-
Notifications
You must be signed in to change notification settings - Fork 752
support qkdim!=vdim #8023
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
support qkdim!=vdim #8023
Changes from all commits
44f547c
fc06ae2
19a7044
2958fda
303ad42
29dce63
db7d260
94fa1f9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -176,6 +176,8 @@ def __init__( | |
| self.num_heads: int = num_heads | ||
| self.group_size: int = self.num_heads // self.kv_num_heads | ||
| self.head_dim: int = fd_config.model_config.head_dim | ||
| self.v_head_dim: int = getattr(fd_config.model_config, "v_head_dim", self.head_dim) | ||
| self.external_norm_rope: bool = True if self.v_head_dim != self.head_dim else False | ||
| self.num_layers: int = fd_config.model_config.num_hidden_layers | ||
|
|
||
| # head wise sliding window attention | ||
|
|
@@ -290,7 +292,9 @@ def get_kv_cache_shape( | |
| key_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim] | ||
| if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp": | ||
| key_cache_shape[-1] = self.head_dim // 2 | ||
| value_cache_shape = key_cache_shape | ||
| value_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.v_head_dim] | ||
| if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp": | ||
| value_cache_shape[-1] = self.v_head_dim // 2 | ||
| return key_cache_shape, value_cache_shape | ||
|
|
||
| def forward_mixed( | ||
|
|
@@ -310,10 +314,23 @@ def forward_mixed( | |
|
|
||
| cache_k = forward_meta.caches[2 * layer.layer_id] | ||
| cache_v = forward_meta.caches[2 * layer.layer_id + 1] | ||
| from fastdeploy.model_executor.ops.triton_ops import ( | ||
| do_rope, | ||
| qk_rmsnorm_fused, | ||
| write_cache, | ||
| ) | ||
|
|
||
| from fastdeploy.model_executor.ops.triton_ops import do_rope, write_cache | ||
|
|
||
| if getattr(layer, "only_do_attn", False): | ||
| if self.external_norm_rope: | ||
| qk_rmsnorm_fused( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里把 |
||
| qkv, | ||
| getattr(layer, "q_norm_weight", None), | ||
| getattr(layer, "k_norm_weight", None), | ||
| getattr(layer, "rms_norm_eps", 1e-6), | ||
| layer.num_heads * layer.head_dim, | ||
| layer.kv_num_heads * layer.head_dim, | ||
| cache_k.shape[-1], | ||
| cache_v.shape[-1], | ||
| ) | ||
| do_rope( | ||
| qkv, | ||
| forward_meta.rotary_embs[0], | ||
|
|
@@ -547,7 +564,7 @@ def forward_mixed( | |
| sliding_window, | ||
| self.sink_size, | ||
| self.head_wise_full_hidden if self.head_wise_swa_ratio > 0 else 0, | ||
| getattr(layer, "only_do_attn", False), | ||
| self.external_norm_rope, # if True is means only_do_attn | ||
| ) | ||
| return res | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -287,6 +287,8 @@ def test_qkv_paths(): | |
| num_heads_per_rank=2, | ||
| kv_num_heads_per_rank=1, | ||
| num_kv_head_replicas=2, | ||
| head_dim=2, | ||
| v_head_dim=2, | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个新增用例仍然设置 |
||
| tp_size=2, | ||
| local_rank=0, | ||
| fd_config=cfg_tp2, | ||
|
|
||
This comment was marked as outdated.
Sorry, something went wrong.
Uh oh!
There was an error while loading. Please reload this page.