feat: intracard cp for sm90 #86
Code Review
This pull request introduces an optimized Hopper (SM90) KDA prefill path featuring fused gate and L2-norm preprocessing, along with intra-card CP (chunk-parallel) scheduling. Key feedback includes optimizing cp_context.py to avoid a synchronous D2H copy by computing sequence mappings on the CPU, adding device validation checks in the C++ API to prevent illegal memory accesses, and using an if/else block in the fused L2-norm Triton kernel to eliminate redundant load instructions.
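The first suggestion (computing sequence mappings on the CPU) can be illustrated with a host-side sketch. It assumes the sequence boundaries are already available as a plain Python list; `cu_seqlens`, `cp_chunk_len`, and `build_cp_chunks` are hypothetical names for illustration, not identifiers taken from `cp_context.py`.

```python
from typing import List, Tuple


def build_cp_chunks(cu_seqlens: List[int], cp_chunk_len: int) -> List[Tuple[int, int, int]]:
    """Return (seq_idx, start, end) token ranges for every CP-chunk.

    Everything here runs on the host from a CPU-side list of cumulative
    sequence lengths, so no device-to-host copy is needed to build the mapping.
    """
    if cp_chunk_len <= 0:
        raise ValueError("cp_chunk_len must be positive")
    chunks = []
    for seq_idx in range(len(cu_seqlens) - 1):
        bos, eos = cu_seqlens[seq_idx], cu_seqlens[seq_idx + 1]
        # Walk the sequence in fixed-size steps; the last CP-chunk may be shorter.
        for start in range(bos, eos, cp_chunk_len):
            chunks.append((seq_idx, start, min(start + cp_chunk_len, eos)))
    return chunks
```

For example, `build_cp_chunks([0, 3, 10], 4)` leaves the 3-token sequence whole and splits the 7-token sequence into two CP-chunks. The resulting table can then be uploaded to the device once, asynchronously, if the kernel needs it there.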
pre-commit adopt cr suggestions
📌 Description
The serial bottleneck
kda_prefill_hopper (cuLA's SM90 KDA prefill) launches one CTA per (seq, head) and runs a strictly
sequential chunk recurrence inside each sequence: h_t = decay(g_t) · h_{t-1} + k_t^T @ (u_t − w_t·h_{t-1}).
Within one sequence, work cannot parallelize across chunks — only across the (raw_batch × H) grid.
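As a reference for the recurrence above, here is a minimal pure-Python sketch of the serial chain for one (seq, head) pair. It assumes `decay(g_t)` has already been reduced to a scalar per chunk and that `h` is K x V, `k` and `w` are C x K, and `u` is C x V; these shapes and the helper names are illustrative assumptions, not the kernel's actual layout.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product, sufficient for a toy example."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def chunk_step(h: Matrix, decay: float, k: Matrix, w: Matrix, u: Matrix) -> Matrix:
    """One step: h_t = decay * h_{t-1} + k_t^T @ (u_t - w_t @ h_{t-1})."""
    wh = matmul(w, h)                                   # C x V
    resid = [[u[i][j] - wh[i][j] for j in range(len(u[0]))] for i in range(len(u))]
    update = matmul(transpose(k), resid)                # K x V
    return [[decay * h[i][j] + update[i][j] for j in range(len(h[0]))]
            for i in range(len(h))]


def run_sequence(chunks: Sequence[Tuple[float, Matrix, Matrix, Matrix]],
                 K: int, V: int) -> Matrix:
    """Serial chain: each step needs the previous state, so chunks cannot overlap."""
    h = [[0.0] * V for _ in range(K)]
    for decay, k, w, u in chunks:
        h = chunk_step(h, decay, k, w, u)
    return h
```

The loop in `run_sequence` is the dependency that keeps a single sequence on one CTA: step t cannot start until step t-1 has produced its state.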
This becomes a bottleneck when both of the following hold:
The grid is small: H=8 with a single sequence launches only 8 CTAs on a 132‑SM H100 (~6% of the SMs). There is so little parallel work that most of the card is idle waiting on 8 serial chains.
Sequence lengths are skewed: a long sequence dominates wall time while short seqs finish in microseconds and leave SMs idle.
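The arithmetic behind both conditions can be checked with a small sketch. It assumes at most one CTA is resident per SM and a hypothetical chunk size of 64 tokens; neither value is taken from the kernel itself.

```python
import math
from typing import Sequence


def sm_utilization(raw_batch: int, num_heads: int, num_sms: int = 132) -> float:
    """Fraction of SMs that receive a CTA when the grid is raw_batch x H.

    Assumes one CTA per SM, which is an illustrative simplification.
    """
    ctas = raw_batch * num_heads
    return min(ctas, num_sms) / num_sms


def longest_serial_chain(seq_lens: Sequence[int], chunk_size: int = 64) -> int:
    """Number of sequential chunk steps in the longest sequence.

    Without CP, wall time is bounded below by this chain, however short
    the other sequences are. chunk_size=64 is a hypothetical value.
    """
    return max(math.ceil(n / chunk_size) for n in seq_lens)
```

With `raw_batch=1` and `H=8`, `sm_utilization` gives 8/132, about 6%. For lengths `[128, 65536]`, the longest chain is 1024 steps while the short sequence needs only 2.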
Approach
Mirroring FLA's intra‑card CP design (and the SM100 cuLA path in cula/ops/cp/chunk_delta_h.py),
this PR splits long sequences into CP‑chunks on the same card and produces per‑CP‑chunk initial
states so the main C++ kernel can run all CP‑chunks in parallel.
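One way to see why per-CP-chunk initial states can be produced cheaply is that each recurrence step is affine in the previous state. The toy below uses a 1 x 1 state so every quantity is a scalar; with a real K x V state the coefficient `a` would be a K x K matrix. This is a sketch of the general idea under those assumptions, not necessarily how this PR's kernels or FLA compute the states, and all function names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Chunk = Tuple[float, float, float, float]  # (decay, k, w, u), all scalars in this toy


def step(h: float, chunk: Chunk) -> float:
    """Scalar form of h_t = decay * h_{t-1} + k * (u - w * h_{t-1})."""
    d, k, w, u = chunk
    return d * h + k * (u - w * h)


def affine_of(chunks: Sequence[Chunk]) -> Tuple[float, float]:
    """Compose a run of steps into one map h_out = a * h_in + b."""
    a, b = 1.0, 0.0
    for d, k, w, u in chunks:
        # A single step is h -> (d - k*w) * h + k*u; apply it after the map so far.
        a, b = (d - k * w) * a, (d - k * w) * b + k * u
    return a, b


def split_into_cp_chunks(chunks: Sequence[Chunk], num_parts: int) -> List[Sequence[Chunk]]:
    size = math.ceil(len(chunks) / num_parts)
    return [chunks[i:i + size] for i in range(0, len(chunks), size)]


def cp_initial_states(parts: Sequence[Sequence[Chunk]], h0: float = 0.0):
    """Return the initial state of every CP-chunk and the final state.

    Phase 1 summarizes each CP-chunk independently (parallelizable).
    Phase 2 is a short serial scan over the summaries only.
    """
    summaries = [affine_of(p) for p in parts]
    states, h = [], h0
    for a, b in summaries:
        states.append(h)
        h = a * h + b
    return states, h
```

Once every CP-chunk has its initial state, the chunks no longer depend on each other, so the main kernel can process them concurrently. The serial part shrinks from one step per chunk to one step per CP-chunk.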
🔍 Related Issues
Similar to issue #20, but for SM90.
🚀 Pull Request Checklist
✅ Pre-commit Checks
Installed pre-commit by running pip install pre-commit (or used your preferred method).
Installed the hooks with pre-commit install.
Ran the hooks manually with pre-commit run --all-files and fixed any reported issues.
🧪 Tests
python -m pytest tests/test_intracard_cp_sm90.py -v
⚡ Performance
python benchmarks/bench_intracard_cp_sm90.py