Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions fastdeploy/model_executor/layers/attention/append_attn_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ def __init__(
self.num_heads: int = num_heads
self.group_size: int = self.num_heads // self.kv_num_heads
self.head_dim: int = fd_config.model_config.head_dim
self.v_head_dim: int = getattr(fd_config.model_config, "v_head_dim", self.head_dim)
self.external_norm_rope: bool = True if self.v_head_dim != self.head_dim else False
self.num_layers: int = fd_config.model_config.num_hidden_layers

# head wise sliding window attention
Expand Down Expand Up @@ -290,7 +292,9 @@ def get_kv_cache_shape(
key_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
key_cache_shape[-1] = self.head_dim // 2
value_cache_shape = key_cache_shape
value_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.v_head_dim]

This comment was marked as outdated.

if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
value_cache_shape[-1] = self.v_head_dim // 2
return key_cache_shape, value_cache_shape

def forward_mixed(
Expand All @@ -310,10 +314,23 @@ def forward_mixed(

cache_k = forward_meta.caches[2 * layer.layer_id]
cache_v = forward_meta.caches[2 * layer.layer_id + 1]
from fastdeploy.model_executor.ops.triton_ops import (
do_rope,
qk_rmsnorm_fused,
write_cache,
)

from fastdeploy.model_executor.ops.triton_ops import do_rope, write_cache

if getattr(layer, "only_do_attn", False):
if self.external_norm_rope:
qk_rmsnorm_fused(

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里把 external_norm_rope 只和 v_head_dim != head_dim 绑定后,会对所有 qk/v 维度不同的层执行 qk_rmsnorm_fused。但 Attention 默认 use_qk_norm=False,只有开启时才会创建 q_norm_weight / k_norm_weight;现有多处 Attention(...) 构造没有传 use_qk_norm。这些模型一旦配置 v_head_dim != head_dim,这里会把 None 作为 Triton 指针传入,qk_rmsnorm_fused_kernel 随后 tl.load(q_weight_ptr + ...) 会直接失败。建议把“需要外部 rope/write_cache”和“需要 q/k norm”拆开:只有权重存在或 layer.use_qk_norm 为真时才跑 fused norm;没有 q/k norm 的模型仍应只做 RoPE/write_cache。

qkv,
getattr(layer, "q_norm_weight", None),
getattr(layer, "k_norm_weight", None),
getattr(layer, "rms_norm_eps", 1e-6),
layer.num_heads * layer.head_dim,
layer.kv_num_heads * layer.head_dim,
cache_k.shape[-1],
cache_v.shape[-1],
)
do_rope(
qkv,
forward_meta.rotary_embs[0],
Expand Down Expand Up @@ -547,7 +564,7 @@ def forward_mixed(
sliding_window,
self.sink_size,
self.head_wise_full_hidden if self.head_wise_swa_ratio > 0 else 0,
getattr(layer, "only_do_attn", False),
self.external_norm_rope, # if True is means only_do_attn
)
return res

Expand Down
95 changes: 52 additions & 43 deletions fastdeploy/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,17 +666,18 @@ def __init__(
self.kv_num_heads = fd_config.model_config.num_key_value_heads if kv_num_heads is None else kv_num_heads
self.hidden_size = fd_config.model_config.hidden_size if hidden_size is None else hidden_size
self.head_dim = fd_config.model_config.head_dim if head_dim is None else head_dim
self.v_head_dim = getattr(fd_config.model_config, "v_head_dim", fd_config.model_config.head_dim)
self.tp_size = fd_config.parallel_config.tensor_parallel_size
self.local_rank = fd_config.parallel_config.tensor_parallel_rank
self.num_heads_per_rank = divide(self.num_heads, self.tp_size)
if self.kv_num_heads < self.tp_size and self.tp_size % self.kv_num_heads == 0:
self.kv_num_heads_per_rank = 1
self.num_kv_head_replicas = divide(self.tp_size, self.kv_num_heads)
output_size = (self.num_heads + 2 * self.tp_size) * self.head_dim
output_size = (self.num_heads + self.tp_size) * self.head_dim + self.tp_size * self.v_head_dim
else:
self.kv_num_heads_per_rank = divide(self.kv_num_heads, self.tp_size)
self.num_kv_head_replicas = 1
output_size = (self.num_heads + 2 * self.kv_num_heads) * self.head_dim
output_size = (self.num_heads + self.kv_num_heads) * self.head_dim + self.kv_num_heads * self.v_head_dim

This comment was marked as outdated.

input_size = self.hidden_size
super().__init__(
fd_config=fd_config,
Expand All @@ -692,15 +693,13 @@ def _get_shard_size_mapping(self, loaded_shard_id: str, head_dim: int):
shard_size_mapping = {
"q": self.num_heads_per_rank * head_dim,
"k": self.kv_num_heads_per_rank * head_dim,
"v": self.kv_num_heads_per_rank * head_dim,
"v": self.kv_num_heads_per_rank * self.v_head_dim,
}
return shard_size_mapping.get(loaded_shard_id)

def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
output_dim = getattr(param, "output_dim", None)
assert output_dim is not None
dim = -1 if output_dim else 0
head_dim = param.shape[dim] // (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank)
weight_need_transpose = getattr(param, "weight_need_transpose", False)
if loaded_shard_id is None:
if weight_need_transpose:
Expand All @@ -711,9 +710,9 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N
# Loaded weight is already fused on disk
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("q", 0, self.num_heads * head_dim),
("k", self.num_heads * head_dim, self.kv_num_heads * head_dim),
("v", (self.num_heads + self.kv_num_heads) * head_dim, self.kv_num_heads * head_dim),
("q", 0, self.num_heads * self.head_dim),
("k", self.num_heads * self.head_dim, self.kv_num_heads * self.head_dim),
("v", (self.num_heads + self.kv_num_heads) * self.head_dim, self.kv_num_heads * self.v_head_dim),
]
for shard_id, shard_offset, shard_size in shard_offsets:
loaded_weight_shard = slice_fn(
Expand All @@ -728,7 +727,7 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N
loaded_weight = loaded_weight.transpose([1, 0])
# Tensor parallelism splits the weight along the output_dim
if self.tp_size > 1 and not self.fd_config.load_config.is_pre_sharded:
block_size = self._get_shard_size_mapping(loaded_shard_id, head_dim)
block_size = self._get_shard_size_mapping(loaded_shard_id, self.head_dim)
shard_id = self.local_rank if loaded_shard_id == "q" else self.local_rank // self.num_kv_head_replicas
shard_offset = shard_id * block_size
shard_size = block_size
Expand All @@ -738,16 +737,15 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N
param.initialize()

if loaded_shard_id == "q":

param_shard_offset = 0
param_shard_size = self.num_heads_per_rank * head_dim
param_shard_size = self.num_heads_per_rank * self.head_dim
elif loaded_shard_id == "k":
param_shard_offset = self.num_heads_per_rank * head_dim
param_shard_size = self.kv_num_heads_per_rank * head_dim
param_shard_offset = self.num_heads_per_rank * self.head_dim
param_shard_size = self.kv_num_heads_per_rank * self.head_dim
else:
# loaded_shard_id == "v"
param_shard_offset = (self.num_heads_per_rank + self.kv_num_heads_per_rank) * head_dim
param_shard_size = self.kv_num_heads_per_rank * head_dim
param_shard_offset = (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim
param_shard_size = self.kv_num_heads_per_rank * self.v_head_dim
if hasattr(param, "tensor_track"):
param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)

Expand Down Expand Up @@ -783,7 +781,8 @@ def load_weight(self, state_dict: dict):
weight_tensor = paddle.concat([q_tensor, k_tensor, v_tensor], axis=-1).transpose([1, 0])
weight_tensor = weight_tensor.reshape(
[
(self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank) * (self.head_dim),
(self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim
+ self.kv_num_heads_per_rank * self.v_head_dim,

This comment was marked as outdated.

self.hidden_size,
]
)
Expand Down Expand Up @@ -1171,18 +1170,24 @@ def __init__(
self.kv_num_heads = fd_config.model_config.num_key_value_heads if kv_num_heads is None else kv_num_heads
self.hidden_size = fd_config.model_config.hidden_size if hidden_size is None else hidden_size
self.head_dim = fd_config.model_config.head_dim if head_dim is None else head_dim
self.v_head_dim = getattr(fd_config.model_config, "v_head_dim", fd_config.model_config.head_dim)
self.tp_size = fd_config.parallel_config.tensor_parallel_size
self.local_rank = fd_config.parallel_config.tensor_parallel_rank
self.num_heads_per_rank = divide(self.num_heads, self.tp_size)

if self.kv_num_heads < self.tp_size and self.tp_size % self.kv_num_heads == 0:
self.kv_num_heads_per_rank = 1
self.num_kv_head_replicas = divide(self.tp_size, self.kv_num_heads)
output_size = (2 * self.num_heads + 2 * self.tp_size) * self.head_dim
output_size = (self.num_heads + self.tp_size) * self.head_dim + (
self.num_heads + self.tp_size
) * self.v_head_dim
else:
self.kv_num_heads_per_rank = divide(self.kv_num_heads, self.tp_size)
self.num_kv_head_replicas = 1
output_size = (2 * self.num_heads + 2 * self.kv_num_heads) * self.head_dim
# qkvg layout: [q (num_heads*head_dim) | k (kv_heads*head_dim) | v (kv_heads*v_head_dim) | gate (num_heads*v_head_dim)]
output_size = (self.num_heads + self.kv_num_heads) * self.head_dim + (
self.num_heads + self.kv_num_heads
) * self.v_head_dim
input_size = self.hidden_size
super().__init__(
fd_config=fd_config,
Expand All @@ -1198,28 +1203,33 @@ def _get_shard_size_mapping(self, loaded_shard_id: str, head_dim: int):
shard_size_mapping = {
"q": self.num_heads_per_rank * head_dim,
"k": self.kv_num_heads_per_rank * head_dim,
"v": self.kv_num_heads_per_rank * head_dim,
"v": self.kv_num_heads_per_rank * self.v_head_dim,
}
return shard_size_mapping.get(loaded_shard_id)

def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
assert loaded_shard_id in [
"qkv",
"q",
"k",
"v",
"gate",
], f"loaded_shard_id must be one of ['qkv', 'gate'], but got {loaded_shard_id}"

if loaded_shard_id == "qkv":
self.qkv_weight_loader(param, loaded_weight, None)
else:
], f"loaded_shard_id must be one of ['qkv', 'q', 'k', 'v', 'gate'], but got {loaded_shard_id}"
if loaded_shard_id == "gate":
self.gate_weight_loader(param, loaded_weight)
elif loaded_shard_id in ("qkv", "q", "k", "v"):
# qkv: 传入的是 fused q+k+v 一次性拆分
# q/k/v: 单独传入某一头,由 qkv_weight_loader 直接放置到对应 offset
sub_id = None if loaded_shard_id == "qkv" else loaded_shard_id
self.qkv_weight_loader(param, loaded_weight, sub_id)
else:
raise ValueError(
f"loaded_shard_id must be one of ['qkv','q','k','v','gate'], " f"but got {loaded_shard_id}"
)

def qkv_weight_loader(self, param, loaded_weight, loaded_shard_id):
output_dim = getattr(param, "output_dim", None)
assert output_dim is not None
dim = -1 if output_dim else 0

# q_head + gate_head + kv_head
head_dim = param.shape[dim] // (2 * self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank)
weight_need_transpose = getattr(param, "weight_need_transpose", False)
if loaded_shard_id is None:
if weight_need_transpose:
Expand All @@ -1230,9 +1240,9 @@ def qkv_weight_loader(self, param, loaded_weight, loaded_shard_id):
# Loaded weight is already fused on disk
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("q", 0, self.num_heads * head_dim),
("k", self.num_heads * head_dim, self.kv_num_heads * head_dim),
("v", (self.num_heads + self.kv_num_heads) * head_dim, self.kv_num_heads * head_dim),
("q", 0, self.num_heads * self.head_dim),
("k", self.num_heads * self.head_dim, self.kv_num_heads * self.head_dim),
("v", (self.num_heads + self.kv_num_heads) * self.head_dim, self.kv_num_heads * self.v_head_dim),
]
for shard_id, shard_offset, shard_size in shard_offsets:
loaded_weight_shard = slice_fn(
Expand All @@ -1247,7 +1257,7 @@ def qkv_weight_loader(self, param, loaded_weight, loaded_shard_id):
loaded_weight = loaded_weight.transpose([1, 0])
# Tensor parallelism splits the weight along the output_dim
if self.tp_size > 1 and output_dim is not None:
block_size = self._get_shard_size_mapping(loaded_shard_id, head_dim)
block_size = self._get_shard_size_mapping(loaded_shard_id, self.head_dim)
shard_id = self.local_rank if loaded_shard_id == "q" else self.local_rank // self.num_kv_head_replicas
shard_offset = shard_id * block_size
shard_size = block_size
Expand All @@ -1258,14 +1268,14 @@ def qkv_weight_loader(self, param, loaded_weight, loaded_shard_id):

if loaded_shard_id == "q":
param_shard_offset = 0
param_shard_size = self.num_heads_per_rank * head_dim
param_shard_size = self.num_heads_per_rank * self.head_dim
elif loaded_shard_id == "k":
param_shard_offset = self.num_heads_per_rank * head_dim
param_shard_size = self.kv_num_heads_per_rank * head_dim
param_shard_offset = self.num_heads_per_rank * self.head_dim
param_shard_size = self.kv_num_heads_per_rank * self.head_dim
else:
# loaded_shard_id == "v"
param_shard_offset = (self.num_heads_per_rank + self.kv_num_heads_per_rank) * head_dim
param_shard_size = self.kv_num_heads_per_rank * head_dim
param_shard_offset = (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim
param_shard_size = self.kv_num_heads_per_rank * self.v_head_dim
if hasattr(param, "tensor_track"):
param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)

Expand All @@ -1276,9 +1286,6 @@ def qkv_weight_loader(self, param, loaded_weight, loaded_shard_id):
def gate_weight_loader(self, param, loaded_weight):
output_dim = getattr(param, "output_dim", None)
assert output_dim is not None
dim = -1 if output_dim else 0
# q_head + gate_head + kv_head
head_dim = param.shape[dim] // (2 * self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank)
weight_need_transpose = getattr(param, "weight_need_transpose", False)

if weight_need_transpose:
Expand All @@ -1287,16 +1294,18 @@ def gate_weight_loader(self, param, loaded_weight):

# Tensor parallelism splits the weight along the output_dim
if self.tp_size > 1 and output_dim is not None:
block_size = self.num_heads_per_rank * head_dim
block_size = self.num_heads_per_rank * self.head_dim

This comment was marked as outdated.

shard_offset = self.local_rank * block_size
shard_size = block_size
loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_offset + shard_size)

if not param._is_initialized():
param.initialize()

param_shard_offset = (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank) * head_dim
param_shard_size = self.num_heads_per_rank * head_dim
param_shard_offset = (
self.num_heads_per_rank + self.kv_num_heads_per_rank
) * self.head_dim + self.kv_num_heads_per_rank * self.v_head_dim
param_shard_size = self.num_heads_per_rank * self.v_head_dim

if hasattr(param, "tensor_track"):
param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)
Expand Down
4 changes: 2 additions & 2 deletions fastdeploy/model_executor/ops/triton_ops/do_rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ def do_rope(

head_dim_k = cache_k.shape[-1]
num_kv_heads = cache_k.shape[1]
head_dim_v = cache_k.shape[-1]
head_dim_v = cache_v.shape[-1]
qkv_size = qkv_out.shape[-1]
num_q_heads = (qkv_size - head_dim_v * num_kv_heads) // head_dim_k - num_kv_heads
num_q_heads = (qkv_size - (head_dim_k + head_dim_v) * num_kv_heads) // head_dim_k

M = qkv_out.shape[0]
grid = (M,)
Expand Down
2 changes: 2 additions & 0 deletions tests/model_executor/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,8 @@ def test_qkv_paths():
num_heads_per_rank=2,
kv_num_heads_per_rank=1,
num_kv_head_replicas=2,
head_dim=2,
v_head_dim=2,

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个新增用例仍然设置 v_head_dim=2、head_dim=2,因此不会覆盖本 PR 最关键的 qkdim != vdim 分支:V shard size、param offset、shared KV slice 等仍按旧路径通过。建议至少加入 v_head_dim != head_dim(例如 head_dim=2, v_head_dim=3)的 fused/split load 断言,验证 V 段和后续 offset/gate 不重叠。

tp_size=2,
local_rank=0,
fd_config=cfg_tp2,
Expand Down
Loading