Skip to content

support qkdim!=vdim#8023

Open
chang-wenbin wants to merge 8 commits into
PaddlePaddle:developfrom
chang-wenbin:qkdim_vdim
Open

support qkdim!=vdim#8023
chang-wenbin wants to merge 8 commits into
PaddlePaddle:developfrom
chang-wenbin:qkdim_vdim

Conversation

@chang-wenbin

Copy link
Copy Markdown
Collaborator

Motivation

💡 If this PR is a Cherry Pick, the PR title needs to follow the format by adding the [Cherry-Pick] label at the very beginning and appending the original PR ID at the end. For example, [Cherry-Pick][CI] Add check trigger and logic(#5191)

💡 如若此PR是Cherry Pick,PR标题需遵循格式,在最开始加上[Cherry-Pick]标签,以及最后面加上原PR ID,例如[Cherry-Pick][CI] Add check trigger and logic(#5191)

Modifications

Usage or Command

Accuracy Tests

Checklist

  • Add at least a tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

@codecov-commenter

codecov-commenter commented Jun 8, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 52.77778% with 17 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@edc885d). Learn more about missing BASE report.

Files with missing lines Patch % Lines
fastdeploy/model_executor/layers/linear.py 57.69% 10 Missing and 1 partial ⚠️
...l_executor/layers/attention/append_attn_backend.py 50.00% 2 Missing and 2 partials ⚠️
...astdeploy/model_executor/ops/triton_ops/do_rope.py 0.00% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #8023   +/-   ##
==========================================
  Coverage           ?   67.67%           
==========================================
  Files              ?      471           
  Lines              ?    66360           
  Branches           ?    10217           
==========================================
  Hits               ?    44912           
  Misses             ?    18576           
  Partials           ?     2872           
Flag Coverage Δ
GPU 77.73% <52.77%> (?)
XPU 6.99% <0.00%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

PaddlePaddle-bot

This comment was marked as outdated.

@PaddlePaddle-bot

PaddlePaddle-bot commented Jun 9, 2026

Copy link
Copy Markdown

🤖 Paddle-CI-Agent | ci_status_monitor | 2026-06-11 23:16:39 UTC+08:00

CI报告基于以下代码生成(30分钟更新一次):
PR commit: 94fa1f9 | Merge base: edc885d (branch: develop)


1 Required任务 : 8/10 通过

总执行(rerun次数) 总任务 ✅ 通过 ❌ 失败 ⏳ 运行中 ⏸️ 等待中 跳过
42(0) 42 37 5 0 0 0
任务 错误类型 置信度 日志
Run FastDeploy Unit Tests and Coverage / run_tests_with_coverage PR问题 Job
Extracted partial CE model tasks to run in CI. / run_ce_cases 未知 Job

2 失败详情

🔴 Run FastDeploy Unit Tests and Coverage / run_tests_with_coverage — PR问题(置信度: 高)

分析器: 通用分析(fallback)
失败用例:

用例 错误摘要
layers/test_attention_layer.py::TestAttentionPerformance::test_append_attn_backend_decode_performance_with_prefill append_attention_gpu 触发 partial_rotary_factor < 1.0 参数约束失败
model_executor/test_paddleformers_base.py::TestPaddleFormersQKVParallelLinearUnit::test_extract_local_shard_with_transpose_and_tp_slice PaddleFormersQKVParallelLinear 缺少 v_head_dim 属性
model_loader/test_offline_model.py::test_offline_model[offline_quant_Qwen3-30B-A3B-FP8.None.triton] worker 进程 300s 超时被终止,疑似受同一批 QK/V 维度变更影响

关键日志:

SystemError: (Fatal) partial_rotary_factor < 1.0 only supports use_neox_rotary_style=True, q_norm_weight/k_norm_weight) is None, GQA and is_scale_channel_wise=false.
AttributeError: 'PaddleFormersQKVParallelLinear' object has no attribute 'v_head_dim'
RuntimeError: Worker process hung and was terminated
  • 根因摘要: qk/v 维度变更未完整同步
    PR 修改了 fastdeploy/model_executor/layers/attention/append_attn_backend.pyfastdeploy/model_executor/layers/linear.pydo_rope.py 以支持 qkdim != vdim,但 attention 路径仍会把 q_norm_weight/k_norm_weight 传给 append attention kernel,触发 partial rotary 场景下的 kernel 参数限制。另一个确定失败来自 PaddleFormersQKVParallelLinear 的测试对象通过 object.__new__ 构造,未设置本 PR 新增的 v_head_dim,而 _get_shard_size_mapping() 已改为读取该属性。

修复建议:

  1. fastdeploy/model_executor/layers/attention/append_attn_backend.py 中,当 external_norm_rope=True 且前面已经执行 qk_rmsnorm_fused/do_rope/write_cache 后,避免继续向 append_attention 传入 q_norm_weight/k_norm_weight,确保 partial rotary 的 kernel 参数满足限制。
  2. tests/model_executor/test_paddleformers_base.py 的手工构造 layer 中补齐 v_head_dim = head_dim,或在相关 shard size 逻辑中使用 getattr(self, "v_head_dim", self.head_dim) 兜底。
  3. 离线模型超时与本 PR 修改的 attention/linear 路径有关联,但还需要补充运行日志确认具体卡点。

