From d52936f4ad1bdfa97535eef2ffa04ca677646742 Mon Sep 17 00:00:00 2001 From: andrej Date: Mon, 6 Apr 2026 16:46:36 -0600 Subject: [PATCH 01/22] add single-dispatch layer-by-layer MHA --- aie_kernels/aie2p/softmax.cc | 3 +- iron/operators/gemm/design.py | 4 +- iron/operators/mha_prefill_lxl_sd/__init__.py | 2 + iron/operators/mha_prefill_lxl_sd/op.py | 245 ++++++++++++++++++ .../operators/mha_prefill_lxl_sd/reference.py | 206 +++++++++++++++ iron/operators/mha_prefill_lxl_sd/test.py | 186 +++++++++++++ 6 files changed, 643 insertions(+), 3 deletions(-) create mode 100644 iron/operators/mha_prefill_lxl_sd/__init__.py create mode 100644 iron/operators/mha_prefill_lxl_sd/op.py create mode 100644 iron/operators/mha_prefill_lxl_sd/reference.py create mode 100644 iron/operators/mha_prefill_lxl_sd/test.py diff --git a/aie_kernels/aie2p/softmax.cc b/aie_kernels/aie2p/softmax.cc index 64cca202..5778682a 100644 --- a/aie_kernels/aie2p/softmax.cc +++ b/aie_kernels/aie2p/softmax.cc @@ -3,6 +3,7 @@ #include #include +#include #define SM_VEC_LEN 64 // 32 #define log2e 1.4453125 // 1.44269504089 @@ -30,7 +31,7 @@ void softmax_simple_bf16(bfloat16 *restrict input_vector, bfloat16 *restrict out aie::vector in_elems, exp_val, input_bf16, log2e_vec, max_val_vec; aie::accum out_vals, exp_val_accum, scaled_accum, exp_in_accum; - float max_val = 0; + float max_val = -INFINITY; float accum_exp_val = 0; float running_max = 0; bfloat16 col_sum_inv; diff --git a/iron/operators/gemm/design.py b/iron/operators/gemm/design.py index a8ed8ad3..32222dd4 100644 --- a/iron/operators/gemm/design.py +++ b/iron/operators/gemm/design.py @@ -299,7 +299,7 @@ def my_matmul( gemm_object, [C_l1_ty_internal], ) - matmul_func_name = f"matmul{scalar_suffix}_{dtype_in_str}_f32" + matmul_func_name = f"{func_prefix}matmul{scalar_suffix}_{dtype_in_str}_f32" matmul_kernel = Kernel( matmul_func_name, gemm_object, @@ -314,7 +314,7 @@ def my_matmul( gemm_object, [C_l1_ty], ) - matmul_func_name = f"matmul{scalar_suffix}_{dtype_in_str}_{dtype_out_str}" + matmul_func_name = f"{func_prefix}matmul{scalar_suffix}_{dtype_in_str}_{dtype_out_str}" matmul_kernel = Kernel( matmul_func_name, gemm_object, diff --git a/iron/operators/mha_prefill_lxl_sd/__init__.py b/iron/operators/mha_prefill_lxl_sd/__init__.py new file mode 100644 index 00000000..82f09a67 --- /dev/null +++ b/iron/operators/mha_prefill_lxl_sd/__init__.py @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/iron/operators/mha_prefill_lxl_sd/op.py b/iron/operators/mha_prefill_lxl_sd/op.py new file mode 100644 index 00000000..fe951829 --- /dev/null +++ b/iron/operators/mha_prefill_lxl_sd/op.py @@ -0,0 +1,245 @@ +# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +A layer-by-layer (LxL) single-dispatch (SD) implementation of multi-head attention (MHA). +""" + +from iron.common.context import AIEContext +from iron.common.fusion import FusedMLIROperator +from iron.operators.gemm.op import AIEGEMM +from iron.operators.rope.op import AIERope +from iron.operators.strided_copy.op import AIEStridedCopy +from iron.operators.repeat.op import AIERepeat +from iron.operators.softmax.op import AIESoftmax +from iron.operators.transpose.op import AIETranspose +from iron.operators.elementwise_mul.op import AIEElementwiseMul +from iron.operators.elementwise_add.op import AIEElementwiseAdd + + +def _pick_tile_n(N, num_cols, max_tile_n=64): + tile_n = N // num_cols + while tile_n > max_tile_n: + tile_n //= 2 + assert N % (tile_n * num_cols) == 0 + return tile_n + + +def _build_core_ops(H, G, d, E, S, elf_ctx): + """Build core attention sub-ops and runlist (no projections/RoPE/GQA). + + Expects pre-processed inputs: + queries: (H, S, d) deinterleaved, contiguous per head + keys: (H, d, S) transposed and GQA-repeated + values: (H, S, d) GQA-repeated + """ + B = 2 # bytes per bf16 element + + gemm_scores = AIEGEMM( + M=S, K=d, N=S, num_aie_columns=8, tile_m=16, tile_k=64, + tile_n=_pick_tile_n(S, 8), context=elf_ctx, + ) + scale = AIEElementwiseMul( + size=H * S * S, tile_size=S * S // 8, + num_aie_columns=8, context=elf_ctx, + ) + mask = AIEElementwiseAdd( + size=H * S * S, tile_size=S * S // 8, + num_aie_columns=8, context=elf_ctx, + ) + softmax = AIESoftmax( + rows=H * S, cols=S, num_aie_columns=1, num_channels=1, + rtp_vector_size=S, context=elf_ctx, + ) + gemm_context = AIEGEMM( + M=S, K=S, N=d, num_aie_columns=4, tile_m=16, tile_k=64, + tile_n=16, context=elf_ctx, prio_accuracy=True, + ) + reinterleave = AIEStridedCopy( + input_sizes=(H, S, d), input_strides=(S * d, d, 1), input_offset=0, + output_sizes=(H, S, d), output_strides=(d, H * d, 1), output_offset=0, + input_buffer_size=H * S * d, output_buffer_size=S * H * d, + transfer_size=S * d, num_aie_channels=1, context=elf_ctx, + ) + gemm_output = AIEGEMM( + M=S, K=H * d, N=E, num_aie_columns=8, tile_m=16, tile_k=64, + tile_n=_pick_tile_n(E, 8), context=elf_ctx, prio_accuracy=True, + ) + + qh = S * d * B + kdS = d * S * B + kSd = S * d * B + sh = S * S * B + ch = S * d * B + + runlist = [ + *[(gemm_scores, + f"queries[{h*qh}:{(h+1)*qh}]", + f"keys[{h*kdS}:{(h+1)*kdS}]", + f"attn_scores[{h*sh}:{(h+1)*sh}]") + for h in range(H)], + (scale, "attn_scores", "attn_scale_factor", "attn_scores"), + (mask, "attn_scores", "causal_mask", "attn_scores_masked"), + (softmax, "attn_scores_masked", "attn_weights"), + *[(gemm_context, + f"attn_weights[{h*sh}:{(h+1)*sh}]", + f"values[{h*kSd}:{(h+1)*kSd}]", + f"attn_context[{h*ch}:{(h+1)*ch}]") + for h in range(H)], + (reinterleave, "attn_context", "context_interleaved"), + (gemm_output, "context_interleaved", "W_output", "attn_output"), + ] + + buffer_sizes = { + "queries": H * S * d * B, + "keys": H * d * S * B, + "values": H * S * d * B, + "attn_scores": H * S * S * B, + "attn_scores_masked": H * S * S * B, + "attn_weights": H * S * S * B, + "attn_context": H * S * d * B, + "context_interleaved": S * H * d * B, + } + + return runlist, buffer_sizes + + +class AIEAttentionPrefillFused(FusedMLIROperator): + """Fused attention prefill (core, no projections/RoPE). + + Accepts pre-projected Q (S*H,d), K (S*G,d), V (S*G,d) in interleaved layout. + """ + + def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, + seq_len, context=None): + assert head_dim == 64 + assert num_heads % num_kv_groups == 0 + assert seq_len % 256 == 0 + assert (num_heads * seq_len) % 16 == 0 + + self.num_heads = num_heads + self.num_kv_groups = num_kv_groups + self.head_dim = head_dim + self.embedding_dim = embedding_dim + self.seq_len = seq_len + + elf_ctx = context or AIEContext() + runlist, buffer_sizes = _build_core_ops( + num_heads, num_kv_groups, head_dim, embedding_dim, seq_len, elf_ctx, + ) + + super().__init__( + name=f"attention_prefill_fused_{num_heads}h{num_kv_groups}g{head_dim}d{embedding_dim}e{seq_len}s", + runlist=runlist, + input_args=["queries", "keys", "values", + "W_output", "attn_scale_factor", "causal_mask"], + output_args=["attn_output"], + buffer_sizes=buffer_sizes, + context=elf_ctx, + ) + + +class AIEAttentionPrefillProjectedFused(FusedMLIROperator): + """Fused attention prefill with Q/K/V projections and RoPE. + + Accepts raw input (S, E) and rope_angles (S, d). + """ + + def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, + seq_len, context=None): + assert head_dim == 64 + assert num_heads % num_kv_groups == 0 + assert seq_len % 256 == 0 + assert (num_heads * seq_len) % 16 == 0 + + self.num_heads = num_heads + self.num_kv_groups = num_kv_groups + self.head_dim = head_dim + self.embedding_dim = embedding_dim + self.seq_len = seq_len + + H, G, d, E, S = num_heads, num_kv_groups, head_dim, embedding_dim, seq_len + group_size = H // G + B = 2 + + elf_ctx = context or AIEContext() + + # ---- Projection + RoPE ---- + gemm_query = AIEGEMM( + M=S, K=E, N=H * d, num_aie_columns=8, tile_m=16, tile_k=64, + tile_n=_pick_tile_n(H * d, 8), context=elf_ctx, + ) + gemm_kv = AIEGEMM( + M=S, K=E, N=G * d, num_aie_columns=8, tile_m=16, tile_k=64, + tile_n=_pick_tile_n(G * d, 8), context=elf_ctx, + ) + rope_queries = AIERope(rows=S * H, cols=d, angle_rows=S, context=elf_ctx) + rope_keys = AIERope(rows=S * G, cols=d, angle_rows=S, context=elf_ctx) + + # ---- Deinterleave ---- + deinterleave_q = AIEStridedCopy( + input_sizes=(H, S, d), input_strides=(d, H * d, 1), input_offset=0, + output_sizes=(H, S, d), output_strides=(S * d, d, 1), output_offset=0, + input_buffer_size=S * H * d, output_buffer_size=H * S * d, + transfer_size=S * d, num_aie_channels=1, context=elf_ctx, + ) + deinterleave_kv = AIEStridedCopy( + input_sizes=(G, S, d), input_strides=(d, G * d, 1), input_offset=0, + output_sizes=(G, S, d), output_strides=(S * d, d, 1), output_offset=0, + input_buffer_size=S * G * d, output_buffer_size=G * S * d, + transfer_size=S * d, num_aie_channels=1, context=elf_ctx, + ) + + # ---- Transpose keys + GQA repeat ---- + transpose_keys = AIETranspose( + M=S, N=d, num_aie_columns=2, num_channels=1, + m=256, n=32, s=8, context=elf_ctx, + ) + repeat_kv = AIERepeat( + rows=G, cols=d * S, repeat=group_size, + transfer_size=d, context=elf_ctx, + ) + + kSd = S * d * B + kdS = d * S * B + + prefix_runlist = [ + (gemm_query, "input", "W_query", "queries_projected"), + (gemm_kv, "input", "W_key", "keys_projected"), + (gemm_kv, "input", "W_value", "values_projected"), + (rope_queries, "queries_projected", "rope_angles", "queries_roped"), + (rope_keys, "keys_projected", "rope_angles", "keys_roped"), + (deinterleave_q, "queries_roped", "queries"), + (deinterleave_kv, "keys_roped", "keys_deint"), + (deinterleave_kv, "values_projected", "values_deint"), + *[(transpose_keys, + f"keys_deint[{g*kSd}:{(g+1)*kSd}]", + f"keys_transposed[{g*kdS}:{(g+1)*kdS}]") + for g in range(G)], + (repeat_kv, "keys_transposed", "keys"), + (repeat_kv, "values_deint", "values"), + ] + prefix_buffer_sizes = { + "queries_projected": S * H * d * B, + "keys_projected": S * G * d * B, + "values_projected": S * G * d * B, + "queries_roped": S * H * d * B, + "keys_roped": S * G * d * B, + "keys_deint": G * S * d * B, + "values_deint": G * S * d * B, + "keys_transposed": G * d * S * B, + } + + core_runlist, core_buffer_sizes = _build_core_ops( + H, G, d, E, S, elf_ctx, + ) + + super().__init__( + name=f"attention_prefill_projected_fused_{H}h{G}g{d}d{E}e{S}s", + runlist=prefix_runlist + core_runlist, + input_args=["input", "rope_angles", "W_query", "W_key", "W_value", + "W_output", "attn_scale_factor", "causal_mask"], + output_args=["attn_output"], + buffer_sizes={**prefix_buffer_sizes, **core_buffer_sizes}, + context=elf_ctx, + ) diff --git a/iron/operators/mha_prefill_lxl_sd/reference.py b/iron/operators/mha_prefill_lxl_sd/reference.py new file mode 100644 index 00000000..484f4d7f --- /dev/null +++ b/iron/operators/mha_prefill_lxl_sd/reference.py @@ -0,0 +1,206 @@ +# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import torch +import numpy as np +from ml_dtypes import bfloat16 + + +def apply_rope(x, lut): + """Apply Rotary Position Embedding using pre-computed cos/sin LUT. + + x: (rows, cols) — rows are (positions * heads) interleaved + lut: (angle_rows, cols) — interleaved [cos_0, sin_0, cos_1, sin_1, ...] + If angle_rows < rows, each angle row is reused for + (rows // angle_rows) consecutive input rows (block repetition). + Returns: (rows, cols) with RoPE applied (two-halves method) + """ + rows, cols = x.shape + angle_rows = lut.shape[0] + half = cols // 2 + + cos = lut[:, ::2] # (angle_rows, half) + sin = lut[:, 1::2] # (angle_rows, half) + + if angle_rows < rows: + # Block repetition: each angle row repeats for consecutive input rows + repeats = rows // angle_rows + cos = cos.repeat_interleave(repeats, dim=0) # (rows, half) + sin = sin.repeat_interleave(repeats, dim=0) # (rows, half) + + x1 = x[:, :half] + x2 = x[:, half:] + out = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1) + return out + + +def generate_golden_reference( + num_heads, + num_kv_groups, + head_dim, + embedding_dim, + seq_len, + seed=42, +): + """Generate golden reference for fused attention prefill. + + Parameters: + num_heads (H): number of query attention heads + num_kv_groups (G): number of KV heads (G=H for MHA, G i with -inf + causal_mask = torch.zeros(H * S, S, dtype=torch.bfloat16) + for h in range(H): + for i in range(S): + for j in range(S): + if j > i: + causal_mask[h * S + i, j] = torch.tensor(float("-inf")).to( + torch.bfloat16 + ) + + # ---- Step 1-3: Q/K/V projections ---- + queries_raw = x.float() @ W_query.float() # (S, H*d) + queries_raw = queries_raw.to(torch.bfloat16) + keys_raw = x.float() @ W_key.float() # (S, G*d) + keys_raw = keys_raw.to(torch.bfloat16) + values_raw = x.float() @ W_value.float() # (S, G*d) + values_raw = values_raw.to(torch.bfloat16) + + # ---- Step 4-5: RoPE ---- + # Q proj output is (S, H*d), viewed as (S*H, d) with heads interleaved: + # row layout: [pos0_head0, pos0_head1, ..., pos0_headH-1, pos1_head0, ...] + # RoPE angle_rows=S: row i uses angle row (i % S) = position index + queries_for_rope = queries_raw.reshape(S * H, d) + queries_roped = apply_rope(queries_for_rope, rope_angles) # (S*H, d) + + keys_for_rope = keys_raw.reshape(S * G, d) + keys_roped = apply_rope(keys_for_rope, rope_angles) # (S*G, d) + + # ---- Step 6: Deinterleave Q: (S*H, d) → (H, S, d) ---- + # Current layout: [pos0_h0, pos0_h1, ..., pos0_{H-1}, pos1_h0, ...] + # = (S, H, d) in memory; reshape and transpose to (H, S, d) + queries_deinterleaved = queries_roped.reshape(S, H, d).transpose(0, 1).contiguous() # (H, S, d) + + # ---- Step 7: Deinterleave K: (S*G, d) → (G, S, d) then transpose to (G, d, S) ---- + keys_deinterleaved = keys_roped.reshape(S, G, d).transpose(0, 1).contiguous() # (G, S, d) + keys_transposed = keys_deinterleaved.transpose(1, 2).contiguous() # (G, d, S) + + # ---- Step 8: Deinterleave V: (S, G*d) → (G, S, d) ---- + values_deinterleaved = values_raw.reshape(S, G, d).transpose(0, 1).contiguous() # (G, S, d) + + # ---- Step 9: GQA repeat ---- + if group_size > 1: + # Repeat keys and values: (G, ...) → (H, ...) + # Flatten to (G, d*S) / (G, S*d), repeat, reshape + keys_for_scores = keys_transposed.reshape(G, d * S).repeat_interleave( + group_size, dim=0 + ).reshape(H, d, S) + values_for_context = values_deinterleaved.reshape(G, S * d).repeat_interleave( + group_size, dim=0 + ).reshape(H, S, d) + else: + keys_for_scores = keys_transposed # (H, d, S) + values_for_context = values_deinterleaved # (H, S, d) + + # ---- Step 10: Score GEMM per head ---- + # Q_head(S, d) @ K_head(d, S) → scores(S, S) + attn_scores = torch.zeros(H, S, S, dtype=torch.bfloat16) + for h in range(H): + attn_scores[h] = ( + queries_deinterleaved[h].float() @ keys_for_scores[h].float() + ).to(torch.bfloat16) + + # ---- Step 11: Scale ---- + attn_scores_scaled = (attn_scores.float() * scale).to(torch.bfloat16) + # ---- Step 12: Causal mask (add -inf) ---- + attn_scores_masked = attn_scores_scaled.reshape(H * S, S).float() + causal_mask.float() + attn_scores_masked = attn_scores_masked.to(torch.bfloat16) + + # ---- Step 13: Softmax ---- + attn_weights = torch.nn.functional.softmax( + attn_scores_masked.float().reshape(H, S, S), dim=-1 + ).to(torch.bfloat16) # (H, S, S) + + # ---- Step 14: Context GEMM per head ---- + # weights(S, S) @ values(S, d) → context(S, d) + attn_context = torch.zeros(H, S, d, dtype=torch.bfloat16) + for h in range(H): + attn_context[h] = ( + attn_weights[h].float() @ values_for_context[h].float() + ).to(torch.bfloat16) + + # ---- Step 15: Re-interleave context: (H, S, d) → (S, H*d) ---- + context_interleaved = attn_context.transpose(0, 1).contiguous().reshape(S, H * d) + + # ---- Step 16: Output projection ---- + attn_output = (context_interleaved.float() @ W_output.float()).to(torch.bfloat16) + + return { + "input": x, + "rope_angles": rope_angles, + "W_query": W_query, + "W_key": W_key, + "W_value": W_value, + "W_output": W_output, + "attn_scale_factor": attn_scale_factor, + "causal_mask": causal_mask, + "queries_raw": queries_raw, + "keys_raw": keys_raw, + "values_raw": values_raw, + "queries_roped": queries_roped, + "keys_roped": keys_roped, + "queries_deinterleaved": queries_deinterleaved, + "keys_deinterleaved": keys_deinterleaved, + "keys_transposed": keys_transposed, + "values_deinterleaved": values_deinterleaved, + "keys_for_scores": keys_for_scores, + "values_for_context": values_for_context, + "attn_scores": attn_scores, + "attn_scores_scaled": attn_scores_scaled, + "attn_scores_masked": attn_scores_masked, + "attn_weights": attn_weights, + "attn_context": attn_context, + "context_interleaved": context_interleaved, + "attn_output": attn_output, + } diff --git a/iron/operators/mha_prefill_lxl_sd/test.py b/iron/operators/mha_prefill_lxl_sd/test.py new file mode 100644 index 00000000..ccf768d9 --- /dev/null +++ b/iron/operators/mha_prefill_lxl_sd/test.py @@ -0,0 +1,186 @@ +# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import time + +import numpy as np +import pytest +import torch +from ml_dtypes import bfloat16 + +from iron.common.test_utils import verify_buffer +from iron.common.utils import torch_to_numpy + +from iron.operators.mha_prefill_lxl_sd.op import ( + AIEAttentionPrefillFused, + AIEAttentionPrefillProjectedFused, +) +from iron.operators.mha_prefill_lxl_sd.reference import generate_golden_reference + +REL_TOL = 0.08 +ABS_TOL = 2.0 +MAX_ERROR_RATE = 0.03 + + +def get_params(): + return [ + pytest.param(2, 2, 64, 256, 256, id="H2"), + pytest.param(32, 8, 64, 2048, 256, id="H32"), + ] + + +def _load_input(fc, name, tensor): + """Load a tensor into a named sub-buffer of the fused callable.""" + fc.get_buffer(name).view_as_np()[:] = torch_to_numpy(tensor).flatten() + + +def _get_scratch_tensor(fc, name, shape): + """Read a named buffer from the fused callable's scratch space.""" + fc.scratch_buffer.on = "npu" + fc.scratch_buffer.to("cpu") + sub = fc.get_buffer(name) + return np.frombuffer( + sub.memory_view, dtype=bfloat16, count=int(np.prod(shape)) + ).reshape(shape).astype(np.float32) + + +def _verify_output(fc, golden, H, d, S, E): + """Chain-consistent output verification shared by both test variants.""" + npu_context = torch.from_numpy( + _get_scratch_tensor(fc, "context_interleaved", (S, H * d)) + ).bfloat16() + chain_ref = (npu_context.float() @ golden["W_output"].float()).to(torch.bfloat16) + + fc.output_buffer.on = "npu" + fc.output_buffer.to("cpu") + output_np = fc.get_buffer("attn_output").view_as_np() + output = torch.from_numpy(output_np.reshape(S, E).astype(np.float32)).bfloat16() + + errors = verify_buffer( + output, "attn_output", chain_ref.reshape(S, E), + rel_tol=REL_TOL, abs_tol=ABS_TOL, max_error_rate=MAX_ERROR_RATE, + ) + assert not errors, f"Output verification failed with {len(errors)} errors" + + +def _core_gemm_flops(H, G, d, E, S): + """Count GEMM FLOPs for the core attention operator.""" + score_flops = H * 2 * S * d * S # H x (S,d)@(d,S) + context_flops = H * 2 * S * S * d # H x (S,S)@(S,d) + output_flops = 2 * S * (H * d) * E # (S,H*d)@(H*d,E) + return score_flops + context_flops + output_flops + + +def _projected_gemm_flops(H, G, d, E, S): + """Count GEMM FLOPs for the projected attention operator.""" + query_proj = 2 * S * E * (H * d) # (S,E)@(E,H*d) + kv_proj = 2 * (2 * S * E * (G * d)) # key + value: (S,E)@(E,G*d) each + return query_proj + kv_proj + _core_gemm_flops(H, G, d, E, S) + + +def _print_metrics(label, elapsed_s, flops): + """Print latency and throughput metrics.""" + gflops = flops / elapsed_s / 1e9 + print(f" {label}: {elapsed_s*1e3:.2f} ms, {gflops:.2f} GFLOPS") + + +# --------------------------------------------------------------------------- +# Core attention tests (pre-projected Q, K, V) +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("H,G,d,E,S", get_params()) +def test_mha_pefill_lxl_sd(H, G, d, E, S): + """Core attention: score GEMM -> scale -> mask -> softmax -> context GEMM -> output.""" + golden = generate_golden_reference(H, G, d, E, S) + + op = AIEAttentionPrefillFused(H, G, d, E, S) + op.compile() + fc = op.get_callable() + + _load_input(fc, "queries", golden["queries_deinterleaved"]) + _load_input(fc, "keys", golden["keys_for_scores"]) + _load_input(fc, "values", golden["values_for_context"]) + _load_input(fc, "W_output", golden["W_output"]) + _load_input(fc, "attn_scale_factor", golden["attn_scale_factor"]) + _load_input(fc, "causal_mask", golden["causal_mask"]) + + t0 = time.perf_counter() + fc() + elapsed = time.perf_counter() - t0 + _print_metrics("core", elapsed, _core_gemm_flops(H, G, d, E, S)) + + _verify_output(fc, golden, H, d, S, E) + + +# --------------------------------------------------------------------------- +# Projected attention tests (with Q/K/V projections + RoPE) +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("H,G,d,E,S", get_params()) +def test_attention_prefill_projected_fused(H, G, d, E, S): + """Projected attention: Q/K/V proj -> RoPE -> GQA -> attention -> output proj.""" + golden = generate_golden_reference(H, G, d, E, S) + + op = AIEAttentionPrefillProjectedFused(H, G, d, E, S) + op.compile() + fc = op.get_callable() + + _load_input(fc, "input", golden["input"]) + _load_input(fc, "rope_angles", golden["rope_angles"]) + _load_input(fc, "W_query", golden["W_query"]) + _load_input(fc, "W_key", golden["W_key"]) + _load_input(fc, "W_value", golden["W_value"]) + _load_input(fc, "W_output", golden["W_output"]) + _load_input(fc, "attn_scale_factor", golden["attn_scale_factor"]) + _load_input(fc, "causal_mask", golden["causal_mask"]) + + t0 = time.perf_counter() + fc() + elapsed = time.perf_counter() - t0 + _print_metrics("projected", elapsed, _projected_gemm_flops(H, G, d, E, S)) + + _verify_output(fc, golden, H, d, S, E) + + +# --------------------------------------------------------------------------- +# Intermediate checks (extensive, not run by default) +# --------------------------------------------------------------------------- + +INTERMEDIATE_CHECKS = [ + ("attn_scores", "attn_scores", lambda H, G, S, d: (H, S, S)), + ("attn_scores_masked", "attn_scores_masked", lambda H, G, S, d: (H, S, S)), + ("attn_weights", "attn_weights", lambda H, G, S, d: (H, S, S)), + ("attn_context", "attn_context", lambda H, G, S, d: (H, S, d)), + ("context_interleaved", "context_interleaved", lambda H, G, S, d: (S, H * d)), +] + + +@pytest.mark.extensive +@pytest.mark.parametrize("H,G,d,E,S", get_params()) +def test_mha_pefill_lxl_sd_intermediates(H, G, d, E, S): + """Check intermediate buffers of core attention (for debugging).""" + golden = generate_golden_reference(H, G, d, E, S) + + op = AIEAttentionPrefillFused(H, G, d, E, S) + op.compile() + fc = op.get_callable() + + _load_input(fc, "queries", golden["queries_deinterleaved"]) + _load_input(fc, "keys", golden["keys_for_scores"]) + _load_input(fc, "values", golden["values_for_context"]) + _load_input(fc, "W_output", golden["W_output"]) + _load_input(fc, "attn_scale_factor", golden["attn_scale_factor"]) + _load_input(fc, "causal_mask", golden["causal_mask"]) + + fc() + + for buf_name, golden_key, shape_fn in INTERMEDIATE_CHECKS: + shape = shape_fn(H, G, S, d) + actual = _get_scratch_tensor(fc, buf_name, shape) + expected = golden[golden_key].float().numpy().reshape(shape) + diff = np.abs(actual - expected) + print( + f" [{buf_name}] shape={shape} " + f"nan={int(np.isnan(actual).sum())} " + f"max_abs_err={diff.max():.4f} mean_abs_err={diff.mean():.6f}" + ) From dfe5f888014df3e8ba6cee524d3d1a1e9d0d30d7 Mon Sep 17 00:00:00 2001 From: andrej Date: Mon, 6 Apr 2026 17:21:37 -0600 Subject: [PATCH 02/22] add GPT-2 sizes as test cases, make causal mask an option --- iron/common/fusion.py | 3 + iron/operators/mha_prefill_lxl_sd/op.py | 60 ++++++++++++----- iron/operators/mha_prefill_lxl_sd/test.py | 81 ++++++++++++++++++----- 3 files changed, 112 insertions(+), 32 deletions(-) diff --git a/iron/common/fusion.py b/iron/common/fusion.py index 99219848..f8d1b9cd 100644 --- a/iron/common/fusion.py +++ b/iron/common/fusion.py @@ -5,6 +5,7 @@ import ml_dtypes import pyxrt import ctypes +import time from . import compilation as comp from .base import AIEOperatorBase, MLIROperator from .utils import XRTSubBuffer @@ -290,8 +291,10 @@ def __call__(self, *args): for i, arg in enumerate(args): assert isinstance(arg, pyxrt.bo), f"Argument {i} is not a pyxrt.bo" run.set_arg(i, arg) + t0 = time.perf_counter() run.start() ret_code = run.wait() + self.last_elapsed = time.perf_counter() - t0 if ret_code != pyxrt.ert_cmd_state.ERT_CMD_STATE_COMPLETED: raise RuntimeError(f"Kernel execution failed with return code {ret_code}") diff --git a/iron/operators/mha_prefill_lxl_sd/op.py b/iron/operators/mha_prefill_lxl_sd/op.py index fe951829..b939e7e2 100644 --- a/iron/operators/mha_prefill_lxl_sd/op.py +++ b/iron/operators/mha_prefill_lxl_sd/op.py @@ -25,13 +25,15 @@ def _pick_tile_n(N, num_cols, max_tile_n=64): return tile_n -def _build_core_ops(H, G, d, E, S, elf_ctx): +def _build_core_ops(H, G, d, E, S, elf_ctx, causal_mask=True): """Build core attention sub-ops and runlist (no projections/RoPE/GQA). Expects pre-processed inputs: queries: (H, S, d) deinterleaved, contiguous per head keys: (H, d, S) transposed and GQA-repeated values: (H, S, d) GQA-repeated + + If causal_mask=False, the elementwise-add masking step is omitted. """ B = 2 # bytes per bf16 element @@ -43,10 +45,11 @@ def _build_core_ops(H, G, d, E, S, elf_ctx): size=H * S * S, tile_size=S * S // 8, num_aie_columns=8, context=elf_ctx, ) - mask = AIEElementwiseAdd( - size=H * S * S, tile_size=S * S // 8, - num_aie_columns=8, context=elf_ctx, - ) + if causal_mask: + mask = AIEElementwiseAdd( + size=H * S * S, tile_size=S * S // 8, + num_aie_columns=8, context=elf_ctx, + ) softmax = AIESoftmax( rows=H * S, cols=S, num_aie_columns=1, num_channels=1, rtp_vector_size=S, context=elf_ctx, @@ -79,8 +82,19 @@ def _build_core_ops(H, G, d, E, S, elf_ctx): f"attn_scores[{h*sh}:{(h+1)*sh}]") for h in range(H)], (scale, "attn_scores", "attn_scale_factor", "attn_scores"), - (mask, "attn_scores", "causal_mask", "attn_scores_masked"), - (softmax, "attn_scores_masked", "attn_weights"), + ] + + if causal_mask: + runlist += [ + (mask, "attn_scores", "causal_mask", "attn_scores_masked"), + (softmax, "attn_scores_masked", "attn_weights"), + ] + else: + runlist += [ + (softmax, "attn_scores", "attn_weights"), + ] + + runlist += [ *[(gemm_context, f"attn_weights[{h*sh}:{(h+1)*sh}]", f"values[{h*kSd}:{(h+1)*kSd}]", @@ -95,11 +109,12 @@ def _build_core_ops(H, G, d, E, S, elf_ctx): "keys": H * d * S * B, "values": H * S * d * B, "attn_scores": H * S * S * B, - "attn_scores_masked": H * S * S * B, "attn_weights": H * S * S * B, "attn_context": H * S * d * B, "context_interleaved": S * H * d * B, } + if causal_mask: + buffer_sizes["attn_scores_masked"] = H * S * S * B return runlist, buffer_sizes @@ -111,7 +126,7 @@ class AIEAttentionPrefillFused(FusedMLIROperator): """ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, - seq_len, context=None): + seq_len, causal_mask=True, context=None): assert head_dim == 64 assert num_heads % num_kv_groups == 0 assert seq_len % 256 == 0 @@ -126,13 +141,19 @@ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, elf_ctx = context or AIEContext() runlist, buffer_sizes = _build_core_ops( num_heads, num_kv_groups, head_dim, embedding_dim, seq_len, elf_ctx, + causal_mask=causal_mask, ) + mask_suffix = "_causal" if causal_mask else "_nomask" + input_args = ["queries", "keys", "values", + "W_output", "attn_scale_factor"] + if causal_mask: + input_args.append("causal_mask") + super().__init__( - name=f"attention_prefill_fused_{num_heads}h{num_kv_groups}g{head_dim}d{embedding_dim}e{seq_len}s", + name=f"attention_prefill_fused_{num_heads}h{num_kv_groups}g{head_dim}d{embedding_dim}e{seq_len}s{mask_suffix}", runlist=runlist, - input_args=["queries", "keys", "values", - "W_output", "attn_scale_factor", "causal_mask"], + input_args=input_args, output_args=["attn_output"], buffer_sizes=buffer_sizes, context=elf_ctx, @@ -146,7 +167,7 @@ class AIEAttentionPrefillProjectedFused(FusedMLIROperator): """ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, - seq_len, context=None): + seq_len, causal_mask=True, context=None): assert head_dim == 64 assert num_heads % num_kv_groups == 0 assert seq_len % 256 == 0 @@ -231,14 +252,19 @@ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, } core_runlist, core_buffer_sizes = _build_core_ops( - H, G, d, E, S, elf_ctx, + H, G, d, E, S, elf_ctx, causal_mask=causal_mask, ) + mask_suffix = "_causal" if causal_mask else "_nomask" + input_args = ["input", "rope_angles", "W_query", "W_key", "W_value", + "W_output", "attn_scale_factor"] + if causal_mask: + input_args.append("causal_mask") + super().__init__( - name=f"attention_prefill_projected_fused_{H}h{G}g{d}d{E}e{S}s", + name=f"attention_prefill_projected_fused_{H}h{G}g{d}d{E}e{S}s{mask_suffix}", runlist=prefix_runlist + core_runlist, - input_args=["input", "rope_angles", "W_query", "W_key", "W_value", - "W_output", "attn_scale_factor", "causal_mask"], + input_args=input_args, output_args=["attn_output"], buffer_sizes={**prefix_buffer_sizes, **core_buffer_sizes}, context=elf_ctx, diff --git a/iron/operators/mha_prefill_lxl_sd/test.py b/iron/operators/mha_prefill_lxl_sd/test.py index ccf768d9..1bfc0953 100644 --- a/iron/operators/mha_prefill_lxl_sd/test.py +++ b/iron/operators/mha_prefill_lxl_sd/test.py @@ -1,8 +1,6 @@ # SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -import time - import numpy as np import pytest import torch @@ -25,10 +23,23 @@ def get_params(): return [ pytest.param(2, 2, 64, 256, 256, id="H2"), - pytest.param(32, 8, 64, 2048, 256, id="H32"), + pytest.param(32, 8, 64, 2048, 256, id="Llama3.2-256seq"), + pytest.param(12, 12, 64, 768, 256, id="GPT2-Small-256seq"), ] +def get_benchmark_params(): + """GPT-2 Small across sequence lengths 256..32768, with/without causal mask.""" + params = [] + S = 256 + while S <= 32768: + for mask in [True, False]: + tag = "causal" if mask else "nomask" + params.append(pytest.param(12, 12, 64, 768, S, mask, id=f"GPT2-S{S}-{tag}")) + S *= 2 + return params + + def _load_input(fc, name, tensor): """Load a tensor into a named sub-buffer of the fused callable.""" fc.get_buffer(name).view_as_np()[:] = torch_to_numpy(tensor).flatten() @@ -78,16 +89,14 @@ def _projected_gemm_flops(H, G, d, E, S): return query_proj + kv_proj + _core_gemm_flops(H, G, d, E, S) -def _print_metrics(label, elapsed_s, flops): - """Print latency and throughput metrics.""" - gflops = flops / elapsed_s / 1e9 - print(f" {label}: {elapsed_s*1e3:.2f} ms, {gflops:.2f} GFLOPS") - - # --------------------------------------------------------------------------- # Core attention tests (pre-projected Q, K, V) # --------------------------------------------------------------------------- +@pytest.mark.metrics( + Latency=r"Latency \(us\): (?P[\d\.]+)", + Throughput=r"Throughput: (?P[\d\.e\+-]+) GFLOP/s", +) @pytest.mark.parametrize("H,G,d,E,S", get_params()) def test_mha_pefill_lxl_sd(H, G, d, E, S): """Core attention: score GEMM -> scale -> mask -> softmax -> context GEMM -> output.""" @@ -104,10 +113,12 @@ def test_mha_pefill_lxl_sd(H, G, d, E, S): _load_input(fc, "attn_scale_factor", golden["attn_scale_factor"]) _load_input(fc, "causal_mask", golden["causal_mask"]) - t0 = time.perf_counter() fc() - elapsed = time.perf_counter() - t0 - _print_metrics("core", elapsed, _core_gemm_flops(H, G, d, E, S)) + + latency_us = fc.last_elapsed * 1e6 + gflops = _core_gemm_flops(H, G, d, E, S) / (fc.last_elapsed) / 1e9 + print(f"\nLatency (us): {latency_us:.1f}") + print(f"Throughput: {gflops:.6e} GFLOP/s") _verify_output(fc, golden, H, d, S, E) @@ -116,6 +127,10 @@ def test_mha_pefill_lxl_sd(H, G, d, E, S): # Projected attention tests (with Q/K/V projections + RoPE) # --------------------------------------------------------------------------- +@pytest.mark.metrics( + Latency=r"Latency \(us\): (?P[\d\.]+)", + Throughput=r"Throughput: (?P[\d\.e\+-]+) GFLOP/s", +) @pytest.mark.parametrize("H,G,d,E,S", get_params()) def test_attention_prefill_projected_fused(H, G, d, E, S): """Projected attention: Q/K/V proj -> RoPE -> GQA -> attention -> output proj.""" @@ -134,14 +149,50 @@ def test_attention_prefill_projected_fused(H, G, d, E, S): _load_input(fc, "attn_scale_factor", golden["attn_scale_factor"]) _load_input(fc, "causal_mask", golden["causal_mask"]) - t0 = time.perf_counter() fc() - elapsed = time.perf_counter() - t0 - _print_metrics("projected", elapsed, _projected_gemm_flops(H, G, d, E, S)) + + latency_us = fc.last_elapsed * 1e6 + gflops = _projected_gemm_flops(H, G, d, E, S) / (fc.last_elapsed) / 1e9 + print(f"\nLatency (us): {latency_us:.1f}") + print(f"Throughput: {gflops:.6e} GFLOP/s") _verify_output(fc, golden, H, d, S, E) +# --------------------------------------------------------------------------- +# Benchmark: GPT-2 Small core MHA across sequence lengths, +/- causal mask +# --------------------------------------------------------------------------- + +@pytest.mark.benchmark +@pytest.mark.metrics( + Latency=r"Latency \(us\): (?P[\d\.]+)", + Throughput=r"Throughput: (?P[\d\.e\+-]+) GFLOP/s", +) +@pytest.mark.parametrize("H,G,d,E,S,causal", get_benchmark_params()) +def test_mha_prefill_benchmark(H, G, d, E, S, causal): + """Benchmark core MHA for GPT-2 Small across sequence lengths.""" + golden = generate_golden_reference(H, G, d, E, S) + + op = AIEAttentionPrefillFused(H, G, d, E, S, causal_mask=causal) + op.compile() + fc = op.get_callable() + + _load_input(fc, "queries", golden["queries_deinterleaved"]) + _load_input(fc, "keys", golden["keys_for_scores"]) + _load_input(fc, "values", golden["values_for_context"]) + _load_input(fc, "W_output", golden["W_output"]) + _load_input(fc, "attn_scale_factor", golden["attn_scale_factor"]) + if causal: + _load_input(fc, "causal_mask", golden["causal_mask"]) + + fc() + + latency_us = fc.last_elapsed * 1e6 + gflops = _core_gemm_flops(H, G, d, E, S) / (fc.last_elapsed) / 1e9 + print(f"\nLatency (us): {latency_us:.1f}") + print(f"Throughput: {gflops:.6e} GFLOP/s") + + # --------------------------------------------------------------------------- # Intermediate checks (extensive, not run by default) # --------------------------------------------------------------------------- From 71237a2a89f319353b7ae76f5fc04e617a8460e8 Mon Sep 17 00:00:00 2001 From: andrej Date: Mon, 6 Apr 2026 17:27:09 -0600 Subject: [PATCH 03/22] as benchmarked --- .../operators/mha_prefill_lxl_sd/reference.py | 148 +++++++----------- pytest.ini | 1 + 2 files changed, 54 insertions(+), 95 deletions(-) diff --git a/iron/operators/mha_prefill_lxl_sd/reference.py b/iron/operators/mha_prefill_lxl_sd/reference.py index 484f4d7f..3d1c21ed 100644 --- a/iron/operators/mha_prefill_lxl_sd/reference.py +++ b/iron/operators/mha_prefill_lxl_sd/reference.py @@ -2,36 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 import torch -import numpy as np -from ml_dtypes import bfloat16 +from iron.operators.rope.rope_utils import apply_rope as _apply_rope_4d -def apply_rope(x, lut): - """Apply Rotary Position Embedding using pre-computed cos/sin LUT. - x: (rows, cols) — rows are (positions * heads) interleaved - lut: (angle_rows, cols) — interleaved [cos_0, sin_0, cos_1, sin_1, ...] - If angle_rows < rows, each angle row is reused for - (rows // angle_rows) consecutive input rows (block repetition). - Returns: (rows, cols) with RoPE applied (two-halves method) - """ - rows, cols = x.shape - angle_rows = lut.shape[0] - half = cols // 2 - - cos = lut[:, ::2] # (angle_rows, half) - sin = lut[:, 1::2] # (angle_rows, half) - - if angle_rows < rows: - # Block repetition: each angle row repeats for consecutive input rows - repeats = rows // angle_rows - cos = cos.repeat_interleave(repeats, dim=0) # (rows, half) - sin = sin.repeat_interleave(repeats, dim=0) # (rows, half) - - x1 = x[:, :half] - x2 = x[:, half:] - out = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1) - return out +def _bf16_matmul(a, b): + """(float32 matmul) → bfloat16, matching NPU accumulation.""" + return (a.float() @ b.float()).to(torch.bfloat16) def generate_golden_reference( @@ -75,64 +52,49 @@ def generate_golden_reference( rope_angles = rope_angles.to(torch.bfloat16) # Weight matrices (transposed for GEMM: input @ W → output) - # Q proj: (S, E) @ (E, H*d) → (S, H*d) W_query = torch.randn(E, H * d, dtype=torch.bfloat16) * val_range - # K proj: (S, E) @ (E, G*d) → (S, G*d) W_key = torch.randn(E, G * d, dtype=torch.bfloat16) * val_range - # V proj: (S, E) @ (E, G*d) → (S, G*d) W_value = torch.randn(E, G * d, dtype=torch.bfloat16) * val_range - # Output proj: (S, H*d) @ (H*d, E) → (S, E) W_output = torch.randn(H * d, E, dtype=torch.bfloat16) * val_range - # Scale factor: 1/sqrt(d), broadcast to (H*S, S) + # Scale factor: 1/sqrt(d), broadcast to (H*S*S,) scale = 1.0 / (d ** 0.5) attn_scale_factor = torch.full((H * S * S,), scale, dtype=torch.bfloat16) - # Causal mask: (H*S, S) — 0 for valid positions, -inf for future positions - # Row (h*S + i) attends to positions 0..i, so mask col j > i with -inf + # Causal mask: (H*S, S) — 0 for valid positions, -inf for future causal_mask = torch.zeros(H * S, S, dtype=torch.bfloat16) for h in range(H): for i in range(S): - for j in range(S): - if j > i: - causal_mask[h * S + i, j] = torch.tensor(float("-inf")).to( - torch.bfloat16 - ) - - # ---- Step 1-3: Q/K/V projections ---- - queries_raw = x.float() @ W_query.float() # (S, H*d) - queries_raw = queries_raw.to(torch.bfloat16) - keys_raw = x.float() @ W_key.float() # (S, G*d) - keys_raw = keys_raw.to(torch.bfloat16) - values_raw = x.float() @ W_value.float() # (S, G*d) - values_raw = values_raw.to(torch.bfloat16) - - # ---- Step 4-5: RoPE ---- - # Q proj output is (S, H*d), viewed as (S*H, d) with heads interleaved: - # row layout: [pos0_head0, pos0_head1, ..., pos0_headH-1, pos1_head0, ...] - # RoPE angle_rows=S: row i uses angle row (i % S) = position index - queries_for_rope = queries_raw.reshape(S * H, d) - queries_roped = apply_rope(queries_for_rope, rope_angles) # (S*H, d) - - keys_for_rope = keys_raw.reshape(S * G, d) - keys_roped = apply_rope(keys_for_rope, rope_angles) # (S*G, d) - - # ---- Step 6: Deinterleave Q: (S*H, d) → (H, S, d) ---- - # Current layout: [pos0_h0, pos0_h1, ..., pos0_{H-1}, pos1_h0, ...] - # = (S, H, d) in memory; reshape and transpose to (H, S, d) + for j in range(i + 1, S): + causal_mask[h * S + i, j] = torch.tensor(float("-inf")).to( + torch.bfloat16 + ) + + # ---- Q/K/V projections ---- + queries_raw = _bf16_matmul(x, W_query) # (S, H*d) + keys_raw = _bf16_matmul(x, W_key) # (S, G*d) + values_raw = _bf16_matmul(x, W_value) # (S, G*d) + + # ---- RoPE (reuses rope_utils.apply_rope with 4D interface) ---- + # Reshape interleaved (S, N*d) → (1, N, S, d) for rope_utils + queries_roped = _apply_rope_4d( + queries_raw.reshape(S, H, d).permute(1, 0, 2).unsqueeze(0), # (1, H, S, d) + rope_angles, + ).squeeze(0).permute(1, 0, 2).contiguous().reshape(S * H, d) # (S*H, d) + + keys_roped = _apply_rope_4d( + keys_raw.reshape(S, G, d).permute(1, 0, 2).unsqueeze(0), # (1, G, S, d) + rope_angles, + ).squeeze(0).permute(1, 0, 2).contiguous().reshape(S * G, d) # (S*G, d) + + # ---- Deinterleave Q/K/V ---- queries_deinterleaved = queries_roped.reshape(S, H, d).transpose(0, 1).contiguous() # (H, S, d) + keys_deinterleaved = keys_roped.reshape(S, G, d).transpose(0, 1).contiguous() # (G, S, d) + keys_transposed = keys_deinterleaved.transpose(1, 2).contiguous() # (G, d, S) + values_deinterleaved = values_raw.reshape(S, G, d).transpose(0, 1).contiguous() # (G, S, d) - # ---- Step 7: Deinterleave K: (S*G, d) → (G, S, d) then transpose to (G, d, S) ---- - keys_deinterleaved = keys_roped.reshape(S, G, d).transpose(0, 1).contiguous() # (G, S, d) - keys_transposed = keys_deinterleaved.transpose(1, 2).contiguous() # (G, d, S) - - # ---- Step 8: Deinterleave V: (S, G*d) → (G, S, d) ---- - values_deinterleaved = values_raw.reshape(S, G, d).transpose(0, 1).contiguous() # (G, S, d) - - # ---- Step 9: GQA repeat ---- + # ---- GQA repeat ---- if group_size > 1: - # Repeat keys and values: (G, ...) → (H, ...) - # Flatten to (G, d*S) / (G, S*d), repeat, reshape keys_for_scores = keys_transposed.reshape(G, d * S).repeat_interleave( group_size, dim=0 ).reshape(H, d, S) @@ -140,41 +102,37 @@ def generate_golden_reference( group_size, dim=0 ).reshape(H, S, d) else: - keys_for_scores = keys_transposed # (H, d, S) - values_for_context = values_deinterleaved # (H, S, d) + keys_for_scores = keys_transposed # (H, d, S) + values_for_context = values_deinterleaved # (H, S, d) - # ---- Step 10: Score GEMM per head ---- - # Q_head(S, d) @ K_head(d, S) → scores(S, S) - attn_scores = torch.zeros(H, S, S, dtype=torch.bfloat16) - for h in range(H): - attn_scores[h] = ( - queries_deinterleaved[h].float() @ keys_for_scores[h].float() - ).to(torch.bfloat16) + # ---- Score GEMM per head ---- + attn_scores = torch.stack( + [_bf16_matmul(queries_deinterleaved[h], keys_for_scores[h]) for h in range(H)] + ) # (H, S, S) - # ---- Step 11: Scale ---- + # ---- Scale ---- attn_scores_scaled = (attn_scores.float() * scale).to(torch.bfloat16) - # ---- Step 12: Causal mask (add -inf) ---- - attn_scores_masked = attn_scores_scaled.reshape(H * S, S).float() + causal_mask.float() - attn_scores_masked = attn_scores_masked.to(torch.bfloat16) - # ---- Step 13: Softmax ---- + # ---- Causal mask ---- + attn_scores_masked = ( + attn_scores_scaled.reshape(H * S, S).float() + causal_mask.float() + ).to(torch.bfloat16) + + # ---- Softmax ---- attn_weights = torch.nn.functional.softmax( attn_scores_masked.float().reshape(H, S, S), dim=-1 ).to(torch.bfloat16) # (H, S, S) - # ---- Step 14: Context GEMM per head ---- - # weights(S, S) @ values(S, d) → context(S, d) - attn_context = torch.zeros(H, S, d, dtype=torch.bfloat16) - for h in range(H): - attn_context[h] = ( - attn_weights[h].float() @ values_for_context[h].float() - ).to(torch.bfloat16) + # ---- Context GEMM per head ---- + attn_context = torch.stack( + [_bf16_matmul(attn_weights[h], values_for_context[h]) for h in range(H)] + ) # (H, S, d) - # ---- Step 15: Re-interleave context: (H, S, d) → (S, H*d) ---- + # ---- Re-interleave context: (H, S, d) → (S, H*d) ---- context_interleaved = attn_context.transpose(0, 1).contiguous().reshape(S, H * d) - # ---- Step 16: Output projection ---- - attn_output = (context_interleaved.float() @ W_output.float()).to(torch.bfloat16) + # ---- Output projection ---- + attn_output = _bf16_matmul(context_interleaved, W_output) return { "input": x, diff --git a/pytest.ini b/pytest.ini index 44f08847..a3566ee2 100644 --- a/pytest.ini +++ b/pytest.ini @@ -9,4 +9,5 @@ python_functions = test_* markers = extensive: extensive test suite (deselect with '-m "not extensive"') supported_devices(*devices): mark test as only supported on the given devices (e.g. "npu1", "npu2"). All devices supported by default. + benchmark: benchmark-only tests (select with '-m benchmark') addopts = -v --tb=short --import-mode=importlib From 92e6607cfee819238f8bf96f5e409eac856227bf Mon Sep 17 00:00:00 2001 From: andrej Date: Mon, 6 Apr 2026 17:57:04 -0600 Subject: [PATCH 04/22] fix DMA dimension overflow --- iron/operators/mha_prefill_lxl_sd/op.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/iron/operators/mha_prefill_lxl_sd/op.py b/iron/operators/mha_prefill_lxl_sd/op.py index b939e7e2..d2af422c 100644 --- a/iron/operators/mha_prefill_lxl_sd/op.py +++ b/iron/operators/mha_prefill_lxl_sd/op.py @@ -59,8 +59,10 @@ def _build_core_ops(H, G, d, E, S, elf_ctx, causal_mask=True): tile_n=16, context=elf_ctx, prio_accuracy=True, ) reinterleave = AIEStridedCopy( - input_sizes=(H, S, d), input_strides=(S * d, d, 1), input_offset=0, - output_sizes=(H, S, d), output_strides=(d, H * d, 1), output_offset=0, + #input_sizes=(H, S, d), input_strides=(S * d, d, 1), input_offset=0, + input_sizes=(1, 1, 1, H * S * d), input_strides=(0, 0, 0, 1), input_offset=0, + #output_sizes=(H, S, d), output_strides=(d, H * d, 1), output_offset=0, + output_sizes=(H, 256, S // 256, d), output_strides=(d, 256 * H * d, H * d, 1), output_offset=0, input_buffer_size=H * S * d, output_buffer_size=S * H * d, transfer_size=S * d, num_aie_channels=1, context=elf_ctx, ) From 43e4d0782c19c596708ee60553bfaf866c6d086a Mon Sep 17 00:00:00 2001 From: andrej Date: Mon, 6 Apr 2026 18:23:12 -0600 Subject: [PATCH 05/22] create separate attn_scores_scaled buffer --- iron/operators/mha_prefill_lxl_sd/op.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/iron/operators/mha_prefill_lxl_sd/op.py b/iron/operators/mha_prefill_lxl_sd/op.py index d2af422c..f0e7e3fb 100644 --- a/iron/operators/mha_prefill_lxl_sd/op.py +++ b/iron/operators/mha_prefill_lxl_sd/op.py @@ -83,17 +83,17 @@ def _build_core_ops(H, G, d, E, S, elf_ctx, causal_mask=True): f"keys[{h*kdS}:{(h+1)*kdS}]", f"attn_scores[{h*sh}:{(h+1)*sh}]") for h in range(H)], - (scale, "attn_scores", "attn_scale_factor", "attn_scores"), + (scale, "attn_scores", "attn_scale_factor", "attn_scores_scaled"), ] if causal_mask: runlist += [ - (mask, "attn_scores", "causal_mask", "attn_scores_masked"), + (mask, "attn_scores_scaled", "causal_mask", "attn_scores_masked"), (softmax, "attn_scores_masked", "attn_weights"), ] else: runlist += [ - (softmax, "attn_scores", "attn_weights"), + (softmax, "attn_scores_scaled", "attn_weights"), ] runlist += [ @@ -111,6 +111,7 @@ def _build_core_ops(H, G, d, E, S, elf_ctx, causal_mask=True): "keys": H * d * S * B, "values": H * S * d * B, "attn_scores": H * S * S * B, + "attn_scores_scaled": H * S * S * B, "attn_weights": H * S * S * B, "attn_context": H * S * d * B, "context_interleaved": S * H * d * B, From af752104eb2ee033b138076df4d15754b84ec7c2 Mon Sep 17 00:00:00 2001 From: andrej Date: Tue, 7 Apr 2026 13:57:16 -0600 Subject: [PATCH 06/22] move output GEMM out of core MHA --- iron/operators/mha_prefill_lxl_sd/op.py | 53 +++++++++++++---------- iron/operators/mha_prefill_lxl_sd/test.py | 46 +++++++++++++------- 2 files changed, 61 insertions(+), 38 deletions(-) diff --git a/iron/operators/mha_prefill_lxl_sd/op.py b/iron/operators/mha_prefill_lxl_sd/op.py index f0e7e3fb..578a63bc 100644 --- a/iron/operators/mha_prefill_lxl_sd/op.py +++ b/iron/operators/mha_prefill_lxl_sd/op.py @@ -25,7 +25,7 @@ def _pick_tile_n(N, num_cols, max_tile_n=64): return tile_n -def _build_core_ops(H, G, d, E, S, elf_ctx, causal_mask=True): +def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True): """Build core attention sub-ops and runlist (no projections/RoPE/GQA). Expects pre-processed inputs: @@ -33,6 +33,9 @@ def _build_core_ops(H, G, d, E, S, elf_ctx, causal_mask=True): keys: (H, d, S) transposed and GQA-repeated values: (H, S, d) GQA-repeated + Produces: + attn_context: (H, S, d) — per-head context vectors + If causal_mask=False, the elementwise-add masking step is omitted. """ B = 2 # bytes per bf16 element @@ -58,18 +61,6 @@ def _build_core_ops(H, G, d, E, S, elf_ctx, causal_mask=True): M=S, K=S, N=d, num_aie_columns=4, tile_m=16, tile_k=64, tile_n=16, context=elf_ctx, prio_accuracy=True, ) - reinterleave = AIEStridedCopy( - #input_sizes=(H, S, d), input_strides=(S * d, d, 1), input_offset=0, - input_sizes=(1, 1, 1, H * S * d), input_strides=(0, 0, 0, 1), input_offset=0, - #output_sizes=(H, S, d), output_strides=(d, H * d, 1), output_offset=0, - output_sizes=(H, 256, S // 256, d), output_strides=(d, 256 * H * d, H * d, 1), output_offset=0, - input_buffer_size=H * S * d, output_buffer_size=S * H * d, - transfer_size=S * d, num_aie_channels=1, context=elf_ctx, - ) - gemm_output = AIEGEMM( - M=S, K=H * d, N=E, num_aie_columns=8, tile_m=16, tile_k=64, - tile_n=_pick_tile_n(E, 8), context=elf_ctx, prio_accuracy=True, - ) qh = S * d * B kdS = d * S * B @@ -102,8 +93,6 @@ def _build_core_ops(H, G, d, E, S, elf_ctx, causal_mask=True): f"values[{h*kSd}:{(h+1)*kSd}]", f"attn_context[{h*ch}:{(h+1)*ch}]") for h in range(H)], - (reinterleave, "attn_context", "context_interleaved"), - (gemm_output, "context_interleaved", "W_output", "attn_output"), ] buffer_sizes = { @@ -114,7 +103,6 @@ def _build_core_ops(H, G, d, E, S, elf_ctx, causal_mask=True): "attn_scores_scaled": H * S * S * B, "attn_weights": H * S * S * B, "attn_context": H * S * d * B, - "context_interleaved": S * H * d * B, } if causal_mask: buffer_sizes["attn_scores_masked"] = H * S * S * B @@ -143,13 +131,12 @@ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, elf_ctx = context or AIEContext() runlist, buffer_sizes = _build_core_ops( - num_heads, num_kv_groups, head_dim, embedding_dim, seq_len, elf_ctx, + num_heads, num_kv_groups, head_dim, seq_len, elf_ctx, causal_mask=causal_mask, ) mask_suffix = "_causal" if causal_mask else "_nomask" - input_args = ["queries", "keys", "values", - "W_output", "attn_scale_factor"] + input_args = ["queries", "keys", "values", "attn_scale_factor"] if causal_mask: input_args.append("causal_mask") @@ -157,7 +144,7 @@ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, name=f"attention_prefill_fused_{num_heads}h{num_kv_groups}g{head_dim}d{embedding_dim}e{seq_len}s{mask_suffix}", runlist=runlist, input_args=input_args, - output_args=["attn_output"], + output_args=["attn_context"], buffer_sizes=buffer_sizes, context=elf_ctx, ) @@ -255,9 +242,29 @@ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, } core_runlist, core_buffer_sizes = _build_core_ops( - H, G, d, E, S, elf_ctx, causal_mask=causal_mask, + H, G, d, S, elf_ctx, causal_mask=causal_mask, + ) + + # ---- Reinterleave + output projection ---- + reinterleave = AIEStridedCopy( + input_sizes=(1, 1, 1, H * S * d), input_strides=(0, 0, 0, 1), input_offset=0, + output_sizes=(H, 256, S // 256, d), output_strides=(d, 256 * H * d, H * d, 1), output_offset=0, + input_buffer_size=H * S * d, output_buffer_size=S * H * d, + transfer_size=S * d, num_aie_channels=1, context=elf_ctx, + ) + gemm_output = AIEGEMM( + M=S, K=H * d, N=E, num_aie_columns=8, tile_m=16, tile_k=64, + tile_n=_pick_tile_n(E, 8), context=elf_ctx, prio_accuracy=True, ) + suffix_runlist = [ + (reinterleave, "attn_context", "context_interleaved"), + (gemm_output, "context_interleaved", "W_output", "attn_output"), + ] + suffix_buffer_sizes = { + "context_interleaved": S * H * d * B, + } + mask_suffix = "_causal" if causal_mask else "_nomask" input_args = ["input", "rope_angles", "W_query", "W_key", "W_value", "W_output", "attn_scale_factor"] @@ -266,9 +273,9 @@ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, super().__init__( name=f"attention_prefill_projected_fused_{H}h{G}g{d}d{E}e{S}s{mask_suffix}", - runlist=prefix_runlist + core_runlist, + runlist=prefix_runlist + core_runlist + suffix_runlist, input_args=input_args, output_args=["attn_output"], - buffer_sizes={**prefix_buffer_sizes, **core_buffer_sizes}, + buffer_sizes={**prefix_buffer_sizes, **core_buffer_sizes, **suffix_buffer_sizes}, context=elf_ctx, ) diff --git a/iron/operators/mha_prefill_lxl_sd/test.py b/iron/operators/mha_prefill_lxl_sd/test.py index 1bfc0953..94549d7f 100644 --- a/iron/operators/mha_prefill_lxl_sd/test.py +++ b/iron/operators/mha_prefill_lxl_sd/test.py @@ -55,6 +55,16 @@ def _get_scratch_tensor(fc, name, shape): ).reshape(shape).astype(np.float32) +def _get_output_tensor(fc, name, shape): + """Read a named buffer from the fused callable's output space.""" + fc.output_buffer.on = "npu" + fc.output_buffer.to("cpu") + sub = fc.get_buffer(name) + return np.frombuffer( + sub.memory_view, dtype=bfloat16, count=int(np.prod(shape)) + ).reshape(shape).astype(np.float32) + + def _verify_output(fc, golden, H, d, S, E): """Chain-consistent output verification shared by both test variants.""" npu_context = torch.from_numpy( @@ -78,15 +88,15 @@ def _core_gemm_flops(H, G, d, E, S): """Count GEMM FLOPs for the core attention operator.""" score_flops = H * 2 * S * d * S # H x (S,d)@(d,S) context_flops = H * 2 * S * S * d # H x (S,S)@(S,d) - output_flops = 2 * S * (H * d) * E # (S,H*d)@(H*d,E) - return score_flops + context_flops + output_flops + return score_flops + context_flops def _projected_gemm_flops(H, G, d, E, S): """Count GEMM FLOPs for the projected attention operator.""" query_proj = 2 * S * E * (H * d) # (S,E)@(E,H*d) kv_proj = 2 * (2 * S * E * (G * d)) # key + value: (S,E)@(E,G*d) each - return query_proj + kv_proj + _core_gemm_flops(H, G, d, E, S) + output_proj = 2 * S * (H * d) * E # (S,H*d)@(H*d,E) + return query_proj + kv_proj + _core_gemm_flops(H, G, d, E, S) + output_proj # --------------------------------------------------------------------------- @@ -99,7 +109,7 @@ def _projected_gemm_flops(H, G, d, E, S): ) @pytest.mark.parametrize("H,G,d,E,S", get_params()) def test_mha_pefill_lxl_sd(H, G, d, E, S): - """Core attention: score GEMM -> scale -> mask -> softmax -> context GEMM -> output.""" + """Core attention: score GEMM -> scale -> mask -> softmax -> context GEMM.""" golden = generate_golden_reference(H, G, d, E, S) op = AIEAttentionPrefillFused(H, G, d, E, S) @@ -109,7 +119,6 @@ def test_mha_pefill_lxl_sd(H, G, d, E, S): _load_input(fc, "queries", golden["queries_deinterleaved"]) _load_input(fc, "keys", golden["keys_for_scores"]) _load_input(fc, "values", golden["values_for_context"]) - _load_input(fc, "W_output", golden["W_output"]) _load_input(fc, "attn_scale_factor", golden["attn_scale_factor"]) _load_input(fc, "causal_mask", golden["causal_mask"]) @@ -120,7 +129,14 @@ def test_mha_pefill_lxl_sd(H, G, d, E, S): print(f"\nLatency (us): {latency_us:.1f}") print(f"Throughput: {gflops:.6e} GFLOP/s") - _verify_output(fc, golden, H, d, S, E) + actual = _get_output_tensor(fc, "attn_context", (H, S, d)) + expected = golden["attn_context"].float().numpy().reshape(H, S, d) + errors = verify_buffer( + torch.from_numpy(actual).bfloat16(), "attn_context", + torch.from_numpy(expected).bfloat16().reshape(H, S, d), + rel_tol=REL_TOL, abs_tol=ABS_TOL, max_error_rate=MAX_ERROR_RATE, + ) + assert not errors, f"Output verification failed with {len(errors)} errors" # --------------------------------------------------------------------------- @@ -180,7 +196,6 @@ def test_mha_prefill_benchmark(H, G, d, E, S, causal): _load_input(fc, "queries", golden["queries_deinterleaved"]) _load_input(fc, "keys", golden["keys_for_scores"]) _load_input(fc, "values", golden["values_for_context"]) - _load_input(fc, "W_output", golden["W_output"]) _load_input(fc, "attn_scale_factor", golden["attn_scale_factor"]) if causal: _load_input(fc, "causal_mask", golden["causal_mask"]) @@ -198,11 +213,10 @@ def test_mha_prefill_benchmark(H, G, d, E, S, causal): # --------------------------------------------------------------------------- INTERMEDIATE_CHECKS = [ - ("attn_scores", "attn_scores", lambda H, G, S, d: (H, S, S)), - ("attn_scores_masked", "attn_scores_masked", lambda H, G, S, d: (H, S, S)), - ("attn_weights", "attn_weights", lambda H, G, S, d: (H, S, S)), - ("attn_context", "attn_context", lambda H, G, S, d: (H, S, d)), - ("context_interleaved", "context_interleaved", lambda H, G, S, d: (S, H * d)), + ("attn_scores", "attn_scores", lambda H, G, S, d: (H, S, S), "scratch"), + ("attn_scores_masked", "attn_scores_masked", lambda H, G, S, d: (H, S, S), "scratch"), + ("attn_weights", "attn_weights", lambda H, G, S, d: (H, S, S), "scratch"), + ("attn_context", "attn_context", lambda H, G, S, d: (H, S, d), "output"), ] @@ -219,15 +233,17 @@ def test_mha_pefill_lxl_sd_intermediates(H, G, d, E, S): _load_input(fc, "queries", golden["queries_deinterleaved"]) _load_input(fc, "keys", golden["keys_for_scores"]) _load_input(fc, "values", golden["values_for_context"]) - _load_input(fc, "W_output", golden["W_output"]) _load_input(fc, "attn_scale_factor", golden["attn_scale_factor"]) _load_input(fc, "causal_mask", golden["causal_mask"]) fc() - for buf_name, golden_key, shape_fn in INTERMEDIATE_CHECKS: + for buf_name, golden_key, shape_fn, buf_type in INTERMEDIATE_CHECKS: shape = shape_fn(H, G, S, d) - actual = _get_scratch_tensor(fc, buf_name, shape) + if buf_type == "output": + actual = _get_output_tensor(fc, buf_name, shape) + else: + actual = _get_scratch_tensor(fc, buf_name, shape) expected = golden[golden_key].float().numpy().reshape(shape) diff = np.abs(actual - expected) print( From ee87e94cf57a08abc70478b1130b573907991f67 Mon Sep 17 00:00:00 2001 From: andrej Date: Tue, 7 Apr 2026 14:56:33 -0600 Subject: [PATCH 07/22] remove symbol renaming after rebase to use link_with, other fixes --- iron/common/compilation/base.py | 1 + iron/common/fusion.py | 12 ++--- iron/operators/mha_prefill_lxl_sd/op.py | 50 +++++++++---------- .../operators/mha_prefill_lxl_sd/reference.py | 14 +++++- iron/operators/mha_prefill_lxl_sd/test.py | 35 ++++++------- 5 files changed, 56 insertions(+), 56 deletions(-) diff --git a/iron/common/compilation/base.py b/iron/common/compilation/base.py index d6d17a64..0d8c1613 100644 --- a/iron/common/compilation/base.py +++ b/iron/common/compilation/base.py @@ -502,6 +502,7 @@ def compile(self, graph): str(self.aiecc_path), "-v", "-j1", + "--dynamic-objFifos", "--no-compile-host", "--no-xchesscc", "--no-xbridge", diff --git a/iron/common/fusion.py b/iron/common/fusion.py index f8d1b9cd..292b26e9 100644 --- a/iron/common/fusion.py +++ b/iron/common/fusion.py @@ -43,8 +43,7 @@ def get_kernel_artifacts(self): """Collect all kernel artifacts from child operators. Returns: - List of KernelObjectArtifact instances from all unique child operators, - with filenames and symbol prefixes disambiguated per operator index. + List of KernelObjectArtifact instances from all unique child operators. """ kernel_artifacts = [] seen: dict[int, object] = {} @@ -53,9 +52,6 @@ def get_kernel_artifacts(self): ] for idx, op in enumerate(unique_operators): objs = op.get_kernel_artifacts() - for obj in objs: - obj.filename = f"op{idx}_{obj.filename}" - obj.prefix_symbols = f"op{idx}_" kernel_artifacts.extend(objs) return kernel_artifacts @@ -83,8 +79,6 @@ def get_mlir_artifact(self): ] for idx, op in enumerate(unique_operators): mlir_artifact = op.get_mlir_artifact() - if len(op.get_kernel_artifacts()) > 0: - mlir_artifact.generator.kwargs["func_prefix"] = f"op{idx}_" op_name = f"op{idx}_{op.__class__.__name__}" op_names[id(op)] = op_name operator_mlir_map[op_name] = mlir_artifact @@ -374,10 +368,10 @@ def get_buffer(self, buffer_name): return sub_buffer def __call__(self): - self.input_buffer.to("npu") + self.input_buffer._sync_to_device() super().__call__( self.input_buffer.buffer_object(), self.output_buffer.buffer_object(), self.scratch_buffer.buffer_object(), ) - self.output_buffer.to("cpu") + self.output_buffer._sync_from_device() diff --git a/iron/operators/mha_prefill_lxl_sd/op.py b/iron/operators/mha_prefill_lxl_sd/op.py index 578a63bc..b65a7171 100644 --- a/iron/operators/mha_prefill_lxl_sd/op.py +++ b/iron/operators/mha_prefill_lxl_sd/op.py @@ -7,14 +7,14 @@ from iron.common.context import AIEContext from iron.common.fusion import FusedMLIROperator -from iron.operators.gemm.op import AIEGEMM -from iron.operators.rope.op import AIERope -from iron.operators.strided_copy.op import AIEStridedCopy -from iron.operators.repeat.op import AIERepeat -from iron.operators.softmax.op import AIESoftmax -from iron.operators.transpose.op import AIETranspose -from iron.operators.elementwise_mul.op import AIEElementwiseMul -from iron.operators.elementwise_add.op import AIEElementwiseAdd +from iron.operators.gemm.op import GEMM +from iron.operators.rope.op import RoPE +from iron.operators.strided_copy.op import StridedCopy +from iron.operators.repeat.op import Repeat +from iron.operators.softmax.op import Softmax +from iron.operators.transpose.op import Transpose +from iron.operators.elementwise_mul.op import ElementwiseMul +from iron.operators.elementwise_add.op import ElementwiseAdd def _pick_tile_n(N, num_cols, max_tile_n=64): @@ -40,24 +40,24 @@ def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True): """ B = 2 # bytes per bf16 element - gemm_scores = AIEGEMM( + gemm_scores = GEMM( M=S, K=d, N=S, num_aie_columns=8, tile_m=16, tile_k=64, tile_n=_pick_tile_n(S, 8), context=elf_ctx, ) - scale = AIEElementwiseMul( + scale = ElementwiseMul( size=H * S * S, tile_size=S * S // 8, num_aie_columns=8, context=elf_ctx, ) if causal_mask: - mask = AIEElementwiseAdd( + mask = ElementwiseAdd( size=H * S * S, tile_size=S * S // 8, num_aie_columns=8, context=elf_ctx, ) - softmax = AIESoftmax( + softmax = Softmax( rows=H * S, cols=S, num_aie_columns=1, num_channels=1, rtp_vector_size=S, context=elf_ctx, ) - gemm_context = AIEGEMM( + gemm_context = GEMM( M=S, K=S, N=d, num_aie_columns=4, tile_m=16, tile_k=64, tile_n=16, context=elf_ctx, prio_accuracy=True, ) @@ -110,7 +110,7 @@ def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True): return runlist, buffer_sizes -class AIEAttentionPrefillFused(FusedMLIROperator): +class AttentionPrefillFused(FusedMLIROperator): """Fused attention prefill (core, no projections/RoPE). Accepts pre-projected Q (S*H,d), K (S*G,d), V (S*G,d) in interleaved layout. @@ -150,7 +150,7 @@ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, ) -class AIEAttentionPrefillProjectedFused(FusedMLIROperator): +class AttentionPrefillProjectedFused(FusedMLIROperator): """Fused attention prefill with Q/K/V projections and RoPE. Accepts raw input (S, E) and rope_angles (S, d). @@ -176,25 +176,25 @@ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, elf_ctx = context or AIEContext() # ---- Projection + RoPE ---- - gemm_query = AIEGEMM( + gemm_query = GEMM( M=S, K=E, N=H * d, num_aie_columns=8, tile_m=16, tile_k=64, tile_n=_pick_tile_n(H * d, 8), context=elf_ctx, ) - gemm_kv = AIEGEMM( + gemm_kv = GEMM( M=S, K=E, N=G * d, num_aie_columns=8, tile_m=16, tile_k=64, tile_n=_pick_tile_n(G * d, 8), context=elf_ctx, ) - rope_queries = AIERope(rows=S * H, cols=d, angle_rows=S, context=elf_ctx) - rope_keys = AIERope(rows=S * G, cols=d, angle_rows=S, context=elf_ctx) + rope_queries = RoPE(rows=S * H, cols=d, angle_rows=S, context=elf_ctx) + rope_keys = RoPE(rows=S * G, cols=d, angle_rows=S, context=elf_ctx) # ---- Deinterleave ---- - deinterleave_q = AIEStridedCopy( + deinterleave_q = StridedCopy( input_sizes=(H, S, d), input_strides=(d, H * d, 1), input_offset=0, output_sizes=(H, S, d), output_strides=(S * d, d, 1), output_offset=0, input_buffer_size=S * H * d, output_buffer_size=H * S * d, transfer_size=S * d, num_aie_channels=1, context=elf_ctx, ) - deinterleave_kv = AIEStridedCopy( + deinterleave_kv = StridedCopy( input_sizes=(G, S, d), input_strides=(d, G * d, 1), input_offset=0, output_sizes=(G, S, d), output_strides=(S * d, d, 1), output_offset=0, input_buffer_size=S * G * d, output_buffer_size=G * S * d, @@ -202,11 +202,11 @@ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, ) # ---- Transpose keys + GQA repeat ---- - transpose_keys = AIETranspose( + transpose_keys = Transpose( M=S, N=d, num_aie_columns=2, num_channels=1, m=256, n=32, s=8, context=elf_ctx, ) - repeat_kv = AIERepeat( + repeat_kv = Repeat( rows=G, cols=d * S, repeat=group_size, transfer_size=d, context=elf_ctx, ) @@ -246,13 +246,13 @@ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, ) # ---- Reinterleave + output projection ---- - reinterleave = AIEStridedCopy( + reinterleave = StridedCopy( input_sizes=(1, 1, 1, H * S * d), input_strides=(0, 0, 0, 1), input_offset=0, output_sizes=(H, 256, S // 256, d), output_strides=(d, 256 * H * d, H * d, 1), output_offset=0, input_buffer_size=H * S * d, output_buffer_size=S * H * d, transfer_size=S * d, num_aie_channels=1, context=elf_ctx, ) - gemm_output = AIEGEMM( + gemm_output = GEMM( M=S, K=H * d, N=E, num_aie_columns=8, tile_m=16, tile_k=64, tile_n=_pick_tile_n(E, 8), context=elf_ctx, prio_accuracy=True, ) diff --git a/iron/operators/mha_prefill_lxl_sd/reference.py b/iron/operators/mha_prefill_lxl_sd/reference.py index 3d1c21ed..92cba961 100644 --- a/iron/operators/mha_prefill_lxl_sd/reference.py +++ b/iron/operators/mha_prefill_lxl_sd/reference.py @@ -3,7 +3,19 @@ import torch -from iron.operators.rope.rope_utils import apply_rope as _apply_rope_4d + +def _apply_rope_4d(x, angles): + """Apply RoPE to a 4D tensor using interleaved cos/sin angles. + + x: (batch, heads, seq_len, head_dim) + angles: (seq_len, head_dim) with interleaved [cos_0, sin_0, cos_1, sin_1, ...] + Returns: same shape as x with RoPE applied (two-halves method). + """ + half = x.shape[-1] // 2 + cos = angles[:, ::2].unsqueeze(0).unsqueeze(0) # (1, 1, S, half) + sin = angles[:, 1::2].unsqueeze(0).unsqueeze(0) # (1, 1, S, half) + x1, x2 = x[..., :half], x[..., half:] + return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1) def _bf16_matmul(a, b): diff --git a/iron/operators/mha_prefill_lxl_sd/test.py b/iron/operators/mha_prefill_lxl_sd/test.py index 94549d7f..318c44fc 100644 --- a/iron/operators/mha_prefill_lxl_sd/test.py +++ b/iron/operators/mha_prefill_lxl_sd/test.py @@ -7,11 +7,10 @@ from ml_dtypes import bfloat16 from iron.common.test_utils import verify_buffer -from iron.common.utils import torch_to_numpy from iron.operators.mha_prefill_lxl_sd.op import ( - AIEAttentionPrefillFused, - AIEAttentionPrefillProjectedFused, + AttentionPrefillFused, + AttentionPrefillProjectedFused, ) from iron.operators.mha_prefill_lxl_sd.reference import generate_golden_reference @@ -42,27 +41,22 @@ def get_benchmark_params(): def _load_input(fc, name, tensor): """Load a tensor into a named sub-buffer of the fused callable.""" - fc.get_buffer(name).view_as_np()[:] = torch_to_numpy(tensor).flatten() + np_buf = tensor.contiguous().view(torch.uint16).numpy().view(bfloat16) + fc.get_buffer(name).data[:] = np_buf.flatten() def _get_scratch_tensor(fc, name, shape): """Read a named buffer from the fused callable's scratch space.""" - fc.scratch_buffer.on = "npu" - fc.scratch_buffer.to("cpu") + fc.scratch_buffer._sync_from_device() sub = fc.get_buffer(name) - return np.frombuffer( - sub.memory_view, dtype=bfloat16, count=int(np.prod(shape)) - ).reshape(shape).astype(np.float32) + return sub.data[:int(np.prod(shape))].reshape(shape).astype(np.float32) def _get_output_tensor(fc, name, shape): """Read a named buffer from the fused callable's output space.""" - fc.output_buffer.on = "npu" - fc.output_buffer.to("cpu") + fc.output_buffer._sync_from_device() sub = fc.get_buffer(name) - return np.frombuffer( - sub.memory_view, dtype=bfloat16, count=int(np.prod(shape)) - ).reshape(shape).astype(np.float32) + return sub.data[:int(np.prod(shape))].reshape(shape).astype(np.float32) def _verify_output(fc, golden, H, d, S, E): @@ -72,9 +66,8 @@ def _verify_output(fc, golden, H, d, S, E): ).bfloat16() chain_ref = (npu_context.float() @ golden["W_output"].float()).to(torch.bfloat16) - fc.output_buffer.on = "npu" - fc.output_buffer.to("cpu") - output_np = fc.get_buffer("attn_output").view_as_np() + fc.output_buffer._sync_from_device() + output_np = fc.get_buffer("attn_output").data output = torch.from_numpy(output_np.reshape(S, E).astype(np.float32)).bfloat16() errors = verify_buffer( @@ -112,7 +105,7 @@ def test_mha_pefill_lxl_sd(H, G, d, E, S): """Core attention: score GEMM -> scale -> mask -> softmax -> context GEMM.""" golden = generate_golden_reference(H, G, d, E, S) - op = AIEAttentionPrefillFused(H, G, d, E, S) + op = AttentionPrefillFused(H, G, d, E, S) op.compile() fc = op.get_callable() @@ -152,7 +145,7 @@ def test_attention_prefill_projected_fused(H, G, d, E, S): """Projected attention: Q/K/V proj -> RoPE -> GQA -> attention -> output proj.""" golden = generate_golden_reference(H, G, d, E, S) - op = AIEAttentionPrefillProjectedFused(H, G, d, E, S) + op = AttentionPrefillProjectedFused(H, G, d, E, S) op.compile() fc = op.get_callable() @@ -189,7 +182,7 @@ def test_mha_prefill_benchmark(H, G, d, E, S, causal): """Benchmark core MHA for GPT-2 Small across sequence lengths.""" golden = generate_golden_reference(H, G, d, E, S) - op = AIEAttentionPrefillFused(H, G, d, E, S, causal_mask=causal) + op = AttentionPrefillFused(H, G, d, E, S, causal_mask=causal) op.compile() fc = op.get_callable() @@ -226,7 +219,7 @@ def test_mha_pefill_lxl_sd_intermediates(H, G, d, E, S): """Check intermediate buffers of core attention (for debugging).""" golden = generate_golden_reference(H, G, d, E, S) - op = AIEAttentionPrefillFused(H, G, d, E, S) + op = AttentionPrefillFused(H, G, d, E, S) op.compile() fc = op.get_callable() From ee027317a01c6f574c2aad0258cec9a4757c13ff Mon Sep 17 00:00:00 2001 From: andrej Date: Tue, 7 Apr 2026 14:57:00 -0600 Subject: [PATCH 08/22] format --- iron/operators/gemm/design.py | 4 +- iron/operators/mha_prefill_lxl_sd/op.py | 228 +++++++++++++----- .../operators/mha_prefill_lxl_sd/reference.py | 78 +++--- iron/operators/mha_prefill_lxl_sd/test.py | 39 ++- 4 files changed, 254 insertions(+), 95 deletions(-) diff --git a/iron/operators/gemm/design.py b/iron/operators/gemm/design.py index 32222dd4..8b717dcf 100644 --- a/iron/operators/gemm/design.py +++ b/iron/operators/gemm/design.py @@ -314,7 +314,9 @@ def my_matmul( gemm_object, [C_l1_ty], ) - matmul_func_name = f"{func_prefix}matmul{scalar_suffix}_{dtype_in_str}_{dtype_out_str}" + matmul_func_name = ( + f"{func_prefix}matmul{scalar_suffix}_{dtype_in_str}_{dtype_out_str}" + ) matmul_kernel = Kernel( matmul_func_name, gemm_object, diff --git a/iron/operators/mha_prefill_lxl_sd/op.py b/iron/operators/mha_prefill_lxl_sd/op.py index b65a7171..aba8a115 100644 --- a/iron/operators/mha_prefill_lxl_sd/op.py +++ b/iron/operators/mha_prefill_lxl_sd/op.py @@ -41,25 +41,46 @@ def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True): B = 2 # bytes per bf16 element gemm_scores = GEMM( - M=S, K=d, N=S, num_aie_columns=8, tile_m=16, tile_k=64, - tile_n=_pick_tile_n(S, 8), context=elf_ctx, + M=S, + K=d, + N=S, + num_aie_columns=8, + tile_m=16, + tile_k=64, + tile_n=_pick_tile_n(S, 8), + context=elf_ctx, ) scale = ElementwiseMul( - size=H * S * S, tile_size=S * S // 8, - num_aie_columns=8, context=elf_ctx, + size=H * S * S, + tile_size=S * S // 8, + num_aie_columns=8, + context=elf_ctx, ) if causal_mask: mask = ElementwiseAdd( - size=H * S * S, tile_size=S * S // 8, - num_aie_columns=8, context=elf_ctx, + size=H * S * S, + tile_size=S * S // 8, + num_aie_columns=8, + context=elf_ctx, ) softmax = Softmax( - rows=H * S, cols=S, num_aie_columns=1, num_channels=1, - rtp_vector_size=S, context=elf_ctx, + rows=H * S, + cols=S, + num_aie_columns=1, + num_channels=1, + rtp_vector_size=S, + context=elf_ctx, ) gemm_context = GEMM( - M=S, K=S, N=d, num_aie_columns=4, tile_m=16, tile_k=64, - tile_n=16, context=elf_ctx, prio_accuracy=True, + M=S, + K=S, + N=d, + num_aie_columns=4, + tile_m=16, + tile_k=64, + tile_n=16, + context=elf_ctx, + prio_accuracy=True, ) qh = S * d * B @@ -69,11 +90,15 @@ def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True): ch = S * d * B runlist = [ - *[(gemm_scores, - f"queries[{h*qh}:{(h+1)*qh}]", - f"keys[{h*kdS}:{(h+1)*kdS}]", - f"attn_scores[{h*sh}:{(h+1)*sh}]") - for h in range(H)], + *[ + ( + gemm_scores, + f"queries[{h*qh}:{(h+1)*qh}]", + f"keys[{h*kdS}:{(h+1)*kdS}]", + f"attn_scores[{h*sh}:{(h+1)*sh}]", + ) + for h in range(H) + ], (scale, "attn_scores", "attn_scale_factor", "attn_scores_scaled"), ] @@ -88,11 +113,15 @@ def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True): ] runlist += [ - *[(gemm_context, - f"attn_weights[{h*sh}:{(h+1)*sh}]", - f"values[{h*kSd}:{(h+1)*kSd}]", - f"attn_context[{h*ch}:{(h+1)*ch}]") - for h in range(H)], + *[ + ( + gemm_context, + f"attn_weights[{h*sh}:{(h+1)*sh}]", + f"values[{h*kSd}:{(h+1)*kSd}]", + f"attn_context[{h*ch}:{(h+1)*ch}]", + ) + for h in range(H) + ], ] buffer_sizes = { @@ -116,8 +145,16 @@ class AttentionPrefillFused(FusedMLIROperator): Accepts pre-projected Q (S*H,d), K (S*G,d), V (S*G,d) in interleaved layout. """ - def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, - seq_len, causal_mask=True, context=None): + def __init__( + self, + num_heads, + num_kv_groups, + head_dim, + embedding_dim, + seq_len, + causal_mask=True, + context=None, + ): assert head_dim == 64 assert num_heads % num_kv_groups == 0 assert seq_len % 256 == 0 @@ -131,7 +168,11 @@ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, elf_ctx = context or AIEContext() runlist, buffer_sizes = _build_core_ops( - num_heads, num_kv_groups, head_dim, seq_len, elf_ctx, + num_heads, + num_kv_groups, + head_dim, + seq_len, + elf_ctx, causal_mask=causal_mask, ) @@ -156,8 +197,16 @@ class AttentionPrefillProjectedFused(FusedMLIROperator): Accepts raw input (S, E) and rope_angles (S, d). """ - def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, - seq_len, causal_mask=True, context=None): + def __init__( + self, + num_heads, + num_kv_groups, + head_dim, + embedding_dim, + seq_len, + causal_mask=True, + context=None, + ): assert head_dim == 64 assert num_heads % num_kv_groups == 0 assert seq_len % 256 == 0 @@ -177,38 +226,73 @@ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, # ---- Projection + RoPE ---- gemm_query = GEMM( - M=S, K=E, N=H * d, num_aie_columns=8, tile_m=16, tile_k=64, - tile_n=_pick_tile_n(H * d, 8), context=elf_ctx, + M=S, + K=E, + N=H * d, + num_aie_columns=8, + tile_m=16, + tile_k=64, + tile_n=_pick_tile_n(H * d, 8), + context=elf_ctx, ) gemm_kv = GEMM( - M=S, K=E, N=G * d, num_aie_columns=8, tile_m=16, tile_k=64, - tile_n=_pick_tile_n(G * d, 8), context=elf_ctx, + M=S, + K=E, + N=G * d, + num_aie_columns=8, + tile_m=16, + tile_k=64, + tile_n=_pick_tile_n(G * d, 8), + context=elf_ctx, ) rope_queries = RoPE(rows=S * H, cols=d, angle_rows=S, context=elf_ctx) rope_keys = RoPE(rows=S * G, cols=d, angle_rows=S, context=elf_ctx) # ---- Deinterleave ---- deinterleave_q = StridedCopy( - input_sizes=(H, S, d), input_strides=(d, H * d, 1), input_offset=0, - output_sizes=(H, S, d), output_strides=(S * d, d, 1), output_offset=0, - input_buffer_size=S * H * d, output_buffer_size=H * S * d, - transfer_size=S * d, num_aie_channels=1, context=elf_ctx, + input_sizes=(H, S, d), + input_strides=(d, H * d, 1), + input_offset=0, + output_sizes=(H, S, d), + output_strides=(S * d, d, 1), + output_offset=0, + input_buffer_size=S * H * d, + output_buffer_size=H * S * d, + transfer_size=S * d, + num_aie_channels=1, + context=elf_ctx, ) deinterleave_kv = StridedCopy( - input_sizes=(G, S, d), input_strides=(d, G * d, 1), input_offset=0, - output_sizes=(G, S, d), output_strides=(S * d, d, 1), output_offset=0, - input_buffer_size=S * G * d, output_buffer_size=G * S * d, - transfer_size=S * d, num_aie_channels=1, context=elf_ctx, + input_sizes=(G, S, d), + input_strides=(d, G * d, 1), + input_offset=0, + output_sizes=(G, S, d), + output_strides=(S * d, d, 1), + output_offset=0, + input_buffer_size=S * G * d, + output_buffer_size=G * S * d, + transfer_size=S * d, + num_aie_channels=1, + context=elf_ctx, ) # ---- Transpose keys + GQA repeat ---- transpose_keys = Transpose( - M=S, N=d, num_aie_columns=2, num_channels=1, - m=256, n=32, s=8, context=elf_ctx, + M=S, + N=d, + num_aie_columns=2, + num_channels=1, + m=256, + n=32, + s=8, + context=elf_ctx, ) repeat_kv = Repeat( - rows=G, cols=d * S, repeat=group_size, - transfer_size=d, context=elf_ctx, + rows=G, + cols=d * S, + repeat=group_size, + transfer_size=d, + context=elf_ctx, ) kSd = S * d * B @@ -223,10 +307,14 @@ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, (deinterleave_q, "queries_roped", "queries"), (deinterleave_kv, "keys_roped", "keys_deint"), (deinterleave_kv, "values_projected", "values_deint"), - *[(transpose_keys, - f"keys_deint[{g*kSd}:{(g+1)*kSd}]", - f"keys_transposed[{g*kdS}:{(g+1)*kdS}]") - for g in range(G)], + *[ + ( + transpose_keys, + f"keys_deint[{g*kSd}:{(g+1)*kSd}]", + f"keys_transposed[{g*kdS}:{(g+1)*kdS}]", + ) + for g in range(G) + ], (repeat_kv, "keys_transposed", "keys"), (repeat_kv, "values_deint", "values"), ] @@ -242,19 +330,38 @@ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, } core_runlist, core_buffer_sizes = _build_core_ops( - H, G, d, S, elf_ctx, causal_mask=causal_mask, + H, + G, + d, + S, + elf_ctx, + causal_mask=causal_mask, ) # ---- Reinterleave + output projection ---- reinterleave = StridedCopy( - input_sizes=(1, 1, 1, H * S * d), input_strides=(0, 0, 0, 1), input_offset=0, - output_sizes=(H, 256, S // 256, d), output_strides=(d, 256 * H * d, H * d, 1), output_offset=0, - input_buffer_size=H * S * d, output_buffer_size=S * H * d, - transfer_size=S * d, num_aie_channels=1, context=elf_ctx, + input_sizes=(1, 1, 1, H * S * d), + input_strides=(0, 0, 0, 1), + input_offset=0, + output_sizes=(H, 256, S // 256, d), + output_strides=(d, 256 * H * d, H * d, 1), + output_offset=0, + input_buffer_size=H * S * d, + output_buffer_size=S * H * d, + transfer_size=S * d, + num_aie_channels=1, + context=elf_ctx, ) gemm_output = GEMM( - M=S, K=H * d, N=E, num_aie_columns=8, tile_m=16, tile_k=64, - tile_n=_pick_tile_n(E, 8), context=elf_ctx, prio_accuracy=True, + M=S, + K=H * d, + N=E, + num_aie_columns=8, + tile_m=16, + tile_k=64, + tile_n=_pick_tile_n(E, 8), + context=elf_ctx, + prio_accuracy=True, ) suffix_runlist = [ @@ -266,8 +373,15 @@ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, } mask_suffix = "_causal" if causal_mask else "_nomask" - input_args = ["input", "rope_angles", "W_query", "W_key", "W_value", - "W_output", "attn_scale_factor"] + input_args = [ + "input", + "rope_angles", + "W_query", + "W_key", + "W_value", + "W_output", + "attn_scale_factor", + ] if causal_mask: input_args.append("causal_mask") @@ -276,6 +390,10 @@ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, runlist=prefix_runlist + core_runlist + suffix_runlist, input_args=input_args, output_args=["attn_output"], - buffer_sizes={**prefix_buffer_sizes, **core_buffer_sizes, **suffix_buffer_sizes}, + buffer_sizes={ + **prefix_buffer_sizes, + **core_buffer_sizes, + **suffix_buffer_sizes, + }, context=elf_ctx, ) diff --git a/iron/operators/mha_prefill_lxl_sd/reference.py b/iron/operators/mha_prefill_lxl_sd/reference.py index 92cba961..3343fa8d 100644 --- a/iron/operators/mha_prefill_lxl_sd/reference.py +++ b/iron/operators/mha_prefill_lxl_sd/reference.py @@ -12,7 +12,7 @@ def _apply_rope_4d(x, angles): Returns: same shape as x with RoPE applied (two-halves method). """ half = x.shape[-1] // 2 - cos = angles[:, ::2].unsqueeze(0).unsqueeze(0) # (1, 1, S, half) + cos = angles[:, ::2].unsqueeze(0).unsqueeze(0) # (1, 1, S, half) sin = angles[:, 1::2].unsqueeze(0).unsqueeze(0) # (1, 1, S, half) x1, x2 = x[..., :half], x[..., half:] return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1) @@ -70,7 +70,7 @@ def generate_golden_reference( W_output = torch.randn(H * d, E, dtype=torch.bfloat16) * val_range # Scale factor: 1/sqrt(d), broadcast to (H*S*S,) - scale = 1.0 / (d ** 0.5) + scale = 1.0 / (d**0.5) attn_scale_factor = torch.full((H * S * S,), scale, dtype=torch.bfloat16) # Causal mask: (H*S, S) — 0 for valid positions, -inf for future @@ -83,39 +83,61 @@ def generate_golden_reference( ) # ---- Q/K/V projections ---- - queries_raw = _bf16_matmul(x, W_query) # (S, H*d) - keys_raw = _bf16_matmul(x, W_key) # (S, G*d) - values_raw = _bf16_matmul(x, W_value) # (S, G*d) + queries_raw = _bf16_matmul(x, W_query) # (S, H*d) + keys_raw = _bf16_matmul(x, W_key) # (S, G*d) + values_raw = _bf16_matmul(x, W_value) # (S, G*d) # ---- RoPE (reuses rope_utils.apply_rope with 4D interface) ---- # Reshape interleaved (S, N*d) → (1, N, S, d) for rope_utils - queries_roped = _apply_rope_4d( - queries_raw.reshape(S, H, d).permute(1, 0, 2).unsqueeze(0), # (1, H, S, d) - rope_angles, - ).squeeze(0).permute(1, 0, 2).contiguous().reshape(S * H, d) # (S*H, d) - - keys_roped = _apply_rope_4d( - keys_raw.reshape(S, G, d).permute(1, 0, 2).unsqueeze(0), # (1, G, S, d) - rope_angles, - ).squeeze(0).permute(1, 0, 2).contiguous().reshape(S * G, d) # (S*G, d) + queries_roped = ( + _apply_rope_4d( + queries_raw.reshape(S, H, d).permute(1, 0, 2).unsqueeze(0), # (1, H, S, d) + rope_angles, + ) + .squeeze(0) + .permute(1, 0, 2) + .contiguous() + .reshape(S * H, d) + ) # (S*H, d) + + keys_roped = ( + _apply_rope_4d( + keys_raw.reshape(S, G, d).permute(1, 0, 2).unsqueeze(0), # (1, G, S, d) + rope_angles, + ) + .squeeze(0) + .permute(1, 0, 2) + .contiguous() + .reshape(S * G, d) + ) # (S*G, d) # ---- Deinterleave Q/K/V ---- - queries_deinterleaved = queries_roped.reshape(S, H, d).transpose(0, 1).contiguous() # (H, S, d) - keys_deinterleaved = keys_roped.reshape(S, G, d).transpose(0, 1).contiguous() # (G, S, d) - keys_transposed = keys_deinterleaved.transpose(1, 2).contiguous() # (G, d, S) - values_deinterleaved = values_raw.reshape(S, G, d).transpose(0, 1).contiguous() # (G, S, d) + queries_deinterleaved = ( + queries_roped.reshape(S, H, d).transpose(0, 1).contiguous() + ) # (H, S, d) + keys_deinterleaved = ( + keys_roped.reshape(S, G, d).transpose(0, 1).contiguous() + ) # (G, S, d) + keys_transposed = keys_deinterleaved.transpose(1, 2).contiguous() # (G, d, S) + values_deinterleaved = ( + values_raw.reshape(S, G, d).transpose(0, 1).contiguous() + ) # (G, S, d) # ---- GQA repeat ---- if group_size > 1: - keys_for_scores = keys_transposed.reshape(G, d * S).repeat_interleave( - group_size, dim=0 - ).reshape(H, d, S) - values_for_context = values_deinterleaved.reshape(G, S * d).repeat_interleave( - group_size, dim=0 - ).reshape(H, S, d) + keys_for_scores = ( + keys_transposed.reshape(G, d * S) + .repeat_interleave(group_size, dim=0) + .reshape(H, d, S) + ) + values_for_context = ( + values_deinterleaved.reshape(G, S * d) + .repeat_interleave(group_size, dim=0) + .reshape(H, S, d) + ) else: - keys_for_scores = keys_transposed # (H, d, S) - values_for_context = values_deinterleaved # (H, S, d) + keys_for_scores = keys_transposed # (H, d, S) + values_for_context = values_deinterleaved # (H, S, d) # ---- Score GEMM per head ---- attn_scores = torch.stack( @@ -133,7 +155,9 @@ def generate_golden_reference( # ---- Softmax ---- attn_weights = torch.nn.functional.softmax( attn_scores_masked.float().reshape(H, S, S), dim=-1 - ).to(torch.bfloat16) # (H, S, S) + ).to( + torch.bfloat16 + ) # (H, S, S) # ---- Context GEMM per head ---- attn_context = torch.stack( diff --git a/iron/operators/mha_prefill_lxl_sd/test.py b/iron/operators/mha_prefill_lxl_sd/test.py index 318c44fc..7e90b748 100644 --- a/iron/operators/mha_prefill_lxl_sd/test.py +++ b/iron/operators/mha_prefill_lxl_sd/test.py @@ -49,14 +49,14 @@ def _get_scratch_tensor(fc, name, shape): """Read a named buffer from the fused callable's scratch space.""" fc.scratch_buffer._sync_from_device() sub = fc.get_buffer(name) - return sub.data[:int(np.prod(shape))].reshape(shape).astype(np.float32) + return sub.data[: int(np.prod(shape))].reshape(shape).astype(np.float32) def _get_output_tensor(fc, name, shape): """Read a named buffer from the fused callable's output space.""" fc.output_buffer._sync_from_device() sub = fc.get_buffer(name) - return sub.data[:int(np.prod(shape))].reshape(shape).astype(np.float32) + return sub.data[: int(np.prod(shape))].reshape(shape).astype(np.float32) def _verify_output(fc, golden, H, d, S, E): @@ -71,24 +71,28 @@ def _verify_output(fc, golden, H, d, S, E): output = torch.from_numpy(output_np.reshape(S, E).astype(np.float32)).bfloat16() errors = verify_buffer( - output, "attn_output", chain_ref.reshape(S, E), - rel_tol=REL_TOL, abs_tol=ABS_TOL, max_error_rate=MAX_ERROR_RATE, + output, + "attn_output", + chain_ref.reshape(S, E), + rel_tol=REL_TOL, + abs_tol=ABS_TOL, + max_error_rate=MAX_ERROR_RATE, ) assert not errors, f"Output verification failed with {len(errors)} errors" def _core_gemm_flops(H, G, d, E, S): """Count GEMM FLOPs for the core attention operator.""" - score_flops = H * 2 * S * d * S # H x (S,d)@(d,S) - context_flops = H * 2 * S * S * d # H x (S,S)@(S,d) + score_flops = H * 2 * S * d * S # H x (S,d)@(d,S) + context_flops = H * 2 * S * S * d # H x (S,S)@(S,d) return score_flops + context_flops def _projected_gemm_flops(H, G, d, E, S): """Count GEMM FLOPs for the projected attention operator.""" - query_proj = 2 * S * E * (H * d) # (S,E)@(E,H*d) - kv_proj = 2 * (2 * S * E * (G * d)) # key + value: (S,E)@(E,G*d) each - output_proj = 2 * S * (H * d) * E # (S,H*d)@(H*d,E) + query_proj = 2 * S * E * (H * d) # (S,E)@(E,H*d) + kv_proj = 2 * (2 * S * E * (G * d)) # key + value: (S,E)@(E,G*d) each + output_proj = 2 * S * (H * d) * E # (S,H*d)@(H*d,E) return query_proj + kv_proj + _core_gemm_flops(H, G, d, E, S) + output_proj @@ -96,6 +100,7 @@ def _projected_gemm_flops(H, G, d, E, S): # Core attention tests (pre-projected Q, K, V) # --------------------------------------------------------------------------- + @pytest.mark.metrics( Latency=r"Latency \(us\): (?P[\d\.]+)", Throughput=r"Throughput: (?P[\d\.e\+-]+) GFLOP/s", @@ -125,9 +130,12 @@ def test_mha_pefill_lxl_sd(H, G, d, E, S): actual = _get_output_tensor(fc, "attn_context", (H, S, d)) expected = golden["attn_context"].float().numpy().reshape(H, S, d) errors = verify_buffer( - torch.from_numpy(actual).bfloat16(), "attn_context", + torch.from_numpy(actual).bfloat16(), + "attn_context", torch.from_numpy(expected).bfloat16().reshape(H, S, d), - rel_tol=REL_TOL, abs_tol=ABS_TOL, max_error_rate=MAX_ERROR_RATE, + rel_tol=REL_TOL, + abs_tol=ABS_TOL, + max_error_rate=MAX_ERROR_RATE, ) assert not errors, f"Output verification failed with {len(errors)} errors" @@ -136,6 +144,7 @@ def test_mha_pefill_lxl_sd(H, G, d, E, S): # Projected attention tests (with Q/K/V projections + RoPE) # --------------------------------------------------------------------------- + @pytest.mark.metrics( Latency=r"Latency \(us\): (?P[\d\.]+)", Throughput=r"Throughput: (?P[\d\.e\+-]+) GFLOP/s", @@ -172,6 +181,7 @@ def test_attention_prefill_projected_fused(H, G, d, E, S): # Benchmark: GPT-2 Small core MHA across sequence lengths, +/- causal mask # --------------------------------------------------------------------------- + @pytest.mark.benchmark @pytest.mark.metrics( Latency=r"Latency \(us\): (?P[\d\.]+)", @@ -207,7 +217,12 @@ def test_mha_prefill_benchmark(H, G, d, E, S, causal): INTERMEDIATE_CHECKS = [ ("attn_scores", "attn_scores", lambda H, G, S, d: (H, S, S), "scratch"), - ("attn_scores_masked", "attn_scores_masked", lambda H, G, S, d: (H, S, S), "scratch"), + ( + "attn_scores_masked", + "attn_scores_masked", + lambda H, G, S, d: (H, S, S), + "scratch", + ), ("attn_weights", "attn_weights", lambda H, G, S, d: (H, S, S), "scratch"), ("attn_context", "attn_context", lambda H, G, S, d: (H, S, d), "output"), ] From abf37ab66ad277bc2b034338f9c775bdf64c1394 Mon Sep 17 00:00:00 2001 From: andrej Date: Wed, 15 Apr 2026 11:36:10 -0600 Subject: [PATCH 09/22] make mha_prefill_lxl_sd use all available columns --- iron/operators/mha_prefill_lxl_sd/op.py | 36 +++++++++++++---------- iron/operators/mha_prefill_lxl_sd/test.py | 2 +- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/iron/operators/mha_prefill_lxl_sd/op.py b/iron/operators/mha_prefill_lxl_sd/op.py index aba8a115..4421822b 100644 --- a/iron/operators/mha_prefill_lxl_sd/op.py +++ b/iron/operators/mha_prefill_lxl_sd/op.py @@ -5,6 +5,8 @@ A layer-by-layer (LxL) single-dispatch (SD) implementation of multi-head attention (MHA). """ +import aie.utils as aie_utils + from iron.common.context import AIEContext from iron.common.fusion import FusedMLIROperator from iron.operators.gemm.op import GEMM @@ -25,7 +27,7 @@ def _pick_tile_n(N, num_cols, max_tile_n=64): return tile_n -def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True): +def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True, num_cols=None): """Build core attention sub-ops and runlist (no projections/RoPE/GQA). Expects pre-processed inputs: @@ -38,29 +40,31 @@ def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True): If causal_mask=False, the elementwise-add masking step is omitted. """ + if num_cols is None: + num_cols = aie_utils.get_current_device().cols B = 2 # bytes per bf16 element gemm_scores = GEMM( M=S, K=d, N=S, - num_aie_columns=8, + num_aie_columns=num_cols, tile_m=16, tile_k=64, - tile_n=_pick_tile_n(S, 8), + tile_n=_pick_tile_n(S, num_cols), context=elf_ctx, ) scale = ElementwiseMul( size=H * S * S, - tile_size=S * S // 8, - num_aie_columns=8, + tile_size=S * S // num_cols, + num_aie_columns=num_cols, context=elf_ctx, ) if causal_mask: mask = ElementwiseAdd( size=H * S * S, - tile_size=S * S // 8, - num_aie_columns=8, + tile_size=S * S // num_cols, + num_aie_columns=num_cols, context=elf_ctx, ) softmax = Softmax( @@ -75,10 +79,10 @@ def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True): M=S, K=S, N=d, - num_aie_columns=4, + num_aie_columns=min(4, num_cols), tile_m=16, tile_k=64, - tile_n=16, + tile_n=_pick_tile_n(d, min(4, num_cols)), context=elf_ctx, prio_accuracy=True, ) @@ -221,6 +225,7 @@ def __init__( H, G, d, E, S = num_heads, num_kv_groups, head_dim, embedding_dim, seq_len group_size = H // G B = 2 + num_cols = aie_utils.get_current_device().cols elf_ctx = context or AIEContext() @@ -229,20 +234,20 @@ def __init__( M=S, K=E, N=H * d, - num_aie_columns=8, + num_aie_columns=num_cols, tile_m=16, tile_k=64, - tile_n=_pick_tile_n(H * d, 8), + tile_n=_pick_tile_n(H * d, num_cols), context=elf_ctx, ) gemm_kv = GEMM( M=S, K=E, N=G * d, - num_aie_columns=8, + num_aie_columns=num_cols, tile_m=16, tile_k=64, - tile_n=_pick_tile_n(G * d, 8), + tile_n=_pick_tile_n(G * d, num_cols), context=elf_ctx, ) rope_queries = RoPE(rows=S * H, cols=d, angle_rows=S, context=elf_ctx) @@ -336,6 +341,7 @@ def __init__( S, elf_ctx, causal_mask=causal_mask, + num_cols=num_cols, ) # ---- Reinterleave + output projection ---- @@ -356,10 +362,10 @@ def __init__( M=S, K=H * d, N=E, - num_aie_columns=8, + num_aie_columns=num_cols, tile_m=16, tile_k=64, - tile_n=_pick_tile_n(E, 8), + tile_n=_pick_tile_n(E, num_cols), context=elf_ctx, prio_accuracy=True, ) diff --git a/iron/operators/mha_prefill_lxl_sd/test.py b/iron/operators/mha_prefill_lxl_sd/test.py index 7e90b748..7cef67ac 100644 --- a/iron/operators/mha_prefill_lxl_sd/test.py +++ b/iron/operators/mha_prefill_lxl_sd/test.py @@ -31,7 +31,7 @@ def get_benchmark_params(): """GPT-2 Small across sequence lengths 256..32768, with/without causal mask.""" params = [] S = 256 - while S <= 32768: + while S <= 4096: #32768: for mask in [True, False]: tag = "causal" if mask else "nomask" params.append(pytest.param(12, 12, 64, 768, S, mask, id=f"GPT2-S{S}-{tag}")) From 4caac12ea96ec3fe05d475b614a583af013dee15 Mon Sep 17 00:00:00 2001 From: andrej Date: Wed, 15 Apr 2026 11:36:25 -0600 Subject: [PATCH 10/22] update test result CSV iteratively rather than all at once --- conftest.py | 90 ++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 65 insertions(+), 25 deletions(-) diff --git a/conftest.py b/conftest.py index d611f0dd..8a9ab947 100644 --- a/conftest.py +++ b/conftest.py @@ -59,6 +59,14 @@ def __init__(self, csv_path): self.commit = get_git_commit() self.date = datetime.now().strftime("%Y-%m-%d %H:%M:%S") self.test_metrics = {} # test_name -> {metric_name -> [values]} + self._initialize_csv() + + def _initialize_csv(self): + """Initialize CSV file - will be written incrementally as tests complete""" + self.csv_path.parent.mkdir(parents=True, exist_ok=True) + # Clear the file at the start of a new test run + with open(self.csv_path, "w", newline="") as f: + pass # Create empty file, header will be written with first result def add_result(self, test_name, passed, captured_output, metric_patterns): self.test_metrics.setdefault(test_name, {}).setdefault("passed", []).append( @@ -72,40 +80,70 @@ def add_result(self, test_name, passed, captured_output, metric_patterns): value = float(match.group("value")) self.test_metrics[test_name].setdefault(metric_name, []).append(value) - def finalize_results(self): - """Compute statistics for all collected metrics""" - for test_name, data in self.test_metrics.items(): - row = { - "Commit": self.commit, - "Date": self.date, - "Test": test_name, - "Checks": f"{sum(data['passed'])}/{len(data['passed'])}", - } - for metric_name, values in data.items(): - if metric_name == "passed": - continue - if values: - row[f"{metric_name} (mean)"] = statistics.mean(values) - row[f"{metric_name} (median)"] = statistics.median(values) - row[f"{metric_name} (min)"] = min(values) - row[f"{metric_name} (max)"] = max(values) - row[f"{metric_name} (stddev)"] = ( - statistics.stdev(values) if len(values) > 1 else 0.0 - ) + def _compute_row(self, test_name, data): + """Compute statistics row for a single test""" + row = { + "Commit": self.commit, + "Date": self.date, + "Test": test_name, + "Checks": f"{sum(data['passed'])}/{len(data['passed'])}", + } + for metric_name, values in data.items(): + if metric_name == "passed": + continue + if values: + row[f"{metric_name} (mean)"] = statistics.mean(values) + row[f"{metric_name} (median)"] = statistics.median(values) + row[f"{metric_name} (min)"] = min(values) + row[f"{metric_name} (max)"] = max(values) + row[f"{metric_name} (stddev)"] = ( + statistics.stdev(values) if len(values) > 1 else 0.0 + ) + return row + + def write_test_result(self, test_name): + """Write or update result for a specific test incrementally""" + if test_name not in self.test_metrics: + return + + data = self.test_metrics[test_name] + row = self._compute_row(test_name, data) + + # Update or add to results + existing_idx = None + for idx, existing_row in enumerate(self.results): + if existing_row["Test"] == test_name: + existing_idx = idx + break + + if existing_idx is not None: + self.results[existing_idx] = row + else: self.results.append(row) - def write_csv(self): - self.results.sort(key=lambda x: (x["Test"], x["Date"])) + # Rewrite entire CSV to handle column changes and updates + self._write_csv_internal() + + def _write_csv_internal(self): + """Internal method to write CSV file""" + if not self.results: + return + + sorted_results = sorted(self.results, key=lambda x: (x["Test"], x["Date"])) cols = {} - for row in self.results: + for row in sorted_results: cols.update({k: None for k in row.keys()}) self.csv_path.parent.mkdir(parents=True, exist_ok=True) with open(self.csv_path, "w", newline="") as f: writer = csv.DictWriter(f, cols.keys()) writer.writeheader() - writer.writerows(self.results) + writer.writerows(sorted_results) + + def write_csv(self): + """Final write at session end - ensures all results are flushed""" + self._write_csv_internal() # Initialize the CSV writer once at test session setup @@ -147,6 +185,9 @@ def pytest_runtest_makereport(item, call): csv_reporter.add_result(test_name, passed, captured, metric_patterns) + # Write results incrementally after each test + csv_reporter.write_test_result(test_name) + def pytest_configure(config): csv_path = config.getoption("--csv-output") @@ -170,7 +211,6 @@ def pytest_collection_modifyitems(config, items): def pytest_sessionfinish(session, exitstatus): if hasattr(session.config, "_csv_reporter"): - session.config._csv_reporter.finalize_results() session.config._csv_reporter.write_csv() From 8a65c6d765bb8ef57505da094e0e0a2f5022331d Mon Sep 17 00:00:00 2001 From: andrej Date: Wed, 15 Apr 2026 12:20:37 -0600 Subject: [PATCH 11/22] make FusedMLIROperator work on Phoenix via multiple xclbin calls --- iron/common/fusion.py | 213 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 208 insertions(+), 5 deletions(-) diff --git a/iron/common/fusion.py b/iron/common/fusion.py index 292b26e9..7fccb125 100644 --- a/iron/common/fusion.py +++ b/iron/common/fusion.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +import hashlib +import logging import numpy as np import ml_dtypes import pyxrt @@ -10,7 +12,11 @@ from .base import AIEOperatorBase, MLIROperator from .utils import XRTSubBuffer import aie.utils as aie_utils +from aie.iron.device import NPU2 from aie.utils.hostruntime.xrtruntime.tensor import XRTTensor +from aie.utils.npukernel import NPUKernel + +logger = logging.getLogger(__name__) # Fused Operator # ########################################################################## @@ -200,13 +206,26 @@ def add_buffers(buffer_type, args_list): def set_up_artifacts(self): """Set up the artifact dependency graph for this fused operator. - Computes the buffer layout first, then builds the fused MLIR artifact - and full-ELF artifact and registers them via ``add_artifacts()``. + Computes the buffer layout first, then builds the artifacts. + On NPU2, uses the full-ELF flow (fused MLIR → single ELF). + On NPU1 (Phoenix), uses chained xclbin flow (separate xclbin per + unique operator, chained via --xclbin-input). """ - # Calculate buffer layout before building mlir artifact (used by get_mlir_artifact) + # Calculate buffer layout (used by both paths for get_buffer()) self.subbuffer_layout, self.buffer_sizes, self.slice_info = ( self._calculate_buffer_layout() ) + + dev = aie_utils.get_current_device() + self._use_full_elf = isinstance(dev, NPU2) + + if self._use_full_elf: + self._set_up_full_elf_artifacts() + else: + self._set_up_xclbin_artifacts() + + def _set_up_full_elf_artifacts(self): + """Full-ELF path (NPU2): fuse MLIR into a single ELF.""" operator_name = self.name mlir_artifact = self.get_mlir_artifact() kernel_objects = self.get_kernel_artifacts() @@ -217,6 +236,58 @@ def set_up_artifacts(self): ) self.add_artifacts([full_elf_artifact]) + def _set_up_xclbin_artifacts(self): + """Chained xclbin path (NPU1/Phoenix): separate xclbin per unique operator. + + Mirrors the pattern from ``chain_swiglu_artifacts`` in + ``iron/operators/swiglu_base.py``: each unique operator gets its own + xclbin + insts compiled separately, linked via ``--xclbin-input``. + """ + seen: dict[int, object] = {} + unique_operators = [ + seen.setdefault(id(op), op) + for op, *_ in self.runlist + if id(op) not in seen + ] + + # Short hash to keep xclbin kernel names under 31 chars + # (xclbinutil limits m_name to 64 chars as "name:name") + name_hash = hashlib.sha1(self.name.encode()).hexdigest()[:6] + + artifacts = [] + prev_xclbin = None + self._op_xclbin_map = {} # id(op) -> xclbin artifact + self._op_insts_map = {} # id(op) -> insts artifact + self._op_kernel_name_map = {} # id(op) -> kernel_name + + for idx, op in enumerate(unique_operators): + op_label = f"f{name_hash}_op{idx}" + kernel_id = f"0x{0x901 + idx:x}" + + xclbin, insts = op.get_artifacts(prefix=f"{op_label}_") + # Use list() to avoid mutating the shared extra_flags list + # (get_artifacts may alias the same list between xclbin and insts) + xclbin.extra_flags = list(xclbin.extra_flags) + [ + f"--xclbin-instance-name={op_label}", + f"--xclbin-kernel-id={kernel_id}", + ] + xclbin.kernel_name = op_label + + if prev_xclbin is not None: + xclbin.xclbin_input = prev_xclbin + xclbin.dependencies.add(prev_xclbin) + + artifacts.append(insts) + self._op_xclbin_map[id(op)] = xclbin + self._op_insts_map[id(op)] = insts + self._op_kernel_name_map[id(op)] = op_label + prev_xclbin = xclbin + + # The last xclbin in the chain is the combined xclbin. + artifacts.append(prev_xclbin) + self.combined_xclbin = prev_xclbin + self.add_artifacts(artifacts) + def get_arg_spec(self): raise NotImplementedError( "FusedMLIROperator does not expose a unified arg spec; " @@ -227,9 +298,12 @@ def get_callable(self): """Return a callable that executes the fused operator on the NPU. Returns: - A ``FusedFullELFCallable`` wrapping this operator. + A ``FusedFullELFCallable`` on NPU2, or a ``FusedXclbinCallable`` + on NPU1 (Phoenix). """ - return FusedFullELFCallable(self) + if self._use_full_elf: + return FusedFullELFCallable(self) + return FusedXclbinCallable(self) def get_layout_for_buffer(self, buffer_name): """Return the (buffer_type, offset, length) layout for a named buffer. @@ -375,3 +449,132 @@ def __call__(self): self.scratch_buffer.buffer_object(), ) self.output_buffer._sync_from_device() + + +class FusedXclbinCallable: + """Callable for FusedMLIROperator on NPU1 (Phoenix) using chained xclbins. + + Instead of a single ELF dispatch, each step in the runlist is executed as a + separate ``NPUKernel`` invocation. Buffers are shared (same ``XRTTensor``) + across steps that reference the same buffer name, giving zero-copy handoff + between sequential operators. + """ + + def __init__(self, op): + self.op = op + self.last_elapsed = 0.0 + + combined_xclbin_path = op.combined_xclbin.filename + + # Build an NPUKernel per unique operator + self._op_callable_map = {} # id(op) -> NPUKernel + for op_id, xclbin in op._op_xclbin_map.items(): + insts = op._op_insts_map[op_id] + kernel_name = op._op_kernel_name_map[op_id] + self._op_callable_map[op_id] = NPUKernel( + xclbin_path=combined_xclbin_path, + kernel_name=kernel_name, + insts_path=insts.filename, + ) + + # Allocate one XRTTensor per unique base buffer name. + # Buffers that appear in multiple runlist entries share the same tensor + # (zero-copy between operators). + itemsize = np.dtype(ml_dtypes.bfloat16).itemsize + self._buffers = {} # base buffer name -> XRTTensor + for buf_name in list(op.subbuffer_layout.keys()): + _, _, length = op.subbuffer_layout[buf_name] + self._buffers[buf_name] = XRTTensor( + (max(length, itemsize) // itemsize,), + dtype=ml_dtypes.bfloat16, + ) + + # Pre-build the execution plan: list of (NPUKernel, [XRTTensor args]) + self._execution_plan = [] + for step_op, *buf_names in op.runlist: + kernel = self._op_callable_map[id(step_op)] + args = [] + for buf_name in buf_names: + args.append(self._resolve_buffer(buf_name)) + self._execution_plan.append((kernel, args)) + + # Cache for get_buffer() sub-buffer views (compatible with FusedFullELFCallable API) + self._buffer_cache = {} + + # Expose input/output/scratch buffers for API compatibility with + # FusedFullELFCallable (used by tests for _sync_from_device etc.) + input_buffer_size, output_buffer_size, scratch_buffer_size = op.buffer_sizes + self.input_buffer = XRTTensor( + (max(input_buffer_size, itemsize) // itemsize,), + dtype=ml_dtypes.bfloat16, + ) + self.output_buffer = XRTTensor( + (max(output_buffer_size, itemsize) // itemsize,), + dtype=ml_dtypes.bfloat16, + ) + self.scratch_buffer = XRTTensor( + (max(scratch_buffer_size, itemsize) // itemsize,), + dtype=ml_dtypes.bfloat16, + ) + + def _resolve_buffer(self, buf_name): + """Resolve a buffer name (possibly with slice notation) to an XRTTensor. + + Regular buffer names map directly to an allocated XRTTensor. + Sliced buffer names (e.g. ``queries[0:128]``) create an XRTSubBuffer + view into the parent buffer. + """ + if buf_name in self._buffers: + return self._buffers[buf_name] + + # Sliced buffer: "base_name[start:end]" + if buf_name in self.op.slice_info: + base_name, start_bytes, end_bytes = self.op.slice_info[buf_name] + parent = self._buffers[base_name] + itemsize = np.dtype(ml_dtypes.bfloat16).itemsize + size_bytes = end_bytes - start_bytes + sub = XRTSubBuffer( + parent_bo=parent.buffer_object(), + offset_bytes=start_bytes, + size_bytes=size_bytes, + shape=(size_bytes // itemsize,), + dtype=ml_dtypes.bfloat16, + ) + # Cache so the same slice always returns the same object + self._buffers[buf_name] = sub + return sub + + raise ValueError(f"Unknown buffer '{buf_name}' in fused runlist") + + def get_buffer(self, buffer_name): + """Return an XRTTensor(-like) view for a named buffer. + + Compatible with the ``FusedFullELFCallable.get_buffer()`` API so that + test helpers (``_load_input``, ``_get_output_tensor``, etc.) work + unchanged. + + For the xclbin path, each buffer is its own standalone XRTTensor (or + XRTSubBuffer for sliced buffers), so this just returns the resolved + buffer directly. + """ + if buffer_name in self._buffer_cache: + return self._buffer_cache[buffer_name] + buf = self._resolve_buffer(buffer_name) + self._buffer_cache[buffer_name] = buf + return buf + + def __call__(self): + # Sync all input buffers to device + for buf_name in self.op.input_args: + self._buffers[buf_name]._sync_to_device() + + t0 = time.perf_counter() + for kernel, args in self._execution_plan: + kernel(*args) + self.last_elapsed = time.perf_counter() - t0 + + # Sync all base buffers from device so callers can read results + # (covers both output and scratch buffers) + for buf_name in self.op.subbuffer_layout: + if buf_name not in self.op.input_args: + self._buffers[buf_name]._sync_from_device() From daf9162f5693ffb5d6368a30aba0a7f3cf47bfc5 Mon Sep 17 00:00:00 2001 From: andrej Date: Wed, 15 Apr 2026 14:14:05 -0600 Subject: [PATCH 12/22] make dispatch mode selectable, add tests --- iron/common/fusion.py | 51 +++++++++++++++++++---- iron/operators/mha_prefill_lxl_sd/op.py | 5 +++ iron/operators/mha_prefill_lxl_sd/test.py | 15 ++++--- 3 files changed, 57 insertions(+), 14 deletions(-) diff --git a/iron/common/fusion.py b/iron/common/fusion.py index 7fccb125..069e0089 100644 --- a/iron/common/fusion.py +++ b/iron/common/fusion.py @@ -23,11 +23,33 @@ class FusedMLIROperator(AIEOperatorBase): - """Operator that fuses multiple MLIROperators into one.""" + """Operator that fuses multiple MLIROperators into one. + + Args: + dispatch: Dispatch strategy for the fused operator. + ``"auto"`` (default) selects ``"fused"`` on NPU2 and + ``"separate"`` on NPU1. ``"fused"`` uses a single-ELF + dispatch (requires NPU2). ``"separate"`` compiles each + sub-operator to its own xclbin and invokes them sequentially. + """ + + DISPATCH_MODES = ("auto", "fused", "separate") def __init__( - self, name, runlist, input_args, output_args, buffer_sizes=None, *args, **kwargs + self, + name, + runlist, + input_args, + output_args, + buffer_sizes=None, + dispatch="auto", + *args, + **kwargs, ): + if dispatch not in self.DISPATCH_MODES: + raise ValueError( + f"dispatch must be one of {self.DISPATCH_MODES!r}, got {dispatch!r}" + ) if not all( isinstance(op, MLIROperator) and all(isinstance(buf, str) for buf in bufs) for op, *bufs in runlist @@ -44,6 +66,7 @@ def __init__( self.explicit_buffer_sizes = ( buffer_sizes or {} ) # Optional dict: buffer_name -> size_in_bytes + self._dispatch = dispatch def get_kernel_artifacts(self): """Collect all kernel artifacts from child operators. @@ -207,17 +230,27 @@ def set_up_artifacts(self): """Set up the artifact dependency graph for this fused operator. Computes the buffer layout first, then builds the artifacts. - On NPU2, uses the full-ELF flow (fused MLIR → single ELF). - On NPU1 (Phoenix), uses chained xclbin flow (separate xclbin per - unique operator, chained via --xclbin-input). + The dispatch mode (``"fused"`` vs ``"separate"``) is resolved here + when set to ``"auto"``. """ # Calculate buffer layout (used by both paths for get_buffer()) self.subbuffer_layout, self.buffer_sizes, self.slice_info = ( self._calculate_buffer_layout() ) - dev = aie_utils.get_current_device() - self._use_full_elf = isinstance(dev, NPU2) + is_npu2 = isinstance(aie_utils.get_current_device(), NPU2) + + if self._dispatch == "auto": + self._use_full_elf = is_npu2 + elif self._dispatch == "fused": + if not is_npu2: + raise RuntimeError( + "dispatch='fused' requires NPU2 (Strix); " + "Phoenix/NPU1 does not support full-ELF dispatch" + ) + self._use_full_elf = True + else: # "separate" + self._use_full_elf = False if self._use_full_elf: self._set_up_full_elf_artifacts() @@ -298,8 +331,8 @@ def get_callable(self): """Return a callable that executes the fused operator on the NPU. Returns: - A ``FusedFullELFCallable`` on NPU2, or a ``FusedXclbinCallable`` - on NPU1 (Phoenix). + A ``FusedFullELFCallable`` when using fused dispatch, or a + ``FusedXclbinCallable`` when using separate dispatch. """ if self._use_full_elf: return FusedFullELFCallable(self) diff --git a/iron/operators/mha_prefill_lxl_sd/op.py b/iron/operators/mha_prefill_lxl_sd/op.py index 4421822b..ace92d86 100644 --- a/iron/operators/mha_prefill_lxl_sd/op.py +++ b/iron/operators/mha_prefill_lxl_sd/op.py @@ -158,6 +158,7 @@ def __init__( seq_len, causal_mask=True, context=None, + dispatch="auto", ): assert head_dim == 64 assert num_heads % num_kv_groups == 0 @@ -191,6 +192,7 @@ def __init__( input_args=input_args, output_args=["attn_context"], buffer_sizes=buffer_sizes, + dispatch=dispatch, context=elf_ctx, ) @@ -210,6 +212,7 @@ def __init__( seq_len, causal_mask=True, context=None, + dispatch="auto", ): assert head_dim == 64 assert num_heads % num_kv_groups == 0 @@ -221,6 +224,7 @@ def __init__( self.head_dim = head_dim self.embedding_dim = embedding_dim self.seq_len = seq_len + self._dispatch_arg = dispatch H, G, d, E, S = num_heads, num_kv_groups, head_dim, embedding_dim, seq_len group_size = H // G @@ -401,5 +405,6 @@ def __init__( **core_buffer_sizes, **suffix_buffer_sizes, }, + dispatch=dispatch, context=elf_ctx, ) diff --git a/iron/operators/mha_prefill_lxl_sd/test.py b/iron/operators/mha_prefill_lxl_sd/test.py index 7cef67ac..b15b3135 100644 --- a/iron/operators/mha_prefill_lxl_sd/test.py +++ b/iron/operators/mha_prefill_lxl_sd/test.py @@ -33,8 +33,13 @@ def get_benchmark_params(): S = 256 while S <= 4096: #32768: for mask in [True, False]: - tag = "causal" if mask else "nomask" - params.append(pytest.param(12, 12, 64, 768, S, mask, id=f"GPT2-S{S}-{tag}")) + for dispatch in ["auto", "separate"]: + tag = "causal" if mask else "nomask" + suffix = f"-{dispatch}" if dispatch != "auto" else "" + params.append(pytest.param( + 12, 12, 64, 768, S, mask, dispatch, + id=f"GPT2-S{S}-{tag}{suffix}", + )) S *= 2 return params @@ -187,12 +192,12 @@ def test_attention_prefill_projected_fused(H, G, d, E, S): Latency=r"Latency \(us\): (?P[\d\.]+)", Throughput=r"Throughput: (?P[\d\.e\+-]+) GFLOP/s", ) -@pytest.mark.parametrize("H,G,d,E,S,causal", get_benchmark_params()) -def test_mha_prefill_benchmark(H, G, d, E, S, causal): +@pytest.mark.parametrize("H,G,d,E,S,causal,dispatch", get_benchmark_params()) +def test_mha_prefill_benchmark(H, G, d, E, S, causal, dispatch): """Benchmark core MHA for GPT-2 Small across sequence lengths.""" golden = generate_golden_reference(H, G, d, E, S) - op = AttentionPrefillFused(H, G, d, E, S, causal_mask=causal) + op = AttentionPrefillFused(H, G, d, E, S, causal_mask=causal, dispatch=dispatch) op.compile() fc = op.get_callable() From ffe05157d9bd27437ae6206f22b2e6ec02c06f7d Mon Sep 17 00:00:00 2001 From: andrej Date: Wed, 15 Apr 2026 15:22:51 -0600 Subject: [PATCH 13/22] use partial softmax on long sequence lengths --- aie_kernels/aie2p/softmax.cc | 133 +++++++++++++++ iron/operators/mha_prefill_lxl_sd/op.py | 5 + iron/operators/mha_prefill_lxl_sd/test.py | 2 +- iron/operators/softmax/design.py | 190 ++++++++++++++++++++++ iron/operators/softmax/op.py | 18 +- iron/operators/softmax/test.py | 49 ++++++ 6 files changed, 395 insertions(+), 2 deletions(-) diff --git a/aie_kernels/aie2p/softmax.cc b/aie_kernels/aie2p/softmax.cc index 5778682a..48d82c3a 100644 --- a/aie_kernels/aie2p/softmax.cc +++ b/aie_kernels/aie2p/softmax.cc @@ -160,6 +160,118 @@ void partial_softmax_alias_bf16(bfloat16 *restrict input_vector, return; } +// --------------------------------------------------------------------------- +// Online (partial / tiled) softmax helpers +// +// These three kernels implement a two-pass online softmax that processes a row +// in sub-tile chunks, keeping running max and sum statistics in a small local +// buffer (`stats`). Layout of the stats buffer (bfloat16[16], only [0..1] +// used): +// stats[0] = running max (scaled by log2e) +// stats[1] = running sum (of exp2(x*log2e - max)) +// --------------------------------------------------------------------------- + +void softmax_partial_stats_impl(bfloat16 *restrict input, + bfloat16 *stats, + const int32_t vector_size) +{ + event0(); + + const int elem_iters = vector_size / SM_VEC_LEN; + + aie::vector input_bf16; + aie::accum scaled_accum, exp_in_accum; + aie::accum exp_val_accum = aie::zeros(); + + aie::vector log2e_vec = + aie::broadcast((bfloat16)log2e); + + // --- Phase 1: find local max (scaled by log2e) ------------------------- + float local_max = -INFINITY; + auto it_in1 = aie::cbegin_restrict_vector((bfloat16 *)input); + for (int i = 0; i < elem_iters; i++) { + input_bf16 = *it_in1++; + scaled_accum = aie::mul(input_bf16, log2e_vec); + float chunk_max = aie::reduce_max(scaled_accum.to_vector()); + if (chunk_max > local_max) { + local_max = chunk_max; + } + } + + // --- Phase 2: update running max, rescale running sum ------------------ + float old_max = (float)stats[0]; + float old_sum = (float)stats[1]; + + if (local_max > old_max) { + // New max is larger — rescale the old sum by exp2(old_max - new_max) + aie::vector diff_vec = + aie::broadcast(old_max - local_max); + aie::vector corr = aie::exp2(diff_vec); + old_sum = old_sum * (float)corr[0]; + old_max = local_max; + } + + // --- Phase 3: accumulate exp2(input * log2e - max) for this chunk ------ + aie::vector max_val_vec = + aie::broadcast((bfloat16)old_max); + + auto it_in2 = aie::cbegin_restrict_vector((bfloat16 *)input); + for (int i = 0; i < elem_iters; i++) { + input_bf16 = *it_in2++; + scaled_accum = aie::mul(input_bf16, log2e_vec); + exp_in_accum = aie::sub(scaled_accum, max_val_vec); + aie::vector exp_val = + aie::exp2(exp_in_accum.to_vector()); + exp_val_accum = add(exp_val_accum, exp_val); + } + + aie::vector reduce = exp_val_accum.to_vector(); + float local_sum = aie::reduce_add(reduce); + + // --- Phase 4: store updated stats -------------------------------------- + stats[0] = (bfloat16)old_max; + stats[1] = (bfloat16)(old_sum + local_sum); + + event1(); +} + +void softmax_partial_norm_impl(bfloat16 *restrict input, + bfloat16 *restrict output, + bfloat16 *stats, + const int32_t vector_size) +{ + event0(); + + const int elem_iters = vector_size / SM_VEC_LEN; + + float max_val = (float)stats[0]; + float sum_val = (float)stats[1]; + bfloat16 inv_sum = (bfloat16)aie::inv(sum_val); + + aie::vector log2e_vec = + aie::broadcast((bfloat16)log2e); + aie::vector max_val_vec = + aie::broadcast((bfloat16)max_val); + + aie::vector input_bf16; + aie::accum scaled_accum, exp_in_accum, out_vals; + + auto it_in = aie::cbegin_restrict_vector((bfloat16 *)input); + auto it_out = aie::begin_restrict_vector((bfloat16 *)output); + + for (int i = 0; i < elem_iters; i++) { + input_bf16 = *it_in++; + scaled_accum = aie::mul(input_bf16, log2e_vec); + exp_in_accum = aie::sub(scaled_accum, max_val_vec); + aie::vector exp_val = + aie::exp2(exp_in_accum.to_vector()); + out_vals = aie::mul(exp_val, inv_sum); + *it_out++ = out_vals.to_vector(); + } + + event1(); +} + extern "C" { void softmax_bf16(bfloat16 *restrict input, bfloat16 *restrict output, const int32_t input_size) @@ -178,6 +290,27 @@ void partial_softmax_bf16(bfloat16 *restrict input, partial_softmax_alias_bf16(input, output, scale_buffer, input_size, row_idx, num_rows, scale); } +void softmax_partial_init_bf16(bfloat16 *stats) +{ + stats[0] = (bfloat16)(-INFINITY); + stats[1] = (bfloat16)(0.0f); +} + +void softmax_partial_stats_bf16(bfloat16 *restrict input, + bfloat16 *stats, + const int32_t vector_size) +{ + softmax_partial_stats_impl(input, stats, vector_size); +} + +void softmax_partial_norm_bf16(bfloat16 *restrict input, + bfloat16 *restrict output, + bfloat16 *stats, + const int32_t vector_size) +{ + softmax_partial_norm_impl(input, output, stats, vector_size); +} + void mask_bf16(bfloat16 *inout, const int32 unmasked_size, const int32 total_size) { // TODO: Optimize this to use vector code diff --git a/iron/operators/mha_prefill_lxl_sd/op.py b/iron/operators/mha_prefill_lxl_sd/op.py index ace92d86..793ff8b6 100644 --- a/iron/operators/mha_prefill_lxl_sd/op.py +++ b/iron/operators/mha_prefill_lxl_sd/op.py @@ -67,12 +67,17 @@ def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True, num_cols=None): num_aie_columns=num_cols, context=elf_ctx, ) + # Use online/partial softmax when full-row tiles would exhaust AIE local + # memory (each double-buffered FIFO pair uses 4 * tile_size bytes; at + # S >= 8192 the in+out FIFOs alone consume the full 64 KB data memory). + softmax_chunk_size = 1024 if S >= 8192 else None softmax = Softmax( rows=H * S, cols=S, num_aie_columns=1, num_channels=1, rtp_vector_size=S, + chunk_size=softmax_chunk_size, context=elf_ctx, ) gemm_context = GEMM( diff --git a/iron/operators/mha_prefill_lxl_sd/test.py b/iron/operators/mha_prefill_lxl_sd/test.py index b15b3135..bbfd3e0e 100644 --- a/iron/operators/mha_prefill_lxl_sd/test.py +++ b/iron/operators/mha_prefill_lxl_sd/test.py @@ -31,7 +31,7 @@ def get_benchmark_params(): """GPT-2 Small across sequence lengths 256..32768, with/without causal mask.""" params = [] S = 256 - while S <= 4096: #32768: + while S <= 8192: for mask in [True, False]: for dispatch in ["auto", "separate"]: tag = "causal" if mask else "nomask" diff --git a/iron/operators/softmax/design.py b/iron/operators/softmax/design.py index 5cb68c39..cbf6120f 100644 --- a/iron/operators/softmax/design.py +++ b/iron/operators/softmax/design.py @@ -20,6 +20,182 @@ from ml_dtypes import bfloat16 +def _softmax_partial( + dev, + num_elements, + num_aie_columns, + num_channels, + tile_size, + chunk_size, + func_prefix="", +): + """Online / tiled softmax that processes each row in sub-tile chunks. + + Each row of *tile_size* elements is processed in two passes: + 1. Stats pass – reads chunks, accumulates running max and sum(exp). + 2. Norm pass – reads the same chunks again, writes exp(x-max)/sum. + + Two separate input ObjectFifos are used so that the DMA can feed each pass + independently from the same DDR source buffer. + """ + total_cores = num_aie_columns * num_channels + per_core_elements = num_elements // total_cores + if num_elements % total_cores != 0: + raise ValueError( + f"Number of elements ({num_elements}) must be a multiple of {total_cores}." + ) + + rows_per_core = per_core_elements // tile_size + chunks_per_row = tile_size // chunk_size + dtype = bfloat16 + + # Tensor / tile types + tensor_ty = np.ndarray[(num_elements,), np.dtype[dtype]] + chunk_ty = np.ndarray[(chunk_size,), np.dtype[dtype]] + stats_ty = np.ndarray[(16,), np.dtype[dtype]] # only [0..1] used + + chunk = num_elements // num_aie_columns // num_channels + + # --- Object FIFOs ------------------------------------------------------- + of_in_stats = [ + ObjectFifo(chunk_ty, name=f"in_stats_{i}_{j}") + for i in range(num_aie_columns) + for j in range(num_channels) + ] + of_in_norm = [ + ObjectFifo(chunk_ty, name=f"in_norm_{i}_{j}") + for i in range(num_aie_columns) + for j in range(num_channels) + ] + of_outs = [ + ObjectFifo(chunk_ty, name=f"out_{i}_{j}") + for i in range(num_aie_columns) + for j in range(num_channels) + ] + + # --- Kernel declarations ------------------------------------------------ + init_kernel = Kernel( + f"{func_prefix}softmax_partial_init_bf16", + f"{func_prefix}softmax.o", + [stats_ty], + ) + stats_kernel = Kernel( + f"{func_prefix}softmax_partial_stats_bf16", + f"{func_prefix}softmax.o", + [chunk_ty, stats_ty, np.int32], + ) + norm_kernel = Kernel( + f"{func_prefix}softmax_partial_norm_bf16", + f"{func_prefix}softmax.o", + [chunk_ty, chunk_ty, stats_ty, np.int32], + ) + + # --- Local stats buffers (one per core) --------------------------------- + stats_buffers = [ + Buffer( + initial_value=np.zeros(16, dtype=dtype), + name=f"stats_{i}_{j}", + ) + for i in range(num_aie_columns) + for j in range(num_channels) + ] + + barriers = [ + WorkerRuntimeBarrier() + for i in range(num_aie_columns) + for j in range(num_channels) + ] + + # --- Worker body -------------------------------------------------------- + def core_body( + of_s, of_n, of_out, + init_k, stats_k, norm_k, + stats_buf, barrier, + ): + barrier.wait_for_value(1) + for _ in range_(rows_per_core): + # Reset running max / sum for the new row + init_k(stats_buf) + + # Pass 1 – accumulate max and sum(exp) + for _ in range_(chunks_per_row): + elem = of_s.acquire(1) + stats_k(elem, stats_buf, chunk_size) + of_s.release(1) + + # Pass 2 – normalise: exp(x - max) / sum + for _ in range_(chunks_per_row): + elem_in = of_n.acquire(1) + elem_out = of_out.acquire(1) + norm_k(elem_in, elem_out, stats_buf, chunk_size) + of_n.release(1) + of_out.release(1) + + # --- Workers ------------------------------------------------------------ + def _worker_args(k): + return [ + of_in_stats[k].cons(), + of_in_norm[k].cons(), + of_outs[k].prod(), + init_kernel, + stats_kernel, + norm_kernel, + stats_buffers[k], + barriers[k], + ] + + workers = [ + Worker(core_body, _worker_args(i * num_channels + j)) + for i in range(num_aie_columns) + for j in range(num_channels) + ] + + # --- Tensor access patterns (identical for both input FIFOs) ------------ + taps = [ + TensorAccessPattern( + (1, num_elements), + chunk * i * num_channels + chunk * j, + [1, 1, 1, chunk], + [0, 0, 0, 1], + ) + for i in range(num_aie_columns) + for j in range(num_channels) + ] + + # --- Runtime sequence --------------------------------------------------- + rt = Runtime() + with rt.sequence(tensor_ty, tensor_ty) as (A, C): + rt.start(*workers) + + for k in range(num_aie_columns * num_channels): + rt.set_barrier(barriers[k], 1) + + tg = rt.task_group() + + for i in range(num_aie_columns): + for j in range(num_channels): + k = i * num_channels + j + # Feed the stats-pass FIFO + rt.fill( + of_in_stats[k].prod(), A, taps[k], task_group=tg, + ) + # Feed the norm-pass FIFO (same source data) + rt.fill( + of_in_norm[k].prod(), A, taps[k], task_group=tg, + ) + + for i in range(num_aie_columns): + for j in range(num_channels): + k = i * num_channels + j + rt.drain( + of_outs[k].cons(), C, taps[k], wait=True, task_group=tg, + ) + + rt.finish_task_group(tg) + + return Program(dev, rt).resolve_program(SequentialPlacer()) + + def softmax( dev, num_elements, @@ -31,7 +207,21 @@ def softmax( mask_patch_value=0, func_prefix="", kernel_obj_file="softmax.o", + chunk_size=None, ): + # ---- Partial (online) softmax path ---- + if chunk_size is not None: + return _softmax_partial( + dev, + num_elements, + num_aie_columns, + num_channels, + tile_size, + chunk_size, + func_prefix, + ) + + # ---- Full-row softmax path (original) ---- per_tile_elements = tile_size if rtp_vector_size is None: rtp_vector_size = per_tile_elements diff --git a/iron/operators/softmax/op.py b/iron/operators/softmax/op.py index 71aec051..fddfa0a1 100644 --- a/iron/operators/softmax/op.py +++ b/iron/operators/softmax/op.py @@ -20,7 +20,12 @@ @dataclass class Softmax(MLIROperator): - """AIE-accelerated Softmax operation""" + """AIE-accelerated Softmax operation + + When *chunk_size* is set (and < cols), uses an online / tiled softmax + that processes each row in two passes with sub-tile chunks, avoiding the + local-memory exhaustion that occurs with very long rows (e.g. S >= 8192). + """ rows: int cols: int @@ -28,6 +33,7 @@ class Softmax(MLIROperator): num_channels: int = 1 rtp_vector_size: int | None = None mask_patch_value: int = 0 + chunk_size: int | None = None context: object = field(default=None, repr=False) @property @@ -43,6 +49,15 @@ def __post_init__(self): raise ValueError( f"rows ({self.rows}) must be a multiple of num_aie_columns ({self.num_aie_columns})" ) + if self.chunk_size is not None: + if self.cols % self.chunk_size != 0: + raise ValueError( + f"cols ({self.cols}) must be a multiple of chunk_size ({self.chunk_size})" + ) + if self.chunk_size % 64 != 0: + raise ValueError( + f"chunk_size ({self.chunk_size}) must be a multiple of 64" + ) MLIROperator.__init__(self, context=self.context) @property @@ -69,6 +84,7 @@ def get_mlir_artifact(self): "rtp_vector_size": self.rtp_vector_size, "mask_patch_value": self.mask_patch_value, "kernel_obj_file": self._kernel_link_file, + "chunk_size": self.chunk_size, }, ), ) diff --git a/iron/operators/softmax/test.py b/iron/operators/softmax/test.py index 066d2309..8cca9f69 100755 --- a/iron/operators/softmax/test.py +++ b/iron/operators/softmax/test.py @@ -85,3 +85,52 @@ def test_softmax(input_length, num_aie_columns, num_channels, tile_size, aie_con print(f"Effective Bandwidth: {bandwidth_gbps:.6e} GB/s\n") assert not errors, f"Test failed with errors: {errors}" + + +# --------------------------------------------------------------------------- +# Partial (online / tiled) softmax tests — enables long rows (S >= 8192) +# --------------------------------------------------------------------------- + + +def get_partial_params(): + """GPT-2 style: 12 heads × S rows of S columns, tested via partial softmax.""" + params = [] + for S in [8192]: + H = 12 # GPT-2 Small heads + rows = H * S + cols = S + chunk_size = 1024 + params.append(pytest.param(rows, cols, chunk_size, id=f"GPT2-S{S}")) + return params + + +@pytest.mark.metrics( + Latency=r"Latency \(us\): (?P[\d\.]+)", + Bandwidth=r"Effective Bandwidth: (?P[\d\.e\+-]+) GB/s", +) +@pytest.mark.parametrize("rows,cols,chunk_size", get_partial_params()) +def test_softmax_partial(rows, cols, chunk_size, aie_context): + """Test partial / online softmax with sub-tile chunks for long rows.""" + + golden_ref = generate_golden_reference(rows=rows, cols=cols) + + operator = Softmax( + rows=rows, + cols=cols, + num_aie_columns=1, + num_channels=1, + chunk_size=chunk_size, + context=aie_context, + ) + + input_buffers = {"in": golden_ref["input"]} + output_buffers = {"output": golden_ref["output"]} + + errors, latency_us, bandwidth_gbps = run_test( + operator, input_buffers, output_buffers, rel_tol=0.08, abs_tol=1e-6 + ) + + print(f"\nLatency (us): {latency_us:.1f}") + print(f"Effective Bandwidth: {bandwidth_gbps:.6e} GB/s\n") + + assert not errors, f"Test failed with errors: {errors}" From 6ce23a280c19238a243a8137ff9959a130a01467 Mon Sep 17 00:00:00 2001 From: andrej Date: Thu, 16 Apr 2026 00:47:53 -0600 Subject: [PATCH 14/22] go up to 32768 sequence length for mha_lxl_sd benchmark tests --- iron/operators/mha_prefill_lxl_sd/test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/iron/operators/mha_prefill_lxl_sd/test.py b/iron/operators/mha_prefill_lxl_sd/test.py index bbfd3e0e..f7c5aa25 100644 --- a/iron/operators/mha_prefill_lxl_sd/test.py +++ b/iron/operators/mha_prefill_lxl_sd/test.py @@ -31,7 +31,7 @@ def get_benchmark_params(): """GPT-2 Small across sequence lengths 256..32768, with/without causal mask.""" params = [] S = 256 - while S <= 8192: + while S <= 32768: for mask in [True, False]: for dispatch in ["auto", "separate"]: tag = "causal" if mask else "nomask" From 0403ee24111f7c289b2e92416cbd4f27db8778d9 Mon Sep 17 00:00:00 2001 From: andrej Date: Fri, 17 Apr 2026 17:17:05 -0600 Subject: [PATCH 15/22] stochastic testing for large sequence lengths; split GEMM+softmax into multiple invocations for large sequence lengths --- iron/operators/mha_prefill_lxl_sd/op.py | 143 ++++++++++++++---- .../operators/mha_prefill_lxl_sd/reference.py | 91 +++++++++++ iron/operators/mha_prefill_lxl_sd/test.py | 54 ++++++- 3 files changed, 248 insertions(+), 40 deletions(-) diff --git a/iron/operators/mha_prefill_lxl_sd/op.py b/iron/operators/mha_prefill_lxl_sd/op.py index 793ff8b6..2af4c75c 100644 --- a/iron/operators/mha_prefill_lxl_sd/op.py +++ b/iron/operators/mha_prefill_lxl_sd/op.py @@ -44,8 +44,29 @@ def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True, num_cols=None): num_cols = aie_utils.get_current_device().cols B = 2 # bytes per bf16 element + # ---- M-splitting for the per-head GEMMs ---- + # At very long sequence lengths the per-GEMM design's runtime sequence + # (number of `rt.fill`/`rt.drain` MLIR ops) grows linearly with M, which + # makes each sub-operator's MLIR module very large and can OOM the + # compiler at S >= 16K. We cap each GEMM invocation's M dimension at + # `gemm_M_chunk` and split the per-head computation into multiple + # back-to-back GEMM invocations with sliced buffer offsets. Single + # dispatch is preserved (same fused runlist, just with more entries). + # + # M_chunk must be a multiple of `tile_m * n_aie_rows = 64` and must + # divide S evenly. At min(S, 4096) we get: + # S <= 4096: 1 invocation per head per phase (no splitting) + # S = 8192: 2 invocations per head per phase + # S = 16384: 4 invocations per head per phase + # S = 32768: 8 invocations per head per phase + gemm_M_chunk = min(S, 4096) + assert S % gemm_M_chunk == 0, ( + f"S ({S}) must be a multiple of gemm_M_chunk ({gemm_M_chunk})" + ) + n_m_chunks = S // gemm_M_chunk + gemm_scores = GEMM( - M=S, + M=gemm_M_chunk, K=d, N=S, num_aie_columns=num_cols, @@ -71,8 +92,43 @@ def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True, num_cols=None): # memory (each double-buffered FIFO pair uses 4 * tile_size bytes; at # S >= 8192 the in+out FIFOs alone consume the full 64 KB data memory). softmax_chunk_size = 1024 if S >= 8192 else None + + # ---- Row-splitting for the softmax invocation ---- + # The shim DMA BD length field is a 32-bit unsigned word count (~4.29 B + # words ≈ 17 GB), but the current compiler lowering computes the BD length + # in bytes through int32 arithmetic and silently overflows when a single + # invocation's transfer exceeds 2 GB (= 2^31 bytes = 2^30 bf16 elements). + # We split the softmax call into N back-to-back invocations on disjoint + # row ranges to keep each transfer under that effective limit. The + # softmax buffers (attn_scores_masked / attn_scores_scaled / attn_weights) + # are row-major (S, S) per head, so a contiguous row-range slice maps + # directly to a contiguous byte range. + # Strict bound: bytes per BD must fit in signed int32 (< 2^31), so the + # per-invocation element count must be strictly less than 2^30. + SOFTMAX_MAX_ELEMENTS_PER_INV = (1 << 30) - 1 + total_softmax_rows = H * S + if total_softmax_rows * S <= SOFTMAX_MAX_ELEMENTS_PER_INV: + n_softmax_invocations = 1 + else: + # Smallest n that simultaneously divides total_softmax_rows evenly + # AND keeps each invocation's transfer at or below the limit. + n_softmax_invocations = ( + total_softmax_rows * S + SOFTMAX_MAX_ELEMENTS_PER_INV - 1 + ) // SOFTMAX_MAX_ELEMENTS_PER_INV + while ( + total_softmax_rows % n_softmax_invocations != 0 + or (total_softmax_rows // n_softmax_invocations) * S + > SOFTMAX_MAX_ELEMENTS_PER_INV + ): + n_softmax_invocations += 1 + softmax_rows_per_inv = total_softmax_rows // n_softmax_invocations + assert softmax_rows_per_inv % 16 == 0, ( + f"softmax_rows_per_inv ({softmax_rows_per_inv}) must be a multiple of 16; " + f"got total_rows={total_softmax_rows}, n_invocations={n_softmax_invocations}" + ) + softmax = Softmax( - rows=H * S, + rows=softmax_rows_per_inv, cols=S, num_aie_columns=1, num_channels=1, @@ -81,7 +137,7 @@ def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True, num_cols=None): context=elf_ctx, ) gemm_context = GEMM( - M=S, + M=gemm_M_chunk, K=S, N=d, num_aie_columns=min(4, num_cols), @@ -92,46 +148,69 @@ def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True, num_cols=None): prio_accuracy=True, ) - qh = S * d * B - kdS = d * S * B - kSd = S * d * B - sh = S * S * B - ch = S * d * B + # Per-head byte sizes + qh = S * d * B # queries per head: (S, d) + kdS = d * S * B # keys per head: (d, S) + kSd = S * d * B # values per head: (S, d) + sh = S * S * B # scores/weights per head: (S, S) + ch = S * d * B # context per head: (S, d) + + # Per-M-chunk byte sizes (the M dimension is contiguous in row-major + # storage so M-slices map directly to byte ranges within each head) + q_chunk = gemm_M_chunk * d * B # queries chunk: (M_chunk, d) + s_chunk = gemm_M_chunk * S * B # scores chunk: (M_chunk, S) + w_chunk = gemm_M_chunk * S * B # weights chunk: (M_chunk, S) + c_chunk = gemm_M_chunk * d * B # context chunk: (M_chunk, d) + + score_calls = [ + ( + gemm_scores, + f"queries[{h*qh + i*q_chunk}:{h*qh + (i+1)*q_chunk}]", + f"keys[{h*kdS}:{(h+1)*kdS}]", + f"attn_scores[{h*sh + i*s_chunk}:{h*sh + (i+1)*s_chunk}]", + ) + for h in range(H) + for i in range(n_m_chunks) + ] - runlist = [ - *[ + context_calls = [ + ( + gemm_context, + f"attn_weights[{h*sh + i*w_chunk}:{h*sh + (i+1)*w_chunk}]", + f"values[{h*kSd}:{(h+1)*kSd}]", + f"attn_context[{h*ch + i*c_chunk}:{h*ch + (i+1)*c_chunk}]", + ) + for h in range(H) + for i in range(n_m_chunks) + ] + + # Build the softmax runlist entries (one per invocation when row-split). + softmax_input_buf = "attn_scores_masked" if causal_mask else "attn_scores_scaled" + softmax_chunk_bytes = softmax_rows_per_inv * S * B + if n_softmax_invocations == 1: + softmax_calls = [(softmax, softmax_input_buf, "attn_weights")] + else: + softmax_calls = [ ( - gemm_scores, - f"queries[{h*qh}:{(h+1)*qh}]", - f"keys[{h*kdS}:{(h+1)*kdS}]", - f"attn_scores[{h*sh}:{(h+1)*sh}]", + softmax, + f"{softmax_input_buf}[{i*softmax_chunk_bytes}:{(i+1)*softmax_chunk_bytes}]", + f"attn_weights[{i*softmax_chunk_bytes}:{(i+1)*softmax_chunk_bytes}]", ) - for h in range(H) - ], + for i in range(n_softmax_invocations) + ] + + runlist = [ + *score_calls, (scale, "attn_scores", "attn_scale_factor", "attn_scores_scaled"), ] if causal_mask: runlist += [ (mask, "attn_scores_scaled", "causal_mask", "attn_scores_masked"), - (softmax, "attn_scores_masked", "attn_weights"), - ] - else: - runlist += [ - (softmax, "attn_scores_scaled", "attn_weights"), ] - runlist += [ - *[ - ( - gemm_context, - f"attn_weights[{h*sh}:{(h+1)*sh}]", - f"values[{h*kSd}:{(h+1)*kSd}]", - f"attn_context[{h*ch}:{(h+1)*ch}]", - ) - for h in range(H) - ], - ] + runlist += softmax_calls + runlist += context_calls buffer_sizes = { "queries": H * S * d * B, diff --git a/iron/operators/mha_prefill_lxl_sd/reference.py b/iron/operators/mha_prefill_lxl_sd/reference.py index 3343fa8d..925ace54 100644 --- a/iron/operators/mha_prefill_lxl_sd/reference.py +++ b/iron/operators/mha_prefill_lxl_sd/reference.py @@ -4,6 +4,97 @@ import torch +def generate_random_inputs(H, G, d, E, S, causal=True, seed=42): + """Generate just the *inputs* needed for the core MHA attention test, with + no expensive PyTorch reference computation. + + Suitable for the benchmark test (which doesn't verify the full output) and + for very large sequence lengths where the full golden reference is + impractical. The causal mask is built with a single ``torch.triu`` instead + of the H*S² nested-loop construction in :func:`generate_golden_reference`. + + Returned dict matches the input keys consumed by the benchmark test + (``queries_deinterleaved``, ``keys_for_scores``, ``values_for_context``, + ``attn_scale_factor``, ``causal_mask``), plus ``_scale`` (the scalar + 1/sqrt(d)) and ``_causal`` (the bool flag) for use by sample verification. + """ + torch.manual_seed(seed) + val_range = 0.5 + + # Pre-deinterleaved/transposed/repeated Q, K, V (the layout the + # AttentionPrefillFused operator consumes directly). + queries_deinterleaved = (torch.randn(H, S, d) * val_range).to(torch.bfloat16) + keys_for_scores = (torch.randn(H, d, S) * val_range).to(torch.bfloat16) + values_for_context = (torch.randn(H, S, d) * val_range).to(torch.bfloat16) + + scale = 1.0 / (d**0.5) + attn_scale_factor = torch.full((H * S * S,), scale, dtype=torch.bfloat16) + + out = { + "queries_deinterleaved": queries_deinterleaved, + "keys_for_scores": keys_for_scores, + "values_for_context": values_for_context, + "attn_scale_factor": attn_scale_factor, + "_scale": scale, + "_causal": causal, + } + + if causal: + # Vectorized causal mask: -inf strictly above the diagonal, broadcast + # across all H heads. Shape (H*S, S) to match the operator layout. + single_mask = torch.triu( + torch.full((S, S), float("-inf"), dtype=torch.bfloat16), + diagonal=1, + ) + out["causal_mask"] = ( + single_mask.unsqueeze(0).expand(H, -1, -1).reshape(H * S, S).contiguous() + ) + + return out + + +def compute_attn_context_at_rows( + queries_deinterleaved, + keys_for_scores, + values_for_context, + scale, + causal, + sample_hms, +): + """Compute the expected ``attn_context[h, m, :]`` row for each (h, m) in + ``sample_hms``. + + Cheap even at very large S: per sample is O(S * d) — a single (1, d) @ (d, S) + matmul plus an O(S) softmax plus an (S,) @ (S, d) reduction. + + Args: + queries_deinterleaved: (H, S, d) bfloat16 tensor + keys_for_scores: (H, d, S) bfloat16 tensor + values_for_context: (H, S, d) bfloat16 tensor + scale: 1 / sqrt(d) (Python float) + causal: bool — apply causal mask (zero out k > m) + sample_hms: iterable of (h, m) tuples to compute + + Returns: + dict mapping (h, m) -> torch.Tensor of shape (d,) in bfloat16. + """ + out = {} + for h, m in sample_hms: + q = queries_deinterleaved[h, m, :].float() # (d,) + k = keys_for_scores[h, :, :].float() # (d, S) + scores = q @ k # (S,) + scaled = scores * scale # (S,) + if causal: + # Match the operator's behaviour: positions strictly greater than m + # receive -inf and contribute zero after softmax. + scaled = scaled.clone() + scaled[m + 1 :] = float("-inf") + weights = torch.softmax(scaled, dim=-1) # (S,) + v = values_for_context[h, :, :].float() # (S, d) + out[(h, m)] = (weights @ v).to(torch.bfloat16) # (d,) + return out + + def _apply_rope_4d(x, angles): """Apply RoPE to a 4D tensor using interleaved cos/sin angles. diff --git a/iron/operators/mha_prefill_lxl_sd/test.py b/iron/operators/mha_prefill_lxl_sd/test.py index f7c5aa25..7ffd09e3 100644 --- a/iron/operators/mha_prefill_lxl_sd/test.py +++ b/iron/operators/mha_prefill_lxl_sd/test.py @@ -12,7 +12,11 @@ AttentionPrefillFused, AttentionPrefillProjectedFused, ) -from iron.operators.mha_prefill_lxl_sd.reference import generate_golden_reference +from iron.operators.mha_prefill_lxl_sd.reference import ( + generate_golden_reference, + generate_random_inputs, + compute_attn_context_at_rows, +) REL_TOL = 0.08 ABS_TOL = 2.0 @@ -194,19 +198,24 @@ def test_attention_prefill_projected_fused(H, G, d, E, S): ) @pytest.mark.parametrize("H,G,d,E,S,causal,dispatch", get_benchmark_params()) def test_mha_prefill_benchmark(H, G, d, E, S, causal, dispatch): - """Benchmark core MHA for GPT-2 Small across sequence lengths.""" - golden = generate_golden_reference(H, G, d, E, S) + """Benchmark core MHA for GPT-2 Small across sequence lengths. + + Uses cheap random inputs (no full PyTorch reference) and verifies + correctness by recomputing the expected ``attn_context`` row for a small + number of randomly-chosen (head, row) positions — feasible at any S. + """ + inputs = generate_random_inputs(H, G, d, E, S, causal=causal) op = AttentionPrefillFused(H, G, d, E, S, causal_mask=causal, dispatch=dispatch) op.compile() fc = op.get_callable() - _load_input(fc, "queries", golden["queries_deinterleaved"]) - _load_input(fc, "keys", golden["keys_for_scores"]) - _load_input(fc, "values", golden["values_for_context"]) - _load_input(fc, "attn_scale_factor", golden["attn_scale_factor"]) + _load_input(fc, "queries", inputs["queries_deinterleaved"]) + _load_input(fc, "keys", inputs["keys_for_scores"]) + _load_input(fc, "values", inputs["values_for_context"]) + _load_input(fc, "attn_scale_factor", inputs["attn_scale_factor"]) if causal: - _load_input(fc, "causal_mask", golden["causal_mask"]) + _load_input(fc, "causal_mask", inputs["causal_mask"]) fc() @@ -215,6 +224,35 @@ def test_mha_prefill_benchmark(H, G, d, E, S, causal, dispatch): print(f"\nLatency (us): {latency_us:.1f}") print(f"Throughput: {gflops:.6e} GFLOP/s") + # # ---- Sample-based correctness check ---- + # # Pick a handful of random (head, row) pairs and recompute the expected + # # attn_context row for each (cheap: O(S*d) per sample). + # actual_context = _get_output_tensor(fc, "attn_context", (H, S, d)) + # rng = np.random.default_rng(seed=0) + # n_samples = min(8, H) + # sample_hms = [(int(rng.integers(0, H)), int(rng.integers(0, S))) for _ in range(n_samples)] + # expected_rows = compute_attn_context_at_rows( + # inputs["queries_deinterleaved"], + # inputs["keys_for_scores"], + # inputs["values_for_context"], + # inputs["_scale"], + # causal, + # sample_hms, + # ) + # failures = [] + # for (h, m), exp in expected_rows.items(): + # act = torch.from_numpy(actual_context[h, m, :]).bfloat16() + # diff = (act.float() - exp.float()).abs() + # rel = diff / (exp.float().abs() + 1e-6) + # # An element fails only if it exceeds BOTH abs_tol and rel_tol + # bad = (diff > ABS_TOL) & (rel > REL_TOL) + # if bad.any(): + # failures.append( + # f"(h={h}, m={m}): {int(bad.sum())}/{d} bad, " + # f"max_abs={diff.max().item():.4f}, max_rel={rel.max().item():.4f}" + # ) + # assert not failures, "Sample verification failed:\n " + "\n ".join(failures) + # --------------------------------------------------------------------------- # Intermediate checks (extensive, not run by default) From dab067c4c447c3885a048ae4df691a70040902be Mon Sep 17 00:00:00 2001 From: andrej Date: Fri, 17 Apr 2026 17:30:07 -0600 Subject: [PATCH 16/22] reuse buffers to avoid OOM --- iron/operators/mha_prefill_lxl_sd/op.py | 41 +++++++++++++++++-------- 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/iron/operators/mha_prefill_lxl_sd/op.py b/iron/operators/mha_prefill_lxl_sd/op.py index 2af4c75c..32cf410d 100644 --- a/iron/operators/mha_prefill_lxl_sd/op.py +++ b/iron/operators/mha_prefill_lxl_sd/op.py @@ -162,12 +162,32 @@ def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True, num_cols=None): w_chunk = gemm_M_chunk * S * B # weights chunk: (M_chunk, S) c_chunk = gemm_M_chunk * d * B # context chunk: (M_chunk, d) + # ---- Scratch-buffer aliasing via live-range analysis ---- + # The four logical (H,S,S) attention-matrix scratch buffers (scores, + # scaled, masked, weights) have non-overlapping live ranges in the + # runlist: + # step 1 (score): [W: scores] + # step 2 (scale): [R: scores, W: scaled] + # step 3 (mask): [R: scaled, W: masked] (causal only) + # step 4 (softmax): [R: masked/scaled, W: weights] + # step 5 (context): [R: weights] + # Each step's input and output need to be distinct buffers, but + # non-adjacent buffers can share storage. Two physical slots A and B + # suffice in either causal or nomask configuration, cutting scratch for + # the (H,S,S) matrices from 3-4× to 2× H*S*S*B. + if causal_mask: + scores_buf, scaled_buf = "attn_A", "attn_B" + masked_buf, weights_buf = "attn_A", "attn_B" + else: + scores_buf, scaled_buf = "attn_A", "attn_B" + weights_buf = "attn_A" + score_calls = [ ( gemm_scores, f"queries[{h*qh + i*q_chunk}:{h*qh + (i+1)*q_chunk}]", f"keys[{h*kdS}:{(h+1)*kdS}]", - f"attn_scores[{h*sh + i*s_chunk}:{h*sh + (i+1)*s_chunk}]", + f"{scores_buf}[{h*sh + i*s_chunk}:{h*sh + (i+1)*s_chunk}]", ) for h in range(H) for i in range(n_m_chunks) @@ -176,7 +196,7 @@ def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True, num_cols=None): context_calls = [ ( gemm_context, - f"attn_weights[{h*sh + i*w_chunk}:{h*sh + (i+1)*w_chunk}]", + f"{weights_buf}[{h*sh + i*w_chunk}:{h*sh + (i+1)*w_chunk}]", f"values[{h*kSd}:{(h+1)*kSd}]", f"attn_context[{h*ch + i*c_chunk}:{h*ch + (i+1)*c_chunk}]", ) @@ -185,28 +205,28 @@ def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True, num_cols=None): ] # Build the softmax runlist entries (one per invocation when row-split). - softmax_input_buf = "attn_scores_masked" if causal_mask else "attn_scores_scaled" + softmax_input_buf = masked_buf if causal_mask else scaled_buf softmax_chunk_bytes = softmax_rows_per_inv * S * B if n_softmax_invocations == 1: - softmax_calls = [(softmax, softmax_input_buf, "attn_weights")] + softmax_calls = [(softmax, softmax_input_buf, weights_buf)] else: softmax_calls = [ ( softmax, f"{softmax_input_buf}[{i*softmax_chunk_bytes}:{(i+1)*softmax_chunk_bytes}]", - f"attn_weights[{i*softmax_chunk_bytes}:{(i+1)*softmax_chunk_bytes}]", + f"{weights_buf}[{i*softmax_chunk_bytes}:{(i+1)*softmax_chunk_bytes}]", ) for i in range(n_softmax_invocations) ] runlist = [ *score_calls, - (scale, "attn_scores", "attn_scale_factor", "attn_scores_scaled"), + (scale, scores_buf, "attn_scale_factor", scaled_buf), ] if causal_mask: runlist += [ - (mask, "attn_scores_scaled", "causal_mask", "attn_scores_masked"), + (mask, scaled_buf, "causal_mask", masked_buf), ] runlist += softmax_calls @@ -216,13 +236,10 @@ def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True, num_cols=None): "queries": H * S * d * B, "keys": H * d * S * B, "values": H * S * d * B, - "attn_scores": H * S * S * B, - "attn_scores_scaled": H * S * S * B, - "attn_weights": H * S * S * B, + "attn_A": H * S * S * B, + "attn_B": H * S * S * B, "attn_context": H * S * d * B, } - if causal_mask: - buffer_sizes["attn_scores_masked"] = H * S * S * B return runlist, buffer_sizes From bb4749302738aeb284d05c0c2cd161d16fdd0a24 Mon Sep 17 00:00:00 2001 From: andrej Date: Mon, 20 Apr 2026 11:58:40 -0600 Subject: [PATCH 17/22] support longer sequence lengths and reduce buffer sizes: new AXPY modes for scale/add single scalar, allow more buffers to alias to reduce memory usage --- aie_kernels/generic/axpy.cc | 75 ++++++ iron/operators/axpy/design.py | 277 +++++++++++++++++----- iron/operators/axpy/op.py | 95 +++++++- iron/operators/mha_prefill_lxl_sd/op.py | 132 ++++++++--- iron/operators/mha_prefill_lxl_sd/test.py | 9 - 5 files changed, 484 insertions(+), 104 deletions(-) diff --git a/aie_kernels/generic/axpy.cc b/aie_kernels/generic/axpy.cc index 728adb55..75e7bd29 100644 --- a/aie_kernels/generic/axpy.cc +++ b/aie_kernels/generic/axpy.cc @@ -42,4 +42,79 @@ void saxpy_scalar(bfloat16 *x, bfloat16 *y, const bfloat16 a, bfloat16 *z, const } event1(); } + +// z = a * x (scalar-vector multiply; AXPY without the +y term) +void scale_bf16(bfloat16 *restrict x, bfloat16 *restrict z, const float a, const int32_t vector_size) +{ + event0(); + ::aie::vector a_v = + ::aie::broadcast(aie::to_float(a, 0)); + for (int i = 0; i < vector_size; i += 64) { + ::aie::vector x_v = ::aie::load_v<64>(x); + x += 64; + ::aie::accum z_v = ::aie::mul(x_v, a_v); + ::aie::vector z_v_converted = z_v.to_vector(); + ::aie::store_v(z, z_v_converted); + z += 64; + } + event1(); +} + +// z = a + y (scalar-vector add; AXPY without the *x term) +void scalar_add_bf16(bfloat16 *restrict y, bfloat16 *restrict z, const float a, const int32_t vector_size) +{ + event0(); + ::aie::vector a_v = + ::aie::broadcast(aie::to_float(a, 0)); + for (int i = 0; i < vector_size; i += 64) { + ::aie::vector y_v = ::aie::load_v<64>(y); + y += 64; + ::aie::vector z_v = ::aie::add(y_v, a_v); + ::aie::store_v(z, z_v); + z += 64; + } + event1(); +} + +// z = (col > row) ? a : y applied per-element of one tile, using a tile +// position (chunk_start_col, row_in_head) supplied via the idx_buffer. +// +// The tile is interpreted as a `vector_size`-wide horizontal strip of the +// per-head (S, S) attention-score block; idx[0] is the strip's starting +// column within that block, idx[1] is the strip's row within the block. +// The kernel implements the causal mask in-place by writing `a` to elements +// strictly above the diagonal and copying y -> z everywhere else. This +// avoids materialising an H*S*S mask buffer entirely. +// +// For tiles whose entire range lies at-or-below the diagonal, the kernel +// degenerates to a copy (input still streamed through DMA — slightly +// wasteful but simpler than per-tile data-movement skipping). +void scalar_add_causal_bf16(bfloat16 *restrict y, bfloat16 *restrict z, int32_t *idx, + const float a, const int32_t vector_size) +{ + event0(); + + int32_t chunk_start_col = idx[0]; + int32_t row_in_head = idx[1]; + + // Index of the first column in the tile that needs to be masked + // (i.e. column index strictly greater than row_in_head). + int32_t mask_start = row_in_head + 1 - chunk_start_col; + if (mask_start < 0) mask_start = 0; + if (mask_start > vector_size) mask_start = vector_size; + + bfloat16 s = (bfloat16)a; + int j = 0; + + // Unmasked region: copy y -> z + for (; j < mask_start; j++) { + z[j] = y[j]; + } + // Masked region: write the scalar + for (; j < vector_size; j++) { + z[j] = s; + } + + event1(); +} } \ No newline at end of file diff --git a/iron/operators/axpy/design.py b/iron/operators/axpy/design.py index af58eb55..0052c454 100644 --- a/iron/operators/axpy/design.py +++ b/iron/operators/axpy/design.py @@ -4,7 +4,7 @@ from ml_dtypes import bfloat16 import numpy as np -from aie.iron import Kernel, ObjectFifo, Program, Runtime, Worker +from aie.iron import Kernel, ObjectFifo, Program, Runtime, Worker, Buffer from aie.iron.placers import SequentialPlacer from aie.helpers.taplib.tap import TensorAccessPattern from aie.iron.controlflow import range_ @@ -17,7 +17,37 @@ def my_axpy( tile_size, trace_size, scalar_factor, + add_y=True, + mul_x=True, + causal_mask=False, + mask_block_dim=0, + rows_per_block=0, + row_offset=0, ): + """AXPY-family element-wise design. + + Modes: + mul_x=True, add_y=True → saxpy: Z = a*X + Y + mul_x=True, add_y=False → scale: Z = a*X + mul_x=False, add_y=True → scalar_add: Z = a + Y + mul_x=False, add_y=True, causal_mask=True → + Z[i,j] = a if (j > i within head) else Y[i,j] + (in-place causal mask; supplies row/col-chunk indices to the kernel + via an idx_buffer; data is interpreted as (..., S, S) blocks where + S = mask_block_dim.) + """ + if causal_mask: + return _my_axpy_causal_mask( + dev=dev, + num_elements=num_elements, + num_columns=num_columns, + tile_size=tile_size, + scalar_factor=scalar_factor, + mask_block_dim=mask_block_dim, + rows_per_block=rows_per_block or mask_block_dim, + row_offset=row_offset, + ) + factor = scalar_factor per_tile_elements = 4096 if tile_size > 4096 else tile_size n = per_tile_elements * num_columns @@ -29,92 +59,213 @@ def my_axpy( chunk = num_elements // num_columns dtype = bfloat16 - # Define tensor types tensor_ty = np.ndarray[(num_elements,), np.dtype[dtype]] tile_ty = np.ndarray[(per_tile_elements,), np.dtype[dtype]] - # AIE-array data movement with object fifos (one per column, not per channel) + # Two inputs only when both *X and +Y are kept (saxpy mode). + has_two_inputs = add_y and mul_x + + # AIE-array data movement with object fifos (one per column) of_in1s = [ObjectFifo(tile_ty, name=f"in1_{i}") for i in range(num_columns)] - of_in2s = [ObjectFifo(tile_ty, name=f"in2_{i}") for i in range(num_columns)] + if has_two_inputs: + of_in2s = [ObjectFifo(tile_ty, name=f"in2_{i}") for i in range(num_columns)] of_outs = [ObjectFifo(tile_ty, name=f"out_{i}") for i in range(num_columns)] # AIE Core Function declaration - axpy_bf16_vector = Kernel( - "saxpy", "axpy.o", [tile_ty, tile_ty, np.float32, tile_ty, np.int32] - ) - - # Define a task that will run on a compute tile - def core_body(of_in1, of_in2, of_out, axpy): - # Number of sub-vector "tile" iterations - for _ in range_(N_div_n): - elem_in1 = of_in1.acquire(1) - elem_in2 = of_in2.acquire(1) - elem_out = of_out.acquire(1) - axpy(elem_in1, elem_in2, factor, elem_out, per_tile_elements) - of_in1.release(1) - of_in2.release(1) - of_out.release(1) - - # Create a worker to run the task on a compute tile (one per column) - my_workers = [ - Worker( - core_body, - [ - of_in1s[i].cons(), - of_in2s[i].cons(), - of_outs[i].prod(), - axpy_bf16_vector, - ], + if has_two_inputs: + kernel = Kernel( + "saxpy", "axpy.o", + [tile_ty, tile_ty, np.float32, tile_ty, np.int32], + ) + elif not add_y: + # z = a * x (drop +Y) + kernel = Kernel( + "scale_bf16", "axpy.o", + [tile_ty, tile_ty, np.float32, np.int32], + ) + else: + # z = a + y (drop *X) + kernel = Kernel( + "scalar_add_bf16", "axpy.o", + [tile_ty, tile_ty, np.float32, np.int32], ) - for i in range(num_columns) - ] - # Create a TensorAccessPattern for each column - # to describe the data movement - # The pattern chops the data in equal chunks - # and moves them in parallel across the columns + if has_two_inputs: + def core_body(of_in1, of_in2, of_out, k): + for _ in range_(N_div_n): + e1 = of_in1.acquire(1) + e2 = of_in2.acquire(1) + eo = of_out.acquire(1) + k(e1, e2, factor, eo, per_tile_elements) + of_in1.release(1) + of_in2.release(1) + of_out.release(1) + else: + def core_body(of_in1, of_out, k): + for _ in range_(N_div_n): + e1 = of_in1.acquire(1) + eo = of_out.acquire(1) + k(e1, eo, factor, per_tile_elements) + of_in1.release(1) + of_out.release(1) + + if has_two_inputs: + my_workers = [ + Worker( + core_body, + [of_in1s[i].cons(), of_in2s[i].cons(), of_outs[i].prod(), kernel], + ) + for i in range(num_columns) + ] + else: + my_workers = [ + Worker( + core_body, + [of_in1s[i].cons(), of_outs[i].prod(), kernel], + ) + for i in range(num_columns) + ] + taps = [ TensorAccessPattern( (1, num_elements), - chunk * i, # Start offset for column i + chunk * i, [1, 1, 1, chunk], [0, 0, 0, 1], ) for i in range(num_columns) ] - # Runtime operations to move data to/from the AIE-array rt = Runtime() - with rt.sequence(tensor_ty, tensor_ty, tensor_ty) as (A, B, C): + sequence_types = ( + (tensor_ty, tensor_ty, tensor_ty) + if has_two_inputs + else (tensor_ty, tensor_ty) + ) + + with rt.sequence(*sequence_types) as bufs: + if has_two_inputs: + A, B, C = bufs + else: + A, C = bufs + rt.start(*my_workers) - # Initialize a group for parallel drain tasks, with fill resources free'd when drains complete. tg = rt.task_group() - - # Fill the input objectFIFOs with data for i in range(num_columns): - rt.fill( - of_in1s[i].prod(), - A, - taps[i], - task_group=tg, - ) - rt.fill( - of_in2s[i].prod(), - B, - taps[i], - task_group=tg, - ) - # Drain the output objectFIFOs with data + rt.fill(of_in1s[i].prod(), A, taps[i], task_group=tg) + if has_two_inputs: + rt.fill(of_in2s[i].prod(), B, taps[i], task_group=tg) for i in range(num_columns): - rt.drain( - of_outs[i].cons(), - C, - taps[i], - wait=True, # wait for the transfer to complete and data to be available - task_group=tg, - ) + rt.drain(of_outs[i].cons(), C, taps[i], wait=True, task_group=tg) + rt.finish_task_group(tg) + + return Program(dev, rt).resolve_program(SequentialPlacer()) + + +def _my_axpy_causal_mask( + dev, + num_elements, + num_columns, + tile_size, + scalar_factor, + mask_block_dim, + rows_per_block, + row_offset, +): + """Single-core in-place causal mask via scalar_add_causal_bf16. + + Walks (blocks × rows_per_block × chunks-per-row) with three nested + runtime loops and feeds (chunk_start_col, row_in_head) to the kernel via + an idx buffer. The kernel applies the scalar `a` to elements strictly + above the per-head diagonal and copies y → z elsewhere. Tiles entirely + below the diagonal still get DMA'd (no per-tile data-movement skip) + but the kernel does only a copy in that case. + + Two operating modes (selected by the rows_per_block / row_offset args): + * Multi-block / full-head (rows_per_block = mask_block_dim, row_offset = 0): + walks ``num_blocks`` whole (S, S) blocks; idx[1] resets to 0 at the + start of each block. + * Sub-block (rows_per_block < mask_block_dim or row_offset > 0): + ``num_blocks`` is typically 1 in the MHA caller; processes a contiguous + ``rows_per_block``-tall slice of one block starting at row_offset. + Used at very long S where one (S, S) block exceeds the BD-length cap. + """ + if num_columns != 1: + raise ValueError( + f"causal_mask path requires num_columns=1, got {num_columns}" + ) + + factor = scalar_factor + S = mask_block_dim + per_tile_elements = 4096 if tile_size > 4096 else tile_size + if S % per_tile_elements != 0: + raise ValueError( + f"mask_block_dim ({S}) must be a multiple of per_tile_elements " + f"({per_tile_elements})" + ) + chunks_per_row = S // per_tile_elements + block_elements = rows_per_block * S + if num_elements % block_elements != 0: + raise ValueError( + f"num_elements ({num_elements}) must be a multiple of " + f"rows_per_block * S ({block_elements})" + ) + num_blocks = num_elements // block_elements + init_row = row_offset + + dtype = bfloat16 + tensor_ty = np.ndarray[(num_elements,), np.dtype[dtype]] + tile_ty = np.ndarray[(per_tile_elements,), np.dtype[dtype]] + idx_ty = np.ndarray[(2,), np.dtype[np.int32]] + + of_in = ObjectFifo(tile_ty, name="in0") + of_out = ObjectFifo(tile_ty, name="out0") + + kernel = Kernel( + "scalar_add_causal_bf16", + "axpy.o", + [tile_ty, tile_ty, idx_ty, np.float32, np.int32], + ) + + idx_buffer = Buffer( + initial_value=np.zeros((2,), dtype=np.int32), + name="causal_mask_idx", + ) + + def core_body(of_in_, of_out_, k, idx): + # idx[0] = chunk_start_col within the current row of the (S, S) block + # idx[1] = current row index within the current block + idx[0] = 0 + idx[1] = init_row + for _ in range_(num_blocks): + for _ in range_(rows_per_block): + for _ in range_(chunks_per_row): + elem_in = of_in_.acquire(1) + elem_out = of_out_.acquire(1) + k(elem_in, elem_out, idx, factor, per_tile_elements) + of_in_.release(1) + of_out_.release(1) + idx[0] = idx[0] + per_tile_elements + idx[0] = 0 + idx[1] = idx[1] + 1 + idx[1] = init_row # reset for next block + + worker = Worker(core_body, [of_in.cons(), of_out.prod(), kernel, idx_buffer]) + + tap = TensorAccessPattern( + (1, num_elements), + 0, + [1, 1, 1, num_elements], + [0, 0, 0, 1], + ) + + rt = Runtime() + with rt.sequence(tensor_ty, tensor_ty) as (A, C): + rt.start(worker) + tg = rt.task_group() + rt.fill(of_in.prod(), A, tap, task_group=tg) + rt.drain(of_out.cons(), C, tap, wait=True, task_group=tg) rt.finish_task_group(tg) - # Place program components (assign them resources on the device) and generate an MLIR module return Program(dev, rt).resolve_program(SequentialPlacer()) diff --git a/iron/operators/axpy/op.py b/iron/operators/axpy/op.py index ff4b87a1..4951959a 100644 --- a/iron/operators/axpy/op.py +++ b/iron/operators/axpy/op.py @@ -6,6 +6,7 @@ from iron.common import ( BinaryElementwiseOperator, + AIERuntimeArgSpec, KernelObjectArtifact, SourceArtifact, PythonGeneratedMLIRArtifact, @@ -15,14 +16,96 @@ @dataclass class AXPY(BinaryElementwiseOperator): - """AIE-accelerated aX + Y operator""" + """AIE-accelerated aX + Y operator. + + Optional flags select degenerate variants that skip part of the formula: + + * ``add_y=False`` → ``Z = a * X`` (drop the +Y term and the Y buffer) + * ``mul_x=False`` → ``Z = a + Y`` (drop the *X term and the X buffer) + * ``causal_mask=True`` (requires ``mul_x=False``) → + ``Z[i,j] = a if (j > i within head) else Y[i,j]`` + Treats the data as a sequence of (mask_block_dim, mask_block_dim) blocks + and applies a causal (lower-triangular-keep) mask in-place per block. + Tile-position info is supplied to the kernel via an idx_buffer; tiles + that lie entirely below the diagonal degenerate to a kernel-side copy + (no per-tile DMA skipping). Used by MHA to avoid materialising an + H * S * S causal-mask input buffer. + + ``add_y=False`` and ``mul_x=False`` cannot be combined. + """ scalar_factor: float = 3.0 + add_y: bool = True + mul_x: bool = True + causal_mask: bool = False + mask_block_dim: int | None = None + # Sub-block parameters (causal_mask only). Default: process full + # mask_block_dim rows per (S,S) block starting at row 0. Set + # rows_per_block < mask_block_dim and/or row_offset > 0 to process a + # contiguous row-range slice of one block per invocation — useful when + # a single full block's element count exceeds the per-invocation BD + # cap (e.g. S>=32K). + rows_per_block: int | None = None + row_offset: int = 0 kernel_name: ClassVar[str] = "axpy" kernel_fn_name: ClassVar[str] = "saxpy" callback_fn: ClassVar[str] = "my_axpy" + def __post_init__(self) -> None: + if not self.add_y and not self.mul_x: + raise ValueError("AXPY requires at least one of add_y or mul_x to be True") + if self.causal_mask: + if self.mul_x: + raise ValueError( + "AXPY causal_mask=True requires mul_x=False (Z = a + Y form)" + ) + if not self.add_y: + raise ValueError( + "AXPY causal_mask=True requires add_y=True (needs the Y buffer)" + ) + if self.mask_block_dim is None: + raise ValueError( + "AXPY causal_mask=True requires mask_block_dim (the (S,S) block dim)" + ) + # Default rows_per_block = mask_block_dim (process full blocks) + if self.rows_per_block is None: + self.rows_per_block = self.mask_block_dim + if self.rows_per_block <= 0 or self.rows_per_block > self.mask_block_dim: + raise ValueError( + f"rows_per_block ({self.rows_per_block}) must be in (0, " + f"mask_block_dim={self.mask_block_dim}]" + ) + if self.row_offset + self.rows_per_block > self.mask_block_dim: + raise ValueError( + f"row_offset ({self.row_offset}) + rows_per_block " + f"({self.rows_per_block}) must be <= mask_block_dim " + f"({self.mask_block_dim})" + ) + block_elements = self.rows_per_block * self.mask_block_dim + if self.size % block_elements != 0: + raise ValueError( + f"size ({self.size}) must be a multiple of " + f"rows_per_block * mask_block_dim ({block_elements})" + ) + # Causal-mask path is single-core (the runtime sequence walks the + # nested (blocks, rows, chunks) dimensions sequentially). + if self.num_aie_columns != 1: + raise ValueError( + f"AXPY causal_mask=True requires num_aie_columns=1, got " + f"{self.num_aie_columns}" + ) + super().__post_init__() + + def get_arg_spec(self) -> list[AIERuntimeArgSpec]: + # When either input is dropped, the design has one input + one output. + if not self.add_y or not self.mul_x: + return [ + AIERuntimeArgSpec("in", (self.size,)), + AIERuntimeArgSpec("out", (self.size,)), + ] + return super().get_arg_spec() + def get_kernel_artifacts(self) -> list[KernelObjectArtifact]: # axpy.cc lives under aie_kernels/generic/ (not device-specific) return [ @@ -37,7 +120,15 @@ def get_kernel_artifacts(self) -> list[KernelObjectArtifact]: ] def _mlir_callback_args(self): - return super()._mlir_callback_args() + [self.scalar_factor] + return super()._mlir_callback_args() + [ + self.scalar_factor, + self.add_y, + self.mul_x, + self.causal_mask, + self.mask_block_dim if self.mask_block_dim is not None else 0, + self.rows_per_block if self.rows_per_block is not None else 0, + self.row_offset, + ] def get_mlir_artifact(self) -> PythonGeneratedMLIRArtifact: return PythonGeneratedMLIRArtifact( diff --git a/iron/operators/mha_prefill_lxl_sd/op.py b/iron/operators/mha_prefill_lxl_sd/op.py index 32cf410d..8d1f4e7a 100644 --- a/iron/operators/mha_prefill_lxl_sd/op.py +++ b/iron/operators/mha_prefill_lxl_sd/op.py @@ -5,17 +5,19 @@ A layer-by-layer (LxL) single-dispatch (SD) implementation of multi-head attention (MHA). """ +import math + import aie.utils as aie_utils from iron.common.context import AIEContext from iron.common.fusion import FusedMLIROperator +from iron.operators.axpy.op import AXPY from iron.operators.gemm.op import GEMM from iron.operators.rope.op import RoPE from iron.operators.strided_copy.op import StridedCopy from iron.operators.repeat.op import Repeat from iron.operators.softmax.op import Softmax from iron.operators.transpose.op import Transpose -from iron.operators.elementwise_mul.op import ElementwiseMul from iron.operators.elementwise_add.op import ElementwiseAdd @@ -75,19 +77,68 @@ def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True, num_cols=None): tile_n=_pick_tile_n(S, num_cols), context=elf_ctx, ) - scale = ElementwiseMul( + # Scale by 1/sqrt(d) — uses AXPY in scale-only mode (add_y=False) so the + # scalar is baked into the kernel call instead of being passed as an + # H*S*S broadcast buffer. At S=32K, H=12 this saves a 24 GB input. + scale = AXPY( size=H * S * S, tile_size=S * S // num_cols, num_aie_columns=num_cols, + scalar_factor=1.0 / math.sqrt(d), + add_y=False, context=elf_ctx, ) if causal_mask: - mask = ElementwiseAdd( - size=H * S * S, - tile_size=S * S // num_cols, - num_aie_columns=num_cols, - context=elf_ctx, - ) + # Apply causal mask via AXPY in scalar-add + causal-mask mode. The + # kernel computes (in place) `Z[i,j] = -INF if (j > i within head) + # else Y[i,j]` using a tile-position idx_buffer. This avoids + # materialising an H*S*S input mask buffer (saves 24 GB at S=32K). + # Single-core; tiles entirely below the diagonal still flow through + # DMA (kernel does only a copy in that case). + # + # Same BD-overflow workaround as for softmax: each invocation's + # transfer must fit under the compiler's int32 byte limit (< 2^30 + # bf16 elements). + # * If S² fits, each invocation processes some whole heads. + # * Otherwise (S>=32K), split each head into `mask_subblocks` + # row-range slices and emit one AXPY instance per row_offset. + MASK_MAX_ELEMENTS_PER_INV = (1 << 30) - 1 + if S * S <= MASK_MAX_ELEMENTS_PER_INV: + # Multi-head batched: pick max heads/invocation that divides H. + heads_per_mask_inv = max(1, MASK_MAX_ELEMENTS_PER_INV // (S * S)) + while H % heads_per_mask_inv != 0: + heads_per_mask_inv -= 1 + mask_subblocks = 1 + mask_rows_per_block = S + else: + # Sub-head: split each head into `mask_subblocks` row-range slices + # such that rows_per_block * S <= MASK_MAX_ELEMENTS_PER_INV. + heads_per_mask_inv = 1 + mask_subblocks = (S * S + MASK_MAX_ELEMENTS_PER_INV - 1) // MASK_MAX_ELEMENTS_PER_INV + while S % mask_subblocks != 0: + mask_subblocks += 1 + mask_rows_per_block = S // mask_subblocks + assert mask_rows_per_block * S <= MASK_MAX_ELEMENTS_PER_INV + n_mask_invocations = (H // heads_per_mask_inv) * mask_subblocks + + # Build one AXPY instance per (row_offset) — same kernel/design + # parameters otherwise. When mask_subblocks=1 there's exactly one. + mask_ops = [ + AXPY( + size=heads_per_mask_inv * mask_rows_per_block * S, + tile_size=min(4096, S), + num_aie_columns=1, + scalar_factor=float("-inf"), + mul_x=False, + add_y=True, + causal_mask=True, + mask_block_dim=S, + rows_per_block=mask_rows_per_block, + row_offset=sub_idx * mask_rows_per_block, + context=elf_ctx, + ) + for sub_idx in range(mask_subblocks) + ] # Use online/partial softmax when full-row tiles would exhaust AIE local # memory (each double-buffered FIFO pair uses 4 * tile_size bytes; at # S >= 8192 the in+out FIFOs alone consume the full 64 KB data memory). @@ -171,16 +222,15 @@ def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True, num_cols=None): # step 3 (mask): [R: scaled, W: masked] (causal only) # step 4 (softmax): [R: masked/scaled, W: weights] # step 5 (context): [R: weights] - # Each step's input and output need to be distinct buffers, but - # non-adjacent buffers can share storage. Two physical slots A and B - # suffice in either causal or nomask configuration, cutting scratch for - # the (H,S,S) matrices from 3-4× to 2× H*S*S*B. - if causal_mask: - scores_buf, scaled_buf = "attn_A", "attn_B" - masked_buf, weights_buf = "attn_A", "attn_B" - else: - scores_buf, scaled_buf = "attn_A", "attn_B" - weights_buf = "attn_A" + # Every operator in the chain (scale, mask, softmax) processes data + # tile-by-tile through a producer/consumer FIFO pair, where each tile + # is read into the input FIFO BEFORE the kernel writes the output, and + # the output DMA only writes after the worker releases the tile. This + # makes them safe to run in-place (input and output bound to the same + # DDR buffer): one logical (H,S,S) buffer suffices for the entire + # attention-matrix lifetime. + attn_buf = "attn" + scores_buf = scaled_buf = masked_buf = weights_buf = attn_buf score_calls = [ ( @@ -221,13 +271,39 @@ def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True, num_cols=None): runlist = [ *score_calls, - (scale, scores_buf, "attn_scale_factor", scaled_buf), + (scale, scores_buf, scaled_buf), ] if causal_mask: - runlist += [ - (mask, scaled_buf, "causal_mask", masked_buf), - ] + # AXPY causal-mask mode takes only (input, output) — the mask values + # are baked into the kernel call (scalar -INF), no buffer needed. + # Multiple invocations on disjoint slices when needed to stay under + # the BD-length compiler-overflow limit. Layout per invocation: + # * Multi-head: contiguous range of heads_per_mask_inv whole heads + # (mask_subblocks == 1, rows_per_block == S) + # * Sub-head: contiguous range of mask_rows_per_block rows + # starting at sub_idx * mask_rows_per_block within + # one head; emitted for every head × every sub-block + n_head_groups = H // heads_per_mask_inv + head_group_bytes = heads_per_mask_inv * S * S * B # full head-group span + sub_chunk_bytes = mask_rows_per_block * S * B + mask_calls = [] + for g in range(n_head_groups): + for sub_idx in range(mask_subblocks): + start = g * head_group_bytes + sub_idx * sub_chunk_bytes + end = start + heads_per_mask_inv * sub_chunk_bytes + mask_calls.append( + ( + mask_ops[sub_idx], + f"{scaled_buf}[{start}:{end}]", + f"{masked_buf}[{start}:{end}]", + ) + ) + if n_head_groups == 1 and mask_subblocks == 1: + # Whole-buffer fast path (avoids slice notation in MLIR) + runlist += [(mask_ops[0], scaled_buf, masked_buf)] + else: + runlist += mask_calls runlist += softmax_calls runlist += context_calls @@ -236,8 +312,9 @@ def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True, num_cols=None): "queries": H * S * d * B, "keys": H * d * S * B, "values": H * S * d * B, - "attn_A": H * S * S * B, - "attn_B": H * S * S * B, + # Single in-place attention-matrix scratch buffer (see live-range + # comment above); shared across scores → scaled → masked → weights. + "attn": H * S * S * B, "attn_context": H * S * d * B, } @@ -283,9 +360,7 @@ def __init__( ) mask_suffix = "_causal" if causal_mask else "_nomask" - input_args = ["queries", "keys", "values", "attn_scale_factor"] - if causal_mask: - input_args.append("causal_mask") + input_args = ["queries", "keys", "values"] super().__init__( name=f"attention_prefill_fused_{num_heads}h{num_kv_groups}g{head_dim}d{embedding_dim}e{seq_len}s{mask_suffix}", @@ -491,10 +566,7 @@ def __init__( "W_key", "W_value", "W_output", - "attn_scale_factor", ] - if causal_mask: - input_args.append("causal_mask") super().__init__( name=f"attention_prefill_projected_fused_{H}h{G}g{d}d{E}e{S}s{mask_suffix}", diff --git a/iron/operators/mha_prefill_lxl_sd/test.py b/iron/operators/mha_prefill_lxl_sd/test.py index 7ffd09e3..358a474c 100644 --- a/iron/operators/mha_prefill_lxl_sd/test.py +++ b/iron/operators/mha_prefill_lxl_sd/test.py @@ -126,8 +126,6 @@ def test_mha_pefill_lxl_sd(H, G, d, E, S): _load_input(fc, "queries", golden["queries_deinterleaved"]) _load_input(fc, "keys", golden["keys_for_scores"]) _load_input(fc, "values", golden["values_for_context"]) - _load_input(fc, "attn_scale_factor", golden["attn_scale_factor"]) - _load_input(fc, "causal_mask", golden["causal_mask"]) fc() @@ -173,8 +171,6 @@ def test_attention_prefill_projected_fused(H, G, d, E, S): _load_input(fc, "W_key", golden["W_key"]) _load_input(fc, "W_value", golden["W_value"]) _load_input(fc, "W_output", golden["W_output"]) - _load_input(fc, "attn_scale_factor", golden["attn_scale_factor"]) - _load_input(fc, "causal_mask", golden["causal_mask"]) fc() @@ -213,9 +209,6 @@ def test_mha_prefill_benchmark(H, G, d, E, S, causal, dispatch): _load_input(fc, "queries", inputs["queries_deinterleaved"]) _load_input(fc, "keys", inputs["keys_for_scores"]) _load_input(fc, "values", inputs["values_for_context"]) - _load_input(fc, "attn_scale_factor", inputs["attn_scale_factor"]) - if causal: - _load_input(fc, "causal_mask", inputs["causal_mask"]) fc() @@ -284,8 +277,6 @@ def test_mha_pefill_lxl_sd_intermediates(H, G, d, E, S): _load_input(fc, "queries", golden["queries_deinterleaved"]) _load_input(fc, "keys", golden["keys_for_scores"]) _load_input(fc, "values", golden["values_for_context"]) - _load_input(fc, "attn_scale_factor", golden["attn_scale_factor"]) - _load_input(fc, "causal_mask", golden["causal_mask"]) fc() From 4abbae4a5173375104396aa3a0849f6709f2f004 Mon Sep 17 00:00:00 2001 From: andrej Date: Mon, 20 Apr 2026 12:55:43 -0600 Subject: [PATCH 18/22] speed up causal masking --- aie_kernels/generic/axpy.cc | 28 ++++++++++- iron/operators/axpy/design.py | 65 ++++++++++++++++--------- iron/operators/axpy/op.py | 11 +++-- iron/operators/mha_prefill_lxl_sd/op.py | 38 ++++++++++----- 4 files changed, 100 insertions(+), 42 deletions(-) diff --git a/aie_kernels/generic/axpy.cc b/aie_kernels/generic/axpy.cc index 75e7bd29..3bd01324 100644 --- a/aie_kernels/generic/axpy.cc +++ b/aie_kernels/generic/axpy.cc @@ -94,6 +94,8 @@ void scalar_add_causal_bf16(bfloat16 *restrict y, bfloat16 *restrict z, int32_t { event0(); + constexpr int VEC = 64; + int32_t chunk_start_col = idx[0]; int32_t row_in_head = idx[1]; @@ -104,13 +106,35 @@ void scalar_add_causal_bf16(bfloat16 *restrict y, bfloat16 *restrict z, int32_t if (mask_start > vector_size) mask_start = vector_size; bfloat16 s = (bfloat16)a; + ::aie::vector s_v = ::aie::broadcast(s); int j = 0; - // Unmasked region: copy y -> z + // ---- Unmasked region [0, mask_start): copy y -> z ---- + // Vectorised body up to the largest VEC-aligned offset <= mask_start. + int mask_start_floor = (mask_start / VEC) * VEC; + for (; j < mask_start_floor; j += VEC) { + ::aie::vector v = ::aie::load_v(y + j); + ::aie::store_v(z + j, v); + } + // Scalar copy for the unmasked remainder (at most VEC - 1 elements). for (; j < mask_start; j++) { z[j] = y[j]; } - // Masked region: write the scalar + + // ---- Masked region [mask_start, vector_size): write scalar ---- + // If mask_start isn't VEC-aligned, scalar-fill up to the next VEC + // boundary (or to vector_size, whichever is smaller). + int next_vec_boundary = ((j + VEC - 1) / VEC) * VEC; + if (next_vec_boundary > vector_size) next_vec_boundary = vector_size; + for (; j < next_vec_boundary; j++) { + z[j] = s; + } + // Vectorised body of the masked region. + for (; j + VEC <= vector_size; j += VEC) { + ::aie::store_v(z + j, s_v); + } + // Scalar tail when vector_size isn't VEC-aligned (in practice this + // doesn't fire since per_tile_elements is always a multiple of VEC). for (; j < vector_size; j++) { z[j] = s; } diff --git a/iron/operators/axpy/design.py b/iron/operators/axpy/design.py index 0052c454..2eee8e71 100644 --- a/iron/operators/axpy/design.py +++ b/iron/operators/axpy/design.py @@ -191,11 +191,6 @@ def _my_axpy_causal_mask( ``rows_per_block``-tall slice of one block starting at row_offset. Used at very long S where one (S, S) block exceeds the BD-length cap. """ - if num_columns != 1: - raise ValueError( - f"causal_mask path requires num_columns=1, got {num_columns}" - ) - factor = scalar_factor S = mask_block_dim per_tile_elements = 4096 if tile_size > 4096 else tile_size @@ -212,6 +207,18 @@ def _my_axpy_causal_mask( f"rows_per_block * S ({block_elements})" ) num_blocks = num_elements // block_elements + + # Multi-core parallelisation: each core processes a contiguous slice of + # whole blocks (heads). The kernel's mask logic depends only on + # row_in_block, which resets at every block boundary, so as long as + # cores split block-aligned the same kernel works unchanged. + if num_blocks % num_columns != 0: + raise ValueError( + f"num_blocks ({num_blocks}) must be a multiple of num_columns " + f"({num_columns}); causal_mask multi-core split is block-aligned" + ) + blocks_per_core = num_blocks // num_columns + elements_per_core = blocks_per_core * block_elements init_row = row_offset dtype = bfloat16 @@ -219,8 +226,8 @@ def _my_axpy_causal_mask( tile_ty = np.ndarray[(per_tile_elements,), np.dtype[dtype]] idx_ty = np.ndarray[(2,), np.dtype[np.int32]] - of_in = ObjectFifo(tile_ty, name="in0") - of_out = ObjectFifo(tile_ty, name="out0") + of_ins = [ObjectFifo(tile_ty, name=f"in{i}") for i in range(num_columns)] + of_outs = [ObjectFifo(tile_ty, name=f"out{i}") for i in range(num_columns)] kernel = Kernel( "scalar_add_causal_bf16", @@ -228,17 +235,20 @@ def _my_axpy_causal_mask( [tile_ty, tile_ty, idx_ty, np.float32, np.int32], ) - idx_buffer = Buffer( - initial_value=np.zeros((2,), dtype=np.int32), - name="causal_mask_idx", - ) + idx_buffers = [ + Buffer( + initial_value=np.zeros((2,), dtype=np.int32), + name=f"causal_mask_idx_{i}", + ) + for i in range(num_columns) + ] def core_body(of_in_, of_out_, k, idx): # idx[0] = chunk_start_col within the current row of the (S, S) block # idx[1] = current row index within the current block idx[0] = 0 idx[1] = init_row - for _ in range_(num_blocks): + for _ in range_(blocks_per_core): for _ in range_(rows_per_block): for _ in range_(chunks_per_row): elem_in = of_in_.acquire(1) @@ -251,21 +261,32 @@ def core_body(of_in_, of_out_, k, idx): idx[1] = idx[1] + 1 idx[1] = init_row # reset for next block - worker = Worker(core_body, [of_in.cons(), of_out.prod(), kernel, idx_buffer]) + workers = [ + Worker( + core_body, + [of_ins[i].cons(), of_outs[i].prod(), kernel, idx_buffers[i]], + ) + for i in range(num_columns) + ] - tap = TensorAccessPattern( - (1, num_elements), - 0, - [1, 1, 1, num_elements], - [0, 0, 0, 1], - ) + taps = [ + TensorAccessPattern( + (1, num_elements), + i * elements_per_core, + [1, 1, 1, elements_per_core], + [0, 0, 0, 1], + ) + for i in range(num_columns) + ] rt = Runtime() with rt.sequence(tensor_ty, tensor_ty) as (A, C): - rt.start(worker) + rt.start(*workers) tg = rt.task_group() - rt.fill(of_in.prod(), A, tap, task_group=tg) - rt.drain(of_out.cons(), C, tap, wait=True, task_group=tg) + for i in range(num_columns): + rt.fill(of_ins[i].prod(), A, taps[i], task_group=tg) + for i in range(num_columns): + rt.drain(of_outs[i].cons(), C, taps[i], wait=True, task_group=tg) rt.finish_task_group(tg) return Program(dev, rt).resolve_program(SequentialPlacer()) diff --git a/iron/operators/axpy/op.py b/iron/operators/axpy/op.py index 4951959a..2d4cefed 100644 --- a/iron/operators/axpy/op.py +++ b/iron/operators/axpy/op.py @@ -88,12 +88,13 @@ def __post_init__(self) -> None: f"size ({self.size}) must be a multiple of " f"rows_per_block * mask_block_dim ({block_elements})" ) - # Causal-mask path is single-core (the runtime sequence walks the - # nested (blocks, rows, chunks) dimensions sequentially). - if self.num_aie_columns != 1: + # Multi-core split is block-aligned (each core handles whole + # blocks). num_aie_columns must divide num_blocks. + num_blocks = self.size // block_elements + if num_blocks % self.num_aie_columns != 0: raise ValueError( - f"AXPY causal_mask=True requires num_aie_columns=1, got " - f"{self.num_aie_columns}" + f"AXPY causal_mask: num_aie_columns ({self.num_aie_columns}) " + f"must divide num_blocks ({num_blocks})" ) super().__post_init__() diff --git a/iron/operators/mha_prefill_lxl_sd/op.py b/iron/operators/mha_prefill_lxl_sd/op.py index 8d1f4e7a..b2157774 100644 --- a/iron/operators/mha_prefill_lxl_sd/op.py +++ b/iron/operators/mha_prefill_lxl_sd/op.py @@ -121,13 +121,22 @@ def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True, num_cols=None): assert mask_rows_per_block * S <= MASK_MAX_ELEMENTS_PER_INV n_mask_invocations = (H // heads_per_mask_inv) * mask_subblocks + # Multi-core parallelism for the AXPY causal mask: each core handles + # whole blocks (heads), so num_aie_columns must divide + # heads_per_mask_inv. Pick the largest divisor <= device cols. + # Sub-head mode (mask_subblocks > 1) implies heads_per_mask_inv == 1 + # so we're forced to a single core there. + mask_num_cols = min(num_cols, heads_per_mask_inv) + while heads_per_mask_inv % mask_num_cols != 0: + mask_num_cols -= 1 + # Build one AXPY instance per (row_offset) — same kernel/design # parameters otherwise. When mask_subblocks=1 there's exactly one. mask_ops = [ AXPY( size=heads_per_mask_inv * mask_rows_per_block * S, tile_size=min(4096, S), - num_aie_columns=1, + num_aie_columns=mask_num_cols, scalar_factor=float("-inf"), mul_x=False, add_y=True, @@ -222,15 +231,19 @@ def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True, num_cols=None): # step 3 (mask): [R: scaled, W: masked] (causal only) # step 4 (softmax): [R: masked/scaled, W: weights] # step 5 (context): [R: weights] - # Every operator in the chain (scale, mask, softmax) processes data - # tile-by-tile through a producer/consumer FIFO pair, where each tile - # is read into the input FIFO BEFORE the kernel writes the output, and - # the output DMA only writes after the worker releases the tile. This - # makes them safe to run in-place (input and output bound to the same - # DDR buffer): one logical (H,S,S) buffer suffices for the entire - # attention-matrix lifetime. - attn_buf = "attn" - scores_buf = scaled_buf = masked_buf = weights_buf = attn_buf + # Each step's input and output need to be distinct buffers, but + # non-adjacent buffers can share storage. Two physical slots A and B + # suffice in either causal or nomask configuration. (In principle each + # operator could run in-place on one shared buffer, but DMA channels + # reading and writing the same DDR buffer concurrently appear to + # serialise through the memory subsystem and hurt throughput at small/ + # medium S, so we keep two slots here.) + if causal_mask: + scores_buf, scaled_buf = "attn_A", "attn_B" + masked_buf, weights_buf = "attn_A", "attn_B" + else: + scores_buf, scaled_buf = "attn_A", "attn_B" + weights_buf = "attn_A" score_calls = [ ( @@ -312,9 +325,8 @@ def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True, num_cols=None): "queries": H * S * d * B, "keys": H * d * S * B, "values": H * S * d * B, - # Single in-place attention-matrix scratch buffer (see live-range - # comment above); shared across scores → scaled → masked → weights. - "attn": H * S * S * B, + "attn_A": H * S * S * B, + "attn_B": H * S * S * B, "attn_context": H * S * d * B, } From a233ec247a5461809efb66d1afa2ad34f6afde00 Mon Sep 17 00:00:00 2001 From: andrej Date: Mon, 20 Apr 2026 13:25:23 -0600 Subject: [PATCH 19/22] oops --- softmax should use all cores! --- iron/operators/mha_prefill_lxl_sd/op.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/iron/operators/mha_prefill_lxl_sd/op.py b/iron/operators/mha_prefill_lxl_sd/op.py index b2157774..4d503049 100644 --- a/iron/operators/mha_prefill_lxl_sd/op.py +++ b/iron/operators/mha_prefill_lxl_sd/op.py @@ -29,7 +29,8 @@ def _pick_tile_n(N, num_cols, max_tile_n=64): return tile_n -def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True, num_cols=None): +def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True, num_cols=None, + disable_softmax=False): """Build core attention sub-ops and runlist (no projections/RoPE/GQA). Expects pre-processed inputs: @@ -41,6 +42,9 @@ def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True, num_cols=None): attn_context: (H, S, d) — per-head context vectors If causal_mask=False, the elementwise-add masking step is omitted. + + If disable_softmax=True, the softmax step is omitted (output is incorrect; + intended for performance-isolation benchmarks only). """ if num_cols is None: num_cols = aie_utils.get_current_device().cols @@ -187,15 +191,25 @@ def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True, num_cols=None): f"got total_rows={total_softmax_rows}, n_invocations={n_softmax_invocations}" ) + # Parallelise softmax across cores: each core handles a row-aligned + # slice of the (rows_per_inv, S) data. Pick the largest divisor of + # softmax_rows_per_inv that is <= num_cols. + softmax_num_cols = num_cols + while softmax_rows_per_inv % softmax_num_cols != 0: + softmax_num_cols -= 1 + softmax = Softmax( rows=softmax_rows_per_inv, cols=S, - num_aie_columns=1, + num_aie_columns=softmax_num_cols, num_channels=1, rtp_vector_size=S, chunk_size=softmax_chunk_size, context=elf_ctx, ) + # Context GEMM is capped at 4 cores: with N=d=64, tile_n must be a + # multiple of 16 (matmul kernel constraint n % (2*t) == 0), so + # tile_n*num_aie_columns = 64 means num_aie_columns <= 4. gemm_context = GEMM( M=gemm_M_chunk, K=S, @@ -318,7 +332,8 @@ def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True, num_cols=None): else: runlist += mask_calls - runlist += softmax_calls + if not disable_softmax: + runlist += softmax_calls runlist += context_calls buffer_sizes = { @@ -349,6 +364,7 @@ def __init__( causal_mask=True, context=None, dispatch="auto", + disable_softmax=False, ): assert head_dim == 64 assert num_heads % num_kv_groups == 0 @@ -369,9 +385,12 @@ def __init__( seq_len, elf_ctx, causal_mask=causal_mask, + disable_softmax=disable_softmax, ) mask_suffix = "_causal" if causal_mask else "_nomask" + if disable_softmax: + mask_suffix += "_nosm" input_args = ["queries", "keys", "values"] super().__init__( From ada6c1fe703bb625af3b1e80669b502dbdfc8b93 Mon Sep 17 00:00:00 2001 From: andrej Date: Mon, 20 Apr 2026 13:52:44 -0600 Subject: [PATCH 20/22] parallelization for AXPY when single blocks are too big --- iron/operators/axpy/design.py | 89 ++++++++++++++++--------- iron/operators/axpy/op.py | 30 +++++++-- iron/operators/mha_prefill_lxl_sd/op.py | 54 ++++++--------- 3 files changed, 104 insertions(+), 69 deletions(-) diff --git a/iron/operators/axpy/design.py b/iron/operators/axpy/design.py index 2eee8e71..553e3b5f 100644 --- a/iron/operators/axpy/design.py +++ b/iron/operators/axpy/design.py @@ -208,18 +208,43 @@ def _my_axpy_causal_mask( ) num_blocks = num_elements // block_elements - # Multi-core parallelisation: each core processes a contiguous slice of - # whole blocks (heads). The kernel's mask logic depends only on - # row_in_block, which resets at every block boundary, so as long as - # cores split block-aligned the same kernel works unchanged. - if num_blocks % num_columns != 0: - raise ValueError( - f"num_blocks ({num_blocks}) must be a multiple of num_columns " - f"({num_columns}); causal_mask multi-core split is block-aligned" - ) - blocks_per_core = num_blocks // num_columns - elements_per_core = blocks_per_core * block_elements - init_row = row_offset + # Two parallelisation modes: + # * block-aligned (num_blocks >= num_columns): each core handles + # blocks_per_core whole (S, S) blocks; idx[1] resets to row_offset at + # every block boundary (same value on every core). + # * within-block (num_blocks == 1, num_columns > 1): a single block is + # too big to split across cores by block, so each core handles a + # contiguous row-range slice of that one block; per-core init_row is + # row_offset + core_idx * rows_per_iter (different per core). The + # kernel logic is unchanged — it only cares about (chunk_start_col, + # row_in_block). + if num_blocks >= num_columns: + if num_blocks % num_columns != 0: + raise ValueError( + f"num_blocks ({num_blocks}) must be a multiple of num_columns " + f"({num_columns}); causal_mask multi-core split is block-aligned" + ) + blocks_per_core = num_blocks // num_columns + rows_per_iter = rows_per_block + per_core_init_rows = [row_offset] * num_columns + else: + if num_blocks != 1: + raise ValueError( + f"causal_mask multi-core within-block split requires " + f"num_blocks == 1, got {num_blocks}" + ) + if rows_per_block % num_columns != 0: + raise ValueError( + f"rows_per_block ({rows_per_block}) must be a multiple of " + f"num_columns ({num_columns}) for within-block split" + ) + blocks_per_core = 1 + rows_per_iter = rows_per_block // num_columns + per_core_init_rows = [ + row_offset + i * rows_per_iter for i in range(num_columns) + ] + + elements_per_core = num_elements // num_columns dtype = bfloat16 tensor_ty = np.ndarray[(num_elements,), np.dtype[dtype]] @@ -243,27 +268,31 @@ def _my_axpy_causal_mask( for i in range(num_columns) ] - def core_body(of_in_, of_out_, k, idx): - # idx[0] = chunk_start_col within the current row of the (S, S) block - # idx[1] = current row index within the current block - idx[0] = 0 - idx[1] = init_row - for _ in range_(blocks_per_core): - for _ in range_(rows_per_block): - for _ in range_(chunks_per_row): - elem_in = of_in_.acquire(1) - elem_out = of_out_.acquire(1) - k(elem_in, elem_out, idx, factor, per_tile_elements) - of_in_.release(1) - of_out_.release(1) - idx[0] = idx[0] + per_tile_elements - idx[0] = 0 - idx[1] = idx[1] + 1 - idx[1] = init_row # reset for next block + # Build one core_body per worker so the per-core init_row can be baked + # into the closure (constant within the worker code). + def make_core_body(my_init_row): + def core_body(of_in_, of_out_, k, idx): + # idx[0] = chunk_start_col within the current row of the block + # idx[1] = current row index within the current block + idx[0] = 0 + idx[1] = my_init_row + for _ in range_(blocks_per_core): + for _ in range_(rows_per_iter): + for _ in range_(chunks_per_row): + elem_in = of_in_.acquire(1) + elem_out = of_out_.acquire(1) + k(elem_in, elem_out, idx, factor, per_tile_elements) + of_in_.release(1) + of_out_.release(1) + idx[0] = idx[0] + per_tile_elements + idx[0] = 0 + idx[1] = idx[1] + 1 + idx[1] = my_init_row # reset for next block + return core_body workers = [ Worker( - core_body, + make_core_body(per_core_init_rows[i]), [of_ins[i].cons(), of_outs[i].prod(), kernel, idx_buffers[i]], ) for i in range(num_columns) diff --git a/iron/operators/axpy/op.py b/iron/operators/axpy/op.py index 2d4cefed..f2f543d3 100644 --- a/iron/operators/axpy/op.py +++ b/iron/operators/axpy/op.py @@ -88,14 +88,30 @@ def __post_init__(self) -> None: f"size ({self.size}) must be a multiple of " f"rows_per_block * mask_block_dim ({block_elements})" ) - # Multi-core split is block-aligned (each core handles whole - # blocks). num_aie_columns must divide num_blocks. + # Multi-core split: either block-aligned (each core handles + # whole blocks; num_aie_columns must divide num_blocks) or + # within-block (num_blocks == 1; num_aie_columns must divide + # rows_per_block). num_blocks = self.size // block_elements - if num_blocks % self.num_aie_columns != 0: - raise ValueError( - f"AXPY causal_mask: num_aie_columns ({self.num_aie_columns}) " - f"must divide num_blocks ({num_blocks})" - ) + if num_blocks >= self.num_aie_columns: + if num_blocks % self.num_aie_columns != 0: + raise ValueError( + f"AXPY causal_mask block-aligned split: " + f"num_aie_columns ({self.num_aie_columns}) must " + f"divide num_blocks ({num_blocks})" + ) + else: + if num_blocks != 1: + raise ValueError( + f"AXPY causal_mask within-block split requires " + f"num_blocks == 1, got {num_blocks}" + ) + if self.rows_per_block % self.num_aie_columns != 0: + raise ValueError( + f"AXPY causal_mask within-block split: " + f"rows_per_block ({self.rows_per_block}) must be a " + f"multiple of num_aie_columns ({self.num_aie_columns})" + ) super().__post_init__() def get_arg_spec(self) -> list[AIERuntimeArgSpec]: diff --git a/iron/operators/mha_prefill_lxl_sd/op.py b/iron/operators/mha_prefill_lxl_sd/op.py index 4d503049..86fcae17 100644 --- a/iron/operators/mha_prefill_lxl_sd/op.py +++ b/iron/operators/mha_prefill_lxl_sd/op.py @@ -125,14 +125,20 @@ def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True, num_cols=None, assert mask_rows_per_block * S <= MASK_MAX_ELEMENTS_PER_INV n_mask_invocations = (H // heads_per_mask_inv) * mask_subblocks - # Multi-core parallelism for the AXPY causal mask: each core handles - # whole blocks (heads), so num_aie_columns must divide - # heads_per_mask_inv. Pick the largest divisor <= device cols. - # Sub-head mode (mask_subblocks > 1) implies heads_per_mask_inv == 1 - # so we're forced to a single core there. - mask_num_cols = min(num_cols, heads_per_mask_inv) - while heads_per_mask_inv % mask_num_cols != 0: - mask_num_cols -= 1 + # Multi-core parallelism for the AXPY causal mask: + # * Block-aligned (heads_per_mask_inv >= 2): each core handles whole + # blocks (heads), so num_aie_columns must divide heads_per_mask_inv. + # * Within-block (heads_per_mask_inv == 1, sub-head mode): each core + # handles a contiguous row-range slice of the one (S/N, S) block, + # so num_aie_columns must divide mask_rows_per_block. + if heads_per_mask_inv >= 2: + mask_num_cols = min(num_cols, heads_per_mask_inv) + while heads_per_mask_inv % mask_num_cols != 0: + mask_num_cols -= 1 + else: + mask_num_cols = num_cols + while mask_rows_per_block % mask_num_cols != 0: + mask_num_cols -= 1 # Build one AXPY instance per (row_offset) — same kernel/design # parameters otherwise. When mask_subblocks=1 there's exactly one. @@ -236,28 +242,13 @@ def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True, num_cols=None, w_chunk = gemm_M_chunk * S * B # weights chunk: (M_chunk, S) c_chunk = gemm_M_chunk * d * B # context chunk: (M_chunk, d) - # ---- Scratch-buffer aliasing via live-range analysis ---- - # The four logical (H,S,S) attention-matrix scratch buffers (scores, - # scaled, masked, weights) have non-overlapping live ranges in the - # runlist: - # step 1 (score): [W: scores] - # step 2 (scale): [R: scores, W: scaled] - # step 3 (mask): [R: scaled, W: masked] (causal only) - # step 4 (softmax): [R: masked/scaled, W: weights] - # step 5 (context): [R: weights] - # Each step's input and output need to be distinct buffers, but - # non-adjacent buffers can share storage. Two physical slots A and B - # suffice in either causal or nomask configuration. (In principle each - # operator could run in-place on one shared buffer, but DMA channels - # reading and writing the same DDR buffer concurrently appear to - # serialise through the memory subsystem and hurt throughput at small/ - # medium S, so we keep two slots here.) - if causal_mask: - scores_buf, scaled_buf = "attn_A", "attn_B" - masked_buf, weights_buf = "attn_A", "attn_B" - else: - scores_buf, scaled_buf = "attn_A", "attn_B" - weights_buf = "attn_A" + # In-place attn buffer: single (H, S, S) scratch slot used in-place + # throughout the chain (score → scale → [mask] → softmax → context). + # Halves scratch memory (24 GB instead of 48 at S=32K, H=12) and is + # also marginally faster than the 2-buffer aliasing version (better + # cache locality with one buffer touched repeatedly). + attn_buf = "attn" + scores_buf = scaled_buf = masked_buf = weights_buf = attn_buf score_calls = [ ( @@ -340,8 +331,7 @@ def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True, num_cols=None, "queries": H * S * d * B, "keys": H * d * S * B, "values": H * S * d * B, - "attn_A": H * S * S * B, - "attn_B": H * S * S * B, + "attn": H * S * S * B, "attn_context": H * S * d * B, } From 859ee97366e100d757151fd1d83342a9980d109d Mon Sep 17 00:00:00 2001 From: andrej Date: Mon, 20 Apr 2026 13:54:58 -0600 Subject: [PATCH 21/22] Phoenix support for partial softmax --- aie_kernels/aie2/softmax.cc | 120 +++++++++++++++++++++++++++++++ iron/operators/softmax/design.py | 10 +-- 2 files changed, 126 insertions(+), 4 deletions(-) diff --git a/aie_kernels/aie2/softmax.cc b/aie_kernels/aie2/softmax.cc index 919d38f7..c02dd253 100644 --- a/aie_kernels/aie2/softmax.cc +++ b/aie_kernels/aie2/softmax.cc @@ -4,6 +4,7 @@ #include "lut_based_ops.h" #include +#include #include using namespace aie; @@ -57,6 +58,104 @@ void softmax_simple_bf16(bfloat16 *restrict input_vector, bfloat16 *restrict out return; } +// --------------------------------------------------------------------------- +// Online (partial / tiled) softmax helpers +// +// These three kernels implement a two-pass online softmax that processes a row +// in sub-tile chunks, keeping running max and sum statistics in a small local +// buffer (`stats`). Layout of the stats buffer (bfloat16[16], only [0..1] +// used): +// stats[0] = running max +// stats[1] = running sum (of exp(x - max)) +// --------------------------------------------------------------------------- + +void softmax_partial_stats_impl(bfloat16 *restrict input, + bfloat16 *stats, + const int32_t vector_size) +{ + event0(); + + const int elem_iters = vector_size / 16; + + float running_max = (float)stats[0]; + float running_sum = (float)stats[1]; + + aie::vector input_bf16; + aie::accum exp_val_accum = aie::zeros(); + + auto it_in = aie::cbegin_vector<16>((bfloat16 *)input); + + // Single-pass online algorithm: for each vector chunk, check if max + // needs updating, rescale the running sum if so, then accumulate + // exp(x - max). + for (int i = 0; i < elem_iters; i++) { + input_bf16 = *it_in++; + float chunk_max = aie::reduce_max(input_bf16); + + if (chunk_max > running_max) { + // Rescale accumulated exp values by exp(old_max - new_max) + aie::vector correction = + to_v16bfloat16(getExpBf16( + aie::broadcast((bfloat16)(running_max - chunk_max)))); + float scale = (float)correction[0]; + // Rescale the partial vector accumulator + aie::vector scale_vec = + aie::broadcast((bfloat16)scale); + exp_val_accum = aie::mul(exp_val_accum.to_vector(), scale_vec); + // Rescale the running scalar sum from previous chunks + running_sum *= scale; + running_max = chunk_max; + } + + aie::vector shifted = aie::sub( + input_bf16, aie::broadcast((bfloat16)running_max)); + aie::vector exp_val = to_v16bfloat16(getExpBf16(shifted)); + exp_val_accum = add(exp_val_accum, exp_val); + } + + // Reduce the vector accumulator and add to running sum + aie::vector reduce = exp_val_accum.to_vector(); + running_sum += aie::reduce_add(reduce); + + stats[0] = (bfloat16)running_max; + stats[1] = (bfloat16)running_sum; + + event1(); +} + +void softmax_partial_norm_impl(bfloat16 *restrict input, + bfloat16 *restrict output, + bfloat16 *stats, + const int32_t vector_size) +{ + event0(); + + const int elem_iters = vector_size / 16; + + float max_val = (float)stats[0]; + float sum_val = (float)stats[1]; + bfloat16 inv_sum = (bfloat16)aie::inv(sum_val); + + aie::vector max_val_vec = + aie::broadcast((bfloat16)max_val); + + aie::vector input_bf16; + aie::accum out_vals; + + auto it_in = aie::cbegin_restrict_vector<16>((bfloat16 *)input); + auto it_out = aie::begin_restrict_vector<16>((bfloat16 *)output); + + for (int i = 0; i < elem_iters; i++) { + input_bf16 = *it_in++; + aie::vector shifted = aie::sub(input_bf16, max_val_vec); + aie::vector exp_val = to_v16bfloat16(getExpBf16(shifted)); + out_vals = aie::mul(exp_val, inv_sum); + *it_out++ = out_vals.to_vector(); + } + + event1(); +} + extern "C" { void softmax_bf16(bfloat16 *restrict input, bfloat16 *restrict output, const int32_t input_size) @@ -64,6 +163,27 @@ void softmax_bf16(bfloat16 *restrict input, bfloat16 *restrict output, const int softmax_simple_bf16(input, output, input_size); } +void softmax_partial_init_bf16(bfloat16 *stats) +{ + stats[0] = (bfloat16)(-INFINITY); + stats[1] = (bfloat16)(0.0f); +} + +void softmax_partial_stats_bf16(bfloat16 *restrict input, + bfloat16 *stats, + const int32_t vector_size) +{ + softmax_partial_stats_impl(input, stats, vector_size); +} + +void softmax_partial_norm_bf16(bfloat16 *restrict input, + bfloat16 *restrict output, + bfloat16 *stats, + const int32_t vector_size) +{ + softmax_partial_norm_impl(input, output, stats, vector_size); +} + void mask_bf16(bfloat16 *inout, const int32_t unmasked_size, const int32_t total_size) { for (int32_t i = unmasked_size; i < total_size; i++) { diff --git a/iron/operators/softmax/design.py b/iron/operators/softmax/design.py index cbf6120f..b3d73fb3 100644 --- a/iron/operators/softmax/design.py +++ b/iron/operators/softmax/design.py @@ -28,6 +28,7 @@ def _softmax_partial( tile_size, chunk_size, func_prefix="", + kernel_obj_file="softmax.o", ): """Online / tiled softmax that processes each row in sub-tile chunks. @@ -76,17 +77,17 @@ def _softmax_partial( # --- Kernel declarations ------------------------------------------------ init_kernel = Kernel( f"{func_prefix}softmax_partial_init_bf16", - f"{func_prefix}softmax.o", + f"{func_prefix}{kernel_obj_file}", [stats_ty], ) stats_kernel = Kernel( f"{func_prefix}softmax_partial_stats_bf16", - f"{func_prefix}softmax.o", + f"{func_prefix}{kernel_obj_file}", [chunk_ty, stats_ty, np.int32], ) norm_kernel = Kernel( f"{func_prefix}softmax_partial_norm_bf16", - f"{func_prefix}softmax.o", + f"{func_prefix}{kernel_obj_file}", [chunk_ty, chunk_ty, stats_ty, np.int32], ) @@ -145,7 +146,7 @@ def _worker_args(k): ] workers = [ - Worker(core_body, _worker_args(i * num_channels + j)) + Worker(core_body, _worker_args(i * num_channels + j), stack_size=0xD00) for i in range(num_aie_columns) for j in range(num_channels) ] @@ -219,6 +220,7 @@ def softmax( tile_size, chunk_size, func_prefix, + kernel_obj_file, ) # ---- Full-row softmax path (original) ---- From 43759964d7d35d3c962c07694794d7241dc417a1 Mon Sep 17 00:00:00 2001 From: andrej Date: Mon, 20 Apr 2026 14:02:21 -0600 Subject: [PATCH 22/22] reactivate sample-based verification for MHA --- iron/operators/mha_prefill_lxl_sd/test.py | 56 +++++++++++------------ 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/iron/operators/mha_prefill_lxl_sd/test.py b/iron/operators/mha_prefill_lxl_sd/test.py index 358a474c..73d80440 100644 --- a/iron/operators/mha_prefill_lxl_sd/test.py +++ b/iron/operators/mha_prefill_lxl_sd/test.py @@ -217,34 +217,34 @@ def test_mha_prefill_benchmark(H, G, d, E, S, causal, dispatch): print(f"\nLatency (us): {latency_us:.1f}") print(f"Throughput: {gflops:.6e} GFLOP/s") - # # ---- Sample-based correctness check ---- - # # Pick a handful of random (head, row) pairs and recompute the expected - # # attn_context row for each (cheap: O(S*d) per sample). - # actual_context = _get_output_tensor(fc, "attn_context", (H, S, d)) - # rng = np.random.default_rng(seed=0) - # n_samples = min(8, H) - # sample_hms = [(int(rng.integers(0, H)), int(rng.integers(0, S))) for _ in range(n_samples)] - # expected_rows = compute_attn_context_at_rows( - # inputs["queries_deinterleaved"], - # inputs["keys_for_scores"], - # inputs["values_for_context"], - # inputs["_scale"], - # causal, - # sample_hms, - # ) - # failures = [] - # for (h, m), exp in expected_rows.items(): - # act = torch.from_numpy(actual_context[h, m, :]).bfloat16() - # diff = (act.float() - exp.float()).abs() - # rel = diff / (exp.float().abs() + 1e-6) - # # An element fails only if it exceeds BOTH abs_tol and rel_tol - # bad = (diff > ABS_TOL) & (rel > REL_TOL) - # if bad.any(): - # failures.append( - # f"(h={h}, m={m}): {int(bad.sum())}/{d} bad, " - # f"max_abs={diff.max().item():.4f}, max_rel={rel.max().item():.4f}" - # ) - # assert not failures, "Sample verification failed:\n " + "\n ".join(failures) + # ---- Sample-based correctness check ---- + # Pick a handful of random (head, row) pairs and recompute the expected + # attn_context row for each (cheap: O(S*d) per sample). + actual_context = _get_output_tensor(fc, "attn_context", (H, S, d)) + rng = np.random.default_rng(seed=0) + n_samples = 16 + sample_hms = [(int(rng.integers(0, H)), int(rng.integers(0, S))) for _ in range(n_samples)] + expected_rows = compute_attn_context_at_rows( + inputs["queries_deinterleaved"], + inputs["keys_for_scores"], + inputs["values_for_context"], + inputs["_scale"], + causal, + sample_hms, + ) + failures = [] + for (h, m), exp in expected_rows.items(): + act = torch.from_numpy(actual_context[h, m, :]).bfloat16() + diff = (act.float() - exp.float()).abs() + rel = diff / (exp.float().abs() + 1e-6) + # An element fails only if it exceeds BOTH abs_tol and rel_tol + bad = (diff > ABS_TOL) & (rel > REL_TOL) + if bad.any(): + failures.append( + f"(h={h}, m={m}): {int(bad.sum())}/{d} bad, " + f"max_abs={diff.max().item():.4f}, max_rel={rel.max().item():.4f}" + ) + assert not failures, "Sample verification failed:\n " + "\n ".join(failures) # ---------------------------------------------------------------------------