fix bug #8043
PaddlePaddle-bot
left a comment
🤖 Paddle-CI-Agent | pr_review | 2026-06-13 21:15:43
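📋 Review Summary
PR overview: Adjusts the shared buffer allocation parameters used when the GPU ModelRunner initializes the attention backend, so that per-layer KV heads configurations are covered.
Scope of change: attention backend buffer initialization in fastdeploy/worker/gpu_model_runner.py.
Impact tags: [Engine] [OP]
Issues
| Severity | File | Summary |
|---|---|---|
| 🔴 Bug | fastdeploy/worker/gpu_model_runner.py:1753 | The capacity of decode-unified's decode_block_indices does not grow with group_size alone, so sizing it directly from the minimum KV heads may underestimate it |
📝 PR Convention Check
The current title lacks an official tag, and every section of the PR description is still the empty template; the content below can be used as a direct replacement.
Suggested title (can be copied as-is):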
[BugFix] Fix attention buffer sizing for per-layer KV heads
Suggested PR description (can be copied as-is)
## Motivation
Fix insufficient pre-allocated capacity of the attention backend's shared buffers when a per-layer `num_key_value_heads_list` is used.
## Modifications
- `fastdeploy/worker/gpu_model_runner.py`: when initializing the attention backend buffers, use the minimum of the per-layer KV heads to compute the maximum `group_size`, so that the `decoder_*` / `encoder_*` tile buffers are not undersized by a larger KV head count.
## Usage or Command
N/A
## Accuracy Tests
N/A (no accuracy data provided; this diff only adjusts the attention backend buffer allocation parameters.)
## Checklist
- [x] Add at least a tag in the PR title.
- Tag list: [`[FDConfig]`,`[APIServer]`,`[Engine]`, `[Scheduler]`, `[PD Disaggregation]`, `[Executor]`, `[Graph Optimization]`, `[Speculative Decoding]`, `[RL]`, `[Models]`, `[Quantization]`, `[Loader]`, `[OP]`, `[KVCache]`, `[DataProcessor]`, `[BugFix]`, `[Docs]`, `[CI]`, `[Optimization]`, `[Feature]`, `[Benchmark]`, `[Others]`, `[XPU]`, `[HPU]`, `[GCU]`, `[DCU]`, `[Iluvatar]`, `[Metax]`]
- You can add new tags based on the PR content, but the semantics must be clear.
- [ ] Format your code, run `pre-commit` before commit.
- [ ] Add unit tests. Please write the reason in this PR if no unit tests.
- [ ] Provide accuracy results.
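- [x] If the current PR is submitting to the `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.
Overall assessment
This change fixes the sizing of the regular attention tile buffers so that they are allocated for the largest group_size, but the same parameter is also reused for the capacity of decode-unified's decode_block_indices. The two capacity upper bounds need to be computed separately before this is merged. The PR title and description also need to be filled in according to the repository template.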
      num_heads=num_heads,
    - kv_num_heads=max(kv_num_heads_per_layer),
    + # This requires the largest possible group size, corresponding to the smallest kv-num-heads.
    + kv_num_heads=min(kv_num_heads_per_layer),
🔴 Bug min(kv_num_heads_per_layer) only covers the buffers that grow with group_size, but the capacity of decode-unified's decode_block_indices is also multiplied by kv_num_heads, so this underestimates the capacity for some per-layer KV-head configurations.
The capacity in allocate_decode_unified_related_buffer() and allocate_launch_related_buffer() is max_batch_size * kv_num_heads * max_num_chunk * q_tile_num, while config_for_attention() actually writes q_tile_num * kv_chunk_num * kv_num_heads entries. For example, with num_heads=64, decoder_step_token_num=1, and per-layer KV heads [8, 1], sizing by min=1 reserves only 1 * ceil(64 / 1 / 16) = 4 q/kv-head tiles; if layer 0 uses 8 KV heads it needs 8 * ceil(64 / 8 / 16) = 8, so there is a risk of out-of-bounds writes or a wrong configuration.
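The arithmetic in this example can be reproduced with a short sketch. It assumes the per-layer tile count is `kv_num_heads * ceil(decoder_step_token_num * group_size / 16)`, as written in the comment above; the helper name `q_kv_head_tiles` and the constant `Q_TILE_ROWS` are hypothetical and are not part of FastDeploy.

```python
import math

# Assumption: the divisor 16 is taken from the review's arithmetic, not read from the kernel.
Q_TILE_ROWS = 16


def q_kv_head_tiles(num_heads: int, kv_num_heads: int, step_tokens: int = 1) -> int:
    """q/kv-head tiles for one layer: kv_num_heads * ceil(step_tokens * group_size / 16)."""
    group_size = num_heads // kv_num_heads
    return kv_num_heads * math.ceil(step_tokens * group_size / Q_TILE_ROWS)


num_heads = 64
kv_num_heads_per_layer = [8, 1]

# What sizing by min(kv_num_heads_per_layer) would reserve.
reserved = q_kv_head_tiles(num_heads, min(kv_num_heads_per_layer))
# What each layer would actually need under the same formula.
needed = [q_kv_head_tiles(num_heads, k) for k in kv_num_heads_per_layer]
# KV head counts whose layers would exceed the reservation.
under_allocated = [k for k, n in zip(kv_num_heads_per_layer, needed) if n > reserved]

print(reserved, needed, under_allocated)  # 4 [8, 4] [8]
```

Under these assumptions the layer with 8 KV heads needs twice the number of tiles that the min-based sizing reserves.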
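Suggested fix:
Keep the minimum KV heads for buffers such as decoder_batch_ids/encoder_batch_ids that only scale with group_size, but size decode_block_indices separately from the worst case across all layers, for example by computing max(k * ceil(decoder_step_token_num * (num_heads // k) / 16) for k in kv_num_heads_per_layer) and using it as the capacity upper bound. Alternatively, extend the allocation API so that the maximum group_size and the decode_block_indices capacity are passed in separately.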
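A minimal sketch of the first option, computing the two bounds independently. The names `BufferBounds` and `attention_buffer_bounds` are hypothetical, and the tile divisor of 16 is taken from the formula above rather than from the kernel source; how the results would be passed to the real allocation functions is left open.

```python
import math
from typing import NamedTuple, Sequence


class BufferBounds(NamedTuple):
    max_group_size: int       # for buffers that scale only with group_size
    max_q_kv_head_tiles: int  # per-batch, per-chunk factor for decode_block_indices


def attention_buffer_bounds(
    num_heads: int,
    kv_num_heads_per_layer: Sequence[int],
    decoder_step_token_num: int = 1,
    q_tile_rows: int = 16,  # assumption: matches the "/ 16" in the review's formula
) -> BufferBounds:
    # The largest group size comes from the layer with the fewest KV heads.
    max_group_size = num_heads // min(kv_num_heads_per_layer)
    # The worst case of kv_num_heads * q_tile_num must be taken over every layer,
    # because it is not monotonic in kv_num_heads.
    max_tiles = max(
        k * math.ceil(decoder_step_token_num * (num_heads // k) / q_tile_rows)
        for k in kv_num_heads_per_layer
    )
    return BufferBounds(max_group_size, max_tiles)


bounds = attention_buffer_bounds(num_heads=64, kv_num_heads_per_layer=[8, 1])
print(bounds)  # BufferBounds(max_group_size=64, max_q_kv_head_tiles=8)
```

Returning the two values separately mirrors the second option in the suggestion: the caller can size the group-size-driven buffers from one field and decode_block_indices from the other.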
Codecov Report
✅ All modified and coverable lines are covered by tests.
Additional details and impacted files
@@ Coverage Diff @@
## develop #8043 +/- ##
==========================================
Coverage ? 67.55%
==========================================
Files ? 475
Lines ? 66657
Branches ? 10283
==========================================
Hits ? 45029
Misses ? 18753
Partials ? 2875
Flags with carried forward coverage won't be shown.