@@ -99,6 +99,9 @@ def m_grouped_fp8_gemm_nt_contiguous_custom_python_op(
         (permute_input.shape[0], layer_added_weight_attrs_0.shape[1]),
         dtype=paddle.bfloat16,
     )
+    if disable_ue8m0_cast:
+        permute_scale = permute_scale.transpose([1, 0]).contiguous()
+        permute_scale = permute_scale.transpose([1, 0])
     # disable_ue8m0_cast is False for SM100
     m_grouped_fp8_gemm_nt_contiguous(
         (permute_input, permute_scale),
@@ -262,10 +265,14 @@ def apply_ep_prefill(
         x, x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
             x,
             using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
-            output_scale_transpose=True,
+            output_scale_transpose=self.quant_config.deepgemm_scale_ue8m0,
             using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
         )
-        x_scale_tensor = x_scale_tensor.T[: x.shape[0]]
+        x_scale_tensor = (
+            x_scale_tensor[: x.shape[0]]
+            if not self.quant_config.deepgemm_scale_ue8m0
+            else x_scale_tensor.T[: x.shape[0]]
+        )

         event = deep_ep.Buffer.capture()
         let_another_thread_run()
@@ -502,10 +509,14 @@ def apply_tp(
         recv_x, recv_x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
             x,
             using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
-            output_scale_transpose=True,
+            output_scale_transpose=self.quant_config.deepgemm_scale_ue8m0,
             using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
         )
-        recv_x_scale = recv_x_scale.T[: recv_x.shape[0]]
+        recv_x_scale = (
+            recv_x_scale[: recv_x.shape[0]]
+            if not self.quant_config.deepgemm_scale_ue8m0
+            else recv_x_scale.T[: recv_x.shape[0]]
+        )
         (
             permute_input,
             permute_scale,