Draft
Changes from all commits
22 commits
d52936f  add single-dispatch layer-by-layer MHA (andrej, Apr 6, 2026)
dfe5f88  add GPT-2 sizes as test cases, make causal mask an option (andrej, Apr 6, 2026)
71237a2  as benchmarked (andrej, Apr 6, 2026)
92e6607  fix DMA dimension overflow (andrej, Apr 6, 2026)
43e4d07  create separate attn_scores_scaled buffer (andrej, Apr 7, 2026)
af75210  move output GEMM out of core MHA (andrej, Apr 7, 2026)
ee87e94  remove symbol renaming after rebase to use link_with, other fixes (andrej, Apr 7, 2026)
ee02731  format (andrej, Apr 7, 2026)
abf37ab  make mha_prefill_lxl_sd use all available columns (andrej, Apr 15, 2026)
4caac12  update test result CSV iteratively rather than all at once (andrej, Apr 15, 2026)
8a65c6d  make FusedMLIROperator work on Phoenix via multiple xclbin calls (andrej, Apr 15, 2026)
daf9162  make dispatch mode selectable, add tests (andrej, Apr 15, 2026)
ffe0515  use partial softmax on long sequence lengths (andrej, Apr 15, 2026)
6ce23a2  go up to 32768 sequence length for mha_lxl_sd benchmark tests (andrej, Apr 16, 2026)
0403ee2  stochastic testing for large sequence lengths; split GEMM+softmax int… (andrej, Apr 17, 2026)
dab067c  reuse buffers to avoid OOM (andrej, Apr 17, 2026)
bb47493  support longer sequence lengths and reduce buffer sizes: new AXPY mod… (andrej, Apr 20, 2026)
4abbae4  speed up causal masking (andrej, Apr 20, 2026)
a233ec2  oops --- softmax should use all cores! (andrej, Apr 20, 2026)
ada6c1f  parallelization for AXPY when single blocks are too big (andrej, Apr 20, 2026)
859ee97  Phoenix support for partial softmax (andrej, Apr 20, 2026)
4375996  reactivate sample-based verification for MHA (andrej, Apr 20, 2026)
120 changes: 120 additions & 0 deletions aie_kernels/aie2/softmax.cc
@@ -4,6 +4,7 @@
#include "lut_based_ops.h"

#include <aie_api/aie.hpp>
#include <math.h>
#include <stdint.h>

using namespace aie;
@@ -57,13 +58,132 @@ void softmax_simple_bf16(bfloat16 *restrict input_vector, bfloat16 *restrict out
  return;
}

// ---------------------------------------------------------------------------
// Online (partial / tiled) softmax helpers
//
// These three kernels implement a two-pass online softmax that processes a row
// in sub-tile chunks, keeping running max and sum statistics in a small local
// buffer (`stats`). Layout of the stats buffer (bfloat16[16], only [0..1]
// used):
// stats[0] = running max
// stats[1] = running sum (of exp(x - max))
// ---------------------------------------------------------------------------
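
// Scalar reference for the update the kernels below perform (documentation
// sketch only, not part of this PR; the function name and the plain
// float/expf arithmetic are illustrative). Each call folds one chunk into the
// running (max, sum) pair; dividing by the final sum yields softmax.
static inline void online_softmax_stats_reference(const float *chunk, int n,
                                                  float *running_max,
                                                  float *running_sum)
{
  // New running max over everything seen so far, including this chunk.
  float new_max = *running_max;
  for (int i = 0; i < n; i++)
    new_max = fmaxf(new_max, chunk[i]);
  // Re-express the old sum relative to the new max, then add this chunk.
  float sum = *running_sum * expf(*running_max - new_max);
  for (int i = 0; i < n; i++)
    sum += expf(chunk[i] - new_max);
  *running_max = new_max;
  *running_sum = sum;
  // After the last chunk: softmax(x[i]) = expf(x[i] - *running_max) / *running_sum.
}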

void softmax_partial_stats_impl(bfloat16 *restrict input,
                                bfloat16 *stats,
                                const int32_t vector_size)
{
  event0();

  const int elem_iters = vector_size / 16;

  float running_max = (float)stats[0];
  float running_sum = (float)stats[1];

  aie::vector<bfloat16, 16> input_bf16;
  aie::accum<accfloat, 16> exp_val_accum = aie::zeros<accfloat, 16>();

  auto it_in = aie::cbegin_vector<16>((bfloat16 *)input);

  // Single-pass online algorithm: for each vector chunk, check if max
  // needs updating, rescale the running sum if so, then accumulate
  // exp(x - max).
  for (int i = 0; i < elem_iters; i++) {
    input_bf16 = *it_in++;
    float chunk_max = aie::reduce_max(input_bf16);

    if (chunk_max > running_max) {
      // Rescale accumulated exp values by exp(old_max - new_max)
      aie::vector<bfloat16, 16> correction =
          to_v16bfloat16(getExpBf16(
              aie::broadcast<bfloat16, 16>((bfloat16)(running_max - chunk_max))));
      float scale = (float)correction[0];
      // Rescale the partial vector accumulator
      aie::vector<bfloat16, 16> scale_vec =
          aie::broadcast<bfloat16, 16>((bfloat16)scale);
      exp_val_accum = aie::mul(exp_val_accum.to_vector<bfloat16>(), scale_vec);
      // Rescale the running scalar sum from previous chunks
      running_sum *= scale;
      running_max = chunk_max;
    }

    aie::vector<bfloat16, 16> shifted = aie::sub(
        input_bf16, aie::broadcast<bfloat16, 16>((bfloat16)running_max));
    aie::vector<bfloat16, 16> exp_val = to_v16bfloat16(getExpBf16(shifted));
    exp_val_accum = add(exp_val_accum, exp_val);
  }

  // Reduce the vector accumulator and add to running sum
  aie::vector<float, 16> reduce = exp_val_accum.to_vector<float>();
  running_sum += aie::reduce_add(reduce);

  stats[0] = (bfloat16)running_max;
  stats[1] = (bfloat16)running_sum;

  event1();
}

void softmax_partial_norm_impl(bfloat16 *restrict input,
                               bfloat16 *restrict output,
                               bfloat16 *stats,
                               const int32_t vector_size)
{
  event0();

  const int elem_iters = vector_size / 16;

  float max_val = (float)stats[0];
  float sum_val = (float)stats[1];
  bfloat16 inv_sum = (bfloat16)aie::inv(sum_val);

  aie::vector<bfloat16, 16> max_val_vec =
      aie::broadcast<bfloat16, 16>((bfloat16)max_val);

  aie::vector<bfloat16, 16> input_bf16;
  aie::accum<accfloat, 16> out_vals;

  auto it_in = aie::cbegin_restrict_vector<16>((bfloat16 *)input);
  auto it_out = aie::begin_restrict_vector<16>((bfloat16 *)output);

  for (int i = 0; i < elem_iters; i++) {
    input_bf16 = *it_in++;
    aie::vector<bfloat16, 16> shifted = aie::sub(input_bf16, max_val_vec);
    aie::vector<bfloat16, 16> exp_val = to_v16bfloat16(getExpBf16(shifted));
    out_vals = aie::mul(exp_val, inv_sum);
    *it_out++ = out_vals.to_vector<bfloat16>();
  }

  event1();
}

extern "C" {

void softmax_bf16(bfloat16 *restrict input, bfloat16 *restrict output, const int32_t input_size)
{
softmax_simple_bf16(input, output, input_size);
}

void softmax_partial_init_bf16(bfloat16 *stats)
{
stats[0] = (bfloat16)(-INFINITY);
stats[1] = (bfloat16)(0.0f);
}

void softmax_partial_stats_bf16(bfloat16 *restrict input,
bfloat16 *stats,
const int32_t vector_size)
{
softmax_partial_stats_impl(input, stats, vector_size);
}

void softmax_partial_norm_bf16(bfloat16 *restrict input,
bfloat16 *restrict output,
bfloat16 *stats,
const int32_t vector_size)
{
softmax_partial_norm_impl(input, output, stats, vector_size);
}
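
// Usage sketch (documentation only, not part of this PR; `row`, `out`,
// `tile_len` and `num_tiles` are illustrative names): how the three
// partial-softmax entry points above compose over a row processed in
// fixed-size chunks.
static inline void softmax_partial_usage_example(bfloat16 *row, bfloat16 *out,
                                                 int32_t tile_len,
                                                 int32_t num_tiles)
{
  bfloat16 stats[16];
  softmax_partial_init_bf16(stats);
  // Pass 1: fold every chunk into the running (max, sum) statistics.
  for (int32_t t = 0; t < num_tiles; t++)
    softmax_partial_stats_bf16(row + t * tile_len, stats, tile_len);
  // Pass 2: normalize each chunk with the final statistics.
  for (int32_t t = 0; t < num_tiles; t++)
    softmax_partial_norm_bf16(row + t * tile_len, out + t * tile_len, stats, tile_len);
}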

void mask_bf16(bfloat16 *inout, const int32_t unmasked_size, const int32_t total_size)
{
  for (int32_t i = unmasked_size; i < total_size; i++) {
136 changes: 135 additions & 1 deletion aie_kernels/aie2p/softmax.cc
@@ -3,6 +3,7 @@

#include <aie_api/aie.hpp>
#include <stdint.h>
#include <math.h>

#define SM_VEC_LEN 64 // 32
#define log2e 1.4453125 // 1.44269504089
@@ -30,7 +31,7 @@ void softmax_simple_bf16(bfloat16 *restrict input_vector, bfloat16 *restrict out
  aie::vector<bfloat16, SM_VEC_LEN> in_elems, exp_val, input_bf16, log2e_vec, max_val_vec;
  aie::accum<accfloat, SM_VEC_LEN> out_vals, exp_val_accum, scaled_accum, exp_in_accum;

-  float max_val = 0;
+  float max_val = -INFINITY;
  float accum_exp_val = 0;
  float running_max = 0;
  bfloat16 col_sum_inv;
@@ -159,6 +160,118 @@ void partial_softmax_alias_bf16(bfloat16 *restrict input_vector,
  return;
}

// ---------------------------------------------------------------------------
// Online (partial / tiled) softmax helpers
//
// These three kernels implement a two-pass online softmax that processes a row
// in sub-tile chunks, keeping running max and sum statistics in a small local
// buffer (`stats`). Layout of the stats buffer (bfloat16[16], only [0..1]
// used):
// stats[0] = running max (scaled by log2e)
// stats[1] = running sum (of exp2(x*log2e - max))
// ---------------------------------------------------------------------------
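
// Scalar base-2 reference for the kernels below (documentation sketch only,
// not part of this PR; the function name and plain float/exp2f arithmetic are
// illustrative). exp(x) = exp2(x * log2(e)), so both statistics live in the
// base-2 domain, mirroring the aie::exp2-based vector code.
static inline void online_softmax_stats_base2_reference(const float *chunk, int n,
                                                        float *running_max_x_log2e,
                                                        float *running_sum)
{
  // Phase 1: chunk max in the scaled (base-2) domain.
  float local_max = -INFINITY;
  for (int i = 0; i < n; i++)
    local_max = fmaxf(local_max, chunk[i] * (float)log2e);
  // Phase 2: if the max grew, rescale the old sum by exp2(old_max - new_max).
  if (local_max > *running_max_x_log2e) {
    *running_sum *= exp2f(*running_max_x_log2e - local_max);
    *running_max_x_log2e = local_max;
  }
  // Phase 3: accumulate exp2(x * log2e - max) for this chunk.
  float local_sum = 0.0f;
  for (int i = 0; i < n; i++)
    local_sum += exp2f(chunk[i] * (float)log2e - *running_max_x_log2e);
  *running_sum += local_sum;
}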

void softmax_partial_stats_impl(bfloat16 *restrict input,
                                bfloat16 *stats,
                                const int32_t vector_size)
{
  event0();

  const int elem_iters = vector_size / SM_VEC_LEN;

  aie::vector<bfloat16, SM_VEC_LEN> input_bf16;
  aie::accum<accfloat, SM_VEC_LEN> scaled_accum, exp_in_accum;
  aie::accum<accfloat, SM_VEC_LEN> exp_val_accum = aie::zeros<accfloat, SM_VEC_LEN>();

  aie::vector<bfloat16, SM_VEC_LEN> log2e_vec =
      aie::broadcast<bfloat16, SM_VEC_LEN>((bfloat16)log2e);

  // --- Phase 1: find local max (scaled by log2e) -------------------------
  float local_max = -INFINITY;
  auto it_in1 = aie::cbegin_restrict_vector<SM_VEC_LEN>((bfloat16 *)input);
  for (int i = 0; i < elem_iters; i++) {
    input_bf16 = *it_in1++;
    scaled_accum = aie::mul(input_bf16, log2e_vec);
    float chunk_max = aie::reduce_max(scaled_accum.to_vector<bfloat16>());
    if (chunk_max > local_max) {
      local_max = chunk_max;
    }
  }

  // --- Phase 2: update running max, rescale running sum ------------------
  float old_max = (float)stats[0];
  float old_sum = (float)stats[1];

  if (local_max > old_max) {
    // New max is larger — rescale the old sum by exp2(old_max - new_max)
    aie::vector<float, SM_VEC_LEN> diff_vec =
        aie::broadcast<float, SM_VEC_LEN>(old_max - local_max);
    aie::vector<bfloat16, SM_VEC_LEN> corr = aie::exp2<bfloat16>(diff_vec);
    old_sum = old_sum * (float)corr[0];
    old_max = local_max;
  }

  // --- Phase 3: accumulate exp2(input * log2e - max) for this chunk ------
  aie::vector<bfloat16, SM_VEC_LEN> max_val_vec =
      aie::broadcast<bfloat16, SM_VEC_LEN>((bfloat16)old_max);

  auto it_in2 = aie::cbegin_restrict_vector<SM_VEC_LEN>((bfloat16 *)input);
  for (int i = 0; i < elem_iters; i++) {
    input_bf16 = *it_in2++;
    scaled_accum = aie::mul(input_bf16, log2e_vec);
    exp_in_accum = aie::sub(scaled_accum, max_val_vec);
    aie::vector<bfloat16, SM_VEC_LEN> exp_val =
        aie::exp2<bfloat16>(exp_in_accum.to_vector<float>());
    exp_val_accum = add(exp_val_accum, exp_val);
  }

  aie::vector<float, SM_VEC_LEN> reduce = exp_val_accum.to_vector<float>();
  float local_sum = aie::reduce_add(reduce);

  // --- Phase 4: store updated stats --------------------------------------
  stats[0] = (bfloat16)old_max;
  stats[1] = (bfloat16)(old_sum + local_sum);

  event1();
}

void softmax_partial_norm_impl(bfloat16 *restrict input,
                               bfloat16 *restrict output,
                               bfloat16 *stats,
                               const int32_t vector_size)
{
  event0();

  const int elem_iters = vector_size / SM_VEC_LEN;

  float max_val = (float)stats[0];
  float sum_val = (float)stats[1];
  bfloat16 inv_sum = (bfloat16)aie::inv(sum_val);

  aie::vector<bfloat16, SM_VEC_LEN> log2e_vec =
      aie::broadcast<bfloat16, SM_VEC_LEN>((bfloat16)log2e);
  aie::vector<bfloat16, SM_VEC_LEN> max_val_vec =
      aie::broadcast<bfloat16, SM_VEC_LEN>((bfloat16)max_val);

  aie::vector<bfloat16, SM_VEC_LEN> input_bf16;
  aie::accum<accfloat, SM_VEC_LEN> scaled_accum, exp_in_accum, out_vals;

  auto it_in = aie::cbegin_restrict_vector<SM_VEC_LEN>((bfloat16 *)input);
  auto it_out = aie::begin_restrict_vector<SM_VEC_LEN>((bfloat16 *)output);

  for (int i = 0; i < elem_iters; i++) {
    input_bf16 = *it_in++;
    scaled_accum = aie::mul(input_bf16, log2e_vec);
    exp_in_accum = aie::sub(scaled_accum, max_val_vec);
    aie::vector<bfloat16, SM_VEC_LEN> exp_val =
        aie::exp2<bfloat16>(exp_in_accum.to_vector<float>());
    out_vals = aie::mul(exp_val, inv_sum);
    *it_out++ = out_vals.to_vector<bfloat16>();
  }

  event1();
}

extern "C" {

void softmax_bf16(bfloat16 *restrict input, bfloat16 *restrict output, const int32_t input_size)
@@ -177,6 +290,27 @@ void partial_softmax_bf16(bfloat16 *restrict input,
  partial_softmax_alias_bf16(input, output, scale_buffer, input_size, row_idx, num_rows, scale);
}

void softmax_partial_init_bf16(bfloat16 *stats)
{
  stats[0] = (bfloat16)(-INFINITY);
  stats[1] = (bfloat16)(0.0f);
}

void softmax_partial_stats_bf16(bfloat16 *restrict input,
                                bfloat16 *stats,
                                const int32_t vector_size)
{
  softmax_partial_stats_impl(input, stats, vector_size);
}

void softmax_partial_norm_bf16(bfloat16 *restrict input,
                               bfloat16 *restrict output,
                               bfloat16 *stats,
                               const int32_t vector_size)
{
  softmax_partial_norm_impl(input, output, stats, vector_size);
}

void mask_bf16(bfloat16 *inout, const int32 unmasked_size, const int32 total_size)
{
  // TODO: Optimize this to use vector code