
Integrate Elastic-Attention (PawQwen3) into FastDeploy#8001

Open
tianzhenxu wants to merge 5 commits into PaddlePaddle:develop from tianzhenxu:elastic_attn_pr


tianzhenxu commented Jun 4, 2026


[Models] Integrate Elastic-Attention (PawQwen3) into FastDeploy

Motivation

Integrate Elastic-Attention (PawQwen3) as a new model architecture in FastDeploy for long-context inference. Reuses FastDeploy's scheduler / KV cache / PagedAttention infrastructure and Qwen3's weight loading, QKV fusion, TP mapping, RoPE and MLP modules; only the attention layer is replaced with the Elastic-Attention (block-sparse + router) path during prefill.

Modifications

  • Added new model package fastdeploy/model_executor/models/qwen3_elastic/:

    • modeling_elastic_qwen3.py: PawQwen3ForCausalLM / Qwen3ElasticAttention / decoder layer / model definition, weight loading and TP mappings.
    • config_elastic.py: bridges ckpt config.json fields (sink_size, local_window_size, toggle_type, retrieval_mode, pooling_mode, xattn_stride/threshold/norm, block_size, etc.) to model_config.
    • utils.py: AttentionRouter (3-layer MLP, argmax over 2-class logits), ctx_q_pool (head-100 + tail-100 mean, aligned with HF eval path), derive_head_mask_type ((retrieval_mode, toggle_type) → {1, 0, -1} per Q-head), _LinearTransposed (HF→Paddle weight transpose flag).
    • kernels/: Triton estimator + block-sparse attention scheduler (Xattention_prefill_dim4, xattn_estimate, find_blocks_chunked) and a Paddle wrapper around the block_sparse_attn_ops custom op.
  • Added BSA CUDA custom op under custom_ops/gpu_ops/block_sparse_attn/ (independent .so to isolate BSA's bundled CUTLASS 3.3 from FastDeploy's newer CUTLASS).

  • Added attention backend fastdeploy/model_executor/layers/attention/elastic_attn_backend.py (Qwen3ElasticAttentionBackend, inherits FlashAttentionBackend):

    • Prefill leg: slices pre-RoPE K from qkv → gqa_rope_write_cache (RoPE + KV cache write) → GQA expand → per-segment ctx_q_pool → AttentionRouter → derive_head_mask_type → Xattention_prefill_dim4 (Triton estimate + BSA op).
    • Decode leg: reuses append_attention (paged FA, dense), identical to dense Qwen3.
    • merge_prefill_decode_output for chunked prefill / continuous batching.
  • Architecture-aware patch in qwen3_elastic/__init__.py to attention_selecter.get_attention_backend / _get_attn_backend: only routes to Qwen3ElasticAttentionBackend when architectures[0] == "PawQwen3ForCausalLM", leaving other models untouched.

  • RoPE/YaRN patch on rotary_embedding.get_rope_impl: for PawQwen3 + rope_scaling.type == "yarn", constructs GptOssScalingRotaryEmbedding (neox style) with the correct (2, 1, T, 1, head_dim) layout to satisfy gqa_rope_write_cache / append_attention kernel asserts.

  • Weight loading: extends stacked_params_mapping for QKV / MLP fusion and q_norm/k_norm → qk_norm.{q,k}_norm; router weights auto-transposed via weight_need_transpose=True; mask_allocator.log_temp skipped (training-only); supports tie_word_embeddings.

  • Added smoke / evaluation scripts: run_elastic_qwen3_4b.py, run_dump_router_2wikimqa.py, run_longbench_elastic.py(.sh), run_longbench_qwen3.py(.sh) (dense baseline), and a README.
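The pooling and routing step listed for utils.py can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the PR's implementation: the `edge` parameter, the fallback for segments shorter than `2 * edge` tokens, and the `logits_fn` callable (standing in for the 3-layer MLP of AttentionRouter) are all hypothetical.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def ctx_q_pool(q: Sequence[Vector], edge: int = 100) -> Vector:
    """Mean over the first `edge` and last `edge` token vectors of one segment.

    Assumption: a segment with at most 2 * edge tokens is pooled over all tokens.
    """
    n = len(q)
    if n == 0:
        raise ValueError("cannot pool an empty segment")
    rows = list(q) if n <= 2 * edge else list(q[:edge]) + list(q[-edge:])
    dim = len(rows[0])
    return [sum(row[d] for row in rows) / len(rows) for d in range(dim)]


def route(pooled: Vector, logits_fn: Callable[[Vector], Vector]) -> int:
    """Return the index of the largest logit (argmax over the class logits)."""
    logits = logits_fn(pooled)
    return max(range(len(logits)), key=logits.__getitem__)


# Toy usage: 300 one-dimensional "query" vectors and a hand-written logits function.
segment = [[float(i)] for i in range(300)]
pooled = ctx_q_pool(segment)
decision = route(pooled, lambda v: [v[0], 100.0])
```

In the PR the router decision then feeds derive_head_mask_type; that mapping is not reproduced here because the description does not spell it out.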

Usage or Command

python fastdeploy/model_executor/models/qwen3_elastic/run_elastic_qwen3_4b.py \
  --model_path /path/to/full_xattn_64k_qwen3-4b_wfrozen

Unit Tests

pre-commit (black / isort / flake8 / ruff / EOF / trailing whitespace / merge-conflict / private-key / large-files) — all passed:

$ pre-commit run --files \
    tests/layers/test_elastic_attention_backend.py \
    tests/model_executor/test_elastic_qwen3_config.py \
    tests/model_executor/test_elastic_qwen3_patches.py \
    tests/model_executor/test_elastic_qwen3_utils.py \
    tests/operators/attention/test_block_sparse_attn_paddle.py \
    tests/operators/attention/test_xattention_estimate.py

black....................................................................Passed
isort....................................................................Passed
flake8...................................................................Passed
ruff.....................................................................Passed
clang-format.........................................(no files to check)Skipped
PyMarkdown...........................................(no files to check)Skipped
check for merge conflicts................................................Passed
check for broken symlinks............................(no files to check)Skipped
fix end of files.........................................................Passed
trim trailing whitespace.................................................Passed
detect private key.......................................................Passed
check for added large files..............................................Passed

pytest, 40 / 40 passed:

$ python -m pytest \
    tests/model_executor/test_elastic_qwen3_utils.py \
    tests/model_executor/test_elastic_qwen3_config.py \
    tests/model_executor/test_elastic_qwen3_patches.py \
    tests/operators/attention/test_xattention_estimate.py \
    tests/operators/attention/test_block_sparse_attn_paddle.py \
    tests/layers/test_elastic_attention_backend.py -v

platform linux -- Python 3.10.20, pytest-9.0.3, pluggy-1.6.0
collected 40 items

tests/model_executor/test_elastic_qwen3_utils.py ........ (15) PASSED
tests/model_executor/test_elastic_qwen3_config.py ....... (5)  PASSED
tests/model_executor/test_elastic_qwen3_patches.py ...... (4)  PASSED
tests/operators/attention/test_xattention_estimate.py ... (4)  PASSED
tests/operators/attention/test_block_sparse_attn_paddle.py (7) PASSED
tests/layers/test_elastic_attention_backend.py .......... (5)  PASSED

================== 40 passed ==================

Coverage breakdown:

  • Utils (15): _LinearTransposed, AttentionRouter (argmax / batch / shape-dtype-range), ctx_q_pool (long / short / varlen), derive_head_mask_type (full+full / full+streaming / full+xattn / GQA repeat-interleave / unsupported pair / xattn+streaming / xattn+xattn).
  • Config (5): elastic-field population, ckpt override, defaults, idempotency, fallback when no pretrained_config.
  • Patches (4): attention selector — other models untouched / PawQwen3 routes to elastic; RoPE — predicate match / YaRN extraction.
  • Xattention estimate (4): decode path all-true, sink+diagonal kept, threshold=1 causal, output shape/dtype.
  • BSA Paddle wrapper (7): replace_ones_with_count (all-sparse / mixed / no-sparse), convert_blockmask_row_reverse (all-kept / all-dropped / descending+padding), BlockSparseAttn smoke (full heads match dense).
  • Backend construction (5): block_size from pretrained config, elastic attrs mirrored, mask_allocator is router, router cache shapes, PawQwen3ForCausalLM declares elastic backend.
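The "sink+diagonal kept" and "threshold=1 causal" cases above suggest a block-selection rule along the following lines. This is a hedged sketch only: `select_blocks` is a hypothetical name, the block grid is assumed square (prefill with no earlier cached blocks), scores are assumed non-negative, and the real Triton estimator in the PR may differ in every detail.

```python
from typing import List


def select_blocks(scores: List[List[float]], threshold: float) -> List[List[bool]]:
    """Pick key blocks per query-block row until `threshold` of the causal
    score mass is covered, always keeping the sink and diagonal blocks."""
    n = len(scores)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        causal = range(i + 1)   # key blocks visible to query block i
        forced = {0, i}         # sink block and diagonal block
        for j in forced:
            mask[i][j] = True
        if threshold >= 1.0:
            # Full coverage requested: keep every causal block.
            for j in causal:
                mask[i][j] = True
            continue
        total = sum(scores[i][j] for j in causal)
        kept = sum(scores[i][j] for j in forced)
        # Remaining causal blocks, most important first.
        rest = sorted(
            (j for j in causal if j not in forced),
            key=lambda j: scores[i][j],
            reverse=True,
        )
        for j in rest:
            if kept >= threshold * total:
                break
            mask[i][j] = True
            kept += scores[i][j]
    return mask
```

With `threshold=1.0` the result is the full lower-triangular (causal) mask, and with `threshold=0.0` only the sink and diagonal blocks survive in each row.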

Signed-off-by: Zhenxu Tian <tianzhenxu@baidu.com>

CLAassistant commented Jun 4, 2026


CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you all sign our Contributor License Agreement before we can accept your contribution.
1 out of 2 committers have signed the CLA.

✅ tianzhenxu
❌ Zhenxu Tian


Zhenxu Tian seems not to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you have already a GitHub account, please add the email address used for this commit to your account.


@PaddlePaddle-bot


🤖 Paddle-CI-Agent | ci_status_monitor | 2026-06-04 21:00:46 UTC+08:00

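This CI report is generated from the following code (refreshed every 30 minutes):


1 Task overview

No required task is currently failing. However, 7 workflows are in the action_required state awaiting approval, and the subsequent CI runs still need to be watched once they are approved. Among the optional tasks, CI_METAX / Trigger Jenkins for PR failed; this is for reference only.

| Total runs (reruns) | Total tasks | ✅ Passed | ❌ Failed | ⏳ Running | ⏸️ Waiting | Skipped |
| --- | --- | --- | --- | --- | --- | --- |
| 2(0) | 2 | 1 | 1 | 0 | 0 | 0 |

⚠️ Note: the following 7 workflows are in the action_required state (they run only after approval): CI_XPU, Check PR Template, Codestyle-Check, CI_HPU, Approval, PR Build and Test, ILUVATAR-CI. These workflows must be approved manually to be triggered.

Note: action_required workflows are not counted in the task statistics in the table above.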

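2 Task status summary

Log column: failed tasks use the pre-generated log_links_markdown field directly; for running tasks the link is assembled manually as [Job]({html_url}).

2.1 Required tasks: 0/0 passed

Required tasks block merging; failures must be handled first.

| Status | Task | Duration | Root cause | Suggested fix | Log | Rerun |
| --- | --- | --- | --- | --- | --- | --- |
| | No required tasks have been triggered | - | - | - | - | - |

2.2 Optional tasks: 1/2 passed

Optional tasks do not block merging; failures are for reference only.

| Status | Task | Duration | Log | Rerun |
| --- | --- | --- | --- | --- |
| ❌ | CI_METAX / Trigger Jenkins for PR | 8m2s | Job | - |
| | The other 1 optional task passed | - | - | - |

3 Failure details (required tasks only)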


Signed-off-by: Zhenxu Tian <tianzhenxu@baidu.com>

PaddlePaddle-bot left a comment


🤖 Paddle-CI-Agent | pr_review | 2026-06-05 12:04:02

📋 Review summary

PR overview: Integrates Elastic-Attention (PawQwen3) into FastDeploy as a new model architecture, including the BSA CUDA custom op, the elastic attention backend, and the Qwen3 weight-loading adaptation.
Scope of changes: custom_ops/gpu_ops/block_sparse_attn/, fastdeploy/model_executor/models/qwen3_elastic/, fastdeploy/model_executor/layers/attention/
Impact tags: [Models] [OP]

⚠️ This PR is large (62 files). Splitting it is recommended to reduce review difficulty and merge risk.

Suggested split

  • PR 1: [BSA CUDA custom op] all files under custom_ops/gpu_ops/block_sparse_attn/
  • PR 2: [Elastic attention model + backend] fastdeploy/model_executor/models/qwen3_elastic/ and fastdeploy/model_executor/layers/attention/elastic_attn_backend.py
  • PR 3: [Tests] the new test files under tests/

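Issues

| Level | File | Summary |
| --- | --- | --- |
| 🔴 Bug | custom_ops/gpu_ops/block_sparse_attn/block_sparse_attn_fwd.cu:243 | dummy_rng_state is host (CPU) memory, yet it is passed to the CUDA kernel as a device pointer and written to (historical F1, not fixed) |
| 🟡 Suggestion | custom_ops/gpu_ops/block_sparse_attn/block_sparse_attn_fwd.cu:181 | Shape dimensions are narrowed from int64_t to int (32-bit), which risks overflow (historical F2, not fixed) |
| 🟡 Suggestion | fastdeploy/model_executor/models/qwen3_elastic/__init__.py:57 | sys._getframe walks the call stack to find fd_config; the call-depth assumption is fragile and the hot-path overhead is not negligible |
| 🟡 Suggestion | fastdeploy/model_executor/models/qwen3_elastic/__init__.py:161 | _patched_get_rope_impl temporarily rewrites model_config.architectures[0], which is not thread-safe; a data race is possible during TP warm-up |
| 🟡 Suggestion | - | Split the large PR (62 files) |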

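Status of historical findings

| Finding | Issue | Status |
| --- | --- | --- |
| F1 | dummy_rng_state CPU memory passed to the CUDA kernel | ⚠️ Still present |
| F2 | int32 receives int64_t shape values, overflow risk | ⚠️ Still present |
| F3 | src was a symlink pointing to a private absolute path | ✅ Fixed (the src/ directory has been materialized and its files are added directly in the diff) |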

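📝 PR convention check

The PR title lacks an official tag prefix. Based on the diff (a new model architecture plus a custom op), the following is suggested:

Suggested title (can be copied as-is):

  • [Models] Integrate Elastic-Attention (PawQwen3) into FastDeploy
Suggested PR description (can be copied as-is):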
## Motivation

Integrate Elastic-Attention (PawQwen3) as a new model architecture in FastDeploy for long-context inference. Reuses FastDeploy's scheduler / KV cache / PagedAttention infrastructure and Qwen3's weight loading, QKV fusion, TP mapping, RoPE and MLP modules; only the attention layer is replaced with the Elastic-Attention (block-sparse + router) path during prefill.

## Modifications

- Added new model package `fastdeploy/model_executor/models/qwen3_elastic/`
- Added BSA CUDA custom op under `custom_ops/gpu_ops/block_sparse_attn/`
- Added `Qwen3ElasticAttentionBackend` in `elastic_attn_backend.py`
- Added architecture-aware attention backend routing patch in `qwen3_elastic/__init__.py`
- Added RoPE/YaRN patch in `rotary_embedding.get_rope_impl` for PawQwen3

## Usage or Command

python fastdeploy/model_executor/models/qwen3_elastic/run_elastic_qwen3_4b.py \
  --model_path /path/to/full_xattn_64k_qwen3-4b_wfrozen

## Accuracy Tests

N/A

## Checklist

- [ ] Add at least a tag in the PR title.
  - Tag list: [`[FDConfig]`,`[APIServer]`,`[Engine]`, `[Scheduler]`, `[PD Disaggregation]`, `[Executor]`, `[Graph Optimization]`, `[Speculative Decoding]`, `[RL]`, `[Models]`, `[Quantization]`, `[Loader]`, `[OP]`, `[KVCache]`, `[DataProcessor]`, `[BugFix]`, `[Docs]`, `[CI]`, `[Optimization]`, `[Feature]`, `[Benchmark]`, `[Others]`, `[XPU]`, `[HPU]`, `[GCU]`, `[DCU]`, `[Iluvatar]`, `[Metax]`]
  - You can add new tags based on the PR content, but the semantics must be clear.
- [ ] Format your code, run `pre-commit` before commit.
- [ ] Add unit tests. Please write the reason in this PR if no unit tests.
- [ ] Provide accuracy results.
- [ ] If the current PR is submitting to the `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.

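Overall assessment

The overall architecture of this PR is sound: it reuses FastDeploy's existing KV Cache / TP / RoPE infrastructure, the Qwen3 weight-loading adaptation is thorough, and the tests cover the key construction contracts. Historical bug F1 (CPU memory passed to a CUDA kernel for writing) is still unfixed; it is P0 and must be resolved before merging. F3 (the src symlink) has been fixed, which is good progress. The global monkey-patch mechanism should later be refactored into a more robust registration-based design.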

// Inference: dropout disabled. Provide a dummy non-null rng_state buffer
// because the kernel writes into it unconditionally.
static thread_local uint64_t dummy_rng_state[2] = {0, 0};
params.rng_state = dummy_rng_state;


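🔴 Bug (historical F1 still present): dummy_rng_state is static thread_local host (CPU) memory, yet it is passed to the CUDA kernel as a device pointer and written to.

params.rng_state is written unconditionally inside the kernel (the philox seed/offset update path), and it points to host memory. A CUDA kernel writing through a host pointer is undefined behavior (illegal memory access or silent data corruption).

Suggested fix: allocate rng_state in GPU memory, with a lifetime bound to this inference call:

// Replaces the static thread_local approach:
auto dummy_rng_gpu = paddle::empty({2}, paddle::DataType::UINT64, q.place());
params.rng_state = reinterpret_cast<uint64_t*>(dummy_rng_gpu.data());
// dummy_rng_gpu is released automatically when it goes out of scope (RAII)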

const int head_size = static_cast<int>(q_shape[2]);
const int total_k = static_cast<int>(k_shape[0]);
const int num_heads_k = static_cast<int>(k_shape[1]);
const int batch_size = static_cast<int>(cu_seqlens_q.shape()[0]) - 1;


🟡 Suggestion (historical F2 still present): total_q, num_heads, head_size, total_k, num_heads_k, and batch_size all receive Paddle shape values (int64_t) as int (32-bit).

static_cast<int>(q_shape[0]) silently overflows when total_q > INT_MAX, which is a must-report item under FastDeploy checklist §C. Add a range assertion before the static_cast:

PD_CHECK(q_shape[0] <= static_cast<int64_t>(INT_MAX), "total_q exceeds INT_MAX");
const int total_q = static_cast<int>(q_shape[0]);
// Do the same for num_heads, head_size, total_k, num_heads_k, batch_size
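The narrowing hazard can be reproduced outside C++ with a short Python sketch. It uses `ctypes.c_int32`, which performs no overflow checking, to mimic what an unchecked 64-bit to 32-bit conversion typically yields on common platforms; `narrow_unchecked` and `narrow_checked` are hypothetical helper names for illustration, not part of the PR.

```python
import ctypes

INT_MAX = 2**31 - 1


def narrow_unchecked(dim: int) -> int:
    """Keep only the low 32 bits of dim, reinterpreted as a signed int."""
    return ctypes.c_int32(dim).value


def narrow_checked(dim: int, name: str) -> int:
    """Refuse any value that a 32-bit signed int cannot represent."""
    if not 0 <= dim <= INT_MAX:
        raise OverflowError(f"{name}={dim} does not fit in a 32-bit int")
    return dim


# A dimension one past INT_MAX wraps to a negative number when unchecked.
wrapped = narrow_unchecked(2**31)
```

The checked variant plays the role of the PD_CHECK guard: it fails loudly instead of handing a negative size to the kernel.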


def _caller_arch():
    """Walk the call stack to find ``self.fd_config.model_config.architectures``."""
    frame = _sys._getframe(2)  # skip this fn + the patched selector fn


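🟡 Suggestion: _caller_arch() uses sys._getframe to walk the call stack to obtain fd_config, which is a fragile design.

Problems:

  1. _getframe(2) hard-codes the call depth; a decorator, a future wrapper layer, or JIT optimization can invalidate the offset, silently degrading to the original backend.
  2. sys._getframe is a CPython private API with non-standard behavior.
  3. Hot-path overhead: every attention backend routing call has to traverse the whole call stack.

Recommendation: move the architecture check to initialization. Either specify Qwen3ElasticAttentionBackend directly through a registration interface in PawQwen3ForCausalLM.__init__, or pass fd_config explicitly to the selector, so that no stack walking happens at runtime.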
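The first problem can be demonstrated with a toy reproduction. The sketch below is an assumption-laden illustration, not FastDeploy code: `caller_arch`, `selector`, `build_layer`, `passthrough`, `BACKENDS`, and `select_backend` are hypothetical names, and the backends are represented by plain strings.

```python
import functools
import sys


def caller_arch():
    """Read `arch` from the frame two levels up (hard-coded depth)."""
    frame = sys._getframe(2)
    return frame.f_locals.get("arch")


def selector():
    # Frame 0 is caller_arch, frame 1 is selector, frame 2 is selector's caller.
    return caller_arch()


def build_layer(arch, select=selector):
    return select()


def passthrough(fn):
    """A harmless wrapper that nevertheless adds one stack frame."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper


direct = build_layer("PawQwen3ForCausalLM")                                   # arch is found
wrapped = build_layer("PawQwen3ForCausalLM", select=passthrough(selector))   # arch is lost

# Explicit alternative: a registry keyed by architecture name.
BACKENDS = {"PawQwen3ForCausalLM": "Qwen3ElasticAttentionBackend"}


def select_backend(arch, default="FlashAttentionBackend"):
    return BACKENDS.get(arch, default)
```

Wrapping the selector shifts the frame offset by one, so the lookup returns None without raising, while the registry lookup does not depend on who calls it.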

# ``QwenRotaryEmbedding``.
original = model_config.architectures[0]
try:
    model_config.architectures[0] = "Qwen3" + original


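🟡 Suggestion: temporarily rewriting model_config.architectures[0] and restoring it in finally is not thread-safe.

In a TP multi-threaded warm-up scenario, if two threads call _patched_get_rope_impl at the same time, thread B may read architectures[0] after thread A has changed it to "Qwen3PawQwen3ForCausalLM" but before A has restored it, which can trigger the wrong RoPE path or wrong YaRN parameters.

Recommendation: do not modify the shared model_config object; instead pass a temporary copy of the configuration to _orig_get_rope_impl: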

import copy
tmp_cfg = copy.copy(model_config)
tmp_cfg.architectures = ["Qwen3" + model_config.architectures[0]]
return _orig_get_rope_impl(rotary_dim, base, position_ids, tmp_cfg, partial_rotary_factor)
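A self-contained check of the copy-based approach, as a sketch: `SimpleNamespace` stands in for the real model_config object and `rope_arch_view` is a hypothetical helper name. It shows that a shallow copy with a freshly built `architectures` list leaves the shared object untouched, while other fields are still shared rather than duplicated.

```python
import copy
from types import SimpleNamespace


def rope_arch_view(model_config, prefix="Qwen3"):
    """Return a shallow copy whose architectures list is rebuilt with a prefix."""
    tmp_cfg = copy.copy(model_config)
    # Rebind the attribute to a new list instead of mutating the shared one.
    tmp_cfg.architectures = [prefix + model_config.architectures[0]]
    return tmp_cfg


shared_cfg = SimpleNamespace(
    architectures=["PawQwen3ForCausalLM"],
    rope_scaling={"type": "yarn"},
)
view = rope_arch_view(shared_cfg)
```

Because the list is replaced on the copy rather than edited in place, concurrent readers of `shared_cfg.architectures` never observe the prefixed name.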
