From fbc6487e1ee1e5060691468c9b625d209c10ebf7 Mon Sep 17 00:00:00 2001 From: Hyaloid Date: Mon, 15 Jun 2026 02:40:58 +0000 Subject: [PATCH] feat: intracard cp for sm90 pre-commit adopt cr suggestions support varlen fuse l2norm+gate cumsum & fix irregular input --- benchmarks/bench_intracard_cp.py | 2 +- benchmarks/bench_intracard_cp_sm90.py | 362 +++++++++++++++ csrc/api/kda_sm90.cu | 49 +- csrc/api/pybind.cu | 23 +- csrc/kda/sm90/collective/mainloop_kda_fwd.hpp | 28 +- csrc/kda/sm90/kda_fwd_sm90.cu | 25 +- csrc/kda/sm90/kda_fwd_sm90_safe_gate.cu | 122 ++--- csrc/kda/sm90/kernel/kernel_kda_fwd.hpp | 5 + csrc/kda/sm90/prefill_kernel.hpp | 5 +- csrc/kda/sm90/prefill_kernel_kda_fwd_sm90.cuh | 8 +- cula/kda/__init__.py | 4 + cula/kda/auto_route.py | 104 +++++ cula/kda/cp_context.py | 426 +++++++++++++++++ cula/kda/cp_h_boundary.py | 190 ++++++++ cula/kda/gate_l2norm_fused.py | 194 ++++++++ cula/kda/hopper_fused_fwd_opt.py | 374 +++++++++++++++ cula/kda/l2norm_qk_fused.py | 119 +++++ cula/kda/wy_intra.py | 354 ++++++++++++++ cula/kda/wy_recompute.py | 137 ++++++ tests/conftest.py | 7 +- tests/test_intracard_cp_sm90.py | 439 ++++++++++++++++++ 21 files changed, 2865 insertions(+), 112 deletions(-) create mode 100644 benchmarks/bench_intracard_cp_sm90.py create mode 100644 cula/kda/auto_route.py create mode 100644 cula/kda/cp_context.py create mode 100644 cula/kda/cp_h_boundary.py create mode 100644 cula/kda/gate_l2norm_fused.py create mode 100644 cula/kda/hopper_fused_fwd_opt.py create mode 100644 cula/kda/l2norm_qk_fused.py create mode 100644 cula/kda/wy_intra.py create mode 100644 cula/kda/wy_recompute.py create mode 100644 tests/test_intracard_cp_sm90.py diff --git a/benchmarks/bench_intracard_cp.py b/benchmarks/bench_intracard_cp.py index f24fb7bf..d1648599 100644 --- a/benchmarks/bench_intracard_cp.py +++ b/benchmarks/bench_intracard_cp.py @@ -59,7 +59,7 @@ # ============================================================ BT, D = 64, 128 H_VALUES = [4, 8] -WARMUP = 10 +WARMUP = 25 N_ITERS = 100 NCU_MODE = False SANITIZER_MODE = False diff --git a/benchmarks/bench_intracard_cp_sm90.py b/benchmarks/bench_intracard_cp_sm90.py new file mode 100644 index 00000000..c9316432 --- /dev/null +++ b/benchmarks/bench_intracard_cp_sm90.py @@ -0,0 +1,362 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""bench_intracard_cp_sm90.py — CP-on vs CP-off for SM90 KDA prefill. + +Mirrors benchmarks/bench_intracard_cp.py (SM100 version) but for the Hopper +(SM90) path: + + CP_on : cula.kda.kda_prefill_hopper_auto + CP_off : cula.kda.kda_prefill_hopper + +Reports per-config `pred` (would CP fire?) and `n_sub` (CP-chunk count). When +`pred=N` we still measure CP_on to confirm the bypass adds no regression. + +Usage: + python benchmarks/bench_intracard_cp_sm90.py [--ncu] [--sanitizer] +""" + +import argparse +import pathlib +import sys + +import torch + +sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent.parent)) + +from benchmarks.utils import ( + SEED, + exclusive_cumsum, + prepare_safe_gate_inputs, + set_seed, + time_cuda_fn, +) +from cula.kda import kda_prefill_hopper, kda_prefill_hopper_auto +from cula.kda.auto_route import _should_use_opt +from cula.kda.cp_context import _calc_cp_seqs +from cula.kda.hopper_fused_fwd_opt import FUSED_GATE_L2NORM_VARLEN_AVG_SEQ, _fused_gate_l2norm_threshold +from cula.utils import get_device_sm_count + +# ============================================================ +# Constants +# ============================================================ +BT, D = 64, 128 +H_VALUES = [4, 8] +WARMUP = 10 +N_ITERS = 10 +NCU_MODE = False +SANITIZER_MODE = False + +# (tag, seq_lens) — varlen configs, run with cu_seqlens=cumsum(seq_lens) +CONFIGS = [ + # small varlen — exercises fused gate+l2norm path (packed_T*H <= 65536) + ("4x256", [256] * 4), + ("8x256", [256] * 8), + ("16x256", [256] * 16), + ("4x1K", [1024] * 4), + ("8x1K", [1024] * 8), + ("4x2K", [2048] * 4), + ("1K+512+256+128", [1024, 512, 256, 128]), + ("2K+1K+512+256", [2048, 1024, 512, 256]), + ("1K+1+63+65+129", [1024, 1, 63, 65, 129]), + # single seq + ("T=4K", [4096]), + ("T=8K", [8192]), + ("T=32K", [32768]), + ("T=64K", [65536]), + ("T=128K", [131072]), + # equal-length batches (~32K total) + ("8x4K", [4096] * 8), + ("4x8K", [8192] * 4), + ("2x16K", [16384] * 2), + # asymmetric multi-seq + ("16K+16K", [16384, 16384]), + ("24K+8K", [24576, 8192]), + ("28K+4K", [28672, 4096]), + ("32K+256+256", [32768, 256, 256]), + ("40K+1K+8K", [40960, 1024, 8192]), + ("64K+512+256+128", [65536, 512, 256, 128]), + ("128K+1K", [131072, 1024]), + ("128K+2x1K", [131072, 1024, 1024]), + ("128K+5x1K", [131072] + [1024] * 5), + ("128K+10x1K", [131072] + [1024] * 10), +] + + +# ============================================================ +# Helpers +# ============================================================ +def _bench_warmup_iters(): + warmup = 1 if (NCU_MODE or SANITIZER_MODE) else WARMUP + n_iters = 1 if (NCU_MODE or SANITIZER_MODE) else N_ITERS + return warmup, n_iters + + +def run_call(q, k, v, g, beta, scale, A_log, dt_bias, cu_seqlens, lower_bound, *, enable_cp, return_state=False): + fn = kda_prefill_hopper_auto if enable_cp else kda_prefill_hopper + out = fn( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + A_log=A_log, + dt_bias=dt_bias, + initial_state=None, + output_final_state=return_state, + use_qk_l2norm_in_kernel=True, + use_gate_in_kernel=True, + safe_gate=True, + lower_bound=lower_bound, + cu_seqlens=cu_seqlens, + ) + return out + + +def accuracy(ref, got): + if ref is None or got is None: + return float("nan"), float("nan") + diff = (ref.float() - got.float()).abs() + return diff.max().item(), diff.mean().item() + + +def predict_cp(seq_lens, H, num_sms, device): + cu = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) + raw_batch = len(seq_lens) + packed_seq = sum(seq_lens) + + if raw_batch > 1: + cp_wf = raw_batch * H <= 16 and packed_seq >= 8192 + else: + cp_wf = (H <= 8 and packed_seq >= 4096) or (H <= 16 and packed_seq >= 4096) or (H <= 32 and packed_seq >= 16384) + if not cp_wf: + return False, 0 + + use_cp, cp_cu, *_ = _calc_cp_seqs(cu, BT, H, num_sms, raw_cu_seqlens_cpu=cu.cpu()) + if not use_cp: + return False, 0 + n_sub = int(cp_cu.numel() - 1) + if n_sub == raw_batch: # no-op split + return False, 0 + return True, n_sub + + +def predict_fused_all_pre(q, v, cu_seqlens_for_opt, *, cu_seqlens_is_none, use_gate_in_kernel, use_qk_l2norm_in_kernel): + if not _should_use_opt(q, cu_seqlens_for_opt): + return False + num_qk_heads = q.shape[-2] + num_v_heads = v.shape[-2] + if cu_seqlens_is_none: + avg_seq_ok = True + else: + N = cu_seqlens_for_opt.numel() - 1 + packed_T = q.shape[1] + avg_seq_ok = N <= 1 or packed_T <= N * FUSED_GATE_L2NORM_VARLEN_AVG_SEQ + return ( + use_gate_in_kernel + and use_qk_l2norm_in_kernel + and (q.numel() // q.shape[-1]) <= _fused_gate_l2norm_threshold(cu_seqlens_is_none) + and num_qk_heads == num_v_heads + and avg_seq_ok + ) + + +# ============================================================ +# Benchmark +# ============================================================ +SEP = " " + "─" * 138 +ROW_HEADER = ( + f" {'config':<24s} {'T':>7s} {'pred':>4s} {'sub':>4s} {'fused_pre':>5s}" + f" │ {'o max/mean':>17s} {'ht max/mean':>17s}" + f" │ {'CP_off(ms)':>10s} {'CP_on(ms)':>10s} {'Speedup':>8s}" +) + + +def _format_row(r): + pred_s = "Y" if r["pred"] else "N" + fused_s = "Y" if r["fused_all_pre"] else "N" + return ( + f" {r['tag']:<24s} {r['total_T']:>7d} {pred_s} {r['n_sub']:>4d} {fused_s}" + f" │ {r['o_max']:>7.1e}/{r['o_mean']:>7.1e} {r['ht_max']:>7.1e}/{r['ht_mean']:>7.1e}" + f" │ {r['ms_off']:>10.4f} {r['ms_on']:>10.4f} {r['speedup']:>7.2f}x" + ) + + +def bench_cp(h_values, configs): + print("\n" + "=" * 110) + print(" BENCHMARK REPORT: Intracard CP (SM90)") + print(" CP-on (kda_prefill_hopper_auto) vs CP-off (kda_prefill_hopper)") + print(f" D={D} dtype=bf16 safe_gate=True") + wu = 1 if (NCU_MODE or SANITIZER_MODE) else WARMUP + ni = 1 if (NCU_MODE or SANITIZER_MODE) else N_ITERS + mode_tag = " [NCU mode]" if NCU_MODE else (" [Sanitizer mode]" if SANITIZER_MODE else "") + print(f" Warmup={wu} Iters={ni}{mode_tag}") + print("=" * 110) + + device = torch.device("cuda") + num_sms = get_device_sm_count(device) + results = [] + + for H in h_values: + print(f"\n [H={H}]", flush=True) + print(SEP, flush=True) + print(ROW_HEADER, flush=True) + print(SEP, flush=True) + + for tag, seq_lens in configs: + set_seed(SEED) + torch.cuda.empty_cache() + + total_T = sum(seq_lens) + cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) + inputs = prepare_safe_gate_inputs(1, total_T, H, D, device, cu_seqlens=cu_seqlens, seed=SEED) + q, k, v, g, beta = inputs["q"], inputs["k"], inputs["v"], inputs["g"], inputs["beta"] + A_log, dt_bias = inputs["A_log"], inputs["dt_bias"] + scale, lower_bound = inputs["scale"], inputs["lower_bound"] + + pred, n_sub = predict_cp(seq_lens, H, num_sms, device) + fused_all_pre = predict_fused_all_pre( + q, + v, + cu_seqlens, + cu_seqlens_is_none=False, + use_gate_in_kernel=True, + use_qk_l2norm_in_kernel=True, + ) + + common = dict( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + A_log=A_log, + dt_bias=dt_bias, + cu_seqlens=cu_seqlens, + lower_bound=lower_bound, + ) + + try: + o_off, ht_off = run_call(**common, enable_cp=False, return_state=True) + o_on, ht_on = run_call(**common, enable_cp=True, return_state=True) + o_max, o_mean = accuracy(o_off, o_on) + ht_max, ht_mean = accuracy(ht_off, ht_on) + del o_off, ht_off, o_on, ht_on + + ms_off = time_cuda_fn(lambda: run_call(**common, enable_cp=False), *_bench_warmup_iters()) + ms_on = time_cuda_fn(lambda: run_call(**common, enable_cp=True), *_bench_warmup_iters()) + speedup = ms_off / ms_on if ms_on > 0 else float("inf") + except torch.cuda.OutOfMemoryError: + ms_off = ms_on = speedup = float("nan") + o_max = o_mean = ht_max = ht_mean = float("nan") + + row = { + "tag": tag, + "H": H, + "total_T": total_T, + "pred": pred, + "n_sub": n_sub, + "fused_all_pre": fused_all_pre, + "ms_off": ms_off, + "ms_on": ms_on, + "speedup": speedup, + "o_max": o_max, + "o_mean": o_mean, + "ht_max": ht_max, + "ht_mean": ht_mean, + } + results.append(row) + print(_format_row(row), flush=True) + + del q, k, v, g, beta, A_log, dt_bias, inputs + torch.cuda.empty_cache() + + print(SEP, flush=True) + + return results + + +# ============================================================ +# Report (summary only — per-row output is streamed inside bench_cp) +# ============================================================ +def print_report(results, h_values): + sep = "=" * 110 + triggered = [r for r in results if r["pred"]] + bypassed = [r for r in results if not r["pred"]] + + print() + print(sep) + print(" Summary") + print(sep) + + if triggered: + speedups = [r["speedup"] for r in triggered if r["speedup"] == r["speedup"]] # NaN filter + if speedups: + geo = 1.0 + for s in speedups: + geo *= s + geo = geo ** (1 / len(speedups)) + print( + f" CP triggered ({len(triggered)} configs): " + f"geo-mean={geo:.2f}x best={max(speedups):.2f}x worst={min(speedups):.2f}x" + ) + + if bypassed: + ratios = [r["ms_on"] / r["ms_off"] for r in bypassed if r["ms_off"] == r["ms_off"] and r["ms_off"] > 0] + if ratios: + print( + f" CP bypassed ({len(bypassed)} configs): " + f"mean overhead={sum(ratios) / len(ratios):.3f}x max={max(ratios):.3f}x " + f"(1.00 = no regression)" + ) + + o_maxes = [r["o_max"] for r in results if r["o_max"] == r["o_max"]] + ht_maxes = [r["ht_max"] for r in results if r["ht_max"] == r["ht_max"]] + if o_maxes: + print( + f" Accuracy (CP-on vs CP-off): " + f"o max={max(o_maxes):.2e} avg={sum(o_maxes) / len(o_maxes):.2e} " + f"ht max={max(ht_maxes):.2e} avg={sum(ht_maxes) / len(ht_maxes):.2e}" + ) + + print(sep) + + +# ============================================================ +# Main +# ============================================================ +def main(): + parser = argparse.ArgumentParser(description="bench_intracard_cp_sm90: CP-on vs CP-off") + parser.add_argument("--ncu", action="store_true", help="NCU profiling mode: warmup=1, iters=1") + parser.add_argument("--sanitizer", action="store_true", help="Sanitizer mode: warmup=1, iters=1") + args = parser.parse_args() + + global NCU_MODE, SANITIZER_MODE + if args.ncu: + NCU_MODE = True + print("[NCU mode] warmup=1, iters=1") + if args.sanitizer: + SANITIZER_MODE = True + print("[Sanitizer mode] warmup=1, iters=1") + + results = bench_cp(H_VALUES, CONFIGS) + print_report(results, H_VALUES) + return results + + +if __name__ == "__main__": + main() diff --git a/csrc/api/kda_sm90.cu b/csrc/api/kda_sm90.cu index 9e016eb1..bd7ba5ee 100644 --- a/csrc/api/kda_sm90.cu +++ b/csrc/api/kda_sm90.cu @@ -35,7 +35,9 @@ kda_fwd_prefill( torch::Tensor workspace_buffer, float scale, bool output_final_state, - bool safe_gate) { + bool safe_gate, + OptionalTensor cp_seq_map_, + OptionalTensor raw_cu_seqlens_) { // Q, K: [packed_seq, num_qk_heads, D] // V/O/g: [packed_seq, num_v_heads, D] (GVA: num_v_heads is a positive integer multiple of num_qk_heads) auto packed_seq = q.size(0); @@ -44,6 +46,31 @@ kda_fwd_prefill( auto head_size = q.size(2); auto num_seqs = cu_seqlens.size(0) - 1; + // Intra-card CP plumbing. + int32_t const* cp_seq_map_ptr = nullptr; + int32_t const* raw_cu_seqlens_ptr = nullptr; + int32_t raw_num_seqs = static_cast(num_seqs); + if (cp_seq_map_.has_value()) { + TORCH_CHECK(raw_cu_seqlens_.has_value(), "raw_cu_seqlens must be provided alongside cp_seq_map"); + auto const& cp_seq_map = cp_seq_map_.value(); + auto const& raw_cu_seqlens = raw_cu_seqlens_.value(); + TORCH_CHECK(cp_seq_map.device() == q.device(), "cp_seq_map must be on the same device as q"); + TORCH_CHECK(raw_cu_seqlens.device() == q.device(), "raw_cu_seqlens must be on the same device as q"); + TORCH_CHECK(cp_seq_map.dtype() == torch::kInt32, "cp_seq_map must be int32"); + TORCH_CHECK(raw_cu_seqlens.dtype() == torch::kInt32, "raw_cu_seqlens must be int32"); + TORCH_CHECK(cp_seq_map.is_contiguous(), "cp_seq_map must be contiguous"); + TORCH_CHECK(raw_cu_seqlens.is_contiguous(), "raw_cu_seqlens must be contiguous"); + TORCH_CHECK( + cp_seq_map.size(0) == num_seqs, + "cp_seq_map.size(0) must equal cu_seqlens.size(0)-1, got ", + cp_seq_map.size(0), + " vs ", + num_seqs); + cp_seq_map_ptr = cp_seq_map.data_ptr(); + raw_cu_seqlens_ptr = raw_cu_seqlens.data_ptr(); + raw_num_seqs = static_cast(raw_cu_seqlens.size(0) - 1); + } + // GVA contract on the C++ side. Order matters: check positivity *before* the modulo to // avoid % 0 / division-by-zero UB in case the Python layer passed a degenerate shape. TORCH_CHECK(num_qk_heads > 0, "KDA requires num_qk_heads > 0, got ", num_qk_heads); @@ -64,15 +91,15 @@ kda_fwd_prefill( {packed_seq, num_v_heads, head_size}, torch::TensorOptions().dtype(q.dtype()).device(q.device())); - // output_final_state controls the API side effect. If it is false, ignore - // even an explicitly provided output_state_ buffer so the kernel skips the - // final-state store. + // Allocate output state if not provided. In CP mode the state is keyed + // by raw_num_seqs (one slot per original sequence), not the inflated + // CP-chunk count. OptionalTensor output_state = std::nullopt; if (output_final_state) { output_state = output_state_.has_value() ? output_state_.value() : torch::zeros( - {num_seqs, num_v_heads, head_size, head_size}, + {raw_num_seqs, num_v_heads, head_size, head_size}, torch::TensorOptions().dtype(torch::kFloat32).device(q.device())); } @@ -123,7 +150,7 @@ kda_fwd_prefill( auto& input_state = input_state_.value(); TORCH_CHECK(input_state.dtype() == torch::kFloat32, "input_state must be float32"); TORCH_CHECK(input_state.is_contiguous(), "input_state must be contiguous"); - // Defense in depth: also enforce shape on the C++ side (Python layer should already check). + // In CP mode the leading dim is the CP-chunk count (== num_seqs), otherwise it is the raw num_seqs. TORCH_CHECK( input_state.dim() == 4 && input_state.size(0) == num_seqs && input_state.size(1) == num_v_heads && input_state.size(2) == head_size && input_state.size(3) == head_size, @@ -164,7 +191,10 @@ kda_fwd_prefill( static_cast(packed_seq), scale, safe_gate, - static_cast(sm_count)); + static_cast(sm_count), + cp_seq_map_ptr, + raw_cu_seqlens_ptr, + raw_num_seqs); } else { float const* beta_ptr = beta_.has_value() ? beta_.value().data_ptr() : nullptr; kda::sm90::launch_kda_fwd_prefill_kernel( @@ -186,7 +216,10 @@ kda_fwd_prefill( static_cast(packed_seq), scale, safe_gate, - static_cast(sm_count)); + static_cast(sm_count), + cp_seq_map_ptr, + raw_cu_seqlens_ptr, + raw_num_seqs); } return {output, output_state}; diff --git a/csrc/api/pybind.cu b/csrc/api/pybind.cu index d14a41c5..5a0f6299 100644 --- a/csrc/api/pybind.cu +++ b/csrc/api/pybind.cu @@ -65,7 +65,9 @@ kda_fwd_prefill( torch::Tensor workspace_buffer, float scale, bool output_final_state, - bool safe_gate); + bool safe_gate, + std::optional cp_seq_map_, + std::optional raw_cu_seqlens_); #endif PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { @@ -75,6 +77,23 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("recompute_w_u_cuda", &ChunkKDAFwdRecompWU); #endif #if defined(CULA_SM90A_ENABLED) - m.def("kda_fwd_prefill", &kda_fwd_prefill); + m.def( + "kda_fwd_prefill", + &kda_fwd_prefill, + pybind11::arg("output_"), + pybind11::arg("output_state_"), + pybind11::arg("q"), + pybind11::arg("k"), + pybind11::arg("v"), + pybind11::arg("input_state_"), + pybind11::arg("alpha_"), + pybind11::arg("beta_"), + pybind11::arg("cu_seqlens"), + pybind11::arg("workspace_buffer"), + pybind11::arg("scale"), + pybind11::arg("output_final_state"), + pybind11::arg("safe_gate"), + pybind11::arg("cp_seq_map_") = std::nullopt, + pybind11::arg("raw_cu_seqlens_") = std::nullopt); #endif } diff --git a/csrc/kda/sm90/collective/mainloop_kda_fwd.hpp b/csrc/kda/sm90/collective/mainloop_kda_fwd.hpp index 301dbfd4..66563e5a 100644 --- a/csrc/kda/sm90/collective/mainloop_kda_fwd.hpp +++ b/csrc/kda/sm90/collective/mainloop_kda_fwd.hpp @@ -930,14 +930,26 @@ struct FlatMainloopTmaWarpSpecializedKdaFwd { v_head_idx); return; } + // Intra-card CP + int32_t out_seq_idx = seq_idx; + int32_t out_num_seqs = problem_size.num_seqs; + if (problem_size.cp_seq_map != nullptr) { + out_seq_idx = problem_size.cp_seq_map[seq_idx]; + out_num_seqs = problem_size.raw_num_seqs; + int32_t this_end = problem_size.cu_seqlens[seq_idx + 1]; + int32_t raw_end = problem_size.raw_cu_seqlens[out_seq_idx + 1]; + if (this_end != raw_end) { + return; + } + } DPRINTF0_WG("[%d,%d,%d,%d]>> save tKVrKV -> tKVgKV\n", seq_idx, q_head_idx, k_head_idx, v_head_idx); // GVA: state is stored per V/O head. int num_state_heads = problem_size.num_v_heads; int state_head_idx = work_desc.o_head_idx(); auto gKV = make_tensor( make_gmem_ptr(params.ptr_output_state), - make_layout(make_shape(Int{}, Int{}, num_state_heads, problem_size.num_seqs)))( - _, _, state_head_idx, seq_idx); // (KDim, VDim), K-contiguous + make_layout(make_shape(Int{}, Int{}, num_state_heads, out_num_seqs)))( + _, _, state_head_idx, out_seq_idx); // (KDim, VDim), K-contiguous auto tiled_copy_kv = make_tiled_copy_C(Copy_Atom{}, kv_tiled_mma); auto thr_copy_kv = tiled_copy_kv.get_thread_slice(thread_idx); @@ -1371,10 +1383,18 @@ struct FlatMainloopTmaWarpSpecializedKdaFwd { if constexpr (!kInitStateFromInput) { clear(tKVrKV); - compute_loop_body(0, /*is_first_block_=*/cute::true_type{}, /*is_final_block_=*/cute::false_type{}); + if (num_blocks == 1) { + compute_loop_body(0, /*is_first_block_=*/cute::true_type{}, /*is_final_block_=*/cute::true_type{}); + } else { + compute_loop_body(0, /*is_first_block_=*/cute::true_type{}, /*is_final_block_=*/cute::false_type{}); + } } else { kv_load(tKVrKV); // GMEM -> Register, only once at the beginning - compute_loop_body(0, /*is_first_block_=*/cute::false_type{}, /*is_final_block_=*/cute::false_type{}); + if (num_blocks == 1) { + compute_loop_body(0, /*is_first_block_=*/cute::false_type{}, /*is_final_block_=*/cute::true_type{}); + } else { + compute_loop_body(0, /*is_first_block_=*/cute::false_type{}, /*is_final_block_=*/cute::false_type{}); + } } CUTE_NO_UNROLL for (int blk = 1; blk < num_blocks - 1; ++blk) { diff --git a/csrc/kda/sm90/kda_fwd_sm90.cu b/csrc/kda/sm90/kda_fwd_sm90.cu index d668db9b..ed855db6 100644 --- a/csrc/kda/sm90/kda_fwd_sm90.cu +++ b/csrc/kda/sm90/kda_fwd_sm90.cu @@ -53,7 +53,10 @@ launch_kda_fwd_prefill_kernel_gbai( int32_t head_size, int64_t total_seqlen, float scale, - int32_t sm_count); + int32_t sm_count, + int32_t const* cp_seq_map, + int32_t const* raw_cu_seqlens, + int32_t raw_num_seqs); template < typename ArchTag, // TODO: hide this @@ -81,7 +84,10 @@ launch_kda_fwd_prefill_kernel( int64_t total_seqlen, float scale, bool safe_gate, - int32_t sm_count = 0) { + int32_t sm_count, + int32_t const* cp_seq_map, + int32_t const* raw_cu_seqlens, + int32_t raw_num_seqs) { bool needs_beta = beta != nullptr; bool needs_alpha = alpha != nullptr; bool init_state = input_state != nullptr; @@ -105,7 +111,10 @@ launch_kda_fwd_prefill_kernel( head_size, \ total_seqlen, \ scale, \ - sm_count); + sm_count, \ + cp_seq_map, \ + raw_cu_seqlens, \ + raw_num_seqs); if (init_state) { if (needs_beta && needs_alpha && safe_gate) { LAUNCH(true, true, true, true); @@ -146,7 +155,10 @@ launch_kda_fwd_prefill_kernel( int64_t total_seqlen, float scale, bool safe_gate, - int32_t sm_count); + int32_t sm_count, + int32_t const* cp_seq_map, + int32_t const* raw_cu_seqlens, + int32_t raw_num_seqs); // TBeta=bf16 template void @@ -169,6 +181,9 @@ launch_kda_fwd_prefill_kernel( int64_t total_seqlen, float scale, bool safe_gate, - int32_t sm_count); + int32_t sm_count, + int32_t const* cp_seq_map, + int32_t const* raw_cu_seqlens, + int32_t raw_num_seqs); } // namespace kda::sm90 diff --git a/csrc/kda/sm90/kda_fwd_sm90_safe_gate.cu b/csrc/kda/sm90/kda_fwd_sm90_safe_gate.cu index 309cefa0..0da2986e 100644 --- a/csrc/kda/sm90/kda_fwd_sm90_safe_gate.cu +++ b/csrc/kda/sm90/kda_fwd_sm90_safe_gate.cu @@ -23,92 +23,44 @@ namespace kda::sm90 { using namespace cute; using bf16 = cute::bfloat16_t; -// SafeGate=true, InitState=false -template void -launch_kda_fwd_prefill_kernel_gbai( - cudaStream_t, - bf16*, - float*, - bf16 const*, - bf16 const*, - bf16 const*, - float const*, - float const*, - float const*, - int32_t const*, - uint8_t*, - int32_t, - int32_t, - int32_t, - int32_t, - int64_t, - float, - int32_t); +#define INSTANTIATE_GBAI(NeedsBeta, NeedsAlpha, InitState, SafeGate, TBeta) \ + template void launch_kda_fwd_prefill_kernel_gbai< \ + NeedsBeta, \ + NeedsAlpha, \ + InitState, \ + SafeGate, \ + cutlass::arch::Sm90, \ + bf16, \ + bf16, \ + float, \ + TBeta>( \ + cudaStream_t, \ + bf16*, \ + float*, \ + bf16 const*, \ + bf16 const*, \ + bf16 const*, \ + float const*, \ + float const*, \ + TBeta const*, \ + int32_t const*, \ + uint8_t*, \ + int32_t, \ + int32_t, \ + int32_t, \ + int32_t, \ + int64_t, \ + float, \ + int32_t, \ + int32_t const*, \ + int32_t const*, \ + int32_t) -// SafeGate=true, InitState=true -template void -launch_kda_fwd_prefill_kernel_gbai( - cudaStream_t, - bf16*, - float*, - bf16 const*, - bf16 const*, - bf16 const*, - float const*, - float const*, - float const*, - int32_t const*, - uint8_t*, - int32_t, - int32_t, - int32_t, - int32_t, - int64_t, - float, - int32_t); +INSTANTIATE_GBAI(true, true, false, true, float); +INSTANTIATE_GBAI(true, true, true, true, float); +INSTANTIATE_GBAI(true, true, false, true, bf16); +INSTANTIATE_GBAI(true, true, true, true, bf16); -// SafeGate=true, InitState=false, BetaBF16 -template void -launch_kda_fwd_prefill_kernel_gbai( - cudaStream_t, - bf16*, - float*, - bf16 const*, - bf16 const*, - bf16 const*, - float const*, - float const*, - bf16 const*, - int32_t const*, - uint8_t*, - int32_t, - int32_t, - int32_t, - int32_t, - int64_t, - float, - int32_t); - -// SafeGate=true, InitState=true, BetaBF16 -template void -launch_kda_fwd_prefill_kernel_gbai( - cudaStream_t, - bf16*, - float*, - bf16 const*, - bf16 const*, - bf16 const*, - float const*, - float const*, - bf16 const*, - int32_t const*, - uint8_t*, - int32_t, - int32_t, - int32_t, - int32_t, - int64_t, - float, - int32_t); +#undef INSTANTIATE_GBAI } // namespace kda::sm90 diff --git a/csrc/kda/sm90/kernel/kernel_kda_fwd.hpp b/csrc/kda/sm90/kernel/kernel_kda_fwd.hpp index 4f8ba027..ac597f7b 100644 --- a/csrc/kda/sm90/kernel/kernel_kda_fwd.hpp +++ b/csrc/kda/sm90/kernel/kernel_kda_fwd.hpp @@ -140,6 +140,11 @@ struct FlatKernelTmaWarpSpecializedKdaFwd { int32_t num_qk_heads; int32_t num_v_heads; int32_t head_size; // d + + // For intra-card CP + int32_t const* cp_seq_map = nullptr; + int32_t const* raw_cu_seqlens = nullptr; + int32_t raw_num_seqs = 0; }; using ProblemShape = VarlenProblemShape; diff --git a/csrc/kda/sm90/prefill_kernel.hpp b/csrc/kda/sm90/prefill_kernel.hpp index d56fafae..6e54c3a3 100644 --- a/csrc/kda/sm90/prefill_kernel.hpp +++ b/csrc/kda/sm90/prefill_kernel.hpp @@ -46,6 +46,9 @@ launch_kda_fwd_prefill_kernel( int64_t total_seqlen, float scale, bool safe_gate, - int32_t sm_count = 0); + int32_t sm_count = 0, + int32_t const* cp_seq_map = nullptr, + int32_t const* raw_cu_seqlens = nullptr, + int32_t raw_num_seqs = 0); } // namespace kda::sm90 diff --git a/csrc/kda/sm90/prefill_kernel_kda_fwd_sm90.cuh b/csrc/kda/sm90/prefill_kernel_kda_fwd_sm90.cuh index 72f13a6f..c53f2ae3 100644 --- a/csrc/kda/sm90/prefill_kernel_kda_fwd_sm90.cuh +++ b/csrc/kda/sm90/prefill_kernel_kda_fwd_sm90.cuh @@ -58,7 +58,10 @@ launch_kda_fwd_prefill_kernel_gbai( int32_t head_size, int64_t total_seqlen, float scale, - int32_t sm_count) { + int32_t sm_count, + int32_t const* cp_seq_map = nullptr, + int32_t const* raw_cu_seqlens = nullptr, + int32_t raw_num_seqs = 0) { #if defined(CULA_SM90A_ENABLED) constexpr bool HopperSupported = true; #else @@ -123,6 +126,9 @@ launch_kda_fwd_prefill_kernel_gbai( .num_qk_heads = num_qk_heads, .num_v_heads = num_v_heads, .head_size = head_size, + .cp_seq_map = cp_seq_map, + .raw_cu_seqlens = raw_cu_seqlens, + .raw_num_seqs = raw_num_seqs, }, .mainloop = { diff --git a/cula/kda/__init__.py b/cula/kda/__init__.py index ee1a2bb9..0090530d 100644 --- a/cula/kda/__init__.py +++ b/cula/kda/__init__.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from cula.kda.auto_route import cula_kda_prefill_auto as kda_prefill_hopper_auto from cula.kda.blackwell_fused_fwd import flash_kda_prefill as kda_prefill_blackwell from cula.kda.chunk import chunk_kda from cula.kda.hopper_fused_fwd import cula_kda_prefill as kda_prefill_hopper +from cula.kda.hopper_fused_fwd_opt import cula_kda_prefill_opt as kda_prefill_hopper_opt from cula.ops.kda_decode import fused_sigmoid_gating_delta_rule_update, kda_decode __all__ = [ @@ -23,4 +25,6 @@ "kda_decode", "fused_sigmoid_gating_delta_rule_update", "kda_prefill_hopper", + "kda_prefill_hopper_opt", + "kda_prefill_hopper_auto", ] diff --git a/cula/kda/auto_route.py b/cula/kda/auto_route.py new file mode 100644 index 00000000..a7dea339 --- /dev/null +++ b/cula/kda/auto_route.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +import torch + +from cula.kda.hopper_fused_fwd import cula_kda_prefill as _basic +from cula.kda.hopper_fused_fwd_opt import FUSED_GATE_L2NORM_TH_VARLEN +from cula.kda.hopper_fused_fwd_opt import cula_kda_prefill_opt as _opt + + +def _should_use_opt(q: torch.Tensor, cu_seqlens: torch.Tensor | None) -> bool: + """Pick opt vs basic based on H100 measurements.""" + B = q.shape[0] + T = q.shape[1] + H = q.shape[2] + + if cu_seqlens is not None: + N = cu_seqlens.numel() - 1 + if N > 1: + packed_T = q.shape[1] + if packed_T * H <= FUSED_GATE_L2NORM_TH_VARLEN: + return True + return N * H <= 16 and T >= 8192 + # N == 1 falls through to the single-sequence logic below. + + # Fused gate+l2norm reliably wins at very small T*H even with B>1. + if T * H <= 6000: + return True + + if B == 1: + # T=1024 H=8/16 gets a small win from fused l2norm_qk (T*H<10000). + if H <= 16 and T <= 1024: + return True + if H <= 8: + return T >= 4096 # CP kicks in + elif H <= 16: + return T >= 4096 + elif H <= 32: + return T >= 16384 + else: # H >= 64 + return False # base ties or wins + + if B == 2: + if H == 8: + return T >= 4096 + return False # B=2 H>=16 mostly ties + + # B >= 4 : B*H >= 32 already saturates a sizable fraction of SMs, CP + # buys little; basic and opt tie. Default to basic (cheaper wrapper). + return False + + +def cula_kda_prefill_auto( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float | None = None, + initial_state: torch.Tensor | None = None, + output_final_state: bool = True, + use_qk_l2norm_in_kernel: bool = True, + use_gate_in_kernel: bool = True, + safe_gate: bool = True, + lower_bound: float | None = -5.0, + cu_seqlens: torch.IntTensor | None = None, + chunk_indices: torch.IntTensor | None = None, + **kwargs, +): + if _should_use_opt(q, cu_seqlens): + return _opt( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + use_gate_in_kernel=use_gate_in_kernel, + safe_gate=safe_gate, + lower_bound=lower_bound, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + auto_cp=True, + **kwargs, + ) + return _basic( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + use_gate_in_kernel=use_gate_in_kernel, + safe_gate=safe_gate, + lower_bound=lower_bound, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + **kwargs, + ) diff --git a/cula/kda/cp_context.py b/cula/kda/cp_context.py new file mode 100644 index 00000000..6c22ac52 --- /dev/null +++ b/cula/kda/cp_context.py @@ -0,0 +1,426 @@ +from __future__ import annotations + +import functools +import math +import weakref + +import torch +from fla.utils import tensor_cache + +from cula.utils import get_device_sm_count + + +@tensor_cache +def _create_cu_seqlens(batch_size: int, num_tokens: int, device_idx: int, dtype: torch.dtype) -> torch.Tensor: + return torch.arange(batch_size + 1, dtype=dtype, device=f"cuda:{device_idx}") * num_tokens + + +@functools.lru_cache(maxsize=32) +def _create_full_cu_seqlens_2(T: int, device_idx: int, dtype: torch.dtype) -> torch.Tensor: + return torch.tensor([0, T], dtype=dtype, device=f"cuda:{device_idx}") + + +_CP_SEQS_CACHE: dict = {} +_SLOT_MAP_CACHE: dict = {} + + +def _get_slot_map(cp_cu_seqlens: torch.Tensor, T: int, chunk_size: int) -> torch.Tensor: + key = (id(cp_cu_seqlens), T, chunk_size) + hit = _SLOT_MAP_CACHE.get(key) + if hit is not None: + weak_ref, cached = hit + if weak_ref() is cp_cu_seqlens: + return cached + # id was reused after the original tensor was collected — recompute. + num_chunks = (T + chunk_size - 1) // chunk_size + cp_starts = (cp_cu_seqlens[:-1] // chunk_size).to(torch.int32) + slot_map = torch.full((num_chunks,), -1, dtype=torch.int32, device=cp_cu_seqlens.device) + slots = torch.arange(cp_starts.numel(), dtype=torch.int32, device=cp_cu_seqlens.device) + slot_map[cp_starts.long()] = slots + _SLOT_MAP_CACHE[key] = (weakref.ref(cp_cu_seqlens), slot_map) + return slot_map + + +def _calc_cp_seqs_cached(raw_cu_seqlens, chunk_size, num_v_heads, sm_count, raw_cu_seqlens_cpu=None): + key = (id(raw_cu_seqlens), chunk_size, num_v_heads, sm_count) + hit = _CP_SEQS_CACHE.get(key) + if hit is not None: + weak_ref, cached = hit + if weak_ref() is raw_cu_seqlens: + return cached + + val = _calc_cp_seqs(raw_cu_seqlens, chunk_size, num_v_heads, sm_count, raw_cu_seqlens_cpu=raw_cu_seqlens_cpu) + _CP_SEQS_CACHE[key] = (weakref.ref(raw_cu_seqlens), val) + return val + + +def _calc_cp_seqs( + raw_cu_seqlens: torch.Tensor, + chunk_size: int, + num_v_heads: int, + sm_count: int, + raw_cu_seqlens_cpu: torch.Tensor | None = None, +) -> tuple[bool, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: + """Decide whether intra-card CP pays off and, if so, build the split tables.""" + device = raw_cu_seqlens.device + seqlen_dtype = raw_cu_seqlens.dtype + + raw_cu_seqlens_list = (raw_cu_seqlens_cpu if raw_cu_seqlens_cpu is not None else raw_cu_seqlens).tolist() + raw_batch_size = len(raw_cu_seqlens_list) - 1 + seqlens = [raw_cu_seqlens_list[i + 1] - raw_cu_seqlens_list[i] for i in range(raw_batch_size)] + num_chunks = [(s + chunk_size - 1) // chunk_size for s in seqlens] + if max(num_chunks) <= 0: + return False, None, None, None, None + + H = num_v_heads + V_BLOCKS = 1 # bump to 2 once main kernel supports V-blocking + target_cp_batch = max(1, sm_count // (H * V_BLOCKS)) + total_chunks = sum(num_chunks) + # mlc * cp_batch * chunk_size ≈ T per raw seq → mlc = total_chunks / cp_batch. + target_mlc = max(1, total_chunks // (max(1, target_cp_batch))) + # Snap to nearest power of 2; clamp to ≥ 4 to keep multi-stage pipelining alive. + max_local_chunks = 2 ** round(math.log2(max(target_mlc, 1.0))) + max_local_chunks = max(max_local_chunks, 4) + max_local_tokens = max_local_chunks * chunk_size + + cp_cu_seqlens: list[int] = [] + ht_mask: list[bool] = [] + seq_map_c2r: list[int] = [] + seq_map_r2c: list[int] = [0] + + for i, c in enumerate(num_chunks): + s = raw_cu_seqlens_list[i] + e = raw_cu_seqlens_list[i + 1] + if c > max_local_chunks: + cut = s + while True: + cp_cu_seqlens.append(cut) + ht_mask.append(False) + seq_map_c2r.append(i) + remaining = e - cut + if remaining <= max_local_tokens + chunk_size: + break + cut += max_local_tokens + ht_mask[-1] = True + else: + cp_cu_seqlens.append(s) + ht_mask.append(True) + seq_map_c2r.append(i) + seq_map_r2c.append(len(cp_cu_seqlens)) + cp_cu_seqlens.append(raw_cu_seqlens_list[-1]) + + Be = total_chunks / max(num_chunks) + use_cp = (Be * H <= 40) or (Be * H <= 56 and max(num_chunks) >= 128) + # Additional cuLA-specific guard: never bother if there is only one CP-chunk + # (no split happened). + if len(cp_cu_seqlens) - 1 == raw_batch_size: + use_cp = False + + if use_cp and raw_batch_size == 1: + T_max = max(seqlens) + if H <= 8 or H <= 16: + if T_max < 4096: + use_cp = False + elif H <= 32: + if T_max < 16384: + use_cp = False + else: # H >= 64 + use_cp = False + elif use_cp and raw_batch_size > 1: + T_packed = sum(seqlens) + native_grid = raw_batch_size * H + + unaligned = any(s % chunk_size != 0 for s in seqlens[:-1]) or (raw_cu_seqlens_list[-1] % chunk_size != 0) + if unaligned: + use_cp = False + elif native_grid > 16: + # Native grid already big enough that CP's lift is marginal. + use_cp = False + elif T_packed * H <= 32768: + use_cp = False + elif H <= 8: + if T_packed < 8192: + use_cp = False + elif H <= 16: + if T_packed < 4096: + use_cp = False + else: # H >= 32 with native_grid <= 16 → only B=1 fits this, handled above + use_cp = False + + if not use_cp: + return False, None, None, None, None + + cp_cu_seqlens_t = torch.tensor(cp_cu_seqlens, dtype=seqlen_dtype, device=device) + seq_map_c2r_t = torch.tensor(seq_map_c2r, dtype=seqlen_dtype, device=device) + seq_map_r2c_t = torch.tensor(seq_map_r2c, dtype=seqlen_dtype, device=device) + ht_mask_t = torch.tensor(ht_mask, dtype=torch.bool, device=device) + return True, cp_cu_seqlens_t, seq_map_r2c_t, seq_map_c2r_t, ht_mask_t + + +def _build_raw_seq_idx( + cp_cu_seqlens: torch.Tensor, seq_map_c2r: torch.Tensor, T: int, chunk_size: int +) -> tuple[torch.Tensor, list[int]]: + """Per-chunk raw seq id: raw_seq_idx[i_t] = raw_seq containing chunk i_t.""" + NT = (T + chunk_size - 1) // chunk_size + out = torch.empty(NT, dtype=torch.int32, device=cp_cu_seqlens.device) + # cp_cu_seqlens[:-1] // chunk_size gives the chunk-index start of each CP-chunk. + cp_starts = (cp_cu_seqlens[:-1] // chunk_size).to(torch.int64) + cp_ends = ((cp_cu_seqlens[1:] + chunk_size - 1) // chunk_size).to(torch.int64) + + c2r_cpu = seq_map_c2r.tolist() + starts = cp_starts.tolist() + ends = cp_ends.tolist() + out_cpu = [0] * NT + for i in range(len(starts)): + out[starts[i] : ends[i]] = c2r_cpu[i] + for j in range(starts[i], ends[i]): + out_cpu[j] = c2r_cpu[i] + return out, out_cpu + + +_RAW_SEQ_IDX_CACHE: dict = {} + + +def _get_raw_seq_idx( + cp_cu_seqlens: torch.Tensor, seq_map_c2r: torch.Tensor, T: int, chunk_size: int +) -> tuple[torch.Tensor, list[int]]: + """weakref-guarded cache for the per-chunk raw_seq_idx tensor.""" + key = (id(cp_cu_seqlens), T, chunk_size) + hit = _RAW_SEQ_IDX_CACHE.get(key) + if hit is not None: + weak_ref, cached = hit + if weak_ref() is cp_cu_seqlens: + return cached + val = _build_raw_seq_idx(cp_cu_seqlens, seq_map_c2r, T, chunk_size) + _RAW_SEQ_IDX_CACHE[key] = (weakref.ref(cp_cu_seqlens), val) + return val + + +def _compute_cp_h0_via_fla_h( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_cumsum: torch.Tensor, + beta: torch.Tensor, + scale: float, + raw_h0: torch.Tensor | None, + cp_cu_seqlens: torch.Tensor, + seq_map_c2r: torch.Tensor, + chunk_size: int, +) -> torch.Tensor: + """Compute cp_h0 directly via FLA's per-chunk h tensor — exact, no mt needed.""" + from cula.kda.cp_h_boundary import kda_cp_h0_boundary + from cula.kda.wy_intra import kda_intra_native + + cp_batch = cp_cu_seqlens.size(0) - 1 + raw_batch = int(seq_map_c2r.max().item()) + 1 if seq_map_c2r.numel() > 0 else 1 + T = k.size(1) + + if T > CP_PREPROCESS_TILE_TOKENS: + return _compute_cp_h0_via_fla_h_tiled( + k=k, + v=v, + g_cumsum=g_cumsum, + beta=beta, + raw_h0=raw_h0, + cp_cu_seqlens=cp_cu_seqlens, + seq_map_c2r=seq_map_c2r, + chunk_size=chunk_size, + ) + + w, u, _, kg = kda_intra_native( + k=k, + v=v, + gk=g_cumsum, + beta=beta, + chunk_size=chunk_size, + ) + + slot_map = _get_slot_map(cp_cu_seqlens, T, chunk_size) + + if raw_batch > 1: + if raw_h0 is None: + H_v = v.size(2) + K = k.size(3) + V = v.size(3) + raw_h0 = torch.zeros(raw_batch, H_v, V, K, dtype=torch.float32, device=k.device) + h0_chunk0 = raw_h0[0:1] # state at chunk 0 (always raw seq 0) + raw_seq_idx, _ = _get_raw_seq_idx(cp_cu_seqlens, seq_map_c2r, T, chunk_size) + else: + h0_chunk0 = raw_h0 + raw_seq_idx = None + + cp_h0 = kda_cp_h0_boundary( + kg=kg, + w=w, + u=u, + g_cumsum=g_cumsum, + h0=h0_chunk0, + slot_map=slot_map, + num_cp=cp_batch, + chunk_size=chunk_size, + raw_h0_dense=raw_h0 if raw_batch > 1 else None, + raw_seq_idx=raw_seq_idx, + ) + del w, u, kg + return cp_h0 + + +CP_PREPROCESS_TILE_TOKENS = 16384 + + +def _compute_cp_h0_via_fla_h_tiled( + k: torch.Tensor, + v: torch.Tensor, + g_cumsum: torch.Tensor, + beta: torch.Tensor, + raw_h0: torch.Tensor | None, + cp_cu_seqlens: torch.Tensor, + seq_map_c2r: torch.Tensor, + chunk_size: int, +) -> torch.Tensor: + from cula.kda.cp_h_boundary import kda_cp_h0_boundary + from cula.kda.wy_intra import kda_intra_native + + assert k.size(0) == 1, "tiled CP preprocess assumes packed [1, T, H, K] layout" + cp_batch = cp_cu_seqlens.size(0) - 1 + raw_batch = int(seq_map_c2r.max().item()) + 1 if seq_map_c2r.numel() > 0 else 1 + T = k.size(1) + H = k.size(2) + K = k.size(3) + V = v.size(3) + + if raw_h0 is None: + raw_h0 = torch.zeros(raw_batch, H, V, K, dtype=torch.float32, device=k.device) + device = k.device + + assert CP_PREPROCESS_TILE_TOKENS % chunk_size == 0 + tile_tokens = CP_PREPROCESS_TILE_TOKENS + + global_slot_map = _get_slot_map(cp_cu_seqlens, T, chunk_size) + # Per-chunk raw-seq map (length NT_total). For raw_batch==1 this is all + # zeros and we don't even pass it to the kernel. + if raw_batch > 1: + global_raw_seq_idx, global_raw_seq_idx_cpu = _get_raw_seq_idx(cp_cu_seqlens, seq_map_c2r, T, chunk_size) + else: + global_raw_seq_idx = None + global_raw_seq_idx_cpu = None + + cp_h0 = torch.empty(cp_batch, H, V, K, dtype=torch.float32, device=device) + + exit_state = torch.empty(H, V, K, dtype=torch.float32, device=device) + + BT = chunk_size + BC = 16 + max_tile_T = min(tile_tokens, T) + buf_w = torch.empty(1, max_tile_T, H, K, dtype=k.dtype, device=device) + buf_u = torch.empty(1, max_tile_T, H, V, dtype=v.dtype, device=device) + buf_kg = torch.empty(1, max_tile_T, H, K, dtype=k.dtype, device=device) + buf_Akkd = torch.empty(1, max_tile_T, H, BC, dtype=torch.float32, device=device) + buf_Akk = torch.empty(1, max_tile_T, H, BT, dtype=k.dtype, device=device) + + n_tiles = (T + tile_tokens - 1) // tile_tokens + for tile_idx in range(n_tiles): + s = tile_idx * tile_tokens + e = min(s + tile_tokens, T) + is_last_tile = tile_idx == n_tiles - 1 + s_chunk = s // chunk_size + e_chunk = (e + chunk_size - 1) // chunk_size + + k_t = k[:, s:e] + v_t = v[:, s:e] + g_t = g_cumsum[:, s:e] + beta_t = beta[:, s:e] + + w_t, u_t, _, kg_t = kda_intra_native( + k=k_t, + v=v_t, + gk=g_t, + beta=beta_t, + chunk_size=chunk_size, + out_w=buf_w, + out_u=buf_u, + out_kg=buf_kg, + out_Akkd=buf_Akkd, + out_Akk=buf_Akk, + ) + + slot_map_t = global_slot_map[s_chunk:e_chunk] + + if raw_batch > 1: + first_raw_in_tile = global_raw_seq_idx_cpu[s_chunk] + if tile_idx == 0: + h0_in = raw_h0[first_raw_in_tile : first_raw_in_tile + 1] + else: + prev_tile_last_raw = global_raw_seq_idx_cpu[s_chunk - 1] + if first_raw_in_tile == prev_tile_last_raw: + h0_in = exit_state + else: + h0_in = raw_h0[first_raw_in_tile : first_raw_in_tile + 1] + raw_seq_idx_t = global_raw_seq_idx[s_chunk:e_chunk] + else: + h0_in = raw_h0 if tile_idx == 0 else exit_state + raw_seq_idx_t = None + + kda_cp_h0_boundary( + kg=kg_t, + w=w_t, + u=u_t, + g_cumsum=g_t, + h0=h0_in, + slot_map=slot_map_t, + num_cp=cp_batch, + chunk_size=chunk_size, + cp_h0_out=cp_h0, + # Skip writing exit_state on the last tile — no consumer. + exit_state=None if is_last_tile else exit_state, + raw_h0_dense=raw_h0 if raw_batch > 1 else None, + raw_seq_idx=raw_seq_idx_t, + ) + + return cp_h0 + + +def intra_card_cp_preprocess( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + raw_h0: torch.Tensor | None, + raw_cu_seqlens: torch.Tensor | None, + chunk_size: int = 64, + raw_cu_seqlens_cpu: torch.Tensor | None = None, +) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: + assert k.dim() == 4 and k.size(0) == 1, "expected packed [1, T, H, K]" + num_v_heads = v.size(2) + sm_count = get_device_sm_count(k.device) + + if raw_cu_seqlens is None: + raw_cu_seqlens = _create_cu_seqlens(1, k.size(1), k.device.index, torch.int32) + + use_cp, cp_cu_seqlens, seq_map_r2c, seq_map_c2r, ht_mask = _calc_cp_seqs_cached( + raw_cu_seqlens, + chunk_size, + num_v_heads, + sm_count, + raw_cu_seqlens_cpu=raw_cu_seqlens_cpu, + ) + if not use_cp: + return None, None, None, None + + cp_h0 = _compute_cp_h0_via_fla_h( + q=q, + k=k, + v=v, + g_cumsum=g, + beta=beta, + scale=scale, + raw_h0=raw_h0, + cp_cu_seqlens=cp_cu_seqlens, + seq_map_c2r=seq_map_c2r, + chunk_size=chunk_size, + ) + + return cp_h0, cp_cu_seqlens, seq_map_c2r, raw_cu_seqlens diff --git a/cula/kda/cp_h_boundary.py b/cula/kda/cp_h_boundary.py new file mode 100644 index 00000000..b584bee9 --- /dev/null +++ b/cula/kda/cp_h_boundary.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +import torch +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({"BV": BV}, num_warps=num_warps, num_stages=num_stages) + for BV in [16, 32, 64] + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=["H", "K", "V", "BT"], +) +@triton.heuristics( + { + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "STORE_EXIT_STATE": lambda args: args["exit_state"] is not None, + "MULTI_RAW": lambda args: args["raw_seq_idx"] is not None, + } +) +@triton.jit(do_not_specialize=["T"]) +def _kda_h_boundary_kernel( + kg, + w, + u, + gk, + h0, # fp32 [H, V, K] OR [1, H, V, K] — state at chunk 0 + cp_h0_out, # fp32 [num_cp, H, V, K], cuLA layout + slot_map, # int32 [NT], slot_map[i_t] = output slot or -1 + exit_state, # fp32 [H, V, K] or None — end-of-tile state for next tile + raw_h0_dense, # fp32 [raw_batch, H, V, K] — per-raw-seq h0 for cross-seq resets + raw_seq_idx, # int32 [NT] — raw_seq_idx[i_t] = which raw seq chunk i_t belongs to + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_EXIT_STATE: tl.constexpr, + MULTI_RAW: tl.constexpr, +): + # Grid: (V_tiles, H). One CTA per (V-tile, head) walks all T chunks serially. + i_v, i_h = tl.program_id(0), tl.program_id(1) + NT = tl.cdiv(T, BT) + + # State tiles: 2× [BV, 64] holds the (V_tile=BV, K=128) state split into K=64 halves. + b_h1 = tl.zeros([BV, 64], dtype=tl.float32) + b_h2 = tl.zeros([BV, 64], dtype=tl.float32) + + # Per-batch (B=1) offset into per-head buffers. + kg_ptr = kg + i_h * K + w_ptr = w + i_h * K + u_ptr = u + i_h * V + gk_ptr = gk + i_h * K + h0_ptr = h0 + i_h * V * K if USE_INITIAL_STATE else h0 # h0 [H, V, K] or [1, H, V, K] + + if USE_INITIAL_STATE: + p_h0_1 = tl.make_block_ptr(h0_ptr, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0)) + p_h0_2 = tl.make_block_ptr(h0_ptr, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)) + b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32) + b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32) + + if MULTI_RAW: + prev_raw = tl.load(raw_seq_idx + 0) + else: + prev_raw = 0 # unused + + for i_t in range(NT): + if MULTI_RAW: + if i_t > 0: + cur_raw = tl.load(raw_seq_idx + i_t) + if cur_raw != prev_raw: + rh_base = raw_h0_dense + cur_raw.to(tl.int64) * (H * V * K) + i_h * (V * K) + p_rh_1 = tl.make_block_ptr(rh_base, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0)) + p_rh_2 = tl.make_block_ptr(rh_base, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)) + b_h1 = tl.load(p_rh_1, boundary_check=(0, 1)).to(tl.float32) + b_h2 = tl.load(p_rh_2, boundary_check=(0, 1)).to(tl.float32) + prev_raw = cur_raw + + # ----- Conditional boundary store ----- + # slot = slot_map[i_t]; if slot >= 0, write h to cp_h0_out[slot, i_h, ...] + slot = tl.load(slot_map + i_t) + is_boundary = slot >= 0 + # Use slot=0 as safe target when not boundary; mask blocks the store. + safe_slot = tl.maximum(slot, 0).to(tl.int64) + out_base = cp_h0_out + safe_slot * (H * V * K) + i_h * (V * K) + p_out_1 = tl.make_block_ptr(out_base, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0)) + p_out_2 = tl.make_block_ptr(out_base, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)) + if is_boundary: + tl.store(p_out_1, b_h1, boundary_check=(0, 1)) + tl.store(p_out_2, b_h2, boundary_check=(0, 1)) + + # ----- Compute v_new = u - w @ h ----- + # w [BT, K] split into 2× [BT, 64] for the matmuls + p_w1 = tl.make_block_ptr(w_ptr, (T, K), (H * K, 1), (i_t * BT, 0), (BT, 64), (1, 0)) + b_w1 = tl.load(p_w1, boundary_check=(0, 1)) + b_v_acc = tl.dot(b_w1, tl.trans(b_h1).to(b_w1.dtype)) + p_w2 = tl.make_block_ptr(w_ptr, (T, K), (H * K, 1), (i_t * BT, 64), (BT, 64), (1, 0)) + b_w2 = tl.load(p_w2, boundary_check=(0, 1)) + b_v_acc += tl.dot(b_w2, tl.trans(b_h2).to(b_w2.dtype)) + + # u [BT, V_tile] + p_u = tl.make_block_ptr(u_ptr, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_u, boundary_check=(0, 1)) - b_v_acc + + # ----- Apply gate decay: h *= exp2(gk_last) ----- + last_idx = tl.minimum((i_t + 1) * BT, T) - 1 + o_k1 = tl.arange(0, 64) + b_gk_last1 = tl.load(gk_ptr + last_idx * (H * K) + o_k1, mask=(o_k1 < K), other=0.0).to(tl.float32) + b_h1 *= tl.math.exp2(b_gk_last1)[None, :] + o_k2 = 64 + o_k1 + b_gk_last2 = tl.load(gk_ptr + last_idx * (H * K) + o_k2, mask=(o_k2 < K), other=0.0).to(tl.float32) + b_h2 *= tl.math.exp2(b_gk_last2)[None, :] + + # ----- State update: h += kg^T @ v_new (transpose_state_layout=True form) ----- + b_v_bf = b_v.to(kg.dtype.element_ty) + p_kg1 = tl.make_block_ptr(kg_ptr, (K, T), (1, H * K), (0, i_t * BT), (64, BT), (0, 1)) + b_kg1 = tl.load(p_kg1, boundary_check=(0, 1)) + b_h1 += tl.trans(tl.dot(b_kg1, b_v_bf)) + p_kg2 = tl.make_block_ptr(kg_ptr, (K, T), (1, H * K), (64, i_t * BT), (64, BT), (0, 1)) + b_kg2 = tl.load(p_kg2, boundary_check=(0, 1)) + b_h2 += tl.trans(tl.dot(b_kg2, b_v_bf)) + + # ----- After the loop: optionally store the end-of-tile state for the next tile ----- + if STORE_EXIT_STATE: + es_ptr = exit_state + i_h * V * K + p_es_1 = tl.make_block_ptr(es_ptr, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0)) + p_es_2 = tl.make_block_ptr(es_ptr, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)) + tl.store(p_es_1, b_h1, boundary_check=(0, 1)) + tl.store(p_es_2, b_h2, boundary_check=(0, 1)) + + +def kda_cp_h0_boundary( + kg: torch.Tensor, # bf16 [1, T, H, K] + w: torch.Tensor, # bf16 [1, T, H, K] + u: torch.Tensor, # bf16 [1, T, H, V] + g_cumsum: torch.Tensor, # fp32 [1, T, H, K] + h0: torch.Tensor | None, # fp32 [H, V, K] (or [1, H, V, K]) — state at this tile's chunk 0 + slot_map: torch.Tensor, # int32 [NT], slot_map[i_t] = output slot or -1 + num_cp: int, + chunk_size: int = 64, + cp_h0_out: torch.Tensor | None = None, # pre-allocated [num_cp, H, V, K] fp32; allocated if None + exit_state: torch.Tensor | None = None, # if not None, kernel writes the end-of-T state to [H, V, K] fp32 + raw_h0_dense: torch.Tensor | None = None, # fp32 [raw_batch, H, V, K] for cross-raw-seq resets + raw_seq_idx: torch.Tensor | None = None, # int32 [NT] mapping chunk → raw seq idx (None for single-raw) +) -> torch.Tensor: + assert kg.is_contiguous() and w.is_contiguous() and u.is_contiguous() + assert kg.size(0) == 1 and w.size(0) == 1 and u.size(0) == 1 and g_cumsum.size(0) == 1 + assert slot_map.dtype == torch.int32 and slot_map.is_contiguous() + if raw_seq_idx is not None: + assert raw_h0_dense is not None, "raw_h0_dense required when raw_seq_idx is set" + assert raw_seq_idx.dtype == torch.int32 and raw_seq_idx.is_contiguous() + + T = kg.size(1) + H = kg.size(2) + K = kg.size(3) + V = u.size(3) + assert K == 128 and V == 128, f"Phase 1 kernel hard-codes K=V=128, got K={K} V={V}" + + if cp_h0_out is None: + cp_h0_out = torch.empty(num_cp, H, V, K, dtype=torch.float32, device=kg.device) + + BT = chunk_size + + # Grid depends on BV (autotune-selected); use meta-grid. + def grid(meta): + return (V // meta["BV"], H) + + _kda_h_boundary_kernel[grid]( + kg, + w, + u, + g_cumsum, + h0, + cp_h0_out, + slot_map, + exit_state, + raw_h0_dense, + raw_seq_idx, + T, + H=H, + K=K, + V=V, + BT=BT, + ) + return cp_h0_out diff --git a/cula/kda/gate_l2norm_fused.py b/cula/kda/gate_l2norm_fused.py new file mode 100644 index 00000000..aa93f69f --- /dev/null +++ b/cula/kda/gate_l2norm_fused.py @@ -0,0 +1,194 @@ +import torch +import triton +import triton.language as tl +from fla.ops.utils.constant import RCP_LN2 as _RCP_LN2 +from fla.ops.utils.index import prepare_chunk_indices +from fla.ops.utils.softplus import softplus + +# Triton requires module-level constants used inside @jit kernels to be +# wrapped in tl.constexpr. +RCP_LN2 = tl.constexpr(_RCP_LN2) + + +@triton.jit +def _gate_l2norm_fused_kernel( + # Pointers + g_ptr, + A_log_ptr, + dt_bias_ptr, # gate inputs + q_ptr, + k_ptr, # qk inputs + g_out_ptr, # gate output (fp32 cumsum) + yq_ptr, + yk_ptr, # qk outputs (bf16) + rstd_q_ptr, + rstd_k_ptr, # qk rstd outputs (fp32) + cu_seqlens_ptr, + chunk_indices_ptr, + # Scalars + lower_bound, + eps_l2, + T, + H: tl.constexpr, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, + HAS_BIAS: tl.constexpr, + USE_LOWER_BOUND: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t = tl.program_id(0) + i_bh = tl.program_id(1) + i_h = i_bh % H + + if IS_VARLEN: + i_n = tl.load(chunk_indices_ptr + i_t * 2).to(tl.int32) + i_t_local = tl.load(chunk_indices_ptr + i_t * 2 + 1).to(tl.int32) + bos = tl.load(cu_seqlens_ptr + i_n).to(tl.int32) + eos = tl.load(cu_seqlens_ptr + i_n + 1).to(tl.int32) + T_seq = eos - bos + bt_base = bos + i_t_local * BT + valid_t = (i_t_local * BT + tl.arange(0, BT)) < T_seq + else: + bt_base = (i_bh // H) * T + i_t * BT + valid_t = (i_t * BT + tl.arange(0, BT)) < T + + rows = bt_base + tl.arange(0, BT) + + cols = tl.arange(0, BD) # (BD,) + valid_d = cols < D + + offs = rows[:, None] * (H * D) + i_h * D + cols[None, :] + mask = valid_t[:, None] & valid_d[None, :] + + # ==================================================================== + # GATE: cumsum( transform(g + bias) ) * RCP_LN2 + # ==================================================================== + b_g = tl.load(g_ptr + offs, mask=mask, other=0.0).to(tl.float32) + + if HAS_BIAS: + b_bias = tl.load(dt_bias_ptr + i_h * D + cols, mask=valid_d, other=0.0).to(tl.float32) + b_g = b_g + b_bias[None, :] + + b_A = tl.load(A_log_ptr + i_h).to(tl.float32) + if USE_LOWER_BOUND: + b_gate = lower_bound * tl.sigmoid(tl.exp(b_A) * b_g) + else: + b_gate = -tl.exp(b_A) * softplus(b_g) + + b_gate_cs = tl.cumsum(b_gate, axis=0) * RCP_LN2 + # zero out the tokens beyond T so we don't pollute g_out tail rows + b_gate_cs = tl.where(mask, b_gate_cs, 0.0) + tl.store(g_out_ptr + offs, b_gate_cs, mask=mask) + + # ==================================================================== + # L2-NORM Q — per-row normalisation along D, rstd written to (B*T*H,) + # ==================================================================== + b_q = tl.load(q_ptr + offs, mask=mask, other=0.0).to(tl.float32) + b_q_sq = tl.sum(b_q * b_q, axis=1) # (BT,) + b_rstd_q = 1.0 / tl.sqrt(b_q_sq + eps_l2) # (BT,) + b_yq = b_q * b_rstd_q[:, None] + tl.store(yq_ptr + offs, b_yq.to(yq_ptr.dtype.element_ty), mask=mask) + + # rstd is shape (B, T, H,) contiguous in (b*T*H + t*H + h) order + rstd_offs = rows * H + i_h # (BT,) + tl.store(rstd_q_ptr + rstd_offs, b_rstd_q, mask=valid_t) + + # ==================================================================== + # L2-NORM K + # ==================================================================== + b_k = tl.load(k_ptr + offs, mask=mask, other=0.0).to(tl.float32) + b_k_sq = tl.sum(b_k * b_k, axis=1) + b_rstd_k = 1.0 / tl.sqrt(b_k_sq + eps_l2) + b_yk = b_k * b_rstd_k[:, None] + tl.store(yk_ptr + offs, b_yk.to(yk_ptr.dtype.element_ty), mask=mask) + tl.store(rstd_k_ptr + rstd_offs, b_rstd_k, mask=valid_t) + + +def gate_l2norm_fused_fwd( + g: torch.Tensor, # (B, T, H, D) bf16 + q: torch.Tensor, # (B, T, H, D) bf16 + k: torch.Tensor, # (B, T, H, D) bf16 + A_log: torch.Tensor, # (H,) fp32 + dt_bias: torch.Tensor | None, # (H*D,) fp32 or None + lower_bound: float | None, # if not None and safe_gate -> use lb*sigmoid path + chunk_size: int = 64, + eps_l2: float = 1e-6, + cu_seqlens: torch.IntTensor | None = None, + chunk_indices: torch.IntTensor | None = None, +): + """One-launch fused preprocessing. + + Returns: + g_out: (B, T, H, D) fp32 -- gate cumsum * RCP_LN2 + y_q: (B, T, H, D) bf16 -- l2-normalised q + y_k: (B, T, H, D) bf16 -- l2-normalised k + rstd_q, rstd_k: (B, T, H) fp32 -- 1/sqrt(sum^2 + eps) + """ + assert g.shape == q.shape == k.shape, f"shapes must match: g{g.shape} q{q.shape} k{k.shape}" + assert g.is_contiguous() and q.is_contiguous() and k.is_contiguous(), "all inputs must be contiguous" + B, T, H, D = g.shape + assert chunk_size == 64, "only chunk_size=64 supported (matches SM90 main kernel)" + + is_varlen = cu_seqlens is not None + if is_varlen: + assert B == 1, "varlen path expects packed B=1 layout" + if chunk_indices is None: + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) + + if dt_bias is None: + dt_bias_arg = A_log # any valid fp32 pointer; HAS_BIAS=False suppresses use + has_bias = False + else: + dt_bias_arg = dt_bias + has_bias = True + + use_lower_bound = lower_bound is not None + if not use_lower_bound: + lower_bound = 0.0 # unused, but Triton needs a value + + g_out = torch.empty_like(g, dtype=torch.float32) + y_q = torch.empty_like(q) + y_k = torch.empty_like(k) + rstd_q = torch.empty((B, T, H), dtype=torch.float32, device=q.device) + rstd_k = torch.empty((B, T, H), dtype=torch.float32, device=q.device) + + BD = triton.next_power_of_2(D) + if is_varlen: + NT = chunk_indices.shape[0] + grid = (NT, H) + cu_seqlens_arg = cu_seqlens + chunk_indices_arg = chunk_indices + else: + NT = triton.cdiv(T, chunk_size) # ceil — partial last chunk handled by mask + grid = (NT, B * H) + cu_seqlens_arg = A_log + chunk_indices_arg = A_log + num_warps = 1 if BD <= 128 else (2 if BD <= 256 else 4) + + _gate_l2norm_fused_kernel[grid]( + g_ptr=g, + A_log_ptr=A_log, + dt_bias_ptr=dt_bias_arg, + q_ptr=q, + k_ptr=k, + g_out_ptr=g_out, + yq_ptr=y_q, + yk_ptr=y_k, + rstd_q_ptr=rstd_q, + rstd_k_ptr=rstd_k, + cu_seqlens_ptr=cu_seqlens_arg, + chunk_indices_ptr=chunk_indices_arg, + lower_bound=lower_bound, + eps_l2=eps_l2, + T=T, + H=H, + D=D, + BT=chunk_size, + BD=BD, + HAS_BIAS=has_bias, + USE_LOWER_BOUND=use_lower_bound, + IS_VARLEN=is_varlen, + num_warps=num_warps, + ) + return g_out, y_q, y_k, rstd_q, rstd_k diff --git a/cula/kda/hopper_fused_fwd_opt.py b/cula/kda/hopper_fused_fwd_opt.py new file mode 100644 index 00000000..69fcd186 --- /dev/null +++ b/cula/kda/hopper_fused_fwd_opt.py @@ -0,0 +1,374 @@ +# Copyright 2025-2026 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Optimized Hopper KDA prefill: fused gate+l2norm preprocessing + intra-card CP.""" + +import torch +from einops import rearrange +from fla.modules.l2norm import l2norm_fwd +from fla.ops.kda.gate import kda_gate_chunk_cumsum +from fla.ops.utils import chunk_local_cumsum +from fla.ops.utils.constant import RCP_LN2 +from fla.ops.utils.index import prepare_chunk_indices +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + +import cula.cudac as cula_cuda +from cula.kda.cp_context import intra_card_cp_preprocess +from cula.kda.gate_l2norm_fused import gate_l2norm_fused_fwd +from cula.kda.l2norm_qk_fused import l2norm_fwd_qk +from cula.utils import _get_cache_buf, assert_hopper, get_device_sm_count, prepare_uniform_cu_seqlens + +FUSED_L2NORM_QK_TH_MAX = 10000 + +FUSED_GATE_L2NORM_TH_FIXED = 16384 + +FUSED_GATE_L2NORM_TH_VARLEN = 65536 + +FUSED_GATE_L2NORM_VARLEN_AVG_SEQ = 256 + + +def _fused_gate_l2norm_threshold(cu_seqlens_is_none): + return FUSED_GATE_L2NORM_TH_FIXED if cu_seqlens_is_none else FUSED_GATE_L2NORM_TH_VARLEN + + +FUSED_GATE_L2NORM_TH_MAX = FUSED_GATE_L2NORM_TH_VARLEN + + +def _inference_forward( + q, + k, + v, + g, + beta, + A_log, + dt_bias, + scale, + initial_state, + output_final_state, + use_qk_l2norm_in_kernel, + use_gate_in_kernel, + safe_gate, + lower_bound, + cu_seqlens, + chunk_indices, + auto_cp, + cu_seqlens_cpu=None, +): + chunk_size = 64 + batch_size, seq_len, num_qk_heads, head_dim = q.shape + num_v_heads = v.shape[-2] + + cu_seqlens_is_none = cu_seqlens is None + if cu_seqlens_is_none: + cu_seqlens = prepare_uniform_cu_seqlens(batch_size, seq_len, q.device, torch.int32) + if batch_size != 1: + q, k, v, g, beta = map(lambda x: rearrange(x, "b t ... -> 1 (b t) ..."), (q, k, v, g, beta)) + + if cu_seqlens_is_none: + avg_seq_ok = True + else: + N = cu_seqlens.numel() - 1 + packed_T = q.shape[1] + avg_seq_ok = N <= 1 or packed_T <= N * FUSED_GATE_L2NORM_VARLEN_AVG_SEQ + + fused_all_pre = ( + use_gate_in_kernel + and use_qk_l2norm_in_kernel + and (q.numel() // q.shape[-1]) <= _fused_gate_l2norm_threshold(cu_seqlens_is_none) + and num_qk_heads == num_v_heads + and avg_seq_ok + ) + + if fused_all_pre: + if chunk_indices is None and not cu_seqlens_is_none: + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size, cu_seqlens_cpu=cu_seqlens_cpu) + g_out, yq, yk, _, _ = gate_l2norm_fused_fwd( + g=g, + q=q, + k=k, + A_log=A_log, + dt_bias=dt_bias, + lower_bound=lower_bound if safe_gate else None, + chunk_size=chunk_size, + cu_seqlens=cu_seqlens if not cu_seqlens_is_none else None, + chunk_indices=chunk_indices if not cu_seqlens_is_none else None, + ) + g, q, k = g_out, yq, yk + else: + if chunk_indices is None and not cu_seqlens_is_none: + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size, cu_seqlens_cpu=cu_seqlens_cpu) + if use_gate_in_kernel: + g = kda_gate_chunk_cumsum( + g=g, + A_log=A_log, + dt_bias=dt_bias, + scale=RCP_LN2, + chunk_size=chunk_size, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + lower_bound=lower_bound, + ) + else: + g = chunk_local_cumsum( + g=g, + chunk_size=chunk_size, + scale=RCP_LN2, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + ) + if use_qk_l2norm_in_kernel: + D = q.shape[-1] + n_rows = q.numel() // D + if n_rows <= FUSED_L2NORM_QK_TH_MAX: + q_flat = q.view(-1, D) + k_flat = k.view(-1, D) + yq_flat, yk_flat, _, _ = l2norm_fwd_qk(q_flat, k_flat) + q = yq_flat.view_as(q) + k = yk_flat.view_as(k) + else: + q, _ = l2norm_fwd(q) + k, _ = l2norm_fwd(k) + + packed_seq = batch_size * seq_len + q = q.reshape(packed_seq, num_qk_heads, head_dim).contiguous() + k = k.reshape(packed_seq, num_qk_heads, head_dim).contiguous() + v = v.reshape(packed_seq, num_v_heads, head_dim).contiguous() + g = g.reshape(packed_seq, num_v_heads, head_dim).contiguous() + beta = beta.reshape(packed_seq, num_v_heads).contiguous() + + cp_seq_map = None + raw_cu_seqlens_for_cp = None + cp_would_fire = auto_cp and ( + # Multi-seq varlen: CP only when grid is starved AND total T amortizes. + ( + cu_seqlens is not None + and (cu_seqlens.numel() - 1) > 1 + and (cu_seqlens.numel() - 1) * num_v_heads <= 16 + and packed_seq >= 8192 + ) + or + # Single sequence: per-H T thresholds matching _calc_cp_seqs. + ( + (cu_seqlens is None or cu_seqlens.numel() - 1 == 1) + and ( + (num_v_heads <= 8 and packed_seq >= 4096) + or (num_v_heads <= 16 and packed_seq >= 4096) + or (num_v_heads <= 32 and packed_seq >= 16384) + ) + ) + ) + if cp_would_fire: + q4 = q.view(1, packed_seq, num_qk_heads, head_dim) + k4 = k.view(1, packed_seq, num_qk_heads, head_dim) + v4 = v.view(1, packed_seq, num_v_heads, head_dim) + g4 = g.view(1, packed_seq, num_v_heads, head_dim) + beta4 = beta.view(1, packed_seq, num_v_heads) + cp_h0, cp_cu_seqlens, cp_seq_map, raw_cu_seqlens_for_cp = intra_card_cp_preprocess( + q=q4, + k=k4, + v=v4, + g=g4, + beta=beta4, + scale=scale, + raw_h0=initial_state, + raw_cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + raw_cu_seqlens_cpu=cu_seqlens_cpu, + ) + del q4, k4, v4, g4, beta4 + if cp_seq_map is not None: + cu_seqlens = cp_cu_seqlens + initial_state = cp_h0 + + sm_count = get_device_sm_count(q.device) + workspace_buffer = _get_cache_buf("hopper_kda_fwd_workspace", sm_count * 128, q.device) + + o, final_state = cula_cuda.kda_fwd_prefill( + None, + None, + q, + k, + v, + initial_state, + g, + beta, + cu_seqlens, + workspace_buffer, + scale, + output_final_state, + safe_gate, + cp_seq_map_=cp_seq_map, + raw_cu_seqlens_=raw_cu_seqlens_for_cp, + ) + o = rearrange(o, "(b t) h d -> b t h d", b=batch_size) + return o.to(q.dtype), final_state + + +class HopperChunkKDAFunctionOpt(torch.autograd.Function): + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + A_log: torch.Tensor, + dt_bias: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool = False, + use_qk_l2norm_in_kernel: bool = False, + use_gate_in_kernel: bool = False, + safe_gate: bool = False, + lower_bound: float | None = None, + cu_seqlens: torch.IntTensor | None = None, + chunk_indices: torch.IntTensor | None = None, + auto_cp: bool = True, + cu_seqlens_cpu: torch.IntTensor | None = None, + ): + return _inference_forward( + q, + k, + v, + g, + beta, + A_log, + dt_bias, + scale, + initial_state, + output_final_state, + use_qk_l2norm_in_kernel, + use_gate_in_kernel, + safe_gate, + lower_bound, + cu_seqlens, + chunk_indices, + auto_cp, + cu_seqlens_cpu=cu_seqlens_cpu, + ) + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do, dht): + raise NotImplementedError("Backward pass is not implemented yet.") + + +@torch.compiler.disable +def cula_kda_prefill_opt( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + use_qk_l2norm_in_kernel: bool = False, + use_gate_in_kernel: bool = False, + safe_gate: bool = False, + lower_bound: float | None = None, + cu_seqlens: torch.IntTensor | None = None, + chunk_indices: torch.IntTensor | None = None, + auto_cp: bool = True, + cu_seqlens_cpu: torch.IntTensor | None = None, + **kwargs, +): + assert_hopper() + assert safe_gate, "Only support safe_gate=True." + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`.") + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.", + ) + if initial_state is not None: + assert initial_state.dtype == torch.float32, "initial_state must be in float32." + + A_log, dt_bias = None, None + if use_gate_in_kernel: + assert "A_log" in kwargs, "A_log must be provided when use_gate_in_kernel=True." + A_log, dt_bias = kwargs["A_log"], kwargs.get("dt_bias") + if safe_gate: + if lower_bound is None: + raise ValueError("`lower_bound` must be specified when `safe_gate=True` and `use_gate_in_kernel=True`.") + if not (-5 <= lower_bound < 0): + raise ValueError(f"`lower_bound` must be in the safe range [-5, 0), got {lower_bound}.") + + assert q.shape == k.shape, "q and k must have the same shape." + assert q.shape[:2] == v.shape[:2] == g.shape[:2], "q, k, v, g must share batch and sequence dimensions." + batch_size, seq_len, num_qk_heads, head_dim = q.shape + num_v_heads = v.shape[-2] + assert num_qk_heads > 0 and num_v_heads > 0 + assert num_v_heads % num_qk_heads == 0 + assert g.shape == (batch_size, seq_len, num_v_heads, head_dim) + assert v.shape == (batch_size, seq_len, num_v_heads, head_dim) + assert beta.shape == (batch_size, seq_len, num_v_heads) + assert q.dtype == k.dtype == v.dtype == torch.bfloat16, "q, k, v must be in bfloat16." + assert beta.dtype == torch.bfloat16 or beta.dtype == torch.float32, "beta must be in bfloat16 or float32." + assert q.shape[-1] == k.shape[-1] == v.shape[-1] == 128, "Currently we only support head dim of 128 for KDA" + if scale is None: + scale = k.shape[-1] ** -0.5 + + needs_grad = torch.is_grad_enabled() and any(t.requires_grad for t in (q, k, v, g, beta) if t is not None) + if not needs_grad: + o, final_state = _inference_forward( + q, + k, + v, + g, + beta, + A_log, + dt_bias, + scale, + initial_state, + output_final_state, + use_qk_l2norm_in_kernel, + use_gate_in_kernel, + safe_gate, + lower_bound, + cu_seqlens, + chunk_indices, + auto_cp, + cu_seqlens_cpu=cu_seqlens_cpu, + ) + return o, (final_state if output_final_state else None) + + o, final_state = HopperChunkKDAFunctionOpt.apply( + q, + k, + v, + g, + beta, + A_log, + dt_bias, + scale, + initial_state, + output_final_state, + use_qk_l2norm_in_kernel, + use_gate_in_kernel, + safe_gate, + lower_bound, + cu_seqlens, + chunk_indices, + auto_cp, + cu_seqlens_cpu, + ) + + return o, (final_state if output_final_state else None) diff --git a/cula/kda/l2norm_qk_fused.py b/cula/kda/l2norm_qk_fused.py new file mode 100644 index 00000000..234d57bf --- /dev/null +++ b/cula/kda/l2norm_qk_fused.py @@ -0,0 +1,119 @@ +# Copyright 2025-2026 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +"""Fused l2-norm for paired (q, k) tensors — one Triton kernel handles both. + +cuLA's baseline `cula_kda_prefill` calls `l2norm_fwd(q)` and `l2norm_fwd(k)` +as two separate Triton kernel launches. Each launch costs ~50-80 μs of CPU +overhead (Python wrapper + torch.empty + CUDA driver dispatch), even though +the actual GPU work is tiny for D=128. + +The two operations are mathematically identical and operate on disjoint +inputs/outputs. By writing one kernel whose grid covers both q and k +(distinguishing them via `tl.program_id(1)`), we cut the Python-driver +overhead in half — saving one launch (~50 μs) per fwd at any T. + +Combined with skipping the gate-stream optimization (which didn't pay off +due to torch.cuda.stream context overhead — see hopper_fused_fwd_opt.py), +this is the cleanest small-T speedup we found. +""" + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2norm_fwd_qk_kernel( + q_ptr, + k_ptr, # input pointers (T*H, D) each + yq_ptr, + yk_ptr, # output pointers (T*H, D) each + rstd_q_ptr, + rstd_k_ptr, # output rstd (T*H,) each + eps, + D, + BD: tl.constexpr, +): + i_row = tl.program_id(0) # 0..T*H-1 + i_qk = tl.program_id(1) # 0=q, 1=k + + cols = tl.arange(0, BD) + mask = cols < D + + # is_q is uniform across all threads in the block (driven by + # tl.program_id(1)), so the if/else compiles to a single conditional + # branch — no warp divergence, and only one tensor is actually loaded. + is_q = i_qk == 0 + base_off = i_row * D + if is_q: + b_x = tl.load(q_ptr + base_off + cols, mask=mask, other=0.0).to(tl.float32) + else: + b_x = tl.load(k_ptr + base_off + cols, mask=mask, other=0.0).to(tl.float32) + + b_rstd = 1.0 / tl.sqrt(tl.sum(b_x * b_x) + eps) + b_y = b_x * b_rstd + + # Symmetric stores — same uniform-branch logic as the load above. + if is_q: + tl.store(yq_ptr + base_off + cols, b_y, mask=mask) + tl.store(rstd_q_ptr + i_row, b_rstd) + else: + tl.store(yk_ptr + base_off + cols, b_y, mask=mask) + tl.store(rstd_k_ptr + i_row, b_rstd) + + +def l2norm_fwd_qk( + q: torch.Tensor, + k: torch.Tensor, + eps: float = 1e-6, +): + """L2-normalize q and k along the last dim, in a single fused kernel. + + Args: + q, k: shape (..., D). Last dim is normalised; preceding dims are + treated as a flat "row" index. q and k must have identical shape. + eps: numerical safety eps. + + Returns: + (y_q, y_k, rstd_q, rstd_k) + y_q, y_k: normalized outputs, same shape as q, k. + rstd_q, rstd_k: 1/sqrt(sum(x^2)+eps), shape q.shape[:-1]. + """ + assert q.shape == k.shape, f"q.shape {q.shape} != k.shape {k.shape}" + assert q.dtype == k.dtype + assert q.device == k.device + assert q.is_contiguous() and k.is_contiguous(), "q, k must be contiguous" + + D = q.shape[-1] + T = q.numel() // D + + y_q = torch.empty_like(q) + y_k = torch.empty_like(k) + rstd_q = torch.empty(q.shape[:-1], dtype=torch.float32, device=q.device) + rstd_k = torch.empty(k.shape[:-1], dtype=torch.float32, device=k.device) + + BD = triton.next_power_of_2(D) + if D > BD or 65536 // q.element_size() < BD: + raise RuntimeError(f"D={D} too large for fused l2norm_fwd_qk") + + # Grid: (T*H, 2) — program_id(1) picks q (0) or k (1). + # Heuristic on num_warps: small D (e.g. 128) needs only 1 warp; up to 4. + num_warps = 1 if BD <= 256 else (2 if BD <= 1024 else 4) + _l2norm_fwd_qk_kernel[(T, 2)]( + q_ptr=q, + k_ptr=k, + yq_ptr=y_q, + yk_ptr=y_k, + rstd_q_ptr=rstd_q, + rstd_k_ptr=rstd_k, + eps=eps, + D=D, + BD=BD, + num_warps=num_warps, + ) + return y_q, y_k, rstd_q, rstd_k diff --git a/cula/kda/wy_intra.py b/cula/kda/wy_intra.py new file mode 100644 index 00000000..58e741c6 --- /dev/null +++ b/cula/kda/wy_intra.py @@ -0,0 +1,354 @@ +from __future__ import annotations + +import torch +import triton +import triton.language as tl + +from cula.kda.wy_recompute import kda_recompute_w_u + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [2, 3, 4] + ], + key=["H", "K", "BT", "BC"], +) +@triton.jit(do_not_specialize=["T"]) +def _kda_intra_sub_chunk_kernel( + k, + g, + beta, + Akkd, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, +): + i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + bos = i_b * T + i_ti = i_t * BT + i_i * BC + if i_ti >= T: + return + + o_c = i_ti + tl.arange(0, BC) + m_c = o_c < T + + # Per-head pointer offsets + k_base = k + (bos * H + i_h) * K + g_base = g + (bos * H + i_h) * K + beta_base = beta + bos * H + i_h + Akkd_base = Akkd + (bos * H + i_h) * BC + + # Load k, g (BC × BK), beta (BC) for this sub-chunk + p_k = tl.make_block_ptr(k_base, (T, K), (H * K, 1), (i_ti, 0), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g_base, (T, K), (H * K, 1), (i_ti, 0), (BC, BK), (1, 0)) + p_beta = tl.make_block_ptr(beta_base, (T,), (H,), (i_ti,), (BC,), (0,)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) + b_beta = tl.load(p_beta, boundary_check=(0,)) + + o_gn = i_ti + tl.minimum(BC // 2, T - i_ti - 1) + o_k = tl.arange(0, BK) + m_k = o_k < K + b_gn = tl.load(g + (bos * H + i_h) * K + o_gn * (H * K) + o_k, mask=m_k, other=0.0).to(tl.float32) + + b_gm = (b_g - b_gn[None, :]).to(tl.float32) + b_gq = tl.where(m_c[:, None], tl.math.exp2(b_gm), 0.0) + b_gk = tl.where(m_c[:, None], tl.math.exp2(-b_gm), 0.0) + + b_kgt = tl.trans(b_k * b_gk) + b_Akk = tl.dot(b_k * b_gq, b_kgt) * b_beta[:, None] + + o_i = tl.arange(0, BC) + m_Akk = o_i[:, None] > o_i[None, :] + m_I = o_i[:, None] == o_i[None, :] + b_Akk = tl.where(m_Akk, b_Akk, 0.0) + + p_Akkd = tl.make_block_ptr(Akkd_base, (T, BC), (H * BC, 1), (i_ti, 0), (BC, BC), (1, 0)) + tl.store(p_Akkd, b_Akk.to(Akkd.dtype.element_ty), boundary_check=(0, 1)) + tl.debug_barrier() + + b_Ai = -b_Akk + for i in range(2, tl.minimum(BC, T - i_ti)): + b_a = -tl.load(Akkd_base + (i_ti + i) * (H * BC) + o_i) + b_a = tl.where(o_i < i, b_a, 0.0) + b_a += tl.sum(b_a[:, None] * b_Ai, 0) + b_Ai = tl.where((o_i == i)[:, None], b_a, b_Ai) + + b_Ai += m_I + + tl.store(p_Akkd, b_Ai.to(Akkd.dtype.element_ty), boundary_check=(0, 1)) + + +_SOLVE_DOT_PRECISION = tl.constexpr("tf32") + + +@triton.autotune( + configs=[triton.Config({"BK": BK}, num_warps=num_warps) for BK in [32, 64] for num_warps in [1, 2, 4]], + key=["H", "K", "BC"], +) +@triton.jit(do_not_specialize=["T"]) +def _kda_intra_inter_solve_kernel( + k, + g, + beta, + Akkd, + Akk, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + bos = i_b * T + + if i_t * BT >= T: + return + + i_tc0 = i_t * BT + i_tc1 = i_t * BT + BC + i_tc2 = i_t * BT + 2 * BC + i_tc3 = i_t * BT + 3 * BC + + k_base = k + (bos * H + i_h) * K + g_base = g + (bos * H + i_h) * K + Akk_base = Akk + (bos * H + i_h) * BT + Akkd_base = Akkd + (bos * H + i_h) * BC + + o_i = tl.arange(0, BC) + m_tc1 = (i_tc1 + o_i) < T + m_tc2 = (i_tc2 + o_i) < T + m_tc3 = (i_tc3 + o_i) < T + + b_Akk10 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk20 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk21 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk30 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk31 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk32 = tl.zeros([BC, BC], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + + p_k0 = tl.make_block_ptr(k_base, (T, K), (H * K, 1), (i_tc0, i_k * BK), (BC, BK), (1, 0)) + p_g0 = tl.make_block_ptr(g_base, (T, K), (H * K, 1), (i_tc0, i_k * BK), (BC, BK), (1, 0)) + b_k0 = tl.load(p_k0, boundary_check=(0, 1)).to(tl.float32) + b_g0 = tl.load(p_g0, boundary_check=(0, 1)).to(tl.float32) + + # sub-chunk 1 (vs sub-chunk 0) + if i_tc1 < T: + p_k1 = tl.make_block_ptr(k_base, (T, K), (H * K, 1), (i_tc1, i_k * BK), (BC, BK), (1, 0)) + p_g1 = tl.make_block_ptr(g_base, (T, K), (H * K, 1), (i_tc1, i_k * BK), (BC, BK), (1, 0)) + b_k1 = tl.load(p_k1, boundary_check=(0, 1)).to(tl.float32) + b_g1 = tl.load(p_g1, boundary_check=(0, 1)).to(tl.float32) + b_gn1 = tl.load(g + (bos * H + i_h) * K + i_tc1 * (H * K) + o_k, mask=m_k, other=0.0).to(tl.float32) + b_gqn = tl.where(m_tc1[:, None], tl.math.exp2(b_g1 - b_gn1[None, :]), 0.0) + b_kgt = tl.trans(b_k0 * tl.math.exp2(b_gn1[None, :] - b_g0)) + b_Akk10 += tl.dot(b_k1 * b_gqn, b_kgt) + + # sub-chunk 2 (vs 0 and 1) + if i_tc2 < T: + p_k2 = tl.make_block_ptr(k_base, (T, K), (H * K, 1), (i_tc2, i_k * BK), (BC, BK), (1, 0)) + p_g2 = tl.make_block_ptr(g_base, (T, K), (H * K, 1), (i_tc2, i_k * BK), (BC, BK), (1, 0)) + b_k2 = tl.load(p_k2, boundary_check=(0, 1)).to(tl.float32) + b_g2 = tl.load(p_g2, boundary_check=(0, 1)).to(tl.float32) + b_gn2 = tl.load(g + (bos * H + i_h) * K + i_tc2 * (H * K) + o_k, mask=m_k, other=0.0).to(tl.float32) + b_gqn2 = tl.where(m_tc2[:, None], tl.math.exp2(b_g2 - b_gn2[None, :]), 0.0) + b_kg2 = b_k2 * b_gqn2 + b_kgt0 = tl.trans(b_k0 * tl.math.exp2(b_gn2[None, :] - b_g0)) + b_Akk20 += tl.dot(b_kg2, b_kgt0) + b_kgt1 = tl.trans(b_k1 * tl.math.exp2(b_gn2[None, :] - b_g1)) + b_Akk21 += tl.dot(b_kg2, b_kgt1) + + # sub-chunk 3 (vs 0, 1, 2) + if i_tc3 < T: + p_k3 = tl.make_block_ptr(k_base, (T, K), (H * K, 1), (i_tc3, i_k * BK), (BC, BK), (1, 0)) + p_g3 = tl.make_block_ptr(g_base, (T, K), (H * K, 1), (i_tc3, i_k * BK), (BC, BK), (1, 0)) + b_k3 = tl.load(p_k3, boundary_check=(0, 1)).to(tl.float32) + b_g3 = tl.load(p_g3, boundary_check=(0, 1)).to(tl.float32) + b_gn3 = tl.load(g + (bos * H + i_h) * K + i_tc3 * (H * K) + o_k, mask=m_k, other=0.0).to(tl.float32) + b_gqn3 = tl.where(m_tc3[:, None], tl.math.exp2(b_g3 - b_gn3[None, :]), 0.0) + b_kg3 = b_k3 * b_gqn3 + b_kgt0 = tl.trans(b_k0 * tl.math.exp2(b_gn3[None, :] - b_g0)) + b_Akk30 += tl.dot(b_kg3, b_kgt0) + b_kgt1 = tl.trans(b_k1 * tl.math.exp2(b_gn3[None, :] - b_g1)) + b_Akk31 += tl.dot(b_kg3, b_kgt1) + b_kgt2 = tl.trans(b_k2 * tl.math.exp2(b_gn3[None, :] - b_g2)) + b_Akk32 += tl.dot(b_kg3, b_kgt2) + + beta_base = beta + bos * H + i_h + if i_tc1 < T: + p_b1 = tl.make_block_ptr(beta_base, (T,), (H,), (i_tc1,), (BC,), (0,)) + b_b1 = tl.load(p_b1, boundary_check=(0,)).to(tl.float32) + b_Akk10 = b_Akk10 * b_b1[:, None] + if i_tc2 < T: + p_b2 = tl.make_block_ptr(beta_base, (T,), (H,), (i_tc2,), (BC,), (0,)) + b_b2 = tl.load(p_b2, boundary_check=(0,)).to(tl.float32) + b_Akk20 = b_Akk20 * b_b2[:, None] + b_Akk21 = b_Akk21 * b_b2[:, None] + if i_tc3 < T: + p_b3 = tl.make_block_ptr(beta_base, (T,), (H,), (i_tc3,), (BC,), (0,)) + b_b3 = tl.load(p_b3, boundary_check=(0,)).to(tl.float32) + b_Akk30 = b_Akk30 * b_b3[:, None] + b_Akk31 = b_Akk31 * b_b3[:, None] + b_Akk32 = b_Akk32 * b_b3[:, None] + + # Load 4 inverted diagonal blocks (from sub_chunk kernel) + p_Akk00 = tl.make_block_ptr(Akkd_base, (T, BC), (H * BC, 1), (i_tc0, 0), (BC, BC), (1, 0)) + p_Akk11 = tl.make_block_ptr(Akkd_base, (T, BC), (H * BC, 1), (i_tc1, 0), (BC, BC), (1, 0)) + p_Akk22 = tl.make_block_ptr(Akkd_base, (T, BC), (H * BC, 1), (i_tc2, 0), (BC, BC), (1, 0)) + p_Akk33 = tl.make_block_ptr(Akkd_base, (T, BC), (H * BC, 1), (i_tc3, 0), (BC, BC), (1, 0)) + b_Ai00 = tl.load(p_Akk00, boundary_check=(0, 1)).to(tl.float32) + b_Ai11 = tl.load(p_Akk11, boundary_check=(0, 1)).to(tl.float32) + b_Ai22 = tl.load(p_Akk22, boundary_check=(0, 1)).to(tl.float32) + b_Ai33 = tl.load(p_Akk33, boundary_check=(0, 1)).to(tl.float32) + + b_Ai10 = -tl.dot( + tl.dot(b_Ai11, b_Akk10, input_precision=_SOLVE_DOT_PRECISION), + b_Ai00, + input_precision=_SOLVE_DOT_PRECISION, + ) + b_Ai21 = -tl.dot( + tl.dot(b_Ai22, b_Akk21, input_precision=_SOLVE_DOT_PRECISION), + b_Ai11, + input_precision=_SOLVE_DOT_PRECISION, + ) + b_Ai32 = -tl.dot( + tl.dot(b_Ai33, b_Akk32, input_precision=_SOLVE_DOT_PRECISION), + b_Ai22, + input_precision=_SOLVE_DOT_PRECISION, + ) + + b_Ai20 = -tl.dot( + b_Ai22, + tl.dot(b_Akk20, b_Ai00, input_precision=_SOLVE_DOT_PRECISION) + + tl.dot(b_Akk21, b_Ai10, input_precision=_SOLVE_DOT_PRECISION), + input_precision=_SOLVE_DOT_PRECISION, + ) + b_Ai31 = -tl.dot( + b_Ai33, + tl.dot(b_Akk31, b_Ai11, input_precision=_SOLVE_DOT_PRECISION) + + tl.dot(b_Akk32, b_Ai21, input_precision=_SOLVE_DOT_PRECISION), + input_precision=_SOLVE_DOT_PRECISION, + ) + b_Ai30 = -tl.dot( + b_Ai33, + tl.dot(b_Akk30, b_Ai00, input_precision=_SOLVE_DOT_PRECISION) + + tl.dot(b_Akk31, b_Ai10, input_precision=_SOLVE_DOT_PRECISION) + + tl.dot(b_Akk32, b_Ai20, input_precision=_SOLVE_DOT_PRECISION), + input_precision=_SOLVE_DOT_PRECISION, + ) + + # Store 10 blocks to the full BT×BT Akk buffer. + p = tl.make_block_ptr(Akk_base, (T, BT), (H * BT, 1), (i_tc0, 0), (BC, BC), (1, 0)) + tl.store(p, b_Ai00.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + p = tl.make_block_ptr(Akk_base, (T, BT), (H * BT, 1), (i_tc1, 0), (BC, BC), (1, 0)) + tl.store(p, b_Ai10.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + p = tl.make_block_ptr(Akk_base, (T, BT), (H * BT, 1), (i_tc1, BC), (BC, BC), (1, 0)) + tl.store(p, b_Ai11.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + p = tl.make_block_ptr(Akk_base, (T, BT), (H * BT, 1), (i_tc2, 0), (BC, BC), (1, 0)) + tl.store(p, b_Ai20.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + p = tl.make_block_ptr(Akk_base, (T, BT), (H * BT, 1), (i_tc2, BC), (BC, BC), (1, 0)) + tl.store(p, b_Ai21.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + p = tl.make_block_ptr(Akk_base, (T, BT), (H * BT, 1), (i_tc2, 2 * BC), (BC, BC), (1, 0)) + tl.store(p, b_Ai22.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + p = tl.make_block_ptr(Akk_base, (T, BT), (H * BT, 1), (i_tc3, 0), (BC, BC), (1, 0)) + tl.store(p, b_Ai30.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + p = tl.make_block_ptr(Akk_base, (T, BT), (H * BT, 1), (i_tc3, BC), (BC, BC), (1, 0)) + tl.store(p, b_Ai31.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + p = tl.make_block_ptr(Akk_base, (T, BT), (H * BT, 1), (i_tc3, 2 * BC), (BC, BC), (1, 0)) + tl.store(p, b_Ai32.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + p = tl.make_block_ptr(Akk_base, (T, BT), (H * BT, 1), (i_tc3, 3 * BC), (BC, BC), (1, 0)) + tl.store(p, b_Ai33.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + + +def kda_intra_native( + k: torch.Tensor, # bf16 [1, T, H, K] + v: torch.Tensor, # bf16 [1, T, H, V] + gk: torch.Tensor, # fp32 [1, T, H, K] + beta: torch.Tensor, # bf16 [1, T, H] + chunk_size: int = 64, + q: torch.Tensor | None = None, # bf16 [1, T, H, K], for qg (optional) + need_qg: bool = False, + out_w: torch.Tensor | None = None, # bf16 [1, T, H, K] + out_u: torch.Tensor | None = None, # bf16 [1, T, H, V] + out_kg: torch.Tensor | None = None, # bf16 [1, T, H, K] + out_Akkd: torch.Tensor | None = None, # fp32 [1, T, H, BC] + out_Akk: torch.Tensor | None = None, # bf16 [1, T, H, BT] +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor]: + B, T, H, K = k.shape + V = v.shape[-1] + assert B == 1, "intra expects packed [1, T, H, K] input" + assert K == 128 and V == 128, f"specialized for K=V=128, got K={K} V={V}" + BT = chunk_size + BC = 16 + NT = (T + BT - 1) // BT + NC = BT // BC + + if out_Akkd is not None: + Akkd = out_Akkd[:B, :T] + else: + Akkd = torch.empty(B, T, H, BC, device=k.device, dtype=torch.float32) + if out_Akk is not None: + Akk = out_Akk[:B, :T] + Akk.zero_() + else: + Akk = torch.zeros(B, T, H, BT, device=k.device, dtype=k.dtype) + + # Step 1: per-sub-chunk diagonal Akk inversion + BK_sub = triton.next_power_of_2(K) # =128 for K=128 + grid_sub = (NT, NC, B * H) + _kda_intra_sub_chunk_kernel[grid_sub]( + k=k, + g=gk, + beta=beta, + Akkd=Akkd, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + BK=BK_sub, + ) + + # Step 2: per-chunk off-diagonal + assemble full Akk_inv + grid_inter = (NT, B * H) + _kda_intra_inter_solve_kernel[grid_inter]( + k=k, + g=gk, + beta=beta, + Akkd=Akkd, + Akk=Akk, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + ) + if out_Akkd is None: + del Akkd + + # Step 3: recompute w, u, kg (and optionally qg) from Akk + w, u, qg, kg = kda_recompute_w_u( + k=k, + v=v, + beta=beta, + A=Akk, + q=q if need_qg else None, + gk=gk, + chunk_size=chunk_size, + out_w=out_w, + out_u=out_u, + out_kg=out_kg, + ) + # Akk is similarly dead after recompute's kernel is queued. + if out_Akk is None: + del Akk + return w, u, qg, kg diff --git a/cula/kda/wy_recompute.py b/cula/kda/wy_recompute.py new file mode 100644 index 00000000..edfb9693 --- /dev/null +++ b/cula/kda/wy_recompute.py @@ -0,0 +1,137 @@ +from __future__ import annotations + +import torch +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [2, 4, 8] for num_stages in [2, 3, 4] + ], + key=["H", "K", "V", "BT"], +) +@triton.heuristics( + { + "STORE_QG": lambda args: args["qg"] is not None, + "STORE_KG": lambda args: args["kg"] is not None, + } +) +@triton.jit(do_not_specialize=["T"]) +def _kda_recompute_wuk_kernel( + q, + k, + qg, + kg, + v, + beta, + w, + u, + A, + gk, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + STORE_QG: tl.constexpr, + STORE_KG: tl.constexpr, +): + """K = V = 128, BT = 64 specialized. BK = K, BV = V (no inner loop).""" + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + bos = i_b * T + + # Per-head pointer offsets + p_b = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_b = tl.load(p_b, boundary_check=(0,)) + + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_A = tl.load(p_A, boundary_check=(0, 1)) + + # ----- u = A @ (β · v) ----- + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, 0), (BT, V), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_b[:, None]).to(b_v.dtype) + b_u = tl.dot(b_A, b_vb) + p_u = tl.make_block_ptr(u + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, 0), (BT, V), (1, 0)) + tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + # ----- Load k, gk, compute β·exp2(gk)·k and kg ----- + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, 0), (BT, K), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + p_gk = tl.make_block_ptr(gk + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, 0), (BT, K), (1, 0)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)).to(tl.float32) + b_exp_gk = tl.math.exp2(b_gk) + + # w = A @ (β · exp2(gk) · k) + b_kb = b_k * b_b[:, None] * b_exp_gk + b_w = tl.dot(b_A, b_kb.to(b_k.dtype)) + p_w = tl.make_block_ptr(w + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, 0), (BT, K), (1, 0)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + # qg = q · exp2(gk) (optional) + if STORE_QG: + p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, 0), (BT, K), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_qg = b_q * b_exp_gk + p_qg = tl.make_block_ptr(qg + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, 0), (BT, K), (1, 0)) + tl.store(p_qg, b_qg.to(p_qg.dtype.element_ty), boundary_check=(0, 1)) + + # kg = β · exp2(g_chunk_end - gk) · k (optional, needed by cp_h0 path) + if STORE_KG: + last_idx = tl.minimum(i_t * BT + BT, T) - 1 + o_k = tl.arange(0, K) + b_gn = tl.load(gk + ((bos + last_idx) * H + i_h) * K + o_k, mask=o_k < K, other=0.0).to(tl.float32) + m_t = (i_t * BT + tl.arange(0, BT)) < T + b_kg = b_k * tl.where(m_t[:, None], tl.math.exp2(b_gn[None, :] - b_gk), 0.0) + p_kg = tl.make_block_ptr(kg + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, 0), (BT, K), (1, 0)) + tl.store(p_kg, b_kg.to(p_kg.dtype.element_ty), boundary_check=(0, 1)) + + +def kda_recompute_w_u( + k: torch.Tensor, # bf16 [B, T, H, K] + v: torch.Tensor, # bf16 [B, T, H, V] + beta: torch.Tensor, # bf16 [B, T, H] + A: torch.Tensor, # bf16 [B, T, H, BT] + q: torch.Tensor | None, # bf16 [B, T, H, K], or None + gk: torch.Tensor, # fp32 [B, T, H, K] + chunk_size: int = 64, + out_w: torch.Tensor | None = None, + out_u: torch.Tensor | None = None, + out_kg: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + B, T, H, K = k.shape + V = v.shape[-1] + BT = chunk_size + assert K == 128 and V == 128, f"specialized for K=V=128, got K={K} V={V}" + assert A.shape[-1] == BT, f"expected A.shape[-1]={BT}, got {A.shape[-1]}" + + w = out_w[:B, :T] if out_w is not None else torch.empty_like(k) + u = out_u[:B, :T] if out_u is not None else torch.empty_like(v) + qg = torch.empty_like(q) if q is not None else None + if gk is not None: + kg = out_kg[:B, :T] if out_kg is not None else torch.empty_like(k) + else: + kg = None + + NT = triton.cdiv(T, BT) + grid = (NT, B * H) + _kda_recompute_wuk_kernel[grid]( + q=q, + k=k, + qg=qg, + kg=kg, + v=v, + beta=beta, + w=w, + u=u, + A=A, + gk=gk, + T=T, + H=H, + K=K, + V=V, + BT=BT, + ) + return w, u, qg, kg diff --git a/tests/conftest.py b/tests/conftest.py index f144c10b..a9338aca 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ import re + import pytest import torch @@ -56,9 +57,5 @@ def pytest_collection_modifyitems(config, items): item.add_marker(skip_slow) continue callspec = getattr(item, "callspec", None) - if ( - callspec is not None - and callspec.params.get("disable_recompute") - and "kda_fast_norecomp" not in item.keywords - ): + if callspec is not None and callspec.params.get("disable_recompute") and "kda_fast_norecomp" not in item.keywords: item.add_marker(skip_fast_norecomp) diff --git a/tests/test_intracard_cp_sm90.py b/tests/test_intracard_cp_sm90.py new file mode 100644 index 00000000..2353a66e --- /dev/null +++ b/tests/test_intracard_cp_sm90.py @@ -0,0 +1,439 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Ant Group Co., Ltd. +# Licensed under the Apache License, Version 2.0. +"""Tests for SM90 intra-card CP: dispatch routing + numerical accuracy. + +Mirrors tests/test_intracard_cp.py (SM100 version) but targets the Hopper +(SM90) `kda_prefill_hopper_opt` / `kda_prefill_hopper_auto` path. + +Three reference levels: + - cuLA basic (kda_prefill_hopper) — same C++ kernel, no CP scheduling + → verifies CP scheduling is value-preserving + - cuLA opt with auto_cp=False — opt Python wrapper but CP disabled + → isolates the CP code paths + - FLA chunk_kda (cross-impl reference) → source of truth for end-to-end output + +The CP path is exercised through: + - kda_prefill_hopper_auto (router picks opt when shape benefits from CP) + - kda_prefill_hopper_opt(auto_cp=True) (force CP entry; bypasses router) +""" + +from __future__ import annotations + +import math +import pathlib +import sys + +import pytest +import torch + +_REPO_ROOT = pathlib.Path(__file__).resolve().parents[1] +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + +from fla.ops.kda import chunk_kda as fla_chunk_kda # noqa: E402 +from fla.utils import assert_close # noqa: E402 + +from cula.kda import ( # noqa: E402 + kda_prefill_hopper, + kda_prefill_hopper_auto, + kda_prefill_hopper_opt, +) +from cula.kda.cp_context import _calc_cp_seqs # noqa: E402 +from cula.utils import get_device_sm_count # noqa: E402 + +BT, D = 64, 128 +DEVICE = "cuda" +DTYPE = torch.bfloat16 +LOWER_BOUND = -5.0 + +# Tolerances — same convention as tests/test_intracard_cp.py: +# * Same-kernel (CP scheduling only): torch.testing.assert_close +# (CP-on vs CP-off both go through cuLA kernels) +# * Cross-impl (vs FLA): fla.utils.assert_close(ratio=...) +ATOL_SAME_KERNEL = 1e-2 +RTOL_SAME_KERNEL = 1e-2 +RATIO_VS_FLA = 0.015 # bf16 cross-impl noise band (matches SM100 test) +RATIO_STRESS = 1e-6 # deterministic re-run: drift implies race + + +pytestmark = [ + pytest.mark.sm90_only, + pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required"), +] + + +# ============================== Helpers ============================== + + +def _cu_from_seq_lens(seq_lens, device=DEVICE): + cu = [0] + for s in seq_lens: + cu.append(cu[-1] + s) + return torch.tensor(cu, dtype=torch.int32, device=device) + + +def make_varlen_inputs(seq_lens, H, *, use_h0=False, seed=42): + """Build varlen-packed B=1 inputs for kda_prefill_hopper_*.""" + total = sum(seq_lens) + N = len(seq_lens) + cu = _cu_from_seq_lens(seq_lens) + torch.manual_seed(seed) + q = torch.randn(1, total, H, D, dtype=DTYPE, device=DEVICE) + k = torch.randn(1, total, H, D, dtype=DTYPE, device=DEVICE) + v = torch.randn(1, total, H, D, dtype=DTYPE, device=DEVICE) + g = -torch.rand(1, total, H, D, dtype=torch.float32, device=DEVICE).abs() * 0.5 + beta = torch.randn(1, total, H, dtype=torch.float32, device=DEVICE).sigmoid().to(DTYPE) + A_log = torch.randn(H, dtype=torch.float32, device=DEVICE) + dt_bias = torch.randn(H, D, dtype=torch.float32, device=DEVICE) + h0 = torch.randn(N, H, D, D, dtype=torch.float32, device=DEVICE) * 0.1 if use_h0 else None + return q, k, v, g, beta, h0, A_log, dt_bias, cu + + +# ---- entry points under test ---- + + +def _common_cula_kw(q, k, v, g, beta, h0, A_log, dt_bias, cu): + return dict( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=1.0 / math.sqrt(D), + A_log=A_log, + dt_bias=dt_bias, + initial_state=h0, + output_final_state=True, + use_qk_l2norm_in_kernel=True, + use_gate_in_kernel=True, + safe_gate=True, + lower_bound=LOWER_BOUND, + cu_seqlens=cu, + ) + + +def run_cula_basic(q, k, v, g, beta, h0, A_log, dt_bias, cu): + return kda_prefill_hopper(**_common_cula_kw(q, k, v, g, beta, h0, A_log, dt_bias, cu)) + + +def run_cula_opt_no_cp(q, k, v, g, beta, h0, A_log, dt_bias, cu): + return kda_prefill_hopper_opt( + **_common_cula_kw(q, k, v, g, beta, h0, A_log, dt_bias, cu), + auto_cp=False, + ) + + +def run_cula_opt_cp(q, k, v, g, beta, h0, A_log, dt_bias, cu): + """Force CP entry through opt wrapper.""" + return kda_prefill_hopper_opt( + **_common_cula_kw(q, k, v, g, beta, h0, A_log, dt_bias, cu), + auto_cp=True, + ) + + +def run_cula_auto(q, k, v, g, beta, h0, A_log, dt_bias, cu): + """Adaptive router — exercises the production entry point.""" + return kda_prefill_hopper_auto( + **_common_cula_kw(q, k, v, g, beta, h0, A_log, dt_bias, cu), + ) + + +def run_fla(q, k, v, g, beta, h0, A_log, dt_bias, cu): + """FLA reference. cuLA returns ht as [N, HV, V, K]; FLA's default layout + is [N, HV, K, V]. We pass ``transpose_state_layout=True`` so its output + matches cuLA's layout — no manual transpose needed before assert_close. + """ + return fla_chunk_kda( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=1.0 / math.sqrt(D), + A_log=A_log, + dt_bias=dt_bias, + initial_state=h0, + output_final_state=True, + use_qk_l2norm_in_kernel=True, + use_gate_in_kernel=True, + safe_gate=True, + lower_bound=LOWER_BOUND, + cu_seqlens=cu.long(), + transpose_state_layout=True, + ) + + +# ---- assertions ---- + + +def _assert_same_kernel(name, actual, ref): + """torch.testing.assert_close with tight atol/rtol (CP-on vs CP-off use the + same C++ kernel; the only delta is per-chunk recurrence reordering).""" + if actual is None or ref is None: + assert actual is ref, f"{name}: one is None and the other isn't" + return + torch.testing.assert_close( + actual.float(), + ref.float(), + atol=ATOL_SAME_KERNEL, + rtol=RTOL_SAME_KERNEL, + msg=lambda m: f"{name}: {m}", + ) + + +def assert_cp_engages(cu, H): + """Fail fast if _calc_cp_seqs won't engage CP for this shape — without + that, the test silently checks CP-off vs CP-off. + """ + num_sms = get_device_sm_count(torch.device(DEVICE)) + use_cp, cp_cu, *_ = _calc_cp_seqs( + cu, + BT, + H, + num_sms, + raw_cu_seqlens_cpu=cu.cpu(), + ) + assert use_cp and cp_cu is not None, f"_calc_cp_seqs returned use_cp=False for cu={cu.tolist()} H={H}" + n_sub = int(cp_cu.numel() - 1) + raw_batch = int(cu.numel() - 1) + assert n_sub > raw_batch, f"CP didn't split: n_sub={n_sub} == raw_batch={raw_batch}" + + +# ====================== Dispatch path: CP vs no-CP ====================== +# Verifies kda_prefill_hopper_opt(auto_cp=True) routes through CP and matches +# the same-kernel no-CP baseline (kda_prefill_hopper). + +DISPATCH_CONFIGS = [ + # (seq_lens, H, use_h0) + ([32768], 4, False), + ([32768], 4, True), + ([65536], 4, True), + ([32768], 8, False), + ([65536], 8, True), + ([16384, 16384], 4, True), + ([28672, 4096], 4, True), + ([131072, 1024], 4, False), +] + + +@pytest.mark.parametrize("seq_lens,H,use_h0", DISPATCH_CONFIGS) +def test_cp_matches_basic_baseline(seq_lens, H, use_h0): + """CP-on (opt+auto_cp) output equals basic baseline (no-CP).""" + q, k, v, g, beta, h0, A_log, dt_bias, cu = make_varlen_inputs( + seq_lens, + H, + use_h0=use_h0, + ) + assert_cp_engages(cu, H) + with torch.inference_mode(): + o_base, ht_base = run_cula_basic(q, k, v, g, beta, h0, A_log, dt_bias, cu) + o_cp, ht_cp = run_cula_opt_cp(q, k, v, g, beta, h0, A_log, dt_bias, cu) + _assert_same_kernel("o", o_cp, o_base) + _assert_same_kernel("ht", ht_cp, ht_base) + + +@pytest.mark.parametrize("seq_lens,H,use_h0", DISPATCH_CONFIGS) +def test_auto_router_matches_basic_baseline(seq_lens, H, use_h0): + """kda_prefill_hopper_auto output (whatever path it picks) equals basic baseline.""" + q, k, v, g, beta, h0, A_log, dt_bias, cu = make_varlen_inputs( + seq_lens, + H, + use_h0=use_h0, + ) + with torch.inference_mode(): + o_base, ht_base = run_cula_basic(q, k, v, g, beta, h0, A_log, dt_bias, cu) + o_auto, ht_auto = run_cula_auto(q, k, v, g, beta, h0, A_log, dt_bias, cu) + _assert_same_kernel("o", o_auto, o_base) + _assert_same_kernel("ht", ht_auto, ht_base) + + +def test_cp_off_matches_basic_baseline(): + """opt with auto_cp=False must match basic (no CP, no fused-pre divergence).""" + seq_lens, H, use_h0 = [32768], 4, True + q, k, v, g, beta, h0, A_log, dt_bias, cu = make_varlen_inputs(seq_lens, H, use_h0=use_h0) + with torch.inference_mode(): + o_base, ht_base = run_cula_basic(q, k, v, g, beta, h0, A_log, dt_bias, cu) + o_off, ht_off = run_cula_opt_no_cp(q, k, v, g, beta, h0, A_log, dt_bias, cu) + _assert_same_kernel("o", o_off, o_base) + _assert_same_kernel("ht", ht_off, ht_base) + + +# ====================== Cross-impl: CP vs FLA ====================== + +VS_FLA_CONFIGS = [ + ([32768], 4), + ([65536], 4), + ([32768], 8), + ([16384, 16384], 4), + ([28672, 4096], 4), + ([131072, 1024], 4), +] + + +# Irregular varlen lengths +IRREGULAR_VARLEN_CONFIGS = [ + ([1], 4), + ([63], 4), + ([64], 4), + ([65], 4), + ([129], 4), + ([1, 63, 64, 65, 129], 4), + ([1, 63, 64, 65, 129], 8), + ([129, 65, 64, 63, 1], 4), + ([1024, 1, 63, 65, 129], 4), + ([4096, 1, 63, 64, 65, 129], 4), + ([4096, 1, 63, 64, 65, 129], 8), + ([8192, 1, 31, 63, 65, 127, 129, 255], 4), + ([1] * 8 + [63] * 4 + [129] * 2, 4), + ([255, 257, 511, 513], 4), +] + + +@pytest.mark.parametrize("seq_lens,H", IRREGULAR_VARLEN_CONFIGS) +def test_irregular_varlen_vs_fla(seq_lens, H): + """Irregular varlen lengths""" + q, k, v, g, beta, _, A_log, dt_bias, cu = make_varlen_inputs(seq_lens, H, use_h0=False) + with torch.inference_mode(): + o_fla, ht_fla = run_fla(q, k, v, g, beta, None, A_log, dt_bias, cu) + o_opt, ht_opt = run_cula_opt_no_cp(q, k, v, g, beta, None, A_log, dt_bias, cu) + assert_close(f"o (cu={cu.tolist()},H={H})", o_fla, o_opt, ratio=RATIO_VS_FLA) + assert_close(f"ht (cu={cu.tolist()},H={H})", ht_fla, ht_opt, ratio=RATIO_VS_FLA) + + +@pytest.mark.parametrize("seq_lens,H", IRREGULAR_VARLEN_CONFIGS) +def test_irregular_varlen_opt_matches_basic(seq_lens, H): + """Irregular varlen: opt path (may take fused gate+l2norm) equals basic baseline.""" + q, k, v, g, beta, _, A_log, dt_bias, cu = make_varlen_inputs(seq_lens, H, use_h0=False) + with torch.inference_mode(): + o_base, ht_base = run_cula_basic(q, k, v, g, beta, None, A_log, dt_bias, cu) + o_opt, ht_opt = run_cula_opt_no_cp(q, k, v, g, beta, None, A_log, dt_bias, cu) + _assert_same_kernel("o", o_opt, o_base) + _assert_same_kernel("ht", ht_opt, ht_base) + + +@pytest.mark.parametrize("seq_lens,H", VS_FLA_CONFIGS) +def test_cp_vs_fla(seq_lens, H): + """CP output matches FLA chunk_kda reference (cross-impl).""" + q, k, v, g, beta, _, A_log, dt_bias, cu = make_varlen_inputs(seq_lens, H, use_h0=False) + assert_cp_engages(cu, H) + with torch.inference_mode(): + o_fla, ht_fla = run_fla(q, k, v, g, beta, None, A_log, dt_bias, cu) + o_cp, ht_cp = run_cula_opt_cp(q, k, v, g, beta, None, A_log, dt_bias, cu) + assert_close(f"o (cu={cu.tolist()},H={H})", o_fla, o_cp, ratio=RATIO_VS_FLA) + assert_close(f"ht (cu={cu.tolist()},H={H})", ht_fla, ht_cp, ratio=RATIO_VS_FLA) + + +# ====================== Final state ht correctness ====================== +# Per-sequence ht must be independently correct for prefill→decode handoff. + +FINAL_STATE_CONFIGS = [ + ([65536], 4, False), + ([65536], 4, True), + ([65536, 16384], 4, True), + ([28672, 4096], 4, False), + ([131072, 1024], 4, True), +] + + +@pytest.mark.parametrize("seq_lens,H,use_h0", FINAL_STATE_CONFIGS) +def test_cp_final_state_per_seq(seq_lens, H, use_h0): + """Each sequence's ht matches basic baseline independently (no cross-leakage).""" + q, k, v, g, beta, h0, A_log, dt_bias, cu = make_varlen_inputs(seq_lens, H, use_h0=use_h0) + assert_cp_engages(cu, H) + with torch.inference_mode(): + _, ht_base = run_cula_basic(q, k, v, g, beta, h0, A_log, dt_bias, cu) + _, ht_cp = run_cula_opt_cp(q, k, v, g, beta, h0, A_log, dt_bias, cu) + assert ht_cp is not None and ht_cp.shape == ht_base.shape, ( + f"shape mismatch: cp={tuple(ht_cp.shape)} base={tuple(ht_base.shape)}" + ) + for i in range(len(seq_lens)): + _assert_same_kernel(f"ht[{i}] (len={seq_lens[i]})", ht_cp[i], ht_base[i]) + + +# ====================== Stress: race / non-determinism ====================== +# CP's per-chunk preprocess + main kernel — re-running same inputs must +# produce bit-identical outputs (no race, no order-dependence). + +STRESS_ITERS = 50 + + +@pytest.mark.parametrize( + "seq_lens,H,use_h0", + [ + pytest.param([65536], 4, True, id="single-64K-H4-h0"), + pytest.param([65536, 4096], 4, True, id="multi-64K+4K-H4-h0"), + ], +) +def test_cp_stress_repeat(seq_lens, H, use_h0): + """Run CP N times; every iter must match the first (deterministic).""" + q, k, v, g, beta, h0, A_log, dt_bias, cu = make_varlen_inputs( + seq_lens, + H, + use_h0=use_h0, + seed=20260516, + ) + assert_cp_engages(cu, H) + with torch.inference_mode(): + o_ref, ht_ref = run_cula_opt_cp(q, k, v, g, beta, h0, A_log, dt_bias, cu) + torch.cuda.synchronize() + for i in range(STRESS_ITERS): + o_i, ht_i = run_cula_opt_cp(q, k, v, g, beta, h0, A_log, dt_bias, cu) + torch.cuda.synchronize() + assert_close(f"iter {i} o", o_ref, o_i, ratio=RATIO_STRESS) + assert_close(f"iter {i} ht", ht_ref, ht_i, ratio=RATIO_STRESS) + + +# ====================== h0=None equivalence ====================== +# We patched cp_context.py so raw_h0=None synthesizes a zero pool. Verify the +# kernel result is numerically equivalent to passing an explicit zero h0. + + +def test_cp_h0_none_equiv_h0_zeros(): + """h0=None must produce identical ht to h0=zeros (no implicit init drift).""" + seq_lens, H = [65536, 4096], 4 + assert_cp_engages(_cu_from_seq_lens(seq_lens), H) + q, k, v, g, beta, _, A_log, dt_bias, cu = make_varlen_inputs( + seq_lens, + H, + use_h0=False, + seed=20260501, + ) + h0_zeros = torch.zeros(len(seq_lens), H, D, D, dtype=torch.float32, device=DEVICE) + with torch.inference_mode(): + o_none, ht_none = run_cula_opt_cp(q, k, v, g, beta, None, A_log, dt_bias, cu) + o_zeros, ht_zeros = run_cula_opt_cp(q, k, v, g, beta, h0_zeros, A_log, dt_bias, cu) + torch.cuda.synchronize() + o_diff = (o_none.float() - o_zeros.float()).abs().max().item() + ht_diff = (ht_none.float() - ht_zeros.float()).abs().max().item() + assert o_diff < 1e-3, f"o: h0=None vs h0=zeros max abs diff {o_diff:.4e}" + assert ht_diff < 1e-4, f"ht: h0=None vs h0=zeros max abs diff {ht_diff:.4e}" + + +# ====================== CP bypass: shapes where _calc_cp_seqs returns False ====================== +# When CP heuristic says "don't split", auto_cp=True must produce bit-identical +# output to basic (because the kernel takes the same path). + +BYPASS_CONFIGS = [ + ([2048], 8), # H=8 single seq T<=2048 → no CP + ([16384], 64), # H=64 → CP never fires (per _calc_cp_seqs H>=64 branch) + ([4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096], 8), # native_grid 64 >> 16 + ([131072] + [1024] * 5, 8), # raw_batch big enough that native_grid > 16 +] + + +@pytest.mark.parametrize("seq_lens,H", BYPASS_CONFIGS) +def test_cp_bypass_matches_basic(seq_lens, H): + """When CP heuristic skips, auto_cp=True must be a no-op (same output as basic).""" + q, k, v, g, beta, _, A_log, dt_bias, cu = make_varlen_inputs(seq_lens, H, use_h0=False) + num_sms = get_device_sm_count(torch.device(DEVICE)) + use_cp, cp_cu, *_ = _calc_cp_seqs(cu, BT, H, num_sms, raw_cu_seqlens_cpu=cu.cpu()) + n_sub = int(cp_cu.numel() - 1) if cp_cu is not None else 0 + assert not use_cp or n_sub == len(seq_lens), ( + f"expected bypass for cu={cu.tolist()} H={H}, got use_cp={use_cp} n_sub={n_sub}" + ) + with torch.inference_mode(): + o_base, ht_base = run_cula_basic(q, k, v, g, beta, None, A_log, dt_bias, cu) + o_cp, ht_cp = run_cula_opt_cp(q, k, v, g, beta, None, A_log, dt_bias, cu) + _assert_same_kernel("o", o_cp, o_base) + _assert_same_kernel("ht", ht_cp, ht_base)