Skip to content

feat: intracard cp for sm90#86

Open
Hyaloid wants to merge 1 commit into
inclusionAI:mainfrom
Hyaloid:intcd-cp
Open

feat: intracard cp for sm90#86
Hyaloid wants to merge 1 commit into
inclusionAI:mainfrom
Hyaloid:intcd-cp

Conversation

@Hyaloid
Copy link
Copy Markdown

@Hyaloid Hyaloid commented Jun 4, 2026

📌 Description

The serial bottleneck

kda_prefill_hopper (cuLA's SM90 KDA prefill) launches one CTA per (seq, head) and runs a strictly
sequential chunk recurrence inside each sequence: h_t = decay(g_t) · h_{t-1} + k_t^T @ (u_t − w_t·h_{t-1}).
Within one sequence, work cannot parallelize across chunks — only across the (raw_batch × H) grid.

This becomes a bottleneck when both:

  1. raw_batch × H is small — the baseline grid under‑utilizes the SMs. A single long sequence at
    H=8 occupies only 8 CTAs on a 132‑SM H100 (~6% occupancy). The per‑SM work is so small that most of the card is idle waiting on 8 serial chains.
  2. The shape has a long‑tail sequence (e.g. 128K+1K packed) — the long seq's serial recurrence
    dominates wall time while short seqs finish in microseconds and leave SMs idle.

Approach

Mirroring FLA's intra‑card CP design (and the SM100 cuLA path in cula/ops/cp/chunk_delta_h.py),
this PR splits long sequences into CP‑chunks on the same card and produces per‑CP‑chunk initial
states so the main C++ kernel can run all CP‑chunks in parallel.

🔍 Related Issues

Similar to this issue #20 , but for SM90.

🚀 Pull Request Checklist

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.
clang-format.............................................................Passed
ruff (legacy alias)......................................................Passed
ruff format..............................................................Passed

🧪 Tests

python -m pytest tests/test_intracard_cp_sm90.py -v

======================================================== test session starts =========================================================
collected 35 items                                                                                                                   

tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens0-4-False] PASSED                                      [  2%]
tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens1-4-True] PASSED                                       [  5%]
tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens2-4-True] PASSED                                       [  8%]
tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens3-8-False] PASSED                                      [ 11%]
tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens4-8-True] PASSED                                       [ 14%]
tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens5-4-True] PASSED                                       [ 17%]
tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens6-4-True] PASSED                                       [ 20%]
tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens7-4-False] PASSED                                      [ 22%]
tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens0-4-False] PASSED                             [ 25%]
tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens1-4-True] PASSED                              [ 28%]
tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens2-4-True] PASSED                              [ 31%]
tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens3-8-False] PASSED                             [ 34%]
tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens4-8-True] PASSED                              [ 37%]
tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens5-4-True] PASSED                              [ 40%]
tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens6-4-True] PASSED                              [ 42%]
tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens7-4-False] PASSED                             [ 45%]
tests/test_intracard_cp_sm90.py::test_cp_off_matches_basic_baseline PASSED                                                     [ 48%]
tests/test_intracard_cp_sm90.py::test_cp_vs_fla[seq_lens0-4] PASSED                                                            [ 51%]
tests/test_intracard_cp_sm90.py::test_cp_vs_fla[seq_lens1-4] PASSED                                                            [ 54%]
tests/test_intracard_cp_sm90.py::test_cp_vs_fla[seq_lens2-8] PASSED                                                            [ 57%]
tests/test_intracard_cp_sm90.py::test_cp_vs_fla[seq_lens3-4] PASSED                                                            [ 60%]
tests/test_intracard_cp_sm90.py::test_cp_vs_fla[seq_lens4-4] PASSED                                                            [ 62%]
tests/test_intracard_cp_sm90.py::test_cp_vs_fla[seq_lens5-4] PASSED                                                            [ 65%]
tests/test_intracard_cp_sm90.py::test_cp_final_state_per_seq[seq_lens0-4-False] PASSED                                         [ 68%]
tests/test_intracard_cp_sm90.py::test_cp_final_state_per_seq[seq_lens1-4-True] PASSED                                          [ 71%]
tests/test_intracard_cp_sm90.py::test_cp_final_state_per_seq[seq_lens2-4-True] PASSED                                          [ 74%]
tests/test_intracard_cp_sm90.py::test_cp_final_state_per_seq[seq_lens3-4-False] PASSED                                         [ 77%]
tests/test_intracard_cp_sm90.py::test_cp_final_state_per_seq[seq_lens4-4-True] PASSED                                          [ 80%]
tests/test_intracard_cp_sm90.py::test_cp_stress_repeat[single-64K-H4-h0] PASSED                                                [ 82%]
tests/test_intracard_cp_sm90.py::test_cp_stress_repeat[multi-64K+4K-H4-h0] PASSED                                              [ 85%]
tests/test_intracard_cp_sm90.py::test_cp_h0_none_equiv_h0_zeros PASSED                                                         [ 88%]
tests/test_intracard_cp_sm90.py::test_cp_bypass_matches_basic[seq_lens0-8] PASSED                                              [ 91%]
tests/test_intracard_cp_sm90.py::test_cp_bypass_matches_basic[seq_lens1-64] PASSED                                             [ 94%]
tests/test_intracard_cp_sm90.py::test_cp_bypass_matches_basic[seq_lens2-8] PASSED                                              [ 97%]
tests/test_intracard_cp_sm90.py::test_cp_bypass_matches_basic[seq_lens3-8] PASSED                                              [100%]

======================================================== 35 passed in 35.22s =========================================================

  • Tests have been added or updated as needed.
  • All tests are passing.

⚡ Performance

python benchmarks/bench_intracard_cp_sm90.py

====================================================================================================
 Intracard CP Benchmark (SM90): CP-on vs CP-off
====================================================================================================


==============================================================================================================
                       BENCHMARK REPORT: Intracard CP (SM90)
                       CP-on (kda_prefill_hopper_auto) vs CP-off (kda_prefill_hopper)
                       D=128  dtype=bf16  safe_gate=True
                       Warmup=25  Iters=100
==============================================================================================================

  [H=4]
  ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
  config                         T  pred  sub  │         o max/mean        ht max/mean  │  CP_off(ms)   CP_on(ms)   Speedup
  ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
  T=4K                        4096     Y    16  │  2.4e-04/5.1e-07  1.3e-06/1.5e-10  │      0.7637      0.4657     1.64x
  T=8K                        8192     Y    32  │  2.4e-04/2.8e-08  0.0e+00/0.0e+00  │      1.5123      0.4712     3.21x
  T=32K                      32768     Y    32  │  2.4e-04/7.7e-08  0.0e+00/0.0e+00  │      6.0002      1.3043     4.60x
  T=64K                      65536     Y    32  │  3.1e-04/4.1e-07  5.6e-06/5.0e-10  │     11.9816      2.4318     4.93x
  T=128K                    131072     Y    32  │  2.4e-04/7.1e-09  0.0e+00/0.0e+00  │     23.9463      4.7367     5.06x
  8x4K                       32768     N     0  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │      0.8402      0.8382     1.00x
  4x8K                       32768     Y    32  │  2.4e-04/6.9e-08  0.0e+00/0.0e+00  │      1.5797      1.4430     1.09x
  2x16K                      32768     Y    32  │  2.4e-04/7.5e-08  0.0e+00/0.0e+00  │      3.0526      1.4272     2.14x
  16K+16K                    32768     Y    32  │  2.4e-04/7.5e-08  0.0e+00/0.0e+00  │      3.0481      1.4233     2.14x
  24K+8K                     32768     Y    32  │  2.4e-04/7.4e-08  0.0e+00/0.0e+00  │      4.5231      1.4288     3.17x
  28K+4K                     32768     Y    32  │  2.4e-04/7.6e-08  0.0e+00/0.0e+00  │      5.2616      1.4207     3.70x
  32K+256+256                33280     Y    34  │  2.4e-04/7.6e-08  0.0e+00/0.0e+00  │      6.0024      1.4963     4.01x
  40K+1K+8K                  50176     Y    25  │  3.7e-04/2.4e-07  0.0e+00/0.0e+00  │      7.5130      2.1328     3.52x
  64K+512+256+128            66432     Y    35  │  3.1e-04/4.1e-07  5.6e-06/1.2e-10  │     11.9893      2.7907     4.30x
  128K+1K                   132096     Y    33  │  2.4e-04/7.1e-09  0.0e+00/0.0e+00  │     23.9516      5.2825     4.53x
  128K+2x1K                 133120     Y    34  │  1.2e-04/4.0e-10  0.0e+00/0.0e+00  │     23.9631      5.3962     4.44x
  128K+5x1K                 136192     N     0  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │     23.9713     23.9775     1.00x
  128K+10x1K                141312     N     0  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │     23.9883     23.9917     1.00x
  ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

  [H=8]
  ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
  config                         T  pred  sub  │         o max/mean        ht max/mean  │  CP_off(ms)   CP_on(ms)   Speedup
  ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
  T=4K                        4096     Y    16  │  2.4e-04/6.3e-08  3.7e-09/2.8e-14  │      0.7778      0.4911     1.58x
  T=8K                        8192     Y    16  │  1.2e-04/1.5e-08  0.0e+00/0.0e+00  │      1.5338      0.5403     2.84x
  T=32K                      32768     Y    16  │  4.9e-04/2.2e-07  3.4e-05/3.4e-10  │      6.0858      1.8168     3.35x
  T=64K                      65536     Y    16  │  1.2e-04/4.2e-09  0.0e+00/0.0e+00  │     12.1578      3.5025     3.47x
  T=128K                    131072     Y    16  │  2.4e-04/6.0e-08  1.5e-06/1.3e-11  │     24.3007      6.7844     3.58x
  8x4K                       32768     N     0  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │      0.9448      0.9410     1.00x
  4x8K                       32768     N     0  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │      1.6544      1.6561     1.00x
  2x16K                      32768     Y    16  │  4.9e-04/1.9e-07  3.4e-05/1.7e-10  │      3.1324      1.9478     1.61x
  16K+16K                    32768     Y    16  │  4.9e-04/1.9e-07  3.4e-05/1.7e-10  │      3.1362      1.9394     1.62x
  24K+8K                     32768     Y    16  │  4.9e-04/2.0e-07  3.4e-05/1.7e-10  │      4.6130      1.9580     2.36x
  28K+4K                     32768     Y    16  │  4.9e-04/2.0e-07  3.4e-05/3.8e-10  │      5.3521      1.9473     2.75x
  32K+256+256                33280     N     0  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │      6.0890      6.0885     1.00x
  40K+1K+8K                  50176     N     0  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │      7.6568      7.6582     1.00x
  64K+512+256+128            66432     N     0  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │     12.1620     12.1602     1.00x
  128K+1K                   132096     Y    17  │  2.4e-04/5.2e-09  0.0e+00/0.0e+00  │     24.2970      7.6274     3.19x
  128K+2x1K                 133120     N     0  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │     24.3063     24.3050     1.00x
  128K+5x1K                 136192     N     0  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │     24.3240     24.3254     1.00x
  128K+10x1K                141312     N     0  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │     24.3527     24.3539     1.00x
  ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

  CP triggered (25 configs): geo-mean=2.92x  best=5.06x  worst=1.09x
  CP bypassed  (11 configs): mean overhead=1.000x  max=1.001x  (1.00 = no regression)
  Accuracy (CP-on vs CP-off): o  max=4.88e-04 avg=2.00e-04   ht max=3.41e-05 avg=5.12e-06

==============================================================================================================

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces an optimized Hopper (SM90) KDA prefill path featuring fused gate and L2-norm preprocessing, along with intra-card CP (chunk-parallel) scheduling. Key feedback includes optimizing cp_context.py to avoid a synchronous D2H copy by computing sequence mappings on the CPU, adding device validation checks in the C++ API to prevent illegal memory accesses, and using an if/else block in the fused L2-norm Triton kernel to eliminate redundant load instructions.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread cula/kda/cp_context.py Outdated
Comment thread cula/kda/cp_context.py Outdated
Comment thread cula/kda/cp_context.py Outdated
Comment thread csrc/api/kda_sm90.cu
Comment thread cula/kda/l2norm_qk_fused.py Outdated
pre-commit

adopt cr suggestions
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant