diff --git a/aie_kernels/aie2/softmax.cc b/aie_kernels/aie2/softmax.cc index 919d38f7..c02dd253 100644 --- a/aie_kernels/aie2/softmax.cc +++ b/aie_kernels/aie2/softmax.cc @@ -4,6 +4,7 @@ #include "lut_based_ops.h" #include +#include #include using namespace aie; @@ -57,6 +58,104 @@ void softmax_simple_bf16(bfloat16 *restrict input_vector, bfloat16 *restrict out return; } +// --------------------------------------------------------------------------- +// Online (partial / tiled) softmax helpers +// +// These three kernels implement a two-pass online softmax that processes a row +// in sub-tile chunks, keeping running max and sum statistics in a small local +// buffer (`stats`). Layout of the stats buffer (bfloat16[16], only [0..1] +// used): +// stats[0] = running max +// stats[1] = running sum (of exp(x - max)) +// --------------------------------------------------------------------------- + +void softmax_partial_stats_impl(bfloat16 *restrict input, + bfloat16 *stats, + const int32_t vector_size) +{ + event0(); + + const int elem_iters = vector_size / 16; + + float running_max = (float)stats[0]; + float running_sum = (float)stats[1]; + + aie::vector input_bf16; + aie::accum exp_val_accum = aie::zeros(); + + auto it_in = aie::cbegin_vector<16>((bfloat16 *)input); + + // Single-pass online algorithm: for each vector chunk, check if max + // needs updating, rescale the running sum if so, then accumulate + // exp(x - max). + for (int i = 0; i < elem_iters; i++) { + input_bf16 = *it_in++; + float chunk_max = aie::reduce_max(input_bf16); + + if (chunk_max > running_max) { + // Rescale accumulated exp values by exp(old_max - new_max) + aie::vector correction = + to_v16bfloat16(getExpBf16( + aie::broadcast((bfloat16)(running_max - chunk_max)))); + float scale = (float)correction[0]; + // Rescale the partial vector accumulator + aie::vector scale_vec = + aie::broadcast((bfloat16)scale); + exp_val_accum = aie::mul(exp_val_accum.to_vector(), scale_vec); + // Rescale the running scalar sum from previous chunks + running_sum *= scale; + running_max = chunk_max; + } + + aie::vector shifted = aie::sub( + input_bf16, aie::broadcast((bfloat16)running_max)); + aie::vector exp_val = to_v16bfloat16(getExpBf16(shifted)); + exp_val_accum = add(exp_val_accum, exp_val); + } + + // Reduce the vector accumulator and add to running sum + aie::vector reduce = exp_val_accum.to_vector(); + running_sum += aie::reduce_add(reduce); + + stats[0] = (bfloat16)running_max; + stats[1] = (bfloat16)running_sum; + + event1(); +} + +void softmax_partial_norm_impl(bfloat16 *restrict input, + bfloat16 *restrict output, + bfloat16 *stats, + const int32_t vector_size) +{ + event0(); + + const int elem_iters = vector_size / 16; + + float max_val = (float)stats[0]; + float sum_val = (float)stats[1]; + bfloat16 inv_sum = (bfloat16)aie::inv(sum_val); + + aie::vector max_val_vec = + aie::broadcast((bfloat16)max_val); + + aie::vector input_bf16; + aie::accum out_vals; + + auto it_in = aie::cbegin_restrict_vector<16>((bfloat16 *)input); + auto it_out = aie::begin_restrict_vector<16>((bfloat16 *)output); + + for (int i = 0; i < elem_iters; i++) { + input_bf16 = *it_in++; + aie::vector shifted = aie::sub(input_bf16, max_val_vec); + aie::vector exp_val = to_v16bfloat16(getExpBf16(shifted)); + out_vals = aie::mul(exp_val, inv_sum); + *it_out++ = out_vals.to_vector(); + } + + event1(); +} + extern "C" { void softmax_bf16(bfloat16 *restrict input, bfloat16 *restrict output, const int32_t input_size) @@ -64,6 +163,27 @@ void softmax_bf16(bfloat16 
*restrict input, bfloat16 *restrict output, const int softmax_simple_bf16(input, output, input_size); } +void softmax_partial_init_bf16(bfloat16 *stats) +{ + stats[0] = (bfloat16)(-INFINITY); + stats[1] = (bfloat16)(0.0f); +} + +void softmax_partial_stats_bf16(bfloat16 *restrict input, + bfloat16 *stats, + const int32_t vector_size) +{ + softmax_partial_stats_impl(input, stats, vector_size); +} + +void softmax_partial_norm_bf16(bfloat16 *restrict input, + bfloat16 *restrict output, + bfloat16 *stats, + const int32_t vector_size) +{ + softmax_partial_norm_impl(input, output, stats, vector_size); +} + void mask_bf16(bfloat16 *inout, const int32_t unmasked_size, const int32_t total_size) { for (int32_t i = unmasked_size; i < total_size; i++) { diff --git a/aie_kernels/aie2p/softmax.cc b/aie_kernels/aie2p/softmax.cc index 64cca202..48d82c3a 100644 --- a/aie_kernels/aie2p/softmax.cc +++ b/aie_kernels/aie2p/softmax.cc @@ -3,6 +3,7 @@ #include #include +#include #define SM_VEC_LEN 64 // 32 #define log2e 1.4453125 // 1.44269504089 @@ -30,7 +31,7 @@ void softmax_simple_bf16(bfloat16 *restrict input_vector, bfloat16 *restrict out aie::vector in_elems, exp_val, input_bf16, log2e_vec, max_val_vec; aie::accum out_vals, exp_val_accum, scaled_accum, exp_in_accum; - float max_val = 0; + float max_val = -INFINITY; float accum_exp_val = 0; float running_max = 0; bfloat16 col_sum_inv; @@ -159,6 +160,118 @@ void partial_softmax_alias_bf16(bfloat16 *restrict input_vector, return; } +// --------------------------------------------------------------------------- +// Online (partial / tiled) softmax helpers +// +// These three kernels implement a two-pass online softmax that processes a row +// in sub-tile chunks, keeping running max and sum statistics in a small local +// buffer (`stats`). 
Layout of the stats buffer (bfloat16[16], only [0..1] +// used): +// stats[0] = running max (scaled by log2e) +// stats[1] = running sum (of exp2(x*log2e - max)) +// --------------------------------------------------------------------------- + +void softmax_partial_stats_impl(bfloat16 *restrict input, + bfloat16 *stats, + const int32_t vector_size) +{ + event0(); + + const int elem_iters = vector_size / SM_VEC_LEN; + + aie::vector input_bf16; + aie::accum scaled_accum, exp_in_accum; + aie::accum exp_val_accum = aie::zeros(); + + aie::vector log2e_vec = + aie::broadcast((bfloat16)log2e); + + // --- Phase 1: find local max (scaled by log2e) ------------------------- + float local_max = -INFINITY; + auto it_in1 = aie::cbegin_restrict_vector((bfloat16 *)input); + for (int i = 0; i < elem_iters; i++) { + input_bf16 = *it_in1++; + scaled_accum = aie::mul(input_bf16, log2e_vec); + float chunk_max = aie::reduce_max(scaled_accum.to_vector()); + if (chunk_max > local_max) { + local_max = chunk_max; + } + } + + // --- Phase 2: update running max, rescale running sum ------------------ + float old_max = (float)stats[0]; + float old_sum = (float)stats[1]; + + if (local_max > old_max) { + // New max is larger — rescale the old sum by exp2(old_max - new_max) + aie::vector diff_vec = + aie::broadcast(old_max - local_max); + aie::vector corr = aie::exp2(diff_vec); + old_sum = old_sum * (float)corr[0]; + old_max = local_max; + } + + // --- Phase 3: accumulate exp2(input * log2e - max) for this chunk ------ + aie::vector max_val_vec = + aie::broadcast((bfloat16)old_max); + + auto it_in2 = aie::cbegin_restrict_vector((bfloat16 *)input); + for (int i = 0; i < elem_iters; i++) { + input_bf16 = *it_in2++; + scaled_accum = aie::mul(input_bf16, log2e_vec); + exp_in_accum = aie::sub(scaled_accum, max_val_vec); + aie::vector exp_val = + aie::exp2(exp_in_accum.to_vector()); + exp_val_accum = add(exp_val_accum, exp_val); + } + + aie::vector reduce = exp_val_accum.to_vector(); + float local_sum = aie::reduce_add(reduce); + + // --- Phase 4: store updated stats -------------------------------------- + stats[0] = (bfloat16)old_max; + stats[1] = (bfloat16)(old_sum + local_sum); + + event1(); +} + +void softmax_partial_norm_impl(bfloat16 *restrict input, + bfloat16 *restrict output, + bfloat16 *stats, + const int32_t vector_size) +{ + event0(); + + const int elem_iters = vector_size / SM_VEC_LEN; + + float max_val = (float)stats[0]; + float sum_val = (float)stats[1]; + bfloat16 inv_sum = (bfloat16)aie::inv(sum_val); + + aie::vector log2e_vec = + aie::broadcast((bfloat16)log2e); + aie::vector max_val_vec = + aie::broadcast((bfloat16)max_val); + + aie::vector input_bf16; + aie::accum scaled_accum, exp_in_accum, out_vals; + + auto it_in = aie::cbegin_restrict_vector((bfloat16 *)input); + auto it_out = aie::begin_restrict_vector((bfloat16 *)output); + + for (int i = 0; i < elem_iters; i++) { + input_bf16 = *it_in++; + scaled_accum = aie::mul(input_bf16, log2e_vec); + exp_in_accum = aie::sub(scaled_accum, max_val_vec); + aie::vector exp_val = + aie::exp2(exp_in_accum.to_vector()); + out_vals = aie::mul(exp_val, inv_sum); + *it_out++ = out_vals.to_vector(); + } + + event1(); +} + extern "C" { void softmax_bf16(bfloat16 *restrict input, bfloat16 *restrict output, const int32_t input_size) @@ -177,6 +290,27 @@ void partial_softmax_bf16(bfloat16 *restrict input, partial_softmax_alias_bf16(input, output, scale_buffer, input_size, row_idx, num_rows, scale); } +void softmax_partial_init_bf16(bfloat16 *stats) +{ + stats[0] = 
(bfloat16)(-INFINITY); + stats[1] = (bfloat16)(0.0f); +} + +void softmax_partial_stats_bf16(bfloat16 *restrict input, + bfloat16 *stats, + const int32_t vector_size) +{ + softmax_partial_stats_impl(input, stats, vector_size); +} + +void softmax_partial_norm_bf16(bfloat16 *restrict input, + bfloat16 *restrict output, + bfloat16 *stats, + const int32_t vector_size) +{ + softmax_partial_norm_impl(input, output, stats, vector_size); +} + void mask_bf16(bfloat16 *inout, const int32 unmasked_size, const int32 total_size) { // TODO: Optimize this to use vector code diff --git a/aie_kernels/generic/axpy.cc b/aie_kernels/generic/axpy.cc index 728adb55..3bd01324 100644 --- a/aie_kernels/generic/axpy.cc +++ b/aie_kernels/generic/axpy.cc @@ -42,4 +42,103 @@ void saxpy_scalar(bfloat16 *x, bfloat16 *y, const bfloat16 a, bfloat16 *z, const } event1(); } + +// z = a * x (scalar-vector multiply; AXPY without the +y term) +void scale_bf16(bfloat16 *restrict x, bfloat16 *restrict z, const float a, const int32_t vector_size) +{ + event0(); + ::aie::vector a_v = + ::aie::broadcast(aie::to_float(a, 0)); + for (int i = 0; i < vector_size; i += 64) { + ::aie::vector x_v = ::aie::load_v<64>(x); + x += 64; + ::aie::accum z_v = ::aie::mul(x_v, a_v); + ::aie::vector z_v_converted = z_v.to_vector(); + ::aie::store_v(z, z_v_converted); + z += 64; + } + event1(); +} + +// z = a + y (scalar-vector add; AXPY without the *x term) +void scalar_add_bf16(bfloat16 *restrict y, bfloat16 *restrict z, const float a, const int32_t vector_size) +{ + event0(); + ::aie::vector a_v = + ::aie::broadcast(aie::to_float(a, 0)); + for (int i = 0; i < vector_size; i += 64) { + ::aie::vector y_v = ::aie::load_v<64>(y); + y += 64; + ::aie::vector z_v = ::aie::add(y_v, a_v); + ::aie::store_v(z, z_v); + z += 64; + } + event1(); +} + +// z = (col > row) ? a : y applied per-element of one tile, using a tile +// position (chunk_start_col, row_in_head) supplied via the idx_buffer. +// +// The tile is interpreted as a `vector_size`-wide horizontal strip of the +// per-head (S, S) attention-score block; idx[0] is the strip's starting +// column within that block, idx[1] is the strip's row within the block. +// The kernel implements the causal mask in-place by writing `a` to elements +// strictly above the diagonal and copying y -> z everywhere else. This +// avoids materialising an H*S*S mask buffer entirely. +// +// For tiles whose entire range lies at-or-below the diagonal, the kernel +// degenerates to a copy (input still streamed through DMA — slightly +// wasteful but simpler than per-tile data-movement skipping). +void scalar_add_causal_bf16(bfloat16 *restrict y, bfloat16 *restrict z, int32_t *idx, + const float a, const int32_t vector_size) +{ + event0(); + + constexpr int VEC = 64; + + int32_t chunk_start_col = idx[0]; + int32_t row_in_head = idx[1]; + + // Index of the first column in the tile that needs to be masked + // (i.e. column index strictly greater than row_in_head). + int32_t mask_start = row_in_head + 1 - chunk_start_col; + if (mask_start < 0) mask_start = 0; + if (mask_start > vector_size) mask_start = vector_size; + + bfloat16 s = (bfloat16)a; + ::aie::vector s_v = ::aie::broadcast(s); + int j = 0; + + // ---- Unmasked region [0, mask_start): copy y -> z ---- + // Vectorised body up to the largest VEC-aligned offset <= mask_start. 
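// Worked example (illustrative numbers, not part of the patch): with VEC = 64,
// vector_size = 4096, chunk_start_col = 4096 and row_in_head = 4130:
//   mask_start       = 4130 + 1 - 4096 = 35  -> columns 4096..4130 stay unmasked
//   mask_start_floor = (35 / 64) * 64  = 0   -> no full-vector copy, 35 scalar copies
// and columns 4131 onward in this tile are filled with `a` (-INF in the MHA use),
// mostly via the 64-wide broadcast stores below.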
+ int mask_start_floor = (mask_start / VEC) * VEC; + for (; j < mask_start_floor; j += VEC) { + ::aie::vector v = ::aie::load_v(y + j); + ::aie::store_v(z + j, v); + } + // Scalar copy for the unmasked remainder (at most VEC - 1 elements). + for (; j < mask_start; j++) { + z[j] = y[j]; + } + + // ---- Masked region [mask_start, vector_size): write scalar ---- + // If mask_start isn't VEC-aligned, scalar-fill up to the next VEC + // boundary (or to vector_size, whichever is smaller). + int next_vec_boundary = ((j + VEC - 1) / VEC) * VEC; + if (next_vec_boundary > vector_size) next_vec_boundary = vector_size; + for (; j < next_vec_boundary; j++) { + z[j] = s; + } + // Vectorised body of the masked region. + for (; j + VEC <= vector_size; j += VEC) { + ::aie::store_v(z + j, s_v); + } + // Scalar tail when vector_size isn't VEC-aligned (in practice this + // doesn't fire since per_tile_elements is always a multiple of VEC). + for (; j < vector_size; j++) { + z[j] = s; + } + + event1(); +} } \ No newline at end of file diff --git a/conftest.py b/conftest.py index d611f0dd..8a9ab947 100644 --- a/conftest.py +++ b/conftest.py @@ -59,6 +59,14 @@ def __init__(self, csv_path): self.commit = get_git_commit() self.date = datetime.now().strftime("%Y-%m-%d %H:%M:%S") self.test_metrics = {} # test_name -> {metric_name -> [values]} + self._initialize_csv() + + def _initialize_csv(self): + """Initialize CSV file - will be written incrementally as tests complete""" + self.csv_path.parent.mkdir(parents=True, exist_ok=True) + # Clear the file at the start of a new test run + with open(self.csv_path, "w", newline="") as f: + pass # Create empty file, header will be written with first result def add_result(self, test_name, passed, captured_output, metric_patterns): self.test_metrics.setdefault(test_name, {}).setdefault("passed", []).append( @@ -72,40 +80,70 @@ def add_result(self, test_name, passed, captured_output, metric_patterns): value = float(match.group("value")) self.test_metrics[test_name].setdefault(metric_name, []).append(value) - def finalize_results(self): - """Compute statistics for all collected metrics""" - for test_name, data in self.test_metrics.items(): - row = { - "Commit": self.commit, - "Date": self.date, - "Test": test_name, - "Checks": f"{sum(data['passed'])}/{len(data['passed'])}", - } - for metric_name, values in data.items(): - if metric_name == "passed": - continue - if values: - row[f"{metric_name} (mean)"] = statistics.mean(values) - row[f"{metric_name} (median)"] = statistics.median(values) - row[f"{metric_name} (min)"] = min(values) - row[f"{metric_name} (max)"] = max(values) - row[f"{metric_name} (stddev)"] = ( - statistics.stdev(values) if len(values) > 1 else 0.0 - ) + def _compute_row(self, test_name, data): + """Compute statistics row for a single test""" + row = { + "Commit": self.commit, + "Date": self.date, + "Test": test_name, + "Checks": f"{sum(data['passed'])}/{len(data['passed'])}", + } + for metric_name, values in data.items(): + if metric_name == "passed": + continue + if values: + row[f"{metric_name} (mean)"] = statistics.mean(values) + row[f"{metric_name} (median)"] = statistics.median(values) + row[f"{metric_name} (min)"] = min(values) + row[f"{metric_name} (max)"] = max(values) + row[f"{metric_name} (stddev)"] = ( + statistics.stdev(values) if len(values) > 1 else 0.0 + ) + return row + + def write_test_result(self, test_name): + """Write or update result for a specific test incrementally""" + if test_name not in self.test_metrics: + return + + data = 
self.test_metrics[test_name] + row = self._compute_row(test_name, data) + + # Update or add to results + existing_idx = None + for idx, existing_row in enumerate(self.results): + if existing_row["Test"] == test_name: + existing_idx = idx + break + + if existing_idx is not None: + self.results[existing_idx] = row + else: self.results.append(row) - def write_csv(self): - self.results.sort(key=lambda x: (x["Test"], x["Date"])) + # Rewrite entire CSV to handle column changes and updates + self._write_csv_internal() + + def _write_csv_internal(self): + """Internal method to write CSV file""" + if not self.results: + return + + sorted_results = sorted(self.results, key=lambda x: (x["Test"], x["Date"])) cols = {} - for row in self.results: + for row in sorted_results: cols.update({k: None for k in row.keys()}) self.csv_path.parent.mkdir(parents=True, exist_ok=True) with open(self.csv_path, "w", newline="") as f: writer = csv.DictWriter(f, cols.keys()) writer.writeheader() - writer.writerows(self.results) + writer.writerows(sorted_results) + + def write_csv(self): + """Final write at session end - ensures all results are flushed""" + self._write_csv_internal() # Initialize the CSV writer once at test session setup @@ -147,6 +185,9 @@ def pytest_runtest_makereport(item, call): csv_reporter.add_result(test_name, passed, captured, metric_patterns) + # Write results incrementally after each test + csv_reporter.write_test_result(test_name) + def pytest_configure(config): csv_path = config.getoption("--csv-output") @@ -170,7 +211,6 @@ def pytest_collection_modifyitems(config, items): def pytest_sessionfinish(session, exitstatus): if hasattr(session.config, "_csv_reporter"): - session.config._csv_reporter.finalize_results() session.config._csv_reporter.write_csv() diff --git a/iron/common/compilation/base.py b/iron/common/compilation/base.py index d6d17a64..0d8c1613 100644 --- a/iron/common/compilation/base.py +++ b/iron/common/compilation/base.py @@ -502,6 +502,7 @@ def compile(self, graph): str(self.aiecc_path), "-v", "-j1", + "--dynamic-objFifos", "--no-compile-host", "--no-xchesscc", "--no-xbridge", diff --git a/iron/common/fusion.py b/iron/common/fusion.py index 99219848..069e0089 100644 --- a/iron/common/fusion.py +++ b/iron/common/fusion.py @@ -1,26 +1,55 @@ # SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +import hashlib +import logging import numpy as np import ml_dtypes import pyxrt import ctypes +import time from . import compilation as comp from .base import AIEOperatorBase, MLIROperator from .utils import XRTSubBuffer import aie.utils as aie_utils +from aie.iron.device import NPU2 from aie.utils.hostruntime.xrtruntime.tensor import XRTTensor +from aie.utils.npukernel import NPUKernel + +logger = logging.getLogger(__name__) # Fused Operator # ########################################################################## class FusedMLIROperator(AIEOperatorBase): - """Operator that fuses multiple MLIROperators into one.""" + """Operator that fuses multiple MLIROperators into one. + + Args: + dispatch: Dispatch strategy for the fused operator. + ``"auto"`` (default) selects ``"fused"`` on NPU2 and + ``"separate"`` on NPU1. ``"fused"`` uses a single-ELF + dispatch (requires NPU2). ``"separate"`` compiles each + sub-operator to its own xclbin and invokes them sequentially. 
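    Example (sketch; ``op_a``/``op_b`` stand for hypothetical MLIROperator
    instances and are not defined in this patch)::

        fused = FusedMLIROperator(
            name="my_fused_pipeline",
            runlist=[(op_a, "x", "tmp"), (op_b, "tmp", "y")],
            input_args=["x"],
            output_args=["y"],
            dispatch="auto",  # "fused" on NPU2, "separate" on NPU1
        )
        run_it = fused.get_callable()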
+ """ + + DISPATCH_MODES = ("auto", "fused", "separate") def __init__( - self, name, runlist, input_args, output_args, buffer_sizes=None, *args, **kwargs + self, + name, + runlist, + input_args, + output_args, + buffer_sizes=None, + dispatch="auto", + *args, + **kwargs, ): + if dispatch not in self.DISPATCH_MODES: + raise ValueError( + f"dispatch must be one of {self.DISPATCH_MODES!r}, got {dispatch!r}" + ) if not all( isinstance(op, MLIROperator) and all(isinstance(buf, str) for buf in bufs) for op, *bufs in runlist @@ -37,13 +66,13 @@ def __init__( self.explicit_buffer_sizes = ( buffer_sizes or {} ) # Optional dict: buffer_name -> size_in_bytes + self._dispatch = dispatch def get_kernel_artifacts(self): """Collect all kernel artifacts from child operators. Returns: - List of KernelObjectArtifact instances from all unique child operators, - with filenames and symbol prefixes disambiguated per operator index. + List of KernelObjectArtifact instances from all unique child operators. """ kernel_artifacts = [] seen: dict[int, object] = {} @@ -52,9 +81,6 @@ def get_kernel_artifacts(self): ] for idx, op in enumerate(unique_operators): objs = op.get_kernel_artifacts() - for obj in objs: - obj.filename = f"op{idx}_{obj.filename}" - obj.prefix_symbols = f"op{idx}_" kernel_artifacts.extend(objs) return kernel_artifacts @@ -82,8 +108,6 @@ def get_mlir_artifact(self): ] for idx, op in enumerate(unique_operators): mlir_artifact = op.get_mlir_artifact() - if len(op.get_kernel_artifacts()) > 0: - mlir_artifact.generator.kwargs["func_prefix"] = f"op{idx}_" op_name = f"op{idx}_{op.__class__.__name__}" op_names[id(op)] = op_name operator_mlir_map[op_name] = mlir_artifact @@ -205,13 +229,36 @@ def add_buffers(buffer_type, args_list): def set_up_artifacts(self): """Set up the artifact dependency graph for this fused operator. - Computes the buffer layout first, then builds the fused MLIR artifact - and full-ELF artifact and registers them via ``add_artifacts()``. + Computes the buffer layout first, then builds the artifacts. + The dispatch mode (``"fused"`` vs ``"separate"``) is resolved here + when set to ``"auto"``. """ - # Calculate buffer layout before building mlir artifact (used by get_mlir_artifact) + # Calculate buffer layout (used by both paths for get_buffer()) self.subbuffer_layout, self.buffer_sizes, self.slice_info = ( self._calculate_buffer_layout() ) + + is_npu2 = isinstance(aie_utils.get_current_device(), NPU2) + + if self._dispatch == "auto": + self._use_full_elf = is_npu2 + elif self._dispatch == "fused": + if not is_npu2: + raise RuntimeError( + "dispatch='fused' requires NPU2 (Strix); " + "Phoenix/NPU1 does not support full-ELF dispatch" + ) + self._use_full_elf = True + else: # "separate" + self._use_full_elf = False + + if self._use_full_elf: + self._set_up_full_elf_artifacts() + else: + self._set_up_xclbin_artifacts() + + def _set_up_full_elf_artifacts(self): + """Full-ELF path (NPU2): fuse MLIR into a single ELF.""" operator_name = self.name mlir_artifact = self.get_mlir_artifact() kernel_objects = self.get_kernel_artifacts() @@ -222,6 +269,58 @@ def set_up_artifacts(self): ) self.add_artifacts([full_elf_artifact]) + def _set_up_xclbin_artifacts(self): + """Chained xclbin path (NPU1/Phoenix): separate xclbin per unique operator. + + Mirrors the pattern from ``chain_swiglu_artifacts`` in + ``iron/operators/swiglu_base.py``: each unique operator gets its own + xclbin + insts compiled separately, linked via ``--xclbin-input``. 
+ """ + seen: dict[int, object] = {} + unique_operators = [ + seen.setdefault(id(op), op) + for op, *_ in self.runlist + if id(op) not in seen + ] + + # Short hash to keep xclbin kernel names under 31 chars + # (xclbinutil limits m_name to 64 chars as "name:name") + name_hash = hashlib.sha1(self.name.encode()).hexdigest()[:6] + + artifacts = [] + prev_xclbin = None + self._op_xclbin_map = {} # id(op) -> xclbin artifact + self._op_insts_map = {} # id(op) -> insts artifact + self._op_kernel_name_map = {} # id(op) -> kernel_name + + for idx, op in enumerate(unique_operators): + op_label = f"f{name_hash}_op{idx}" + kernel_id = f"0x{0x901 + idx:x}" + + xclbin, insts = op.get_artifacts(prefix=f"{op_label}_") + # Use list() to avoid mutating the shared extra_flags list + # (get_artifacts may alias the same list between xclbin and insts) + xclbin.extra_flags = list(xclbin.extra_flags) + [ + f"--xclbin-instance-name={op_label}", + f"--xclbin-kernel-id={kernel_id}", + ] + xclbin.kernel_name = op_label + + if prev_xclbin is not None: + xclbin.xclbin_input = prev_xclbin + xclbin.dependencies.add(prev_xclbin) + + artifacts.append(insts) + self._op_xclbin_map[id(op)] = xclbin + self._op_insts_map[id(op)] = insts + self._op_kernel_name_map[id(op)] = op_label + prev_xclbin = xclbin + + # The last xclbin in the chain is the combined xclbin. + artifacts.append(prev_xclbin) + self.combined_xclbin = prev_xclbin + self.add_artifacts(artifacts) + def get_arg_spec(self): raise NotImplementedError( "FusedMLIROperator does not expose a unified arg spec; " @@ -232,9 +331,12 @@ def get_callable(self): """Return a callable that executes the fused operator on the NPU. Returns: - A ``FusedFullELFCallable`` wrapping this operator. + A ``FusedFullELFCallable`` when using fused dispatch, or a + ``FusedXclbinCallable`` when using separate dispatch. """ - return FusedFullELFCallable(self) + if self._use_full_elf: + return FusedFullELFCallable(self) + return FusedXclbinCallable(self) def get_layout_for_buffer(self, buffer_name): """Return the (buffer_type, offset, length) layout for a named buffer. @@ -290,8 +392,10 @@ def __call__(self, *args): for i, arg in enumerate(args): assert isinstance(arg, pyxrt.bo), f"Argument {i} is not a pyxrt.bo" run.set_arg(i, arg) + t0 = time.perf_counter() run.start() ret_code = run.wait() + self.last_elapsed = time.perf_counter() - t0 if ret_code != pyxrt.ert_cmd_state.ERT_CMD_STATE_COMPLETED: raise RuntimeError(f"Kernel execution failed with return code {ret_code}") @@ -371,10 +475,139 @@ def get_buffer(self, buffer_name): return sub_buffer def __call__(self): - self.input_buffer.to("npu") + self.input_buffer._sync_to_device() super().__call__( self.input_buffer.buffer_object(), self.output_buffer.buffer_object(), self.scratch_buffer.buffer_object(), ) - self.output_buffer.to("cpu") + self.output_buffer._sync_from_device() + + +class FusedXclbinCallable: + """Callable for FusedMLIROperator on NPU1 (Phoenix) using chained xclbins. + + Instead of a single ELF dispatch, each step in the runlist is executed as a + separate ``NPUKernel`` invocation. Buffers are shared (same ``XRTTensor``) + across steps that reference the same buffer name, giving zero-copy handoff + between sequential operators. 
+ """ + + def __init__(self, op): + self.op = op + self.last_elapsed = 0.0 + + combined_xclbin_path = op.combined_xclbin.filename + + # Build an NPUKernel per unique operator + self._op_callable_map = {} # id(op) -> NPUKernel + for op_id, xclbin in op._op_xclbin_map.items(): + insts = op._op_insts_map[op_id] + kernel_name = op._op_kernel_name_map[op_id] + self._op_callable_map[op_id] = NPUKernel( + xclbin_path=combined_xclbin_path, + kernel_name=kernel_name, + insts_path=insts.filename, + ) + + # Allocate one XRTTensor per unique base buffer name. + # Buffers that appear in multiple runlist entries share the same tensor + # (zero-copy between operators). + itemsize = np.dtype(ml_dtypes.bfloat16).itemsize + self._buffers = {} # base buffer name -> XRTTensor + for buf_name in list(op.subbuffer_layout.keys()): + _, _, length = op.subbuffer_layout[buf_name] + self._buffers[buf_name] = XRTTensor( + (max(length, itemsize) // itemsize,), + dtype=ml_dtypes.bfloat16, + ) + + # Pre-build the execution plan: list of (NPUKernel, [XRTTensor args]) + self._execution_plan = [] + for step_op, *buf_names in op.runlist: + kernel = self._op_callable_map[id(step_op)] + args = [] + for buf_name in buf_names: + args.append(self._resolve_buffer(buf_name)) + self._execution_plan.append((kernel, args)) + + # Cache for get_buffer() sub-buffer views (compatible with FusedFullELFCallable API) + self._buffer_cache = {} + + # Expose input/output/scratch buffers for API compatibility with + # FusedFullELFCallable (used by tests for _sync_from_device etc.) + input_buffer_size, output_buffer_size, scratch_buffer_size = op.buffer_sizes + self.input_buffer = XRTTensor( + (max(input_buffer_size, itemsize) // itemsize,), + dtype=ml_dtypes.bfloat16, + ) + self.output_buffer = XRTTensor( + (max(output_buffer_size, itemsize) // itemsize,), + dtype=ml_dtypes.bfloat16, + ) + self.scratch_buffer = XRTTensor( + (max(scratch_buffer_size, itemsize) // itemsize,), + dtype=ml_dtypes.bfloat16, + ) + + def _resolve_buffer(self, buf_name): + """Resolve a buffer name (possibly with slice notation) to an XRTTensor. + + Regular buffer names map directly to an allocated XRTTensor. + Sliced buffer names (e.g. ``queries[0:128]``) create an XRTSubBuffer + view into the parent buffer. + """ + if buf_name in self._buffers: + return self._buffers[buf_name] + + # Sliced buffer: "base_name[start:end]" + if buf_name in self.op.slice_info: + base_name, start_bytes, end_bytes = self.op.slice_info[buf_name] + parent = self._buffers[base_name] + itemsize = np.dtype(ml_dtypes.bfloat16).itemsize + size_bytes = end_bytes - start_bytes + sub = XRTSubBuffer( + parent_bo=parent.buffer_object(), + offset_bytes=start_bytes, + size_bytes=size_bytes, + shape=(size_bytes // itemsize,), + dtype=ml_dtypes.bfloat16, + ) + # Cache so the same slice always returns the same object + self._buffers[buf_name] = sub + return sub + + raise ValueError(f"Unknown buffer '{buf_name}' in fused runlist") + + def get_buffer(self, buffer_name): + """Return an XRTTensor(-like) view for a named buffer. + + Compatible with the ``FusedFullELFCallable.get_buffer()`` API so that + test helpers (``_load_input``, ``_get_output_tensor``, etc.) work + unchanged. + + For the xclbin path, each buffer is its own standalone XRTTensor (or + XRTSubBuffer for sliced buffers), so this just returns the resolved + buffer directly. 
+ """ + if buffer_name in self._buffer_cache: + return self._buffer_cache[buffer_name] + buf = self._resolve_buffer(buffer_name) + self._buffer_cache[buffer_name] = buf + return buf + + def __call__(self): + # Sync all input buffers to device + for buf_name in self.op.input_args: + self._buffers[buf_name]._sync_to_device() + + t0 = time.perf_counter() + for kernel, args in self._execution_plan: + kernel(*args) + self.last_elapsed = time.perf_counter() - t0 + + # Sync all base buffers from device so callers can read results + # (covers both output and scratch buffers) + for buf_name in self.op.subbuffer_layout: + if buf_name not in self.op.input_args: + self._buffers[buf_name]._sync_from_device() diff --git a/iron/operators/axpy/design.py b/iron/operators/axpy/design.py index af58eb55..553e3b5f 100644 --- a/iron/operators/axpy/design.py +++ b/iron/operators/axpy/design.py @@ -4,7 +4,7 @@ from ml_dtypes import bfloat16 import numpy as np -from aie.iron import Kernel, ObjectFifo, Program, Runtime, Worker +from aie.iron import Kernel, ObjectFifo, Program, Runtime, Worker, Buffer from aie.iron.placers import SequentialPlacer from aie.helpers.taplib.tap import TensorAccessPattern from aie.iron.controlflow import range_ @@ -17,7 +17,37 @@ def my_axpy( tile_size, trace_size, scalar_factor, + add_y=True, + mul_x=True, + causal_mask=False, + mask_block_dim=0, + rows_per_block=0, + row_offset=0, ): + """AXPY-family element-wise design. + + Modes: + mul_x=True, add_y=True → saxpy: Z = a*X + Y + mul_x=True, add_y=False → scale: Z = a*X + mul_x=False, add_y=True → scalar_add: Z = a + Y + mul_x=False, add_y=True, causal_mask=True → + Z[i,j] = a if (j > i within head) else Y[i,j] + (in-place causal mask; supplies row/col-chunk indices to the kernel + via an idx_buffer; data is interpreted as (..., S, S) blocks where + S = mask_block_dim.) + """ + if causal_mask: + return _my_axpy_causal_mask( + dev=dev, + num_elements=num_elements, + num_columns=num_columns, + tile_size=tile_size, + scalar_factor=scalar_factor, + mask_block_dim=mask_block_dim, + rows_per_block=rows_per_block or mask_block_dim, + row_offset=row_offset, + ) + factor = scalar_factor per_tile_elements = 4096 if tile_size > 4096 else tile_size n = per_tile_elements * num_columns @@ -29,92 +59,263 @@ def my_axpy( chunk = num_elements // num_columns dtype = bfloat16 - # Define tensor types tensor_ty = np.ndarray[(num_elements,), np.dtype[dtype]] tile_ty = np.ndarray[(per_tile_elements,), np.dtype[dtype]] - # AIE-array data movement with object fifos (one per column, not per channel) + # Two inputs only when both *X and +Y are kept (saxpy mode). 
+ has_two_inputs = add_y and mul_x + + # AIE-array data movement with object fifos (one per column) of_in1s = [ObjectFifo(tile_ty, name=f"in1_{i}") for i in range(num_columns)] - of_in2s = [ObjectFifo(tile_ty, name=f"in2_{i}") for i in range(num_columns)] + if has_two_inputs: + of_in2s = [ObjectFifo(tile_ty, name=f"in2_{i}") for i in range(num_columns)] of_outs = [ObjectFifo(tile_ty, name=f"out_{i}") for i in range(num_columns)] # AIE Core Function declaration - axpy_bf16_vector = Kernel( - "saxpy", "axpy.o", [tile_ty, tile_ty, np.float32, tile_ty, np.int32] - ) - - # Define a task that will run on a compute tile - def core_body(of_in1, of_in2, of_out, axpy): - # Number of sub-vector "tile" iterations - for _ in range_(N_div_n): - elem_in1 = of_in1.acquire(1) - elem_in2 = of_in2.acquire(1) - elem_out = of_out.acquire(1) - axpy(elem_in1, elem_in2, factor, elem_out, per_tile_elements) - of_in1.release(1) - of_in2.release(1) - of_out.release(1) - - # Create a worker to run the task on a compute tile (one per column) - my_workers = [ - Worker( - core_body, - [ - of_in1s[i].cons(), - of_in2s[i].cons(), - of_outs[i].prod(), - axpy_bf16_vector, - ], + if has_two_inputs: + kernel = Kernel( + "saxpy", "axpy.o", + [tile_ty, tile_ty, np.float32, tile_ty, np.int32], ) - for i in range(num_columns) - ] + elif not add_y: + # z = a * x (drop +Y) + kernel = Kernel( + "scale_bf16", "axpy.o", + [tile_ty, tile_ty, np.float32, np.int32], + ) + else: + # z = a + y (drop *X) + kernel = Kernel( + "scalar_add_bf16", "axpy.o", + [tile_ty, tile_ty, np.float32, np.int32], + ) + + if has_two_inputs: + def core_body(of_in1, of_in2, of_out, k): + for _ in range_(N_div_n): + e1 = of_in1.acquire(1) + e2 = of_in2.acquire(1) + eo = of_out.acquire(1) + k(e1, e2, factor, eo, per_tile_elements) + of_in1.release(1) + of_in2.release(1) + of_out.release(1) + else: + def core_body(of_in1, of_out, k): + for _ in range_(N_div_n): + e1 = of_in1.acquire(1) + eo = of_out.acquire(1) + k(e1, eo, factor, per_tile_elements) + of_in1.release(1) + of_out.release(1) + + if has_two_inputs: + my_workers = [ + Worker( + core_body, + [of_in1s[i].cons(), of_in2s[i].cons(), of_outs[i].prod(), kernel], + ) + for i in range(num_columns) + ] + else: + my_workers = [ + Worker( + core_body, + [of_in1s[i].cons(), of_outs[i].prod(), kernel], + ) + for i in range(num_columns) + ] - # Create a TensorAccessPattern for each column - # to describe the data movement - # The pattern chops the data in equal chunks - # and moves them in parallel across the columns taps = [ TensorAccessPattern( (1, num_elements), - chunk * i, # Start offset for column i + chunk * i, [1, 1, 1, chunk], [0, 0, 0, 1], ) for i in range(num_columns) ] - # Runtime operations to move data to/from the AIE-array rt = Runtime() - with rt.sequence(tensor_ty, tensor_ty, tensor_ty) as (A, B, C): + sequence_types = ( + (tensor_ty, tensor_ty, tensor_ty) + if has_two_inputs + else (tensor_ty, tensor_ty) + ) + + with rt.sequence(*sequence_types) as bufs: + if has_two_inputs: + A, B, C = bufs + else: + A, C = bufs + rt.start(*my_workers) - # Initialize a group for parallel drain tasks, with fill resources free'd when drains complete. 
tg = rt.task_group() - - # Fill the input objectFIFOs with data for i in range(num_columns): - rt.fill( - of_in1s[i].prod(), - A, - taps[i], - task_group=tg, + rt.fill(of_in1s[i].prod(), A, taps[i], task_group=tg) + if has_two_inputs: + rt.fill(of_in2s[i].prod(), B, taps[i], task_group=tg) + for i in range(num_columns): + rt.drain(of_outs[i].cons(), C, taps[i], wait=True, task_group=tg) + rt.finish_task_group(tg) + + return Program(dev, rt).resolve_program(SequentialPlacer()) + + +def _my_axpy_causal_mask( + dev, + num_elements, + num_columns, + tile_size, + scalar_factor, + mask_block_dim, + rows_per_block, + row_offset, +): + """Single-core in-place causal mask via scalar_add_causal_bf16. + + Walks (blocks × rows_per_block × chunks-per-row) with three nested + runtime loops and feeds (chunk_start_col, row_in_head) to the kernel via + an idx buffer. The kernel applies the scalar `a` to elements strictly + above the per-head diagonal and copies y → z elsewhere. Tiles entirely + below the diagonal still get DMA'd (no per-tile data-movement skip) + but the kernel does only a copy in that case. + + Two operating modes (selected by the rows_per_block / row_offset args): + * Multi-block / full-head (rows_per_block = mask_block_dim, row_offset = 0): + walks ``num_blocks`` whole (S, S) blocks; idx[1] resets to 0 at the + start of each block. + * Sub-block (rows_per_block < mask_block_dim or row_offset > 0): + ``num_blocks`` is typically 1 in the MHA caller; processes a contiguous + ``rows_per_block``-tall slice of one block starting at row_offset. + Used at very long S where one (S, S) block exceeds the BD-length cap. + """ + factor = scalar_factor + S = mask_block_dim + per_tile_elements = 4096 if tile_size > 4096 else tile_size + if S % per_tile_elements != 0: + raise ValueError( + f"mask_block_dim ({S}) must be a multiple of per_tile_elements " + f"({per_tile_elements})" + ) + chunks_per_row = S // per_tile_elements + block_elements = rows_per_block * S + if num_elements % block_elements != 0: + raise ValueError( + f"num_elements ({num_elements}) must be a multiple of " + f"rows_per_block * S ({block_elements})" + ) + num_blocks = num_elements // block_elements + + # Two parallelisation modes: + # * block-aligned (num_blocks >= num_columns): each core handles + # blocks_per_core whole (S, S) blocks; idx[1] resets to row_offset at + # every block boundary (same value on every core). + # * within-block (num_blocks == 1, num_columns > 1): a single block is + # too big to split across cores by block, so each core handles a + # contiguous row-range slice of that one block; per-core init_row is + # row_offset + core_idx * rows_per_iter (different per core). The + # kernel logic is unchanged — it only cares about (chunk_start_col, + # row_in_block). 
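    # Worked examples (illustrative numbers):
    #   block-aligned:  S = 4096, num_elements = 8 * S * S, num_columns = 4
    #     -> num_blocks = 8, blocks_per_core = 2, rows_per_iter = 4096,
    #        per_core_init_rows = [row_offset] * 4
    #   within-block:   S = 32768, rows_per_block = 16384, num_elements = 16384 * S,
    #                   num_columns = 4
    #     -> num_blocks = 1, rows_per_iter = 4096,
    #        per_core_init_rows = [row_offset + i * 4096 for i in range(4)]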
+ if num_blocks >= num_columns: + if num_blocks % num_columns != 0: + raise ValueError( + f"num_blocks ({num_blocks}) must be a multiple of num_columns " + f"({num_columns}); causal_mask multi-core split is block-aligned" ) - rt.fill( - of_in2s[i].prod(), - B, - taps[i], - task_group=tg, + blocks_per_core = num_blocks // num_columns + rows_per_iter = rows_per_block + per_core_init_rows = [row_offset] * num_columns + else: + if num_blocks != 1: + raise ValueError( + f"causal_mask multi-core within-block split requires " + f"num_blocks == 1, got {num_blocks}" ) - # Drain the output objectFIFOs with data - for i in range(num_columns): - rt.drain( - of_outs[i].cons(), - C, - taps[i], - wait=True, # wait for the transfer to complete and data to be available - task_group=tg, + if rows_per_block % num_columns != 0: + raise ValueError( + f"rows_per_block ({rows_per_block}) must be a multiple of " + f"num_columns ({num_columns}) for within-block split" ) + blocks_per_core = 1 + rows_per_iter = rows_per_block // num_columns + per_core_init_rows = [ + row_offset + i * rows_per_iter for i in range(num_columns) + ] + + elements_per_core = num_elements // num_columns + + dtype = bfloat16 + tensor_ty = np.ndarray[(num_elements,), np.dtype[dtype]] + tile_ty = np.ndarray[(per_tile_elements,), np.dtype[dtype]] + idx_ty = np.ndarray[(2,), np.dtype[np.int32]] + + of_ins = [ObjectFifo(tile_ty, name=f"in{i}") for i in range(num_columns)] + of_outs = [ObjectFifo(tile_ty, name=f"out{i}") for i in range(num_columns)] + + kernel = Kernel( + "scalar_add_causal_bf16", + "axpy.o", + [tile_ty, tile_ty, idx_ty, np.float32, np.int32], + ) + + idx_buffers = [ + Buffer( + initial_value=np.zeros((2,), dtype=np.int32), + name=f"causal_mask_idx_{i}", + ) + for i in range(num_columns) + ] + + # Build one core_body per worker so the per-core init_row can be baked + # into the closure (constant within the worker code). 
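    # Example idx progression for one core (illustrative: per_tile_elements = 4096,
    # S = 8192 -> chunks_per_row = 2, my_init_row = 0):
    #   kernel sees (idx[0], idx[1]) = (0, 0), (4096, 0), (0, 1), (4096, 1), ...
    # idx[0] walks the column chunks within a row, idx[1] advances per row and is
    # reset to my_init_row at every block boundary.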
+ def make_core_body(my_init_row): + def core_body(of_in_, of_out_, k, idx): + # idx[0] = chunk_start_col within the current row of the block + # idx[1] = current row index within the current block + idx[0] = 0 + idx[1] = my_init_row + for _ in range_(blocks_per_core): + for _ in range_(rows_per_iter): + for _ in range_(chunks_per_row): + elem_in = of_in_.acquire(1) + elem_out = of_out_.acquire(1) + k(elem_in, elem_out, idx, factor, per_tile_elements) + of_in_.release(1) + of_out_.release(1) + idx[0] = idx[0] + per_tile_elements + idx[0] = 0 + idx[1] = idx[1] + 1 + idx[1] = my_init_row # reset for next block + return core_body + + workers = [ + Worker( + make_core_body(per_core_init_rows[i]), + [of_ins[i].cons(), of_outs[i].prod(), kernel, idx_buffers[i]], + ) + for i in range(num_columns) + ] + + taps = [ + TensorAccessPattern( + (1, num_elements), + i * elements_per_core, + [1, 1, 1, elements_per_core], + [0, 0, 0, 1], + ) + for i in range(num_columns) + ] + + rt = Runtime() + with rt.sequence(tensor_ty, tensor_ty) as (A, C): + rt.start(*workers) + tg = rt.task_group() + for i in range(num_columns): + rt.fill(of_ins[i].prod(), A, taps[i], task_group=tg) + for i in range(num_columns): + rt.drain(of_outs[i].cons(), C, taps[i], wait=True, task_group=tg) rt.finish_task_group(tg) - # Place program components (assign them resources on the device) and generate an MLIR module return Program(dev, rt).resolve_program(SequentialPlacer()) diff --git a/iron/operators/axpy/op.py b/iron/operators/axpy/op.py index ff4b87a1..f2f543d3 100644 --- a/iron/operators/axpy/op.py +++ b/iron/operators/axpy/op.py @@ -6,6 +6,7 @@ from iron.common import ( BinaryElementwiseOperator, + AIERuntimeArgSpec, KernelObjectArtifact, SourceArtifact, PythonGeneratedMLIRArtifact, @@ -15,14 +16,113 @@ @dataclass class AXPY(BinaryElementwiseOperator): - """AIE-accelerated aX + Y operator""" + """AIE-accelerated aX + Y operator. + + Optional flags select degenerate variants that skip part of the formula: + + * ``add_y=False`` → ``Z = a * X`` (drop the +Y term and the Y buffer) + * ``mul_x=False`` → ``Z = a + Y`` (drop the *X term and the X buffer) + * ``causal_mask=True`` (requires ``mul_x=False``) → + ``Z[i,j] = a if (j > i within head) else Y[i,j]`` + Treats the data as a sequence of (mask_block_dim, mask_block_dim) blocks + and applies a causal (lower-triangular-keep) mask in-place per block. + Tile-position info is supplied to the kernel via an idx_buffer; tiles + that lie entirely below the diagonal degenerate to a kernel-side copy + (no per-tile DMA skipping). Used by MHA to avoid materialising an + H * S * S causal-mask input buffer. + + ``add_y=False`` and ``mul_x=False`` cannot be combined. + """ scalar_factor: float = 3.0 + add_y: bool = True + mul_x: bool = True + causal_mask: bool = False + mask_block_dim: int | None = None + # Sub-block parameters (causal_mask only). Default: process full + # mask_block_dim rows per (S,S) block starting at row 0. Set + # rows_per_block < mask_block_dim and/or row_offset > 0 to process a + # contiguous row-range slice of one block per invocation — useful when + # a single full block's element count exceeds the per-invocation BD + # cap (e.g. S>=32K). 
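    # Example (illustrative): at S = 32768 a full (S, S) block holds 2**30 elements,
    # which exceeds the MHA caller's per-invocation cap, so it builds two AXPY
    # instances with rows_per_block = 16384 and row_offset in {0, 16384}.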
+ rows_per_block: int | None = None + row_offset: int = 0 kernel_name: ClassVar[str] = "axpy" kernel_fn_name: ClassVar[str] = "saxpy" callback_fn: ClassVar[str] = "my_axpy" + def __post_init__(self) -> None: + if not self.add_y and not self.mul_x: + raise ValueError("AXPY requires at least one of add_y or mul_x to be True") + if self.causal_mask: + if self.mul_x: + raise ValueError( + "AXPY causal_mask=True requires mul_x=False (Z = a + Y form)" + ) + if not self.add_y: + raise ValueError( + "AXPY causal_mask=True requires add_y=True (needs the Y buffer)" + ) + if self.mask_block_dim is None: + raise ValueError( + "AXPY causal_mask=True requires mask_block_dim (the (S,S) block dim)" + ) + # Default rows_per_block = mask_block_dim (process full blocks) + if self.rows_per_block is None: + self.rows_per_block = self.mask_block_dim + if self.rows_per_block <= 0 or self.rows_per_block > self.mask_block_dim: + raise ValueError( + f"rows_per_block ({self.rows_per_block}) must be in (0, " + f"mask_block_dim={self.mask_block_dim}]" + ) + if self.row_offset + self.rows_per_block > self.mask_block_dim: + raise ValueError( + f"row_offset ({self.row_offset}) + rows_per_block " + f"({self.rows_per_block}) must be <= mask_block_dim " + f"({self.mask_block_dim})" + ) + block_elements = self.rows_per_block * self.mask_block_dim + if self.size % block_elements != 0: + raise ValueError( + f"size ({self.size}) must be a multiple of " + f"rows_per_block * mask_block_dim ({block_elements})" + ) + # Multi-core split: either block-aligned (each core handles + # whole blocks; num_aie_columns must divide num_blocks) or + # within-block (num_blocks == 1; num_aie_columns must divide + # rows_per_block). + num_blocks = self.size // block_elements + if num_blocks >= self.num_aie_columns: + if num_blocks % self.num_aie_columns != 0: + raise ValueError( + f"AXPY causal_mask block-aligned split: " + f"num_aie_columns ({self.num_aie_columns}) must " + f"divide num_blocks ({num_blocks})" + ) + else: + if num_blocks != 1: + raise ValueError( + f"AXPY causal_mask within-block split requires " + f"num_blocks == 1, got {num_blocks}" + ) + if self.rows_per_block % self.num_aie_columns != 0: + raise ValueError( + f"AXPY causal_mask within-block split: " + f"rows_per_block ({self.rows_per_block}) must be a " + f"multiple of num_aie_columns ({self.num_aie_columns})" + ) + super().__post_init__() + + def get_arg_spec(self) -> list[AIERuntimeArgSpec]: + # When either input is dropped, the design has one input + one output. 
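        # e.g. scale mode (add_y=False): the runtime arg spec is just
        #   [AIERuntimeArgSpec("in", (size,)), AIERuntimeArgSpec("out", (size,))],
        # whereas the default saxpy spec from the binary base class also carries
        # the second input buffer.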
+ if not self.add_y or not self.mul_x: + return [ + AIERuntimeArgSpec("in", (self.size,)), + AIERuntimeArgSpec("out", (self.size,)), + ] + return super().get_arg_spec() + def get_kernel_artifacts(self) -> list[KernelObjectArtifact]: # axpy.cc lives under aie_kernels/generic/ (not device-specific) return [ @@ -37,7 +137,15 @@ def get_kernel_artifacts(self) -> list[KernelObjectArtifact]: ] def _mlir_callback_args(self): - return super()._mlir_callback_args() + [self.scalar_factor] + return super()._mlir_callback_args() + [ + self.scalar_factor, + self.add_y, + self.mul_x, + self.causal_mask, + self.mask_block_dim if self.mask_block_dim is not None else 0, + self.rows_per_block if self.rows_per_block is not None else 0, + self.row_offset, + ] def get_mlir_artifact(self) -> PythonGeneratedMLIRArtifact: return PythonGeneratedMLIRArtifact( diff --git a/iron/operators/gemm/design.py b/iron/operators/gemm/design.py index a8ed8ad3..8b717dcf 100644 --- a/iron/operators/gemm/design.py +++ b/iron/operators/gemm/design.py @@ -299,7 +299,7 @@ def my_matmul( gemm_object, [C_l1_ty_internal], ) - matmul_func_name = f"matmul{scalar_suffix}_{dtype_in_str}_f32" + matmul_func_name = f"{func_prefix}matmul{scalar_suffix}_{dtype_in_str}_f32" matmul_kernel = Kernel( matmul_func_name, gemm_object, @@ -314,7 +314,9 @@ def my_matmul( gemm_object, [C_l1_ty], ) - matmul_func_name = f"matmul{scalar_suffix}_{dtype_in_str}_{dtype_out_str}" + matmul_func_name = ( + f"{func_prefix}matmul{scalar_suffix}_{dtype_in_str}_{dtype_out_str}" + ) matmul_kernel = Kernel( matmul_func_name, gemm_object, diff --git a/iron/operators/mha_prefill_lxl_sd/__init__.py b/iron/operators/mha_prefill_lxl_sd/__init__.py new file mode 100644 index 00000000..82f09a67 --- /dev/null +++ b/iron/operators/mha_prefill_lxl_sd/__init__.py @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/iron/operators/mha_prefill_lxl_sd/op.py b/iron/operators/mha_prefill_lxl_sd/op.py new file mode 100644 index 00000000..86fcae17 --- /dev/null +++ b/iron/operators/mha_prefill_lxl_sd/op.py @@ -0,0 +1,604 @@ +# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +A layer-by-layer (LxL) single-dispatch (SD) implementation of multi-head attention (MHA). +""" + +import math + +import aie.utils as aie_utils + +from iron.common.context import AIEContext +from iron.common.fusion import FusedMLIROperator +from iron.operators.axpy.op import AXPY +from iron.operators.gemm.op import GEMM +from iron.operators.rope.op import RoPE +from iron.operators.strided_copy.op import StridedCopy +from iron.operators.repeat.op import Repeat +from iron.operators.softmax.op import Softmax +from iron.operators.transpose.op import Transpose +from iron.operators.elementwise_add.op import ElementwiseAdd + + +def _pick_tile_n(N, num_cols, max_tile_n=64): + tile_n = N // num_cols + while tile_n > max_tile_n: + tile_n //= 2 + assert N % (tile_n * num_cols) == 0 + return tile_n + + +def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True, num_cols=None, + disable_softmax=False): + """Build core attention sub-ops and runlist (no projections/RoPE/GQA). 
+ + Expects pre-processed inputs: + queries: (H, S, d) deinterleaved, contiguous per head + keys: (H, d, S) transposed and GQA-repeated + values: (H, S, d) GQA-repeated + + Produces: + attn_context: (H, S, d) — per-head context vectors + + If causal_mask=False, the elementwise-add masking step is omitted. + + If disable_softmax=True, the softmax step is omitted (output is incorrect; + intended for performance-isolation benchmarks only). + """ + if num_cols is None: + num_cols = aie_utils.get_current_device().cols + B = 2 # bytes per bf16 element + + # ---- M-splitting for the per-head GEMMs ---- + # At very long sequence lengths the per-GEMM design's runtime sequence + # (number of `rt.fill`/`rt.drain` MLIR ops) grows linearly with M, which + # makes each sub-operator's MLIR module very large and can OOM the + # compiler at S >= 16K. We cap each GEMM invocation's M dimension at + # `gemm_M_chunk` and split the per-head computation into multiple + # back-to-back GEMM invocations with sliced buffer offsets. Single + # dispatch is preserved (same fused runlist, just with more entries). + # + # M_chunk must be a multiple of `tile_m * n_aie_rows = 64` and must + # divide S evenly. At min(S, 4096) we get: + # S <= 4096: 1 invocation per head per phase (no splitting) + # S = 8192: 2 invocations per head per phase + # S = 16384: 4 invocations per head per phase + # S = 32768: 8 invocations per head per phase + gemm_M_chunk = min(S, 4096) + assert S % gemm_M_chunk == 0, ( + f"S ({S}) must be a multiple of gemm_M_chunk ({gemm_M_chunk})" + ) + n_m_chunks = S // gemm_M_chunk + + gemm_scores = GEMM( + M=gemm_M_chunk, + K=d, + N=S, + num_aie_columns=num_cols, + tile_m=16, + tile_k=64, + tile_n=_pick_tile_n(S, num_cols), + context=elf_ctx, + ) + # Scale by 1/sqrt(d) — uses AXPY in scale-only mode (add_y=False) so the + # scalar is baked into the kernel call instead of being passed as an + # H*S*S broadcast buffer. At S=32K, H=12 this saves a 24 GB input. + scale = AXPY( + size=H * S * S, + tile_size=S * S // num_cols, + num_aie_columns=num_cols, + scalar_factor=1.0 / math.sqrt(d), + add_y=False, + context=elf_ctx, + ) + if causal_mask: + # Apply causal mask via AXPY in scalar-add + causal-mask mode. The + # kernel computes (in place) `Z[i,j] = -INF if (j > i within head) + # else Y[i,j]` using a tile-position idx_buffer. This avoids + # materialising an H*S*S input mask buffer (saves 24 GB at S=32K). + # Single-core; tiles entirely below the diagonal still flow through + # DMA (kernel does only a copy in that case). + # + # Same BD-overflow workaround as for softmax: each invocation's + # transfer must fit under the compiler's int32 byte limit (< 2^30 + # bf16 elements). + # * If S² fits, each invocation processes some whole heads. + # * Otherwise (S>=32K), split each head into `mask_subblocks` + # row-range slices and emit one AXPY instance per row_offset. + MASK_MAX_ELEMENTS_PER_INV = (1 << 30) - 1 + if S * S <= MASK_MAX_ELEMENTS_PER_INV: + # Multi-head batched: pick max heads/invocation that divides H. + heads_per_mask_inv = max(1, MASK_MAX_ELEMENTS_PER_INV // (S * S)) + while H % heads_per_mask_inv != 0: + heads_per_mask_inv -= 1 + mask_subblocks = 1 + mask_rows_per_block = S + else: + # Sub-head: split each head into `mask_subblocks` row-range slices + # such that rows_per_block * S <= MASK_MAX_ELEMENTS_PER_INV. 
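            # Worked example (illustrative): S = 32768 gives S * S = 2**30 elements,
            # just over the cap, so mask_subblocks = 2 and mask_rows_per_block = 16384
            # (one AXPY instance per row_offset in {0, 16384}, each half a head tall).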
+ heads_per_mask_inv = 1 + mask_subblocks = (S * S + MASK_MAX_ELEMENTS_PER_INV - 1) // MASK_MAX_ELEMENTS_PER_INV + while S % mask_subblocks != 0: + mask_subblocks += 1 + mask_rows_per_block = S // mask_subblocks + assert mask_rows_per_block * S <= MASK_MAX_ELEMENTS_PER_INV + n_mask_invocations = (H // heads_per_mask_inv) * mask_subblocks + + # Multi-core parallelism for the AXPY causal mask: + # * Block-aligned (heads_per_mask_inv >= 2): each core handles whole + # blocks (heads), so num_aie_columns must divide heads_per_mask_inv. + # * Within-block (heads_per_mask_inv == 1, sub-head mode): each core + # handles a contiguous row-range slice of the one (S/N, S) block, + # so num_aie_columns must divide mask_rows_per_block. + if heads_per_mask_inv >= 2: + mask_num_cols = min(num_cols, heads_per_mask_inv) + while heads_per_mask_inv % mask_num_cols != 0: + mask_num_cols -= 1 + else: + mask_num_cols = num_cols + while mask_rows_per_block % mask_num_cols != 0: + mask_num_cols -= 1 + + # Build one AXPY instance per (row_offset) — same kernel/design + # parameters otherwise. When mask_subblocks=1 there's exactly one. + mask_ops = [ + AXPY( + size=heads_per_mask_inv * mask_rows_per_block * S, + tile_size=min(4096, S), + num_aie_columns=mask_num_cols, + scalar_factor=float("-inf"), + mul_x=False, + add_y=True, + causal_mask=True, + mask_block_dim=S, + rows_per_block=mask_rows_per_block, + row_offset=sub_idx * mask_rows_per_block, + context=elf_ctx, + ) + for sub_idx in range(mask_subblocks) + ] + # Use online/partial softmax when full-row tiles would exhaust AIE local + # memory (each double-buffered FIFO pair uses 4 * tile_size bytes; at + # S >= 8192 the in+out FIFOs alone consume the full 64 KB data memory). + softmax_chunk_size = 1024 if S >= 8192 else None + + # ---- Row-splitting for the softmax invocation ---- + # The shim DMA BD length field is a 32-bit unsigned word count (~4.29 B + # words ≈ 17 GB), but the current compiler lowering computes the BD length + # in bytes through int32 arithmetic and silently overflows when a single + # invocation's transfer exceeds 2 GB (= 2^31 bytes = 2^30 bf16 elements). + # We split the softmax call into N back-to-back invocations on disjoint + # row ranges to keep each transfer under that effective limit. The + # softmax buffers (attn_scores_masked / attn_scores_scaled / attn_weights) + # are row-major (S, S) per head, so a contiguous row-range slice maps + # directly to a contiguous byte range. + # Strict bound: bytes per BD must fit in signed int32 (< 2^31), so the + # per-invocation element count must be strictly less than 2^30. + SOFTMAX_MAX_ELEMENTS_PER_INV = (1 << 30) - 1 + total_softmax_rows = H * S + if total_softmax_rows * S <= SOFTMAX_MAX_ELEMENTS_PER_INV: + n_softmax_invocations = 1 + else: + # Smallest n that simultaneously divides total_softmax_rows evenly + # AND keeps each invocation's transfer at or below the limit. 
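        # Worked example (illustrative): H = 12, S = 16384 -> 196608 total rows and
        # ~3.2e9 elements > cap, so the computation below yields
        # n_softmax_invocations = 4 (49152 rows, ~0.8e9 elements per invocation).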
+ n_softmax_invocations = ( + total_softmax_rows * S + SOFTMAX_MAX_ELEMENTS_PER_INV - 1 + ) // SOFTMAX_MAX_ELEMENTS_PER_INV + while ( + total_softmax_rows % n_softmax_invocations != 0 + or (total_softmax_rows // n_softmax_invocations) * S + > SOFTMAX_MAX_ELEMENTS_PER_INV + ): + n_softmax_invocations += 1 + softmax_rows_per_inv = total_softmax_rows // n_softmax_invocations + assert softmax_rows_per_inv % 16 == 0, ( + f"softmax_rows_per_inv ({softmax_rows_per_inv}) must be a multiple of 16; " + f"got total_rows={total_softmax_rows}, n_invocations={n_softmax_invocations}" + ) + + # Parallelise softmax across cores: each core handles a row-aligned + # slice of the (rows_per_inv, S) data. Pick the largest divisor of + # softmax_rows_per_inv that is <= num_cols. + softmax_num_cols = num_cols + while softmax_rows_per_inv % softmax_num_cols != 0: + softmax_num_cols -= 1 + + softmax = Softmax( + rows=softmax_rows_per_inv, + cols=S, + num_aie_columns=softmax_num_cols, + num_channels=1, + rtp_vector_size=S, + chunk_size=softmax_chunk_size, + context=elf_ctx, + ) + # Context GEMM is capped at 4 cores: with N=d=64, tile_n must be a + # multiple of 16 (matmul kernel constraint n % (2*t) == 0), so + # tile_n*num_aie_columns = 64 means num_aie_columns <= 4. + gemm_context = GEMM( + M=gemm_M_chunk, + K=S, + N=d, + num_aie_columns=min(4, num_cols), + tile_m=16, + tile_k=64, + tile_n=_pick_tile_n(d, min(4, num_cols)), + context=elf_ctx, + prio_accuracy=True, + ) + + # Per-head byte sizes + qh = S * d * B # queries per head: (S, d) + kdS = d * S * B # keys per head: (d, S) + kSd = S * d * B # values per head: (S, d) + sh = S * S * B # scores/weights per head: (S, S) + ch = S * d * B # context per head: (S, d) + + # Per-M-chunk byte sizes (the M dimension is contiguous in row-major + # storage so M-slices map directly to byte ranges within each head) + q_chunk = gemm_M_chunk * d * B # queries chunk: (M_chunk, d) + s_chunk = gemm_M_chunk * S * B # scores chunk: (M_chunk, S) + w_chunk = gemm_M_chunk * S * B # weights chunk: (M_chunk, S) + c_chunk = gemm_M_chunk * d * B # context chunk: (M_chunk, d) + + # In-place attn buffer: single (H, S, S) scratch slot used in-place + # throughout the chain (score → scale → [mask] → softmax → context). + # Halves scratch memory (24 GB instead of 48 at S=32K, H=12) and is + # also marginally faster than the 2-buffer aliasing version (better + # cache locality with one buffer touched repeatedly). + attn_buf = "attn" + scores_buf = scaled_buf = masked_buf = weights_buf = attn_buf + + score_calls = [ + ( + gemm_scores, + f"queries[{h*qh + i*q_chunk}:{h*qh + (i+1)*q_chunk}]", + f"keys[{h*kdS}:{(h+1)*kdS}]", + f"{scores_buf}[{h*sh + i*s_chunk}:{h*sh + (i+1)*s_chunk}]", + ) + for h in range(H) + for i in range(n_m_chunks) + ] + + context_calls = [ + ( + gemm_context, + f"{weights_buf}[{h*sh + i*w_chunk}:{h*sh + (i+1)*w_chunk}]", + f"values[{h*kSd}:{(h+1)*kSd}]", + f"attn_context[{h*ch + i*c_chunk}:{h*ch + (i+1)*c_chunk}]", + ) + for h in range(H) + for i in range(n_m_chunks) + ] + + # Build the softmax runlist entries (one per invocation when row-split). 
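    # e.g. with 2 invocations the entries are (illustrative; in-place on "attn"):
    #   (softmax, "attn[0:chunk]",        "attn[0:chunk]")
    #   (softmax, "attn[chunk:2*chunk]",  "attn[chunk:2*chunk]")
    # where chunk = softmax_rows_per_inv * S * 2 bytes.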
+ softmax_input_buf = masked_buf if causal_mask else scaled_buf + softmax_chunk_bytes = softmax_rows_per_inv * S * B + if n_softmax_invocations == 1: + softmax_calls = [(softmax, softmax_input_buf, weights_buf)] + else: + softmax_calls = [ + ( + softmax, + f"{softmax_input_buf}[{i*softmax_chunk_bytes}:{(i+1)*softmax_chunk_bytes}]", + f"{weights_buf}[{i*softmax_chunk_bytes}:{(i+1)*softmax_chunk_bytes}]", + ) + for i in range(n_softmax_invocations) + ] + + runlist = [ + *score_calls, + (scale, scores_buf, scaled_buf), + ] + + if causal_mask: + # AXPY causal-mask mode takes only (input, output) — the mask values + # are baked into the kernel call (scalar -INF), no buffer needed. + # Multiple invocations on disjoint slices when needed to stay under + # the BD-length compiler-overflow limit. Layout per invocation: + # * Multi-head: contiguous range of heads_per_mask_inv whole heads + # (mask_subblocks == 1, rows_per_block == S) + # * Sub-head: contiguous range of mask_rows_per_block rows + # starting at sub_idx * mask_rows_per_block within + # one head; emitted for every head × every sub-block + n_head_groups = H // heads_per_mask_inv + head_group_bytes = heads_per_mask_inv * S * S * B # full head-group span + sub_chunk_bytes = mask_rows_per_block * S * B + mask_calls = [] + for g in range(n_head_groups): + for sub_idx in range(mask_subblocks): + start = g * head_group_bytes + sub_idx * sub_chunk_bytes + end = start + heads_per_mask_inv * sub_chunk_bytes + mask_calls.append( + ( + mask_ops[sub_idx], + f"{scaled_buf}[{start}:{end}]", + f"{masked_buf}[{start}:{end}]", + ) + ) + if n_head_groups == 1 and mask_subblocks == 1: + # Whole-buffer fast path (avoids slice notation in MLIR) + runlist += [(mask_ops[0], scaled_buf, masked_buf)] + else: + runlist += mask_calls + + if not disable_softmax: + runlist += softmax_calls + runlist += context_calls + + buffer_sizes = { + "queries": H * S * d * B, + "keys": H * d * S * B, + "values": H * S * d * B, + "attn": H * S * S * B, + "attn_context": H * S * d * B, + } + + return runlist, buffer_sizes + + +class AttentionPrefillFused(FusedMLIROperator): + """Fused attention prefill (core, no projections/RoPE). + + Accepts pre-projected Q (S*H,d), K (S*G,d), V (S*G,d) in interleaved layout. + """ + + def __init__( + self, + num_heads, + num_kv_groups, + head_dim, + embedding_dim, + seq_len, + causal_mask=True, + context=None, + dispatch="auto", + disable_softmax=False, + ): + assert head_dim == 64 + assert num_heads % num_kv_groups == 0 + assert seq_len % 256 == 0 + assert (num_heads * seq_len) % 16 == 0 + + self.num_heads = num_heads + self.num_kv_groups = num_kv_groups + self.head_dim = head_dim + self.embedding_dim = embedding_dim + self.seq_len = seq_len + + elf_ctx = context or AIEContext() + runlist, buffer_sizes = _build_core_ops( + num_heads, + num_kv_groups, + head_dim, + seq_len, + elf_ctx, + causal_mask=causal_mask, + disable_softmax=disable_softmax, + ) + + mask_suffix = "_causal" if causal_mask else "_nomask" + if disable_softmax: + mask_suffix += "_nosm" + input_args = ["queries", "keys", "values"] + + super().__init__( + name=f"attention_prefill_fused_{num_heads}h{num_kv_groups}g{head_dim}d{embedding_dim}e{seq_len}s{mask_suffix}", + runlist=runlist, + input_args=input_args, + output_args=["attn_context"], + buffer_sizes=buffer_sizes, + dispatch=dispatch, + context=elf_ctx, + ) + + +class AttentionPrefillProjectedFused(FusedMLIROperator): + """Fused attention prefill with Q/K/V projections and RoPE. 
+ + Accepts raw input (S, E) and rope_angles (S, d). + """ + + def __init__( + self, + num_heads, + num_kv_groups, + head_dim, + embedding_dim, + seq_len, + causal_mask=True, + context=None, + dispatch="auto", + ): + assert head_dim == 64 + assert num_heads % num_kv_groups == 0 + assert seq_len % 256 == 0 + assert (num_heads * seq_len) % 16 == 0 + + self.num_heads = num_heads + self.num_kv_groups = num_kv_groups + self.head_dim = head_dim + self.embedding_dim = embedding_dim + self.seq_len = seq_len + self._dispatch_arg = dispatch + + H, G, d, E, S = num_heads, num_kv_groups, head_dim, embedding_dim, seq_len + group_size = H // G + B = 2 + num_cols = aie_utils.get_current_device().cols + + elf_ctx = context or AIEContext() + + # ---- Projection + RoPE ---- + gemm_query = GEMM( + M=S, + K=E, + N=H * d, + num_aie_columns=num_cols, + tile_m=16, + tile_k=64, + tile_n=_pick_tile_n(H * d, num_cols), + context=elf_ctx, + ) + gemm_kv = GEMM( + M=S, + K=E, + N=G * d, + num_aie_columns=num_cols, + tile_m=16, + tile_k=64, + tile_n=_pick_tile_n(G * d, num_cols), + context=elf_ctx, + ) + rope_queries = RoPE(rows=S * H, cols=d, angle_rows=S, context=elf_ctx) + rope_keys = RoPE(rows=S * G, cols=d, angle_rows=S, context=elf_ctx) + + # ---- Deinterleave ---- + deinterleave_q = StridedCopy( + input_sizes=(H, S, d), + input_strides=(d, H * d, 1), + input_offset=0, + output_sizes=(H, S, d), + output_strides=(S * d, d, 1), + output_offset=0, + input_buffer_size=S * H * d, + output_buffer_size=H * S * d, + transfer_size=S * d, + num_aie_channels=1, + context=elf_ctx, + ) + deinterleave_kv = StridedCopy( + input_sizes=(G, S, d), + input_strides=(d, G * d, 1), + input_offset=0, + output_sizes=(G, S, d), + output_strides=(S * d, d, 1), + output_offset=0, + input_buffer_size=S * G * d, + output_buffer_size=G * S * d, + transfer_size=S * d, + num_aie_channels=1, + context=elf_ctx, + ) + + # ---- Transpose keys + GQA repeat ---- + transpose_keys = Transpose( + M=S, + N=d, + num_aie_columns=2, + num_channels=1, + m=256, + n=32, + s=8, + context=elf_ctx, + ) + repeat_kv = Repeat( + rows=G, + cols=d * S, + repeat=group_size, + transfer_size=d, + context=elf_ctx, + ) + + kSd = S * d * B + kdS = d * S * B + + prefix_runlist = [ + (gemm_query, "input", "W_query", "queries_projected"), + (gemm_kv, "input", "W_key", "keys_projected"), + (gemm_kv, "input", "W_value", "values_projected"), + (rope_queries, "queries_projected", "rope_angles", "queries_roped"), + (rope_keys, "keys_projected", "rope_angles", "keys_roped"), + (deinterleave_q, "queries_roped", "queries"), + (deinterleave_kv, "keys_roped", "keys_deint"), + (deinterleave_kv, "values_projected", "values_deint"), + *[ + ( + transpose_keys, + f"keys_deint[{g*kSd}:{(g+1)*kSd}]", + f"keys_transposed[{g*kdS}:{(g+1)*kdS}]", + ) + for g in range(G) + ], + (repeat_kv, "keys_transposed", "keys"), + (repeat_kv, "values_deint", "values"), + ] + prefix_buffer_sizes = { + "queries_projected": S * H * d * B, + "keys_projected": S * G * d * B, + "values_projected": S * G * d * B, + "queries_roped": S * H * d * B, + "keys_roped": S * G * d * B, + "keys_deint": G * S * d * B, + "values_deint": G * S * d * B, + "keys_transposed": G * d * S * B, + } + + core_runlist, core_buffer_sizes = _build_core_ops( + H, + G, + d, + S, + elf_ctx, + causal_mask=causal_mask, + num_cols=num_cols, + ) + + # ---- Reinterleave + output projection ---- + reinterleave = StridedCopy( + input_sizes=(1, 1, 1, H * S * d), + input_strides=(0, 0, 0, 1), + input_offset=0, + output_sizes=(H, 256, S // 256, d), + 
output_strides=(d, 256 * H * d, H * d, 1), + output_offset=0, + input_buffer_size=H * S * d, + output_buffer_size=S * H * d, + transfer_size=S * d, + num_aie_channels=1, + context=elf_ctx, + ) + gemm_output = GEMM( + M=S, + K=H * d, + N=E, + num_aie_columns=num_cols, + tile_m=16, + tile_k=64, + tile_n=_pick_tile_n(E, num_cols), + context=elf_ctx, + prio_accuracy=True, + ) + + suffix_runlist = [ + (reinterleave, "attn_context", "context_interleaved"), + (gemm_output, "context_interleaved", "W_output", "attn_output"), + ] + suffix_buffer_sizes = { + "context_interleaved": S * H * d * B, + } + + mask_suffix = "_causal" if causal_mask else "_nomask" + input_args = [ + "input", + "rope_angles", + "W_query", + "W_key", + "W_value", + "W_output", + ] + + super().__init__( + name=f"attention_prefill_projected_fused_{H}h{G}g{d}d{E}e{S}s{mask_suffix}", + runlist=prefix_runlist + core_runlist + suffix_runlist, + input_args=input_args, + output_args=["attn_output"], + buffer_sizes={ + **prefix_buffer_sizes, + **core_buffer_sizes, + **suffix_buffer_sizes, + }, + dispatch=dispatch, + context=elf_ctx, + ) diff --git a/iron/operators/mha_prefill_lxl_sd/reference.py b/iron/operators/mha_prefill_lxl_sd/reference.py new file mode 100644 index 00000000..925ace54 --- /dev/null +++ b/iron/operators/mha_prefill_lxl_sd/reference.py @@ -0,0 +1,291 @@ +# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +def generate_random_inputs(H, G, d, E, S, causal=True, seed=42): + """Generate just the *inputs* needed for the core MHA attention test, with + no expensive PyTorch reference computation. + + Suitable for the benchmark test (which doesn't verify the full output) and + for very large sequence lengths where the full golden reference is + impractical. The causal mask is built with a single ``torch.triu`` instead + of the H*S² nested-loop construction in :func:`generate_golden_reference`. + + Returned dict matches the input keys consumed by the benchmark test + (``queries_deinterleaved``, ``keys_for_scores``, ``values_for_context``, + ``attn_scale_factor``, ``causal_mask``), plus ``_scale`` (the scalar + 1/sqrt(d)) and ``_causal`` (the bool flag) for use by sample verification. + """ + torch.manual_seed(seed) + val_range = 0.5 + + # Pre-deinterleaved/transposed/repeated Q, K, V (the layout the + # AttentionPrefillFused operator consumes directly). + queries_deinterleaved = (torch.randn(H, S, d) * val_range).to(torch.bfloat16) + keys_for_scores = (torch.randn(H, d, S) * val_range).to(torch.bfloat16) + values_for_context = (torch.randn(H, S, d) * val_range).to(torch.bfloat16) + + scale = 1.0 / (d**0.5) + attn_scale_factor = torch.full((H * S * S,), scale, dtype=torch.bfloat16) + + out = { + "queries_deinterleaved": queries_deinterleaved, + "keys_for_scores": keys_for_scores, + "values_for_context": values_for_context, + "attn_scale_factor": attn_scale_factor, + "_scale": scale, + "_causal": causal, + } + + if causal: + # Vectorized causal mask: -inf strictly above the diagonal, broadcast + # across all H heads. Shape (H*S, S) to match the operator layout. 
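+        # For example, a single head at S = 4 looks like (0 keeps the score,
+        # -inf masks it):
+        #   [[  0., -inf, -inf, -inf],
+        #    [  0.,   0., -inf, -inf],
+        #    [  0.,   0.,   0., -inf],
+        #    [  0.,   0.,   0.,   0.]]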
+ single_mask = torch.triu( + torch.full((S, S), float("-inf"), dtype=torch.bfloat16), + diagonal=1, + ) + out["causal_mask"] = ( + single_mask.unsqueeze(0).expand(H, -1, -1).reshape(H * S, S).contiguous() + ) + + return out + + +def compute_attn_context_at_rows( + queries_deinterleaved, + keys_for_scores, + values_for_context, + scale, + causal, + sample_hms, +): + """Compute the expected ``attn_context[h, m, :]`` row for each (h, m) in + ``sample_hms``. + + Cheap even at very large S: per sample is O(S * d) — a single (1, d) @ (d, S) + matmul plus an O(S) softmax plus an (S,) @ (S, d) reduction. + + Args: + queries_deinterleaved: (H, S, d) bfloat16 tensor + keys_for_scores: (H, d, S) bfloat16 tensor + values_for_context: (H, S, d) bfloat16 tensor + scale: 1 / sqrt(d) (Python float) + causal: bool — apply causal mask (zero out k > m) + sample_hms: iterable of (h, m) tuples to compute + + Returns: + dict mapping (h, m) -> torch.Tensor of shape (d,) in bfloat16. + """ + out = {} + for h, m in sample_hms: + q = queries_deinterleaved[h, m, :].float() # (d,) + k = keys_for_scores[h, :, :].float() # (d, S) + scores = q @ k # (S,) + scaled = scores * scale # (S,) + if causal: + # Match the operator's behaviour: positions strictly greater than m + # receive -inf and contribute zero after softmax. + scaled = scaled.clone() + scaled[m + 1 :] = float("-inf") + weights = torch.softmax(scaled, dim=-1) # (S,) + v = values_for_context[h, :, :].float() # (S, d) + out[(h, m)] = (weights @ v).to(torch.bfloat16) # (d,) + return out + + +def _apply_rope_4d(x, angles): + """Apply RoPE to a 4D tensor using interleaved cos/sin angles. + + x: (batch, heads, seq_len, head_dim) + angles: (seq_len, head_dim) with interleaved [cos_0, sin_0, cos_1, sin_1, ...] + Returns: same shape as x with RoPE applied (two-halves method). + """ + half = x.shape[-1] // 2 + cos = angles[:, ::2].unsqueeze(0).unsqueeze(0) # (1, 1, S, half) + sin = angles[:, 1::2].unsqueeze(0).unsqueeze(0) # (1, 1, S, half) + x1, x2 = x[..., :half], x[..., half:] + return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1) + + +def _bf16_matmul(a, b): + """(float32 matmul) → bfloat16, matching NPU accumulation.""" + return (a.float() @ b.float()).to(torch.bfloat16) + + +def generate_golden_reference( + num_heads, + num_kv_groups, + head_dim, + embedding_dim, + seq_len, + seed=42, +): + """Generate golden reference for fused attention prefill. 
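+
+    The reference mirrors the fused operator stage by stage (Q/K/V projection,
+    RoPE, deinterleave, key transpose, GQA repeat, score GEMM, scale, causal
+    mask, softmax, context GEMM, reinterleave, output projection) and returns
+    every intermediate tensor so the tests can check individual buffers.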
+ + Parameters: + num_heads (H): number of query attention heads + num_kv_groups (G): number of KV heads (G=H for MHA, G 1: + keys_for_scores = ( + keys_transposed.reshape(G, d * S) + .repeat_interleave(group_size, dim=0) + .reshape(H, d, S) + ) + values_for_context = ( + values_deinterleaved.reshape(G, S * d) + .repeat_interleave(group_size, dim=0) + .reshape(H, S, d) + ) + else: + keys_for_scores = keys_transposed # (H, d, S) + values_for_context = values_deinterleaved # (H, S, d) + + # ---- Score GEMM per head ---- + attn_scores = torch.stack( + [_bf16_matmul(queries_deinterleaved[h], keys_for_scores[h]) for h in range(H)] + ) # (H, S, S) + + # ---- Scale ---- + attn_scores_scaled = (attn_scores.float() * scale).to(torch.bfloat16) + + # ---- Causal mask ---- + attn_scores_masked = ( + attn_scores_scaled.reshape(H * S, S).float() + causal_mask.float() + ).to(torch.bfloat16) + + # ---- Softmax ---- + attn_weights = torch.nn.functional.softmax( + attn_scores_masked.float().reshape(H, S, S), dim=-1 + ).to( + torch.bfloat16 + ) # (H, S, S) + + # ---- Context GEMM per head ---- + attn_context = torch.stack( + [_bf16_matmul(attn_weights[h], values_for_context[h]) for h in range(H)] + ) # (H, S, d) + + # ---- Re-interleave context: (H, S, d) → (S, H*d) ---- + context_interleaved = attn_context.transpose(0, 1).contiguous().reshape(S, H * d) + + # ---- Output projection ---- + attn_output = _bf16_matmul(context_interleaved, W_output) + + return { + "input": x, + "rope_angles": rope_angles, + "W_query": W_query, + "W_key": W_key, + "W_value": W_value, + "W_output": W_output, + "attn_scale_factor": attn_scale_factor, + "causal_mask": causal_mask, + "queries_raw": queries_raw, + "keys_raw": keys_raw, + "values_raw": values_raw, + "queries_roped": queries_roped, + "keys_roped": keys_roped, + "queries_deinterleaved": queries_deinterleaved, + "keys_deinterleaved": keys_deinterleaved, + "keys_transposed": keys_transposed, + "values_deinterleaved": values_deinterleaved, + "keys_for_scores": keys_for_scores, + "values_for_context": values_for_context, + "attn_scores": attn_scores, + "attn_scores_scaled": attn_scores_scaled, + "attn_scores_masked": attn_scores_masked, + "attn_weights": attn_weights, + "attn_context": attn_context, + "context_interleaved": context_interleaved, + "attn_output": attn_output, + } diff --git a/iron/operators/mha_prefill_lxl_sd/test.py b/iron/operators/mha_prefill_lxl_sd/test.py new file mode 100644 index 00000000..73d80440 --- /dev/null +++ b/iron/operators/mha_prefill_lxl_sd/test.py @@ -0,0 +1,295 @@ +# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest +import torch +from ml_dtypes import bfloat16 + +from iron.common.test_utils import verify_buffer + +from iron.operators.mha_prefill_lxl_sd.op import ( + AttentionPrefillFused, + AttentionPrefillProjectedFused, +) +from iron.operators.mha_prefill_lxl_sd.reference import ( + generate_golden_reference, + generate_random_inputs, + compute_attn_context_at_rows, +) + +REL_TOL = 0.08 +ABS_TOL = 2.0 +MAX_ERROR_RATE = 0.03 + + +def get_params(): + return [ + pytest.param(2, 2, 64, 256, 256, id="H2"), + pytest.param(32, 8, 64, 2048, 256, id="Llama3.2-256seq"), + pytest.param(12, 12, 64, 768, 256, id="GPT2-Small-256seq"), + ] + + +def get_benchmark_params(): + """GPT-2 Small across sequence lengths 256..32768, with/without causal mask.""" + params = [] + S = 256 + while S <= 32768: + for mask in [True, False]: + for dispatch in ["auto", "separate"]: + tag = "causal" if mask else "nomask" + suffix = f"-{dispatch}" if dispatch != "auto" else "" + params.append(pytest.param( + 12, 12, 64, 768, S, mask, dispatch, + id=f"GPT2-S{S}-{tag}{suffix}", + )) + S *= 2 + return params + + +def _load_input(fc, name, tensor): + """Load a tensor into a named sub-buffer of the fused callable.""" + np_buf = tensor.contiguous().view(torch.uint16).numpy().view(bfloat16) + fc.get_buffer(name).data[:] = np_buf.flatten() + + +def _get_scratch_tensor(fc, name, shape): + """Read a named buffer from the fused callable's scratch space.""" + fc.scratch_buffer._sync_from_device() + sub = fc.get_buffer(name) + return sub.data[: int(np.prod(shape))].reshape(shape).astype(np.float32) + + +def _get_output_tensor(fc, name, shape): + """Read a named buffer from the fused callable's output space.""" + fc.output_buffer._sync_from_device() + sub = fc.get_buffer(name) + return sub.data[: int(np.prod(shape))].reshape(shape).astype(np.float32) + + +def _verify_output(fc, golden, H, d, S, E): + """Chain-consistent output verification shared by both test variants.""" + npu_context = torch.from_numpy( + _get_scratch_tensor(fc, "context_interleaved", (S, H * d)) + ).bfloat16() + chain_ref = (npu_context.float() @ golden["W_output"].float()).to(torch.bfloat16) + + fc.output_buffer._sync_from_device() + output_np = fc.get_buffer("attn_output").data + output = torch.from_numpy(output_np.reshape(S, E).astype(np.float32)).bfloat16() + + errors = verify_buffer( + output, + "attn_output", + chain_ref.reshape(S, E), + rel_tol=REL_TOL, + abs_tol=ABS_TOL, + max_error_rate=MAX_ERROR_RATE, + ) + assert not errors, f"Output verification failed with {len(errors)} errors" + + +def _core_gemm_flops(H, G, d, E, S): + """Count GEMM FLOPs for the core attention operator.""" + score_flops = H * 2 * S * d * S # H x (S,d)@(d,S) + context_flops = H * 2 * S * S * d # H x (S,S)@(S,d) + return score_flops + context_flops + + +def _projected_gemm_flops(H, G, d, E, S): + """Count GEMM FLOPs for the projected attention operator.""" + query_proj = 2 * S * E * (H * d) # (S,E)@(E,H*d) + kv_proj = 2 * (2 * S * E * (G * d)) # key + value: (S,E)@(E,G*d) each + output_proj = 2 * S * (H * d) * E # (S,H*d)@(H*d,E) + return query_proj + kv_proj + _core_gemm_flops(H, G, d, E, S) + output_proj + + +# --------------------------------------------------------------------------- +# Core attention tests (pre-projected Q, K, V) +# --------------------------------------------------------------------------- + + +@pytest.mark.metrics( + Latency=r"Latency \(us\): (?P[\d\.]+)", + Throughput=r"Throughput: (?P[\d\.e\+-]+) 
GFLOP/s", +) +@pytest.mark.parametrize("H,G,d,E,S", get_params()) +def test_mha_pefill_lxl_sd(H, G, d, E, S): + """Core attention: score GEMM -> scale -> mask -> softmax -> context GEMM.""" + golden = generate_golden_reference(H, G, d, E, S) + + op = AttentionPrefillFused(H, G, d, E, S) + op.compile() + fc = op.get_callable() + + _load_input(fc, "queries", golden["queries_deinterleaved"]) + _load_input(fc, "keys", golden["keys_for_scores"]) + _load_input(fc, "values", golden["values_for_context"]) + + fc() + + latency_us = fc.last_elapsed * 1e6 + gflops = _core_gemm_flops(H, G, d, E, S) / (fc.last_elapsed) / 1e9 + print(f"\nLatency (us): {latency_us:.1f}") + print(f"Throughput: {gflops:.6e} GFLOP/s") + + actual = _get_output_tensor(fc, "attn_context", (H, S, d)) + expected = golden["attn_context"].float().numpy().reshape(H, S, d) + errors = verify_buffer( + torch.from_numpy(actual).bfloat16(), + "attn_context", + torch.from_numpy(expected).bfloat16().reshape(H, S, d), + rel_tol=REL_TOL, + abs_tol=ABS_TOL, + max_error_rate=MAX_ERROR_RATE, + ) + assert not errors, f"Output verification failed with {len(errors)} errors" + + +# --------------------------------------------------------------------------- +# Projected attention tests (with Q/K/V projections + RoPE) +# --------------------------------------------------------------------------- + + +@pytest.mark.metrics( + Latency=r"Latency \(us\): (?P[\d\.]+)", + Throughput=r"Throughput: (?P[\d\.e\+-]+) GFLOP/s", +) +@pytest.mark.parametrize("H,G,d,E,S", get_params()) +def test_attention_prefill_projected_fused(H, G, d, E, S): + """Projected attention: Q/K/V proj -> RoPE -> GQA -> attention -> output proj.""" + golden = generate_golden_reference(H, G, d, E, S) + + op = AttentionPrefillProjectedFused(H, G, d, E, S) + op.compile() + fc = op.get_callable() + + _load_input(fc, "input", golden["input"]) + _load_input(fc, "rope_angles", golden["rope_angles"]) + _load_input(fc, "W_query", golden["W_query"]) + _load_input(fc, "W_key", golden["W_key"]) + _load_input(fc, "W_value", golden["W_value"]) + _load_input(fc, "W_output", golden["W_output"]) + + fc() + + latency_us = fc.last_elapsed * 1e6 + gflops = _projected_gemm_flops(H, G, d, E, S) / (fc.last_elapsed) / 1e9 + print(f"\nLatency (us): {latency_us:.1f}") + print(f"Throughput: {gflops:.6e} GFLOP/s") + + _verify_output(fc, golden, H, d, S, E) + + +# --------------------------------------------------------------------------- +# Benchmark: GPT-2 Small core MHA across sequence lengths, +/- causal mask +# --------------------------------------------------------------------------- + + +@pytest.mark.benchmark +@pytest.mark.metrics( + Latency=r"Latency \(us\): (?P[\d\.]+)", + Throughput=r"Throughput: (?P[\d\.e\+-]+) GFLOP/s", +) +@pytest.mark.parametrize("H,G,d,E,S,causal,dispatch", get_benchmark_params()) +def test_mha_prefill_benchmark(H, G, d, E, S, causal, dispatch): + """Benchmark core MHA for GPT-2 Small across sequence lengths. + + Uses cheap random inputs (no full PyTorch reference) and verifies + correctness by recomputing the expected ``attn_context`` row for a small + number of randomly-chosen (head, row) positions — feasible at any S. 
+ """ + inputs = generate_random_inputs(H, G, d, E, S, causal=causal) + + op = AttentionPrefillFused(H, G, d, E, S, causal_mask=causal, dispatch=dispatch) + op.compile() + fc = op.get_callable() + + _load_input(fc, "queries", inputs["queries_deinterleaved"]) + _load_input(fc, "keys", inputs["keys_for_scores"]) + _load_input(fc, "values", inputs["values_for_context"]) + + fc() + + latency_us = fc.last_elapsed * 1e6 + gflops = _core_gemm_flops(H, G, d, E, S) / (fc.last_elapsed) / 1e9 + print(f"\nLatency (us): {latency_us:.1f}") + print(f"Throughput: {gflops:.6e} GFLOP/s") + + # ---- Sample-based correctness check ---- + # Pick a handful of random (head, row) pairs and recompute the expected + # attn_context row for each (cheap: O(S*d) per sample). + actual_context = _get_output_tensor(fc, "attn_context", (H, S, d)) + rng = np.random.default_rng(seed=0) + n_samples = 16 + sample_hms = [(int(rng.integers(0, H)), int(rng.integers(0, S))) for _ in range(n_samples)] + expected_rows = compute_attn_context_at_rows( + inputs["queries_deinterleaved"], + inputs["keys_for_scores"], + inputs["values_for_context"], + inputs["_scale"], + causal, + sample_hms, + ) + failures = [] + for (h, m), exp in expected_rows.items(): + act = torch.from_numpy(actual_context[h, m, :]).bfloat16() + diff = (act.float() - exp.float()).abs() + rel = diff / (exp.float().abs() + 1e-6) + # An element fails only if it exceeds BOTH abs_tol and rel_tol + bad = (diff > ABS_TOL) & (rel > REL_TOL) + if bad.any(): + failures.append( + f"(h={h}, m={m}): {int(bad.sum())}/{d} bad, " + f"max_abs={diff.max().item():.4f}, max_rel={rel.max().item():.4f}" + ) + assert not failures, "Sample verification failed:\n " + "\n ".join(failures) + + +# --------------------------------------------------------------------------- +# Intermediate checks (extensive, not run by default) +# --------------------------------------------------------------------------- + +INTERMEDIATE_CHECKS = [ + ("attn_scores", "attn_scores", lambda H, G, S, d: (H, S, S), "scratch"), + ( + "attn_scores_masked", + "attn_scores_masked", + lambda H, G, S, d: (H, S, S), + "scratch", + ), + ("attn_weights", "attn_weights", lambda H, G, S, d: (H, S, S), "scratch"), + ("attn_context", "attn_context", lambda H, G, S, d: (H, S, d), "output"), +] + + +@pytest.mark.extensive +@pytest.mark.parametrize("H,G,d,E,S", get_params()) +def test_mha_pefill_lxl_sd_intermediates(H, G, d, E, S): + """Check intermediate buffers of core attention (for debugging).""" + golden = generate_golden_reference(H, G, d, E, S) + + op = AttentionPrefillFused(H, G, d, E, S) + op.compile() + fc = op.get_callable() + + _load_input(fc, "queries", golden["queries_deinterleaved"]) + _load_input(fc, "keys", golden["keys_for_scores"]) + _load_input(fc, "values", golden["values_for_context"]) + + fc() + + for buf_name, golden_key, shape_fn, buf_type in INTERMEDIATE_CHECKS: + shape = shape_fn(H, G, S, d) + if buf_type == "output": + actual = _get_output_tensor(fc, buf_name, shape) + else: + actual = _get_scratch_tensor(fc, buf_name, shape) + expected = golden[golden_key].float().numpy().reshape(shape) + diff = np.abs(actual - expected) + print( + f" [{buf_name}] shape={shape} " + f"nan={int(np.isnan(actual).sum())} " + f"max_abs_err={diff.max():.4f} mean_abs_err={diff.mean():.6f}" + ) diff --git a/iron/operators/softmax/design.py b/iron/operators/softmax/design.py index 5cb68c39..b3d73fb3 100644 --- a/iron/operators/softmax/design.py +++ b/iron/operators/softmax/design.py @@ -20,6 +20,183 @@ from ml_dtypes import 
bfloat16 +def _softmax_partial( + dev, + num_elements, + num_aie_columns, + num_channels, + tile_size, + chunk_size, + func_prefix="", + kernel_obj_file="softmax.o", +): + """Online / tiled softmax that processes each row in sub-tile chunks. + + Each row of *tile_size* elements is processed in two passes: + 1. Stats pass – reads chunks, accumulates running max and sum(exp). + 2. Norm pass – reads the same chunks again, writes exp(x-max)/sum. + + Two separate input ObjectFifos are used so that the DMA can feed each pass + independently from the same DDR source buffer. + """ + total_cores = num_aie_columns * num_channels + per_core_elements = num_elements // total_cores + if num_elements % total_cores != 0: + raise ValueError( + f"Number of elements ({num_elements}) must be a multiple of {total_cores}." + ) + + rows_per_core = per_core_elements // tile_size + chunks_per_row = tile_size // chunk_size + dtype = bfloat16 + + # Tensor / tile types + tensor_ty = np.ndarray[(num_elements,), np.dtype[dtype]] + chunk_ty = np.ndarray[(chunk_size,), np.dtype[dtype]] + stats_ty = np.ndarray[(16,), np.dtype[dtype]] # only [0..1] used + + chunk = num_elements // num_aie_columns // num_channels + + # --- Object FIFOs ------------------------------------------------------- + of_in_stats = [ + ObjectFifo(chunk_ty, name=f"in_stats_{i}_{j}") + for i in range(num_aie_columns) + for j in range(num_channels) + ] + of_in_norm = [ + ObjectFifo(chunk_ty, name=f"in_norm_{i}_{j}") + for i in range(num_aie_columns) + for j in range(num_channels) + ] + of_outs = [ + ObjectFifo(chunk_ty, name=f"out_{i}_{j}") + for i in range(num_aie_columns) + for j in range(num_channels) + ] + + # --- Kernel declarations ------------------------------------------------ + init_kernel = Kernel( + f"{func_prefix}softmax_partial_init_bf16", + f"{func_prefix}{kernel_obj_file}", + [stats_ty], + ) + stats_kernel = Kernel( + f"{func_prefix}softmax_partial_stats_bf16", + f"{func_prefix}{kernel_obj_file}", + [chunk_ty, stats_ty, np.int32], + ) + norm_kernel = Kernel( + f"{func_prefix}softmax_partial_norm_bf16", + f"{func_prefix}{kernel_obj_file}", + [chunk_ty, chunk_ty, stats_ty, np.int32], + ) + + # --- Local stats buffers (one per core) --------------------------------- + stats_buffers = [ + Buffer( + initial_value=np.zeros(16, dtype=dtype), + name=f"stats_{i}_{j}", + ) + for i in range(num_aie_columns) + for j in range(num_channels) + ] + + barriers = [ + WorkerRuntimeBarrier() + for i in range(num_aie_columns) + for j in range(num_channels) + ] + + # --- Worker body -------------------------------------------------------- + def core_body( + of_s, of_n, of_out, + init_k, stats_k, norm_k, + stats_buf, barrier, + ): + barrier.wait_for_value(1) + for _ in range_(rows_per_core): + # Reset running max / sum for the new row + init_k(stats_buf) + + # Pass 1 – accumulate max and sum(exp) + for _ in range_(chunks_per_row): + elem = of_s.acquire(1) + stats_k(elem, stats_buf, chunk_size) + of_s.release(1) + + # Pass 2 – normalise: exp(x - max) / sum + for _ in range_(chunks_per_row): + elem_in = of_n.acquire(1) + elem_out = of_out.acquire(1) + norm_k(elem_in, elem_out, stats_buf, chunk_size) + of_n.release(1) + of_out.release(1) + + # --- Workers ------------------------------------------------------------ + def _worker_args(k): + return [ + of_in_stats[k].cons(), + of_in_norm[k].cons(), + of_outs[k].prod(), + init_kernel, + stats_kernel, + norm_kernel, + stats_buffers[k], + barriers[k], + ] + + workers = [ + Worker(core_body, _worker_args(i * 
num_channels + j), stack_size=0xD00) + for i in range(num_aie_columns) + for j in range(num_channels) + ] + + # --- Tensor access patterns (identical for both input FIFOs) ------------ + taps = [ + TensorAccessPattern( + (1, num_elements), + chunk * i * num_channels + chunk * j, + [1, 1, 1, chunk], + [0, 0, 0, 1], + ) + for i in range(num_aie_columns) + for j in range(num_channels) + ] + + # --- Runtime sequence --------------------------------------------------- + rt = Runtime() + with rt.sequence(tensor_ty, tensor_ty) as (A, C): + rt.start(*workers) + + for k in range(num_aie_columns * num_channels): + rt.set_barrier(barriers[k], 1) + + tg = rt.task_group() + + for i in range(num_aie_columns): + for j in range(num_channels): + k = i * num_channels + j + # Feed the stats-pass FIFO + rt.fill( + of_in_stats[k].prod(), A, taps[k], task_group=tg, + ) + # Feed the norm-pass FIFO (same source data) + rt.fill( + of_in_norm[k].prod(), A, taps[k], task_group=tg, + ) + + for i in range(num_aie_columns): + for j in range(num_channels): + k = i * num_channels + j + rt.drain( + of_outs[k].cons(), C, taps[k], wait=True, task_group=tg, + ) + + rt.finish_task_group(tg) + + return Program(dev, rt).resolve_program(SequentialPlacer()) + + def softmax( dev, num_elements, @@ -31,7 +208,22 @@ def softmax( mask_patch_value=0, func_prefix="", kernel_obj_file="softmax.o", + chunk_size=None, ): + # ---- Partial (online) softmax path ---- + if chunk_size is not None: + return _softmax_partial( + dev, + num_elements, + num_aie_columns, + num_channels, + tile_size, + chunk_size, + func_prefix, + kernel_obj_file, + ) + + # ---- Full-row softmax path (original) ---- per_tile_elements = tile_size if rtp_vector_size is None: rtp_vector_size = per_tile_elements diff --git a/iron/operators/softmax/op.py b/iron/operators/softmax/op.py index 71aec051..fddfa0a1 100644 --- a/iron/operators/softmax/op.py +++ b/iron/operators/softmax/op.py @@ -20,7 +20,12 @@ @dataclass class Softmax(MLIROperator): - """AIE-accelerated Softmax operation""" + """AIE-accelerated Softmax operation + + When *chunk_size* is set (and < cols), uses an online / tiled softmax + that processes each row in two passes with sub-tile chunks, avoiding the + local-memory exhaustion that occurs with very long rows (e.g. S >= 8192). 
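+    The chunked path streams each row from memory twice (once for the stats
+    pass, once for the normalisation pass); chunk_size must evenly divide
+    cols and be a multiple of 64.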
+ """ rows: int cols: int @@ -28,6 +33,7 @@ class Softmax(MLIROperator): num_channels: int = 1 rtp_vector_size: int | None = None mask_patch_value: int = 0 + chunk_size: int | None = None context: object = field(default=None, repr=False) @property @@ -43,6 +49,15 @@ def __post_init__(self): raise ValueError( f"rows ({self.rows}) must be a multiple of num_aie_columns ({self.num_aie_columns})" ) + if self.chunk_size is not None: + if self.cols % self.chunk_size != 0: + raise ValueError( + f"cols ({self.cols}) must be a multiple of chunk_size ({self.chunk_size})" + ) + if self.chunk_size % 64 != 0: + raise ValueError( + f"chunk_size ({self.chunk_size}) must be a multiple of 64" + ) MLIROperator.__init__(self, context=self.context) @property @@ -69,6 +84,7 @@ def get_mlir_artifact(self): "rtp_vector_size": self.rtp_vector_size, "mask_patch_value": self.mask_patch_value, "kernel_obj_file": self._kernel_link_file, + "chunk_size": self.chunk_size, }, ), ) diff --git a/iron/operators/softmax/test.py b/iron/operators/softmax/test.py index 066d2309..8cca9f69 100755 --- a/iron/operators/softmax/test.py +++ b/iron/operators/softmax/test.py @@ -85,3 +85,52 @@ def test_softmax(input_length, num_aie_columns, num_channels, tile_size, aie_con print(f"Effective Bandwidth: {bandwidth_gbps:.6e} GB/s\n") assert not errors, f"Test failed with errors: {errors}" + + +# --------------------------------------------------------------------------- +# Partial (online / tiled) softmax tests — enables long rows (S >= 8192) +# --------------------------------------------------------------------------- + + +def get_partial_params(): + """GPT-2 style: 12 heads × S rows of S columns, tested via partial softmax.""" + params = [] + for S in [8192]: + H = 12 # GPT-2 Small heads + rows = H * S + cols = S + chunk_size = 1024 + params.append(pytest.param(rows, cols, chunk_size, id=f"GPT2-S{S}")) + return params + + +@pytest.mark.metrics( + Latency=r"Latency \(us\): (?P[\d\.]+)", + Bandwidth=r"Effective Bandwidth: (?P[\d\.e\+-]+) GB/s", +) +@pytest.mark.parametrize("rows,cols,chunk_size", get_partial_params()) +def test_softmax_partial(rows, cols, chunk_size, aie_context): + """Test partial / online softmax with sub-tile chunks for long rows.""" + + golden_ref = generate_golden_reference(rows=rows, cols=cols) + + operator = Softmax( + rows=rows, + cols=cols, + num_aie_columns=1, + num_channels=1, + chunk_size=chunk_size, + context=aie_context, + ) + + input_buffers = {"in": golden_ref["input"]} + output_buffers = {"output": golden_ref["output"]} + + errors, latency_us, bandwidth_gbps = run_test( + operator, input_buffers, output_buffers, rel_tol=0.08, abs_tol=1e-6 + ) + + print(f"\nLatency (us): {latency_us:.1f}") + print(f"Effective Bandwidth: {bandwidth_gbps:.6e} GB/s\n") + + assert not errors, f"Test failed with errors: {errors}" diff --git a/pytest.ini b/pytest.ini index 44f08847..a3566ee2 100644 --- a/pytest.ini +++ b/pytest.ini @@ -9,4 +9,5 @@ python_functions = test_* markers = extensive: extensive test suite (deselect with '-m "not extensive"') supported_devices(*devices): mark test as only supported on the given devices (e.g. "npu1", "npu2"). All devices supported by default. + benchmark: benchmark-only tests (select with '-m benchmark') addopts = -v --tb=short --import-mode=importlib