Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 28 additions & 26 deletions custom_ops/gpu_ops/per_token_quant_fp8.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ __host__ __device__ __forceinline__ int align(int x, int y) {

template <typename T, typename ScaleT, bool UseUE8M0>
__global__ void quant_per_token_per_block(
const T *input,
phi::dtype::float8_e4m3fn *quanted_res,
ScaleT *quanted_scale,
const T* input,
phi::dtype::float8_e4m3fn* quanted_res,
ScaleT* quanted_scale,
const int token_num,
const int hidden_size,
const int hidden_size_scale,
Expand All @@ -46,11 +46,11 @@ __global__ void quant_per_token_per_block(
AlignedVector<float, NUM_PER_THREADS> load_vec_float;
AlignedVector<phi::dtype::float8_e4m3fn, NUM_PER_THREADS> res_vec;
for (int token_idx = bid; token_idx < token_num; token_idx += gridDim.x) {
const T *input_now = input + static_cast<int64_t>(token_idx) * hidden_size;
phi::dtype::float8_e4m3fn *quanted_res_now =
const T* input_now = input + static_cast<int64_t>(token_idx) * hidden_size;
phi::dtype::float8_e4m3fn* quanted_res_now =
quanted_res + static_cast<int64_t>(token_idx) * hidden_size;
float *quanted_scale_now = reinterpret_cast<float *>(quanted_scale) +
token_idx * hidden_size_scale;
float* quanted_scale_now =
reinterpret_cast<float*>(quanted_scale) + token_idx * hidden_size_scale;
// deal a block per warp
for (int iter = warp_id; iter < end_iter; iter += num_warp) {
const int start_offset = iter * 128;
Expand Down Expand Up @@ -91,22 +91,24 @@ __global__ void quant_per_token_per_block(
max_value_thread *= 7.0f;
}

float scale_to_store = max_value_thread / MAX_VALUE;
float scale_to_store = max_value_thread * __frcp_rn(MAX_VALUE);

// quant
if constexpr (UseUE8M0) {
scale_to_store =
exp2f(ceilf(log2f(fmaxf(scale_to_store, epsilon) + 5e-7f)));
const float rcp_scale = __frcp_rn(scale_to_store);
#pragma unroll
for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
load_vec_float[vid] / scale_to_store);
load_vec_float[vid] * rcp_scale);
}
} else {
const float rcp_max = __frcp_rn(max_value_thread);
#pragma unroll
for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
load_vec_float[vid] * MAX_VALUE / max_value_thread);
load_vec_float[vid] * MAX_VALUE * rcp_max);
}
}
// store
Expand All @@ -116,13 +118,13 @@ __global__ void quant_per_token_per_block(
quanted_res_now + start_offset + lane_id * NUM_PER_THREADS);
if (lane_id == 0) {
if constexpr (UseUE8M0) {
int exp = (reinterpret_cast<int &>(scale_to_store) >> 23) & 0xFF;
int exp = (reinterpret_cast<int&>(scale_to_store) >> 23) & 0xFF;
const int pack_idx = iter >> 2;
const int byte_idx = iter & 3;
const int pack_num = ceil_div(hidden_size_scale, 4);
int32_t *scale_now = quanted_scale;
int32_t* scale_now = quanted_scale;
const int base_idx = token_idx * pack_num + pack_idx;
reinterpret_cast<uint8_t *>(&scale_now[base_idx])[byte_idx] =
reinterpret_cast<uint8_t*>(&scale_now[base_idx])[byte_idx] =
static_cast<uint8_t>(exp);
} else {
quanted_scale_now[iter] = scale_to_store;
Expand All @@ -132,7 +134,7 @@ __global__ void quant_per_token_per_block(
}
}

std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor &input,
std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor& input,
const int block_size,
const bool use_ue8m0) {
auto input_dim = input.dims();
Expand All @@ -149,7 +151,7 @@ std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor &input,
const int blockx = min(1024, hidden_size / 128 * 32);

bool use_finegrained_range = false;
char *env_var = getenv("PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE");
char* env_var = getenv("PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE");
if (env_var) {
use_finegrained_range = static_cast<bool>(std::stoi(env_var));
}
Expand Down Expand Up @@ -242,9 +244,9 @@ std::vector<paddle::DataType> PerTokenQuantInferDtype(

template <typename T, typename ScaleT, bool UseUE8M0>
__global__ void quant_per_token_per_block_padding(
const T *input,
phi::dtype::float8_e4m3fn *quanted_res,
ScaleT *quanted_scale,
const T* input,
phi::dtype::float8_e4m3fn* quanted_res,
ScaleT* quanted_scale,
const int token_num,
const int padded_token_num,
const int hidden_size,
Expand All @@ -262,8 +264,8 @@ __global__ void quant_per_token_per_block_padding(
AlignedVector<float, NUM_PER_THREADS> load_vec_float;
AlignedVector<phi::dtype::float8_e4m3fn, NUM_PER_THREADS> res_vec;
for (int token_idx = bid; token_idx < token_num; token_idx += gridDim.x) {
const T *input_now = input + static_cast<int64_t>(token_idx) * hidden_size;
phi::dtype::float8_e4m3fn *quanted_res_now =
const T* input_now = input + static_cast<int64_t>(token_idx) * hidden_size;
phi::dtype::float8_e4m3fn* quanted_res_now =
quanted_res + static_cast<int64_t>(token_idx) * hidden_size;
// deal a block per warp
for (int iter = warp_id; iter < end_iter; iter += num_warp) {
Expand Down Expand Up @@ -320,7 +322,7 @@ __global__ void quant_per_token_per_block_padding(
if (lane_id == 0) {
if constexpr (UseUE8M0) {
// exp
int exp = (reinterpret_cast<int &>(scale_to_store) >> 23) & 0xFF;
int exp = (reinterpret_cast<int&>(scale_to_store) >> 23) & 0xFF;

const int pack_idx = iter >> 2;
const int byte_idx = iter & 3;
Expand All @@ -329,14 +331,14 @@ __global__ void quant_per_token_per_block_padding(
const int pack_num = align(hidden_size_scale, 4) >> 2;

// column-major base index
int32_t *scale_now = quanted_scale;
int32_t* scale_now = quanted_scale;
const int base_idx = token_idx + pack_idx * padded_token_num;

// ---------------- store exp ----------------
reinterpret_cast<uint8_t *>(&scale_now[base_idx])[byte_idx] =
reinterpret_cast<uint8_t*>(&scale_now[base_idx])[byte_idx] =
static_cast<uint8_t>(exp);
} else {
float *scale_now =
float* scale_now =
quanted_scale + iter * padded_token_num + token_idx;
*scale_now = scale_to_store;
}
Expand All @@ -345,7 +347,7 @@ __global__ void quant_per_token_per_block_padding(
}
}

std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor &input,
std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor& input,
const int block_size,
const bool use_ue8m0) {
using ScaleDtype = float;
Expand All @@ -372,7 +374,7 @@ std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor &input,
const int blockx = min(1024, hidden_size / 128 * 32);

bool use_finegrained_range = false;
char *env_var = getenv("PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE");
char* env_var = getenv("PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE");
if (env_var) {
use_finegrained_range = static_cast<bool>(std::stoi(env_var));
}
Expand Down
2 changes: 2 additions & 0 deletions fastdeploy/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,8 @@ def _validate_split_kv_size(value: int) -> int:
# When set to a valid JSON dict, metric labels are automatically enabled.
# Example: '{"model_id":"my_model"}' adds model_id label to all metrics.
"FD_DEFAULT_METRIC_LABEL_VALUES": lambda: os.getenv("FD_DEFAULT_METRIC_LABEL_VALUES", "{}"),
# When set to 1, skip certain layers in unit tests to ensure Prefill/Decode output consistency during PD consistency checks.
"FD_SKIP_IN_DETERMINISTIC": lambda: int(os.getenv("FD_SKIP_IN_DETERMINISTIC", "0")),
}


Expand Down
4 changes: 4 additions & 0 deletions fastdeploy/model_executor/layers/attention/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__(
qk_norm_before_rope: bool = False,
rms_norm_eps: float = 1e-6,
with_sinks: bool = False,
skip_attn: bool = False,
) -> None:
"""
Initializes `LMLayer` with the given parameters.
Expand Down Expand Up @@ -109,6 +110,7 @@ def __init__(
self.use_neox_rotary_style: bool = use_neox_rotary_style

self.with_sinks: bool = with_sinks
self.skip_attn: bool = skip_attn

if fd_config.quant_config and hasattr(fd_config.quant_config, "kv_cache_quant_type"):
self.quant_method: QuantMethodBase = fd_config.quant_config.get_quant_method(self)
Expand Down Expand Up @@ -272,6 +274,8 @@ def forward(
compressed_kv: optional compressed key-value cache (for MLA)
k_pe: optional key positional encoding (for MLA)
"""
if self.skip_attn:
return qkv[..., : self.head_dim * self.num_heads]
# ============ V1 KVCACHE Manager: Layer-by-layer swap wait ============
# Wait for swap-in of current layer before using cache
if forward_meta.layer_done_counter is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,11 @@ def matmul_kernel_persistent(
num_pid_in_group = GROUP_SIZE_M * num_pid_n

for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True):
pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS)
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (tile_id % group_size_m)
pid_n = (tile_id % num_pid_in_group) // group_size_m
start_m = pid_m * BLOCK_SIZE_M
start_n = pid_n * BLOCK_SIZE_N
offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
Expand Down Expand Up @@ -117,7 +121,11 @@ def matmul_kernel_persistent(
accumulator = tl.dot(a, b, accumulator)

tile_id_c += NUM_SMS
pid_m, pid_n = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS)
group_id = tile_id_c // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (tile_id_c % group_size_m)
pid_n = (tile_id_c % num_pid_in_group) // group_size_m
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
if C_LARGE:
Expand Down Expand Up @@ -531,7 +539,11 @@ def bmm_kernel_persistent(
batch_idx = tile_id // num_tiles_per_batch
tile_in_batch = tile_id % num_tiles_per_batch

pid_m, pid_n = _compute_pid(tile_in_batch, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS)
group_id = tile_in_batch // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (tile_in_batch % group_size_m)
pid_n = (tile_in_batch % num_pid_in_group) // group_size_m
start_m = pid_m * BLOCK_SIZE_M
start_n = pid_n * BLOCK_SIZE_N
offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
Expand Down Expand Up @@ -662,18 +674,45 @@ def bmm_batch_invariant(x, y):
return bmm_persistent(x, y)


def mm_batch_invariant(a, b, transpose_x=False, transpose_y=False, out=None):
def mm_batch_invariant(a, b, bias=None, transpose_x=False, transpose_y=False, out=None):
if transpose_x:
a = a.T
if transpose_y:
b = b.T
result = matmul_persistent(a, b)
result = matmul_persistent(a, b, bias)
if out is not None:
out.copy_(result, False)
return out
return result


def linear_batch_invariant(x, weight, bias=None, *args, **kwargs):
    """Stand-in for ``paddle._C_ops.linear`` used in batch-invariant mode.

    The replaced op computes ``out = x @ weight + bias``. Here ``x``,
    ``weight`` and ``bias`` are forwarded unchanged to ``matmul_persistent``;
    ``weight`` is used as given, i.e. in [K, N] layout with no transpose.

    Args:
        x: left-hand operand.
        weight: right-hand operand, shape [K, N].
        bias: optional bias, forwarded as the third argument (default None).
        *args: accepted and ignored.
        **kwargs: accepted and ignored.

    Returns:
        Whatever ``matmul_persistent(x, weight, bias)`` returns.
    """
    # Only the three named operands are forwarded; extra arguments are dropped.
    operands = (x, weight, bias)
    return matmul_persistent(*operands)


def linear_v2_batch_invariant(x, weight, bias=None, weight_transposed=False):
"""Drop-in replacement for paddle._C_ops.linear_v2.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 形状判断 `weight.shape[0] == K` 用于推断 `weight` 是否已经是 `[K, N]`,但当矩阵为方阵(`K == N`)时无法区分两种布局,会静默地走 “weight is already [K, N]” 分支,可能产生错误的矩阵乘法。

建议直接使用 weight_transposed 参数进行判断(该参数已由调用方传入),或与 Paddle flag FLAGS_use_accuracy_compatible_kernel 对齐:

def linear_v2_batch_invariant(x, weight, bias=None, weight_transposed=False):
    # 依据 weight_transposed 标志而非形状猜测
    if weight_transposed:
        # weight is [K, N] (已由调用方转置)
        return matmul_persistent(x, weight, bias)
    else:
        # weight is [N, K],需要转置
        return matmul_persistent(x, weight.T, bias)

weight_transposed 语义在不同 Paddle 版本下不一致,至少应在方阵场景下加断言或警告。

linear_v2 computes: out = x @ weight + bias, where weight should be [K, N].
The weight_transposed flag semantics varies depending on FLAGS_use_accuracy_compatible_kernel:
- True mode: weight_transposed=True → weight is [K, N] (transposed from [N, K] by caller)
- False mode: weight_transposed=False → weight is [K, N] (caller didn't transpose)
In both cases the weight passed in is [K, N], so we determine by actual dimensions.
"""
K = x.shape[1]
if weight.shape[0] == K:
# weight is already [K, N], use directly
return matmul_persistent(x, weight, bias)
else:
# weight is [N, K], need transpose to [K, N]
return matmul_persistent(x, weight.T, bias)


def addmm_batch_invariant(
input: paddle.Tensor, x: paddle.Tensor, y: paddle.Tensor, beta: float = 1.0, alpha: float = 1.0
) -> paddle.Tensor:
Expand Down Expand Up @@ -791,7 +830,15 @@ def rms_norm_batch_invariant(x: paddle.Tensor, weight: paddle.Tensor, eps: float
return out.reshape(orig_shape)


_original_ops = {"mm": None, "addmm": None, "_log_softmax": None, "mean_dim": None, "bmm": None}
_original_ops = {
"mm": None,
"addmm": None,
"_log_softmax": None,
"mean_dim": None,
"bmm": None,
"linear": None,
"linear_v2": None,
}

_batch_invariant_MODE = False

Expand Down Expand Up @@ -822,7 +869,10 @@ def enable_batch_invariant_mode():
_original_ops["log_softmax"] = paddle._C_ops.log_softmax
_original_ops["mean"] = paddle._C_ops.mean
_original_ops["bmm"] = paddle._C_ops.bmm

_original_ops["linear"] = paddle._C_ops.linear
_original_ops["linear_v2"] = paddle._C_ops.linear_v2
paddle._C_ops.linear = linear_batch_invariant
paddle._C_ops.linear_v2 = linear_v2_batch_invariant
paddle._C_ops.matmul = mm_batch_invariant
paddle._C_ops.addmm = addmm_batch_invariant
paddle._C_ops.log_softmax = _log_softmax_batch_invariant
Expand Down Expand Up @@ -856,6 +906,10 @@ def disable_batch_invariant_mode():
paddle._C_ops.mean = _original_ops["mean"]
if _original_ops["bmm"]:
paddle._C_ops.bmm = _original_ops["bmm"]
if _original_ops["linear"]:
paddle._C_ops.linear = _original_ops["linear"]
if _original_ops["linear_v2"]:
paddle._C_ops.linear_v2 = _original_ops["linear_v2"]

_batch_invariant_MODE = False

Expand Down
2 changes: 2 additions & 0 deletions fastdeploy/model_executor/models/glm4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from __future__ import annotations

import os
import re
from functools import partial
from typing import Dict
Expand Down Expand Up @@ -248,6 +249,7 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None
prefix=prefix,
use_neox_rotary_style=True,
rms_norm_eps=fd_config.model_config.rms_norm_eps,
skip_attn=bool(int(os.getenv("FD_SKIP_IN_DETERMINISTIC", "0"))),
)
if self.use_qk_norm:
self.qk_norm = QKRMSNorm(
Expand Down
3 changes: 3 additions & 0 deletions fastdeploy/output/token_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,8 @@ def _process_batch_output_use_zmq(self, receive_datas):
finished=False,
metrics=metrics,
ic_req_data=task.ic_req_data,
prompt_token_ids_len=task.prompt_token_ids_len,
prompt_token_ids=task.prompt_token_ids,
)
if self.use_logprobs:
if getattr(stream_data, "logprobs", None) is not None:
Expand Down Expand Up @@ -975,6 +977,7 @@ def _process_batch_output(self):
metrics=metrics,
ic_req_data=task.ic_req_data,
prompt_token_ids_len=task.prompt_token_ids_len,
prompt_token_ids=task.prompt_token_ids,
trace_carrier=trace_carrier,
)
if self.tokens_counter[task_id] == 0:
Expand Down
18 changes: 17 additions & 1 deletion fastdeploy/worker/gpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,23 @@ def graph_optimize_and_warm_up_model(self) -> None:
# than subsequent requests, causing occasional first-run divergence.
if envs.FD_DETERMINISTIC_MODE:
set_random_seed(self.fd_config.model_config.seed)
self.model_runner.share_inputs.reset_share_inputs()
# self.model_runner.share_inputs.reset_share_inputs()
if hasattr(self.model_runner.share_inputs, "reset_share_inputs"):
self.model_runner.share_inputs.reset_share_inputs()
elif isinstance(self.model_runner.share_inputs, dict):
# 创建一个临时的 InputBatch 来借用它的逻辑
from fastdeploy.worker.input_batch import InputBatch

temp_batch = InputBatch.__new__(InputBatch)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 由 `gpu_model_runner.py:226` 处的 `self.share_inputs = InputBatch(self.fd_config)` 可以确认 `share_inputs` 始终是 `InputBatch` 实例,因此 `elif isinstance(self.model_runner.share_inputs, dict)` 分支为死代码,在生产中永远不会执行。

另外,被注释掉的原始调用 # self.model_runner.share_inputs.reset_share_inputs() 应一并删除,避免维护歧义。

建议简化为:

if envs.FD_DETERMINISTIC_MODE:
    set_random_seed(self.fd_config.model_config.seed)
    if hasattr(self.model_runner.share_inputs, 'reset_share_inputs'):
        self.model_runner.share_inputs.reset_share_inputs()

temp_batch.__dict__.update(self.model_runner.share_inputs) # 把字典内容注入
temp_batch.model_config = self.fd_config.model_config
temp_batch.scheduler_config = self.fd_config.scheduler_config
temp_batch.cache_config = self.fd_config.cache_config
temp_batch.reset_share_inputs()
# 把结果写回字典
self.model_runner.share_inputs.update(
{k: v for k, v in temp_batch.__dict__.items() if k in self.model_runner.share_inputs}
)

def check_health(self) -> bool:
""" """
Expand Down
Loading
Loading