
Commit 4cb6ef9

sm90 still uses row-major scale for permute-fp8
1 parent: e967af8

3 files changed: 27 additions & 10 deletions


custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu

Lines changed: 8 additions & 4 deletions
@@ -942,9 +942,14 @@ __global__ void permute_x_fp8_kernel(
       }
 
   } else {
-    for (int s = tid; s < hidden_size_scale; s += blockDim.x) {
-      permute_scale[s * permute_scale_stride0 + dst_token_idx] =
-          scale[s * padded_num_rows + s_token_idx];
+    for (int v_id = tid; v_id < hidden_size_scale_int4;
+         v_id += blockDim.x) {
+      *(reinterpret_cast<int4*>(permute_scale +
+                                dst_token_idx * hidden_size_scale) +
+        v_id) =
+          *(reinterpret_cast<const int4*>(scale + s_token_idx *
+                                              hidden_size_scale) +
+            v_id);
     }
   }
 }
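
Why the kernel changed: the removed branch scattered scales into a column-major permute_scale one fp32 at a time (token index fastest, stride permute_scale_stride0 between scale rows), while the new branch keeps both scale and permute_scale row-major and copies each token's whole scale row with 16-byte int4 loads/stores, four fp32 values per transfer, hence the hidden_size_scale_int4 loop bound (which requires hidden_size_scale to be a multiple of 4). A minimal NumPy sketch of the layout difference, with illustrative names and shapes rather than the kernel's actual tensors:

    import numpy as np

    num_tokens, hidden_size_scale = 8, 16        # illustrative sizes
    dst_of = np.random.permutation(num_tokens)   # source token -> permuted slot

    # Old branch: scale arrives column-major (block, token) and is scattered
    # into a column-major permute_scale one scalar at a time.
    scale_cm = np.random.rand(hidden_size_scale, num_tokens).astype(np.float32)
    permute_cm = np.empty_like(scale_cm)
    for t in range(num_tokens):
        for s in range(hidden_size_scale):
            permute_cm[s, dst_of[t]] = scale_cm[s, t]

    # New branch: scale arrives row-major (token, block); each token's whole
    # scale row is copied, 4 fp32 (one 16-byte int4) at a time in the kernel,
    # so hidden_size_scale must be a multiple of 4.
    scale_rm = scale_cm.T.copy()
    permute_rm = np.empty_like(scale_rm)
    for t in range(num_tokens):
        permute_rm[dst_of[t]] = scale_rm[t]

    assert np.array_equal(permute_rm, permute_cm.T)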
@@ -1106,7 +1111,6 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
                 m_indices};
   } else {
     permute_scale = GetEmptyTensor({token_nums_feed_to_ffn, hidden_size / 128},
-                                   {1, permute_scale_stride0},
                                    paddle::DataType::FLOAT32,
                                    place);
     EPMoeDispatchFP8Kernel<float>(input,
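
With the row-major copy above, the permuted-scale tensor no longer needs the explicit column-major strides {1, permute_scale_stride0}, so the allocation drops to a plain shape-only call, whose default layout is row-major. A rough sketch of the resulting allocation (GetEmptyTensor is a repo-local helper; paddle.empty is shown as an assumed stand-in):

    import paddle

    token_nums_feed_to_ffn, hidden_size = 256, 4096  # illustrative sizes

    # Shape-only allocation: contiguous row-major by default, i.e. one row of
    # hidden_size // 128 fp32 block scales per permuted token.
    permute_scale = paddle.empty(
        [token_nums_feed_to_ffn, hidden_size // 128], dtype="float32"
    )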

fastdeploy/model_executor/layers/moe/ep.py

Lines changed: 4 additions & 2 deletions
@@ -173,7 +173,8 @@ def create_buffer(self):
                 self.num_nvl_bytes,
                 self.num_rdma_bytes,
                 low_latency_mode=True,
-                num_qps_per_rank=24,
+                # num_qps_per_rank=24,
+                num_qps_per_rank=48,
             )
             self.deepep_buffer.set_num_sms(14)  # TODO: tune in future
         else:
@@ -186,7 +187,8 @@ def create_buffer(self):
                 self.num_nvl_bytes,
                 self.num_rdma_bytes,
                 low_latency_mode=True,
-                num_qps_per_rank=24,
+                # num_qps_per_rank=24,
+                num_qps_per_rank=48,
             )
         else:
             raise ValueError(f"Unknown generation phase: {self.moe_phase.phase}")
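
Both phases of create_buffer() make the same one-line change: the low-latency DeepEP buffer now requests 48 RDMA queue pairs per rank instead of 24, with the old value left commented out in the source. A minimal sketch of just the changed keywords (the other deep_ep.Buffer constructor arguments sit outside these hunks and are omitted):

    # The single knob changed in ep.py, shown in isolation; everything else
    # passed to deep_ep.Buffer is unchanged and not reproduced here.
    LOW_LATENCY_BUFFER_KWARGS = dict(
        low_latency_mode=True,
        # num_qps_per_rank=24,  # previous setting, kept commented in the source
        num_qps_per_rank=48,    # doubled RDMA queue pairs per rank
    )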

fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py

Lines changed: 15 additions & 4 deletions
@@ -99,6 +99,9 @@ def m_grouped_fp8_gemm_nt_contiguous_custom_python_op(
         (permute_input.shape[0], layer_added_weight_attrs_0.shape[1]),
         dtype=paddle.bfloat16,
     )
+    if disable_ue8m0_cast:
+        permute_scale = permute_scale.transpose([1, 0]).contiguous()
+        permute_scale = permute_scale.transpose([1, 0])
     # disable_ue8m0_cast is False for SM100
     m_grouped_fp8_gemm_nt_contiguous(
         (permute_input, permute_scale),
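
The new disable_ue8m0_cast branch is the SM90 accommodation the commit title names: transpose([1, 0]).contiguous() materializes the scales transposed in memory, and the second transpose([1, 0]) restores the original logical shape without another copy, leaving a tensor that is logically (m, k // 128) but whose memory is column-major, which is what the grouped GEMM expects on this path. A NumPy sketch of the trick (NumPy as a stand-in for the paddle calls):

    import numpy as np

    m, k_blocks = 6, 4
    permute_scale = np.arange(m * k_blocks, dtype=np.float32).reshape(m, k_blocks)

    # transpose([1, 0]).contiguous(): physically rewrite memory so the block
    # dimension becomes the outer (slow) one.
    tmp = np.ascontiguousarray(permute_scale.T)        # shape (k_blocks, m)

    # Second transpose([1, 0]): a view, no further copy. Logical shape returns
    # to (m, k_blocks), but memory now walks tokens fastest, i.e. column-major.
    col_major = tmp.T                                  # shape (m, k_blocks)

    assert np.array_equal(col_major, permute_scale)    # same values
    assert col_major.strides != permute_scale.strides  # different memory layout
    print(permute_scale.strides, col_major.strides)    # (16, 4) vs (4, 24) here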
@@ -262,10 +265,14 @@ def apply_ep_prefill(
             x, x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
                 x,
                 using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
-                output_scale_transpose=True,
+                output_scale_transpose=self.quant_config.deepgemm_scale_ue8m0,
                 using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
             )
-            x_scale_tensor = x_scale_tensor.T[: x.shape[0]]
+            x_scale_tensor = (
+                x_scale_tensor[: x.shape[0]]
+                if not self.quant_config.deepgemm_scale_ue8m0
+                else x_scale_tensor.T[: x.shape[0]]
+            )
 
             event = deep_ep.Buffer.capture()
             let_another_thread_run()
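
The quantizer call and its fix-up slice move together: output_scale_transpose is now tied to deepgemm_scale_ue8m0, and the slicing branches to match. Judging from the .T[: x.shape[0]] fix-up (an inference from the code, not documented here), the transposed output is (k_blocks, m_padded) and needs .T before the row slice, while the SM90 row-major output is already (m_padded, k_blocks) and needs only the slice. The apply_tp hunk below makes the identical change for recv_x_scale. A shape-only NumPy sketch of the two cases (the quantizer itself is not reproduced):

    import numpy as np

    m, m_padded, k_blocks = 5, 8, 4  # illustrative; the quantizer may pad rows

    def slice_scales(scale, ue8m0: bool, m: int):
        # Mirrors the new branch: transpose only on the ue8m0 path.
        return scale.T[:m] if ue8m0 else scale[:m]

    scale_t = np.zeros((k_blocks, m_padded), np.float32)   # transposed output
    scale_rm = np.zeros((m_padded, k_blocks), np.float32)  # row-major output

    assert slice_scales(scale_t, ue8m0=True, m=m).shape == (m, k_blocks)
    assert slice_scales(scale_rm, ue8m0=False, m=m).shape == (m, k_blocks)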
@@ -502,10 +509,14 @@ def apply_tp(
             recv_x, recv_x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
                 x,
                 using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
-                output_scale_transpose=True,
+                output_scale_transpose=self.quant_config.deepgemm_scale_ue8m0,
                 using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
             )
-            recv_x_scale = recv_x_scale.T[: recv_x.shape[0]]
+            recv_x_scale = (
+                recv_x_scale[: recv_x.shape[0]]
+                if not self.quant_config.deepgemm_scale_ue8m0
+                else recv_x_scale.T[: recv_x.shape[0]]
+            )
             (
                 permute_input,
                 permute_scale,