关联变更: fastdeploy/model_executor/layers/attention/append_attn_backend.py, fastdeploy/model_executor/layers/linear.py, fastdeploy/model_executor/ops/triton_ops/do_rope.py, tests/model_executor/test_linear.py

🔴 Extracted partial CE model tasks to run in CI. / run_ce_cases — 未知(置信度: 低)

分析省略/待后续深挖。

当前仅掌握该任务以 exit code 123 失败,尚未读取该 job 的详细日志,无法判断是 PR 问题、环境问题还是不稳定问题。

PaddlePaddle-bot

This comment was marked as outdated.

@PaddlePaddle-bot PaddlePaddle-bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤖 Paddle-CI-Agent | pr_review | 2026-06-11 02:07:41 Asia/Shanghai

📋 Review 摘要

PR 概述:为 attention/QKV 路径补充 qkdim != vdim 支持。
变更范围:AppendAttention backend、QKV/QKVG linear loader、Triton RoPE、线性层单测。
影响面 Tag[OP] [KVCache] [Loader]

问题

级别 文件 概述
🔴 Bug fastdeploy/model_executor/layers/attention/append_attn_backend.py:324 v_head_dim != head_dim 会无条件执行 q/k RMSNorm,未启用 qk norm 的模型会把 None 传进 Triton kernel
🟡 建议 tests/model_executor/test_linear.py:291 新增测试仍使用 v_head_dim == head_dim,没有覆盖本 PR 的核心维度不等分支

📝 PR 规范检查

标题缺少官方 Tag,且 PR 描述的 Motivation / Modifications / Usage or Command / Accuracy Tests 仍未填写。

标题建议(可直接复制):

  • [OP] Support qkdim != vdim in attention and QKV loading
PR 描述建议(点击展开,可直接复制)
## Motivation
Support models whose Q/K head dimension differs from V head dimension (`qkdim != vdim`) in FastDeploy attention and QKV projection paths.

## Modifications
- Add `v_head_dim` to attention backend constructors and use it for value cache shape.
- Pass `model_config.v_head_dim` from GPU model runner to the selected attention backend.
- Update QKV/QKVG parallel linear output sizing and V shard placement to use `v_head_dim`.

## Usage or Command
N/A

## Accuracy Tests
N/A. Current PR does not include accuracy or regression test results.

## Checklist

- [x] Add at least a tag in the PR title.
  - Tag list: [`[FDConfig]`,`[APIServer]`,`[Engine]`, `[Scheduler]`, `[PD Disaggregation]`, `[Executor]`, `[Graph Optimization]`, `[Speculative Decoding]`, `[RL]`, `[Models]`, `[Quantization]`, `[Loader]`, `[OP]`, `[KVCache]`, `[DataProcessor]`, `[BugFix]`, `[Docs]`, `[CI]`, `[Optimization]`, `[Feature]`, `[Benchmark]`, `[Others]`, `[XPU]`, `[HPU]`, `[GCU]`, `[DCU]`, `[Iluvatar]`, `[Metax]`]
  - You can add new tags based on the PR content, but the semantics must be clear.
- [ ] Format your code, run `pre-commit` before commit.
- [ ] Add unit tests. Please write the reason in this PR if no unit tests.
- [ ] Provide accuracy results.
- [x] If the current PR is submitting to the `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.


if getattr(layer, "only_do_attn", False):
if self.external_norm_rope:
qk_rmsnorm_fused(

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里把 external_norm_rope 只和 v_head_dim != head_dim 绑定后,会对所有 qk/v 维度不同的层执行 qk_rmsnorm_fused。但 Attention 默认 use_qk_norm=False,只有开启时才会创建 q_norm_weight / k_norm_weight;现有多处 Attention(...) 构造没有传 use_qk_norm。这些模型一旦配置 v_head_dim != head_dim,这里会把 None 作为 Triton 指针传入,qk_rmsnorm_fused_kernel 随后 tl.load(q_weight_ptr + ...) 会直接失败。建议把“需要外部 rope/write_cache”和“需要 q/k norm”拆开:只有权重存在或 layer.use_qk_norm 为真时才跑 fused norm;没有 q/k norm 的模型仍应只做 RoPE/write_cache。

kv_num_heads_per_rank=1,
num_kv_head_replicas=2,
head_dim=2,
v_head_dim=2,

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个新增用例仍然设置 v_head_dim=2head_dim=2,因此不会覆盖本 PR 最关键的 qkdim != vdim 分支:V shard size、param offset、shared KV slice 等仍按旧路径通过。建议至少加入 v_head_dim != head_dim(例如 head_dim=2, v_head_dim=3)的 fused/split load 断言,验证 V 段和后续 offset/gate 不重叠。

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants