diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index f2b0b07fed..e29d510c1f 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -57,6 +57,8 @@ fi python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_partial_cast.xml $TE_PATH/tests/pytorch/test_partial_cast.py || test_fail "test_partial_cast.py" +# Currently there is no built-in method to disable autotune: https://github.com/triton-lang/triton/issues/7932 +TRITON_SKIP_AUTOTUNING=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_mhc.xml $TE_PATH/tests/pytorch/test_mhc.py || test_fail "test_mhc.py" if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" diff --git a/tests/pytorch/test_mhc.py b/tests/pytorch/test_mhc.py new file mode 100644 index 0000000000..b497a5ca4c --- /dev/null +++ b/tests/pytorch/test_mhc.py @@ -0,0 +1,482 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from attr import dataclass +import pytest +import torch +import torch.nn.functional as F + +from utils import reset_rng_states +from transformer_engine.pytorch.triton.mhc import ( + mHCScaleFusedOp, + mHCExpandCombineOp, + mHCAggregateOp, + mHCProjectionOp, + mHCSinkhornOp, +) + +seed = 1234 +reset_rng_states() + +# Enable TF32 for matmul to ensure consistency between the fused and reference implementations +torch.backends.cuda.matmul.allow_tf32 = False + + +def mHCProjectionRef(x, phi): + """ + Reference operator for mHC's projection building operation. 
+ + x: (M, nC) where M = s * b + phi: (2n + n^2, nC), which consists of the following matrices + - phi_pre: (n, nC) + - phi_post: (n, nC) + - phi_res: (n^2, nC) + n: number of Hyper Connection streams + C: hidden dimension per stream + """ + x_dtype = x.dtype + x = x.to(torch.float32) + phi = phi.to(torch.float32) + + Hs = x @ phi.T # (M, 2n + n^2) + + x_fp32 = x.to(torch.float32) # Use fp32 for better numerical stability in variance calculation + ms = (x_fp32 * x_fp32).mean(dim=1) + + return Hs.to(x_dtype), ms + + +def mHCScaleRef(H, alpha, beta, ms, n): + """ + Reference operator for mHC's pre and post calculations + + :param: H: (M, 2n + n^2), the unprocessed H matrices where M = s * b + :param: alpha: (3,), three scalar parameters + :param: beta: (1, 2n + n^2), bias term + :param: r: (M,), the denominator for RMSNorm + :param: n: int, the width of Hyper-Connection + + :return Hs: (M, 2n + n^2), the processed H matrices + """ + + M, _ = H.shape + H_dtype = H.dtype + H = H.to(torch.float32) + alpha = alpha.to(torch.float32) + beta = beta.to(torch.float32) + eps = torch.finfo(torch.float32).eps + rms = torch.sqrt(ms + eps) # (M,) + rms = rms.to(torch.float32) + + H_pre = H[:, :n] # (M, n) + H_post = H[:, n : 2 * n] # (M, n) + H_res = H[:, 2 * n :] # (M, n^2) + + beta_pre = beta[0, :n] + beta_post = beta[0, n : 2 * n] + beta_res = beta[0, 2 * n : 2 * n + n * n] + + alpha_pre, alpha_post, alpha_res = alpha[0], alpha[1], alpha[2] + + H_pre = H_pre * alpha_pre + H_post = H_post * alpha_post + H_res = H_res * alpha_res + + H_pre = H_pre / rms[:, None] + H_post = H_post / rms[:, None] + H_res = H_res / rms[:, None] + + H_pre = H_pre + beta_pre + H_post = H_post + beta_post + H_res = H_res + beta_res + + H_pre = F.sigmoid(H_pre) + H_post = 2 * F.sigmoid(H_post) + + out = torch.cat([H_pre, H_post, H_res], dim=-1) # (M, 2n + n^2) + + return out.to(H_dtype) + + +def mHCSinkhornRef(H_res, n=4, iterations=20): + """ + Sinkhorn-Knopp algorithm to convert a matrix into a doubly 
stochastic matrix. + Calculated in log space for numerical stability. + + :param H_res: a tensor of shape (s, b, n, n) + :return: a tensor of shape (s, b, n, n) + """ + s, b = H_res.shape[:2] + device = H_res.device + dtype = H_res.dtype + + H_res_f = H_res.to( + torch.float32 + ).clone() # Use float32 for better numerical stability during Sinkhorn iterations + + log_mu = torch.zeros(s, b, n, device=device, dtype=torch.float32) + log_nu = torch.zeros(s, b, n, device=device, dtype=torch.float32) + + f = torch.zeros(s, b, n, device=device, dtype=torch.float32) + g = torch.zeros(s, b, n, device=device, dtype=torch.float32) + + for _ in range(iterations): + # Update f: logsumexp over the column dimension (3) + f = log_mu - torch.logsumexp(H_res_f + g.unsqueeze(2), dim=3) + # Update g: logsumexp over the row dimension (2) + g = log_nu - torch.logsumexp(H_res_f + f.unsqueeze(3), dim=2) + + log_P = f.unsqueeze(3) + H_res_f + g.unsqueeze(2) + H_res_out = torch.exp(log_P).to(dtype) # Convert back to original dtype + + return H_res_out + + +def mHCAggregateRef(x, H_pre, n): + """ + Reference operator for applying mHC's pre matrix H to a vector x. 
+ + x: (s, b, C, n) + H_pre: (s, b, n) + """ + H_pre = H_pre.contiguous() + + s, b, C, n = x.shape + H_pre = H_pre.view(s, b, n, 1) + + out = (x @ H_pre).view(s, b, C) + + return out + + +def mHCExpandCombineRef(f, bias, H_post, x, H_res, n): + """ + Reference operator for applying mHC's post transformation and residual transformation + + f: (s, b, C) + bias: (C,) or None + H_post: (s, b, n) + x: (s, b, C, n) + H_res: (s, b, n, n) + """ + + s, b, C, n = x.shape + + # My triton kernels use FMA and MMA instructions with fp32 accumulator for bf16 test cases + # which has better numerical stability than this pytorch implementation + # To match the kernel's accuracy we need to cast to fp32 here to match kernels' result + input_dtype = f.dtype + f = f.to(torch.float32) + bias = bias.to(torch.float32) if bias is not None else None + H_post = H_post.to(torch.float32) + x = x.to(torch.float32) + H_res = H_res.to(torch.float32) + + if bias is not None: + f = f + bias[None, None, :] + + f = f.view(s, b, C, 1) + H_post = H_post.view(s, b, 1, n) + + out = f @ H_post + x @ H_res # (s, b, C, n) + + return out.to(input_dtype) + + +@dataclass +class MHCConfig: + s: int = 2048 # Sequence length + b: int = 32 # Batch size + C: int = 1024 # Hidden dimension + n: int = 4 # Number of Hyper Connection streams + + allow_n = [ + 4, + ] + + def __init__(self, b, s, C, n=4): + assert n in self.allow_n, f"n must be one of {self.allow_n}" + self.b = b + self.s = s + self.C = C + self.n = n + + @staticmethod + def desc(cfg): + return f"b{cfg.b}_s{cfg.s}_C{cfg.C}_n{cfg.n}" + + +mhc_configs = [ + MHCConfig(8, 32, 32), + MHCConfig(8, 128, 16 * 64), + MHCConfig( + 4, + 128, + 16 * 64, + ), + MHCConfig(2, 2048, 24 * 128), + MHCConfig( + 1, + 2048, + 24 * 128, + ), + MHCConfig( + 13, + 1, + 16 * 128, + ), + MHCConfig( + 7, + 1, + 16 * 256, + ), + MHCConfig( + 8, + 1, + 16 * 192, + ), + MHCConfig( + 8, + 128, + 5129, + ), + MHCConfig( + 8, + 512, + 8000, + ), + MHCConfig( + 4, + 1024, + 8192, + ), + 
MHCConfig( + 2, + 4096, + 8192, + ), + MHCConfig( + 8, + 128, + 16384, + ), +] + + +def get_tols(dtype): + if dtype == torch.bfloat16: + tols = dict(atol=2.5e-2, rtol=2.5e-2) + else: + tols = dict(atol=5e-3, rtol=5e-3) + return tols + + +@pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"]) +def test_mhc_projection(cfg: MHCConfig, dtype): + s, b, C, n = cfg.s, cfg.b, cfg.C, cfg.n + nC = n * C + N = 2 * n + n * n + + tols = get_tols(dtype) + use_tf32 = False + + x = torch.randn(s * b, nC, device="cuda", requires_grad=True, dtype=dtype) + phi = torch.randn(N, nC, dtype=dtype, requires_grad=True, device="cuda") + + x_ref = x.detach().clone().requires_grad_(True) + phi_ref = phi.detach().clone().requires_grad_(True) + + ref_out_Hs, ref_out_ms = mHCProjectionRef(x_ref, phi_ref) + fused_out_Hs_padded, fused_out_ms = mHCProjectionOp.apply(x, phi, use_tf32) + fused_out_Hs = fused_out_Hs_padded[:, :N] + + torch.testing.assert_close(fused_out_Hs, ref_out_Hs, **tols) + torch.testing.assert_close(fused_out_ms, ref_out_ms, **tols) + (ref_out_Hs.sum() + ref_out_ms.sum()).backward() + (fused_out_Hs.sum() + fused_out_ms.sum()).backward() + + torch.testing.assert_close(x.grad, x_ref.grad, **tols) + torch.testing.assert_close(phi.grad, phi_ref.grad, **tols) + + +@pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc) +@pytest.mark.parametrize("dtype", [torch.float32], ids=["fp32"]) +def test_mhc_elementwise(cfg: MHCConfig, dtype): + s, b, C, n = cfg.s, cfg.b, cfg.C, cfg.n + N = 2 * n + n * n + + tols = get_tols(dtype) + + H_padded = torch.randn(s * b, 32, device="cuda", requires_grad=True, dtype=dtype) + H = H_padded[:, :N] + alpha = torch.randn(3, device="cuda", requires_grad=True, dtype=dtype) + beta = torch.randn(1, 2 * n + n * n, device="cuda", requires_grad=True, dtype=dtype) + ms_raw = torch.randn(s * b, device="cuda", dtype=dtype).abs() + 1.0 + ms = 
ms_raw.detach().clone().requires_grad_(True) + + H_ref = H.detach().clone().requires_grad_(True) + alpha_ref = alpha.detach().clone().requires_grad_(True) + beta_ref = beta.detach().clone().requires_grad_(True) + ms_ref = ms.detach().clone().requires_grad_(True) + + ref_out = mHCScaleRef(H_ref[:, :N], alpha_ref, beta_ref, ms_ref, n) + fused_out_padded = mHCScaleFusedOp.apply(H_padded, alpha, beta, ms, n) + fused_out = fused_out_padded[:, :N] + + torch.testing.assert_close(fused_out, ref_out, **tols) + + ref_out.sum().backward() + fused_out.sum().backward() + + torch.testing.assert_close(H_padded.grad[:, :N], H_ref.grad, **tols) + torch.testing.assert_close(alpha.grad, alpha_ref.grad, **tols) + torch.testing.assert_close(beta.grad, beta_ref.grad, **tols) + torch.testing.assert_close(ms.grad, ms_ref.grad, **tols) + + +@pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"]) +def test_mhc_combined(cfg: MHCConfig, dtype): + s, b, C, n = cfg.s, cfg.b, cfg.C, cfg.n + N = 2 * n + n * n + nC = n * C + + tols = get_tols(dtype) + + tols = get_tols(dtype) + use_tf32 = False + + x = torch.randn(s * b, nC, device="cuda", requires_grad=True, dtype=dtype) + phi = torch.randn(N, nC, dtype=dtype, requires_grad=True, device="cuda") + + alpha = torch.randn(3, device="cuda", requires_grad=True, dtype=dtype) + beta = torch.randn(1, 2 * n + n * n, device="cuda", requires_grad=True, dtype=dtype) + + x_ref = x.detach().clone().requires_grad_(True) + phi_ref = phi.detach().clone().requires_grad_(True) + + alpha_ref = alpha.detach().clone().requires_grad_(True) + beta_ref = beta.detach().clone().requires_grad_(True) + + ref_out_H, ref_out_r = mHCProjectionRef(x_ref, phi_ref) + fused_out_H_padded, fused_out_r = mHCProjectionOp.apply(x, phi, use_tf32) + + ref_out = mHCScaleRef(ref_out_H[:, :N], alpha_ref, beta_ref, ref_out_r, n) + fused_out_padded = mHCScaleFusedOp.apply(fused_out_H_padded, 
alpha, beta, fused_out_r, n) + fused_out = fused_out_padded[:, :N] + + def mhc_combined(x_ref, phi_ref, alpha_ref, beta_ref): + dtype = x_ref.dtype + x_ref = x_ref.to(torch.float32) + phi_ref = phi_ref.to(torch.float32) + alpha_ref = alpha_ref.to(torch.float32) + beta_ref = beta_ref.to(torch.float32) + + x_rmsnorm = F.rms_norm(x_ref, normalized_shape=(nC,)) + H = x_rmsnorm @ phi_ref.T + H_pre = H[:, :n] + H_post = H[:, n : 2 * n] + H_res = H[:, 2 * n :] + + out_pre = H_pre * alpha_ref[0] + beta_ref[:, :n] + out_post = H_post * alpha_ref[1] + beta_ref[:, n : 2 * n] + out_res = H_res * alpha_ref[2] + beta_ref[:, 2 * n :] + + out_pre = out_pre.sigmoid() + out_post = 2 * out_post.sigmoid() + out_res = out_res + + return out_pre.to(dtype), out_post.to(dtype), out_res.to(dtype) + + H_pre_combined, H_post_combined, _ = mhc_combined(x_ref, phi_ref, alpha_ref, beta_ref) + + torch.testing.assert_close(H_pre_combined, ref_out[:, :n], **tols) + torch.testing.assert_close(H_post_combined, ref_out[:, n : 2 * n], **tols) + + torch.testing.assert_close(H_pre_combined, fused_out[:, :n], **tols) + torch.testing.assert_close(H_post_combined, fused_out[:, n : 2 * n], **tols) + + +@pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"]) +@pytest.mark.parametrize("recompute", [False, True], ids=["no_recompute", "recompute"]) +def test_mhc_sinkhorn_knopp(cfg: MHCConfig, dtype, recompute): + s, b, C, n = cfg.s, cfg.b, cfg.C, cfg.n + + tols = get_tols(dtype) + + x = torch.randn(s, b, n, n, device="cuda", requires_grad=True, dtype=dtype) + x_ref = x.detach().clone().requires_grad_(True) + + ref_out = mHCSinkhornRef(x_ref, n) + fused_out = mHCSinkhornOp.apply(x, n, recompute) + + torch.testing.assert_close(fused_out, ref_out, **tols) + + ref_out.sum().backward() + fused_out.sum().backward() + + torch.testing.assert_close(x.grad, x_ref.grad, **tols) + + +@pytest.mark.parametrize("cfg", 
mhc_configs, ids=MHCConfig.desc) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"]) +def test_mhc_aggregate(cfg: MHCConfig, dtype): + s, b, C, n = cfg.s, cfg.b, cfg.C, cfg.n + + tols = get_tols(dtype) + + x = torch.randn(s, b, C, n, device="cuda", requires_grad=True, dtype=dtype) + H_pre = torch.randn(s, b, n, device="cuda", requires_grad=True, dtype=dtype) + + x_ref = x.detach().clone().requires_grad_(True) + H_pre_ref = H_pre.detach().clone().requires_grad_(True) + + ref_out = mHCAggregateRef(x_ref, H_pre_ref, n) + fused_out = mHCAggregateOp.apply(x, H_pre, n, False) + + torch.testing.assert_close(fused_out, ref_out, **tols) + + ref_out.sum().backward() + fused_out.sum().backward() + + torch.testing.assert_close(x.grad, x_ref.grad, **tols) + torch.testing.assert_close(H_pre.grad, H_pre_ref.grad, **tols) + + +@pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"]) +@pytest.mark.parametrize("with_bias", [True, False], ids=["with_bias", "no_bias"]) +def test_mhc_expand_combine(cfg: MHCConfig, dtype, with_bias): + s, b, C, n = cfg.s, cfg.b, cfg.C, cfg.n + + tols = get_tols(dtype) + + f = torch.randn(s, b, C, device="cuda", requires_grad=True, dtype=dtype) + bias = None + if with_bias: + bias = torch.randn(C, device="cuda", requires_grad=True, dtype=dtype) + H_post = torch.randn(s, b, n, device="cuda", requires_grad=True, dtype=dtype) + x = torch.randn(s, b, C, n, device="cuda", requires_grad=True, dtype=dtype) + H_res = torch.randn(s, b, n, n, device="cuda", requires_grad=True, dtype=dtype) + + f_ref = f.detach().clone().requires_grad_(True) + bias_ref = None if bias is None else bias.detach().clone().requires_grad_(True) + H_post_ref = H_post.detach().clone().requires_grad_(True) + x_ref = x.detach().clone().requires_grad_(True) + H_res_ref = H_res.detach().clone().requires_grad_(True) + + ref_out = mHCExpandCombineRef(f_ref, 
bias_ref, H_post_ref, x_ref, H_res_ref, n) + fused_out = mHCExpandCombineOp.apply(f, bias, H_post, x, H_res, n, False) + + torch.testing.assert_close(fused_out, ref_out, **tols) + + ref_out.sum().backward() + fused_out.sum().backward() + + torch.testing.assert_close(f.grad, f_ref.grad, **tols) + torch.testing.assert_close(H_post.grad, H_post_ref.grad, **tols) + torch.testing.assert_close(x.grad, x_ref.grad, **tols) + torch.testing.assert_close(H_res.grad, H_res_ref.grad, **tols) diff --git a/transformer_engine/common/triton/mhc.py b/transformer_engine/common/triton/mhc.py new file mode 100644 index 0000000000..e83f7043ae --- /dev/null +++ b/transformer_engine/common/triton/mhc.py @@ -0,0 +1,1762 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import itertools +import os + +import triton +import triton.language as tl + + +def projection_config_fwd(): + block_m = [32, 64, 128] + block_k = [512, 1024] + step_k = [32, 64, 128] + warps = [2, 4] + stages = [3, 4] + + configs = [] + for m, bk, sk, w, s in itertools.product(block_m, block_k, step_k, warps, stages): + configs.append( + triton.Config( + {"BLOCK_SIZE_M": m, "BLOCK_SIZE_K": bk, "STEP_SIZE_K": sk}, + num_warps=w, + num_stages=s, + ) + ) + if os.environ.get("TRITON_SKIP_AUTOTUNING", "0") == "1": + configs = configs[:1] + return configs + + +def projection_config_bwd(): + block_m = [32, 64, 128] + block_k = [32, 64, 128] + warps = [2, 4] + stages = [2, 3, 4] + + configs = [] + for m, bk, w, s in itertools.product(block_m, block_k, warps, stages): + configs.append( + triton.Config({"BLOCK_SIZE_M": m, "BLOCK_SIZE_K": bk}, num_warps=w, num_stages=s) + ) + if os.environ.get("TRITON_SKIP_AUTOTUNING", "0") == "1": + configs = configs[:1] + return configs + + +@triton.autotune(configs=projection_config_fwd(), key=["M", "K"], reset_to_zero=["h_ptr", "ms_ptr"]) +@triton.jit +def _mhc_projection_fwd_fused( + x_ptr, # (M, K) + phi_ptr, # (N, 
K) + h_ptr, # (M, 32) + ms_ptr, # (M,) + M, + N, + K, + stride_xm, + stride_xk: tl.constexpr, + stride_phin, + stride_phik: tl.constexpr, + stride_hm: tl.constexpr, + stride_hn: tl.constexpr, + stride_ms: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + STEP_SIZE_K: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + precision: tl.constexpr, +): + """ + Kernel for computing the matmul Y = X @ W.T and ms = (X * X).mean(dim=1) in a fused manner. + """ + + pid_m = tl.program_id(axis=0) + pid_k = tl.program_id(axis=1) + + tl.assume(pid_m >= 0) + tl.assume(pid_k >= 0) + tl.assume(stride_xm > 0) + tl.assume(stride_xk == 1) + tl.assume(stride_phin == K) + tl.assume(stride_phik == 1) + tl.assume(stride_hm == 32) + tl.assume(stride_hn == 1) + tl.assume(stride_ms == 1) + + tl.assume(BLOCK_SIZE_M % 32 == 0) + tl.assume(BLOCK_SIZE_K % 32 == 0) + tl.assume(BLOCK_SIZE_N == 32) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n_full = tl.arange(0, BLOCK_SIZE_N) + mask_m = offs_m < M + + h_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + ms_acc = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + + k_base = pid_k * BLOCK_SIZE_K + for k_start in range(0, tl.cdiv(BLOCK_SIZE_K, STEP_SIZE_K)): + k_offs = k_base + k_start * STEP_SIZE_K + tl.arange(0, STEP_SIZE_K) + mask_k = k_offs < K + x_ptrs = x_ptr + offs_m[:, None] * stride_xm + k_offs[None, :] * stride_xk + x = tl.load( + x_ptrs, mask=mask_m[:, None] & mask_k[None, :], other=0.0 + ) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + phi_ptrs = phi_ptr + offs_n_full[:, None] * stride_phin + k_offs[None, :] * stride_phik + phi = tl.load( + phi_ptrs, + mask=(offs_n_full[:, None] < N) & mask_k[None, :], + other=0.0, + cache_modifier=".ca", + ) # (BLOCK_SIZE_N, BLOCK_SIZE_K) + # RMSNorm denominator computation + ms_acc += tl.sum(x * x, axis=1) + # Matrix multiplication + h_acc = tl.dot( + x, tl.trans(phi, (1, 0)), h_acc, input_precision=precision, out_dtype=tl.float32 + ) + + 
h_ptrs = h_ptr + offs_m[:, None] * stride_hm + offs_n_full[None, :] * stride_hn + tl.atomic_add(h_ptrs, h_acc, mask=mask_m[:, None], sem="relaxed") + + offs_rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + masks_rm = offs_rm < M + offs_rm %= M + ms_ptrs = ms_ptr + offs_rm * stride_ms + ms = ms_acc / tl.cast(K, tl.float32) + tl.atomic_add(ms_ptrs, ms, mask=masks_rm, sem="relaxed") + + +@triton.autotune( + configs=projection_config_bwd(), + key=["M", "K"], +) +@triton.jit +def _mhc_projection_bwd_fused( + x_ptr, + grad_x_ptr, # (M, K) + phi_ptr, # (N, K) + grad_h_ptr, # (M, N) + grad_ms_ptr, # (M,) + M, + N, + K, + stride_xm, + stride_xk: tl.constexpr, + stride_grad_xm, + stride_grad_xk: tl.constexpr, + stride_phin, + stride_phik: tl.constexpr, + stride_grad_phin, + stride_grad_phik: tl.constexpr, + stride_grad_hm: tl.constexpr, + stride_grad_hn: tl.constexpr, + stride_grad_ms: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + precision: tl.constexpr, +): + """ + This computes + Each block handles (BLOCK_SIZE_M, N) of dY and (N, BLOCK_SIZE_K) of W^T, where N is covered by BLOCK_SIZE_N + and also handles the element-wise multiplication part, and writes back (BLOCK_SIZE_M, BLOCK_SIZE_K) of dX each time + """ + pid_m = tl.program_id(axis=0) + pid_k = tl.program_id(axis=1) + + tl.assume(pid_m >= 0) + tl.assume(pid_k >= 0) + tl.assume(stride_xm > 0) + tl.assume(stride_xk == 1) + tl.assume(stride_grad_hm == 32) + tl.assume(stride_grad_hn == 1) + tl.assume(stride_phin == K) + tl.assume(stride_phik == 1) + tl.assume(stride_grad_phin == K) + tl.assume(stride_grad_phik == 1) + tl.assume(stride_grad_ms == 1) + + tl.assume(BLOCK_SIZE_M % 32 == 0) + tl.assume(BLOCK_SIZE_K % 32 == 0) + tl.assume(BLOCK_SIZE_N == 32) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offs_n_full = tl.arange(0, BLOCK_SIZE_N) + mask_m = offs_m < 
M + mask_k = offs_k < K + + x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk + x = tl.load( + x_ptrs, mask=mask_m[:, None] & mask_k[None, :], other=0.0 + ) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + + grad_h_ptrs = ( + grad_h_ptr + offs_m[:, None] * stride_grad_hm + offs_n_full[None, :] * stride_grad_hn + ) + grad_h = tl.load( + grad_h_ptrs, mask=mask_m[:, None] & (offs_n_full[None, :] < N), other=0.0 + ) # (BLOCK_SIZE_M, BLOCK_SIZE_N) + + phi_ptrs = phi_ptr + offs_n_full[:, None] * stride_phin + offs_k[None, :] * stride_phik + offs_r = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + grad_ms_ptrs = grad_ms_ptr + offs_r * stride_grad_ms + + phi = tl.load( + phi_ptrs, mask=(offs_n_full[:, None] < N) & mask_k[None, :], other=0.0 + ) # (BLOCK_SIZE_N, BLOCK_SIZE_K) + grad_ms = tl.load( + grad_ms_ptrs, mask=offs_r < M, other=0.0, cache_modifier=".ca" + ) # (BLOCK_SIZE_M,) + + grad_x = x * (grad_ms * 2 / tl.cast(K, tl.float32))[:, None] + grad_x = tl.dot( + grad_h, phi, acc=grad_x, input_precision=precision, out_dtype=tl.float32 + ) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + grad_x_ptrs = grad_x_ptr + offs_m[:, None] * stride_grad_xm + offs_k[None, :] * stride_grad_xk + grad_x = grad_x.to(x.dtype) + tl.store(grad_x_ptrs, grad_x, mask=mask_m[:, None] & mask_k[None, :]) + + +def scale_config(): + block_m = [128, 256, 512, 1024] + warps = [2, 4] + stages = [1, 2, 3, 4] + + configs = [] + for m, w, s in itertools.product(block_m, warps, stages): + configs.append(triton.Config({"BLOCK_SIZE_M": m}, num_warps=w, num_stages=s)) + + if os.environ.get("TRITON_SKIP_AUTOTUNING", "0") == "1": + configs = configs[:1] + return configs + + +@triton.autotune( + configs=scale_config(), + key=["M"], +) +@triton.jit +def _mhc_scale_fwd_fused( + h_ptr, # (M, 2n + n^2), which is padded to (M, 32) in the last dimension + a_ptr, # (3,) + b_ptr, # (2n + n^2) + ms_ptr, # (M,) + out_ptr, # (M, 2n + n^2), which is padded to (M, 32) in the last dimension + M, + n, + stride_hm, + stride_hn, 
+ stride_a, + stride_b, + stride_ms, + stride_out_m, + stride_out_n, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + eps: tl.constexpr, +): + pid = tl.program_id(0) # 1D grid + + tl.assume(M > 0) + tl.assume(n == 4) + tl.assume(stride_hm == 32) + tl.assume(stride_hn == 1) + tl.assume(stride_out_m == 32) + tl.assume(stride_out_n == 1) + tl.assume(stride_a == 1) + tl.assume(stride_b == 1) + tl.assume(stride_ms == 1) + tl.assume(BLOCK_SIZE_N == 32) + + N = 2 * n + n * n + + offs_m = pid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + cols = tl.arange(0, BLOCK_SIZE_N) + mask_m = offs_m < M + + # Expand a to BLOCK_SIZE_N length + offs_a = tl.zeros_like(cols) + offs_a = tl.where((cols >= n) & (cols < 2 * n), 1, offs_a) + offs_a = tl.where((cols >= 2 * n) & (cols < 2 * n + n * n), 2, offs_a) + # Pick a[0] from a for the first 4 columns, a[1] for the next 4 columns, and a[2] for the rest of the columns + a = tl.load( + a_ptr + offs_a * stride_a, mask=offs_a < N, other=0.0 + ) # a[2*n + n*n:] is filled with garbage + a = tl.where(cols < N, a, 0.0) # Mask out the garbage values in a + + b = tl.load(b_ptr + cols * stride_b, mask=cols < N, other=0.0) # (BLOCK_SIZE_N,) + ms = tl.load(ms_ptr + offs_m * stride_ms, mask=mask_m, other=0.0) # (BLOCK_SIZE_M,) + # In projection kernel we use split-K so we only have the accumulated ms, + # and now we need to take sqrt on the accumulated ms to obtain the RMSNorm denominator. 
+ rms = tl.sqrt(ms + eps) + + h = tl.load( + h_ptr + offs_m[:, None] * stride_hm + cols[None, :] * stride_hn, + mask=mask_m[:, None], + other=0.0, + ) # (BLOCK_SIZE_M, BLOCK_SIZE_N) + + h = a[None, :] * h + h = tl.fma( + h, 1.0 / rms[:, None], b[None, :] + ) # (BLOCK_SIZE_M, BLOCK_SIZE_N), where the first 2n columns are H_pre and H_post, and the rest are H_res + h_sigmoid_pre = tl.sigmoid(h) + h_sigmold_post = 2 * h_sigmoid_pre + + # Use this mask to select h[:, :2n] + h = tl.where(cols[None, :] < n, h_sigmoid_pre, h) + h = tl.where((cols[None, :] >= n) & (cols[None, :] < 2 * n), h_sigmold_post, h) + + tl.store( + out_ptr + offs_m[:, None] * stride_out_m + cols[None, :] * stride_out_n, + h, + mask=mask_m[:, None], + ) + + +@triton.autotune( + configs=scale_config(), + key=["M"], + reset_to_zero=["grad_a_ptr", "grad_b_ptr"], +) +@triton.jit +def _mhc_scale_bwd_fused( + grad_out_ptr, + out_ptr, # (M, 2n + n^2), which is padded to (M, 32) in the last dimension + grad_h_ptr, + h_ptr, # (M, 2n + n^2), which is padded to (M, 32) in the last dimension + grad_a_ptr, + a_ptr, # (3,) + grad_b_ptr, # (2n + n^2,) + grad_ms_ptr, + ms_ptr, # (M,) + M, + n, + stride_grad_out_m, + stride_grad_out_n, + stride_out_m, + stride_out_n, + stride_grad_hm, + stride_grad_hn, + stride_hm, + stride_hn, + stride_grad_a, + stride_a, + stride_grad_b, + stride_grad_ms, + stride_ms, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + eps: tl.constexpr, +): + pid = tl.program_id(0) + + tl.assume(M > 0) + tl.assume(n == 4) + tl.assume(stride_grad_out_m == 32) + tl.assume(stride_grad_out_n == 1) + tl.assume(stride_out_m == 32) + tl.assume(stride_out_n == 1) + tl.assume(stride_grad_hm == 32) + tl.assume(stride_grad_hn == 1) + tl.assume(stride_hm == 32) + tl.assume(stride_hn == 1) + tl.assume(stride_grad_a == 1) + tl.assume(stride_a == 1) + tl.assume(stride_grad_b == 1) + tl.assume(stride_grad_ms == 1) + tl.assume(stride_ms == 1) + tl.assume(BLOCK_SIZE_N == 32) + + N = 2 * n + n * n + + 
offs_m = pid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + cols = tl.arange(0, BLOCK_SIZE_N) + mask_m = offs_m < M + mask_n = cols < N + + # Expand a to BLOCK_SIZE_N length + offs_a = tl.zeros_like(cols) + offs_a = tl.where((cols >= n) & (cols < 2 * n), 1, offs_a) + offs_a = tl.where((cols >= 2 * n) & (cols < 2 * n + n * n), 2, offs_a) + # Pick a[0] from a for the first 4 columns, a[1] for the next 4 columns, and a[2] for the rest of the columns + a = tl.load( + a_ptr + offs_a * stride_a, mask=offs_a < 3, other=0.0 + ) # a[2*n + n*n:] is filled with garbage + a = tl.where(cols < N, a, 0.0) # Mask out the garbage values in a + + ms_offsets = offs_m + ms_mask = mask_m + ms = tl.load(ms_ptr + ms_offsets * stride_ms, mask=ms_mask, other=1.0) # (BLOCK_SIZE_M,) + rms = tl.sqrt(ms + eps) + + grad_out = tl.load( + grad_out_ptr + offs_m[:, None] * stride_grad_out_m + cols[None, :] * stride_grad_out_n, + mask=mask_m[:, None] & mask_n[None, :], + other=0.0, + ) # (BLOCK_SIZE_M, BLOCK_SIZE_N) + out = tl.load( + out_ptr + offs_m[:, None] * stride_out_m + cols[None, :] * stride_out_n, + mask=mask_m[:, None] & mask_n[None, :], + other=0.0, + ) # (BLOCK_SIZE_M, BLOCK_SIZE_N) + h = tl.load( + h_ptr + offs_m[:, None] * stride_hm + cols[None, :] * stride_hn, + mask=mask_m[:, None] & mask_n[None, :], + other=0.0, + ) # (BLOCK_SIZE_M, BLOCK_SIZE_N) + + # Gradiient of H before H_pre and H_post go through sigmoid + grad_out_out = grad_out * out + grad_h_pre = grad_out_out * (1 - out) + grad_h_post = grad_out_out * 0.5 * (2 - out) + grad_h = grad_out + grad_h = tl.where(cols[None, :] < n, grad_h_pre, grad_h) + grad_h = tl.where((cols[None, :] >= n) & (cols[None, :] < 2 * n), grad_h_post, grad_h) + + grad_a = tl.sum(h * grad_h / rms[:, None], axis=0).to(a.dtype) + # Write grad_a[0:4].sum to grad_a_ptr[0], grad_a[4:8].sum to grad_a_ptr[1], and grad_a[8:24].sum to grad_a_ptr[2] + tl.atomic_add(grad_a_ptr, tl.where(cols[None, :] < n, grad_a, 0.0).sum(), sem="relaxed") + tl.atomic_add( + 
grad_a_ptr + stride_grad_a, + tl.where((cols[None, :] >= n) & (cols[None, :] < 2 * n), grad_a, 0.0).sum(), + sem="relaxed", + ) + tl.atomic_add( + grad_a_ptr + 2 * stride_grad_a, + tl.where((cols[None, :] >= 2 * n) & (cols[None, :] < 2 * n + n * n), grad_a, 0.0).sum(), + sem="relaxed", + ) + + grad_b = tl.sum(grad_h, axis=0).to(a.dtype) + tl.atomic_add(grad_b_ptr + cols * stride_grad_b, grad_b, mask=cols < N, sem="relaxed") + + grad_rms = (tl.sum((-grad_h * h * a[None, :]), axis=1) / (rms * rms)).to(rms.dtype) + grad_ms = grad_rms / (2 * rms) + tl.store(grad_ms_ptr + ms_offsets * stride_grad_ms, grad_ms, mask=ms_mask) + + grad_h = a[None, :] * grad_h / rms[:, None] + tl.store( + grad_h_ptr + offs_m[:, None] * stride_grad_hm + cols[None, :] * stride_grad_hn, + grad_h, + mask=mask_m[:, None] & mask_n[None, :], + ) + + +def sinkhorn_config(): + block = [128, 256, 512, 1024] + warps = [2, 4, 8] + stages = [1, 2, 3, 4] + configs = [] + for b, w, s in itertools.product(block, warps, stages): + configs.append(triton.Config({"BLOCK_SIZE": b}, num_warps=w, num_stages=s)) + if os.environ.get("TRITON_SKIP_AUTOTUNING", "0") == "1": + configs = configs[:1] + return configs + + +@triton.autotune( + configs=sinkhorn_config(), + key=["M"], +) +@triton.jit +def _mhc_sinkhorn_fwd_fused_recompute( + x_ptr, # (M, n*n) + output_ptr, + stride_xm, + stride_xn, + stride_out_m, + stride_out_n, + M, + n: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + iters, +): + """ + Fused Sinkhorn-Knopp algorithm to convert a matrix into a doubly stochastic matrix. + Calculated in log space for numerical stability. 
+ + :param X: a tensor of shape (s, b, n, n), input + :param output_ptr: a tensor of shape (s, b, n, n), output + :param hist_f_ptr: a tensor of shape (iters+1, s, b, n), to store f history + :param hist_g_ptr: a tensor of shape (iters+1, s, b, n), to store g history + :param s: sequence length + :param b: batch size + :param BLOCK_SIZE: size of the blocks to process + :param iters: number of Sinkhorn iterations + """ + pid = tl.program_id(0) # 1D grid + + tl.static_assert(BLOCK_SIZE % (n * n) == 0, "BLOCK_SIZE must be divisible by n*n") + tl.assume(M > 0 and iters > 0) + tl.assume(n == 4) + + BATCH_SIZE: tl.constexpr = BLOCK_SIZE // (n * n) # Assume there's no remainder for simplicity + + offs_batch = pid * BATCH_SIZE + tl.arange(0, BATCH_SIZE) + offs_nn = tl.arange(0, n * n) + mask_batch = offs_batch < M + + x_ptrs = x_ptr + offs_batch[:, None] * stride_xm + offs_nn[None, :] * stride_xn + x = tl.load(x_ptrs, mask=mask_batch[:, None], other=0.0) # (BATCH_SIZE, n*n) + x = tl.reshape(x, (BATCH_SIZE, n, n)) # (BATCH_SIZE, n, n) + + log_mu = tl.zeros((BATCH_SIZE, n), dtype=x.dtype) # (BATCH_SIZE, n) + log_nu = tl.zeros((BATCH_SIZE, n), dtype=x.dtype) # (BATCH_SIZE, n) + + f = tl.zeros((BATCH_SIZE, n), dtype=x.dtype) # (BATCH_SIZE, n) + g = tl.zeros((BATCH_SIZE, n), dtype=x.dtype) # (BATCH_SIZE, n) + + for _ in range(iters): + # Update f: logsumexp over the column dimension (1) + f = x + g[:, None, :] # Broadcast g to (BATCH_SIZE, n, n) + f_max = tl.max(f, axis=2) + f = tl.log(tl.sum(tl.exp(f - f_max[:, :, None]), axis=2)) # logsumexp over columns + f = log_mu - f - f_max + + # Update g: logsumexp over the row dimension (2) + g = x + f[:, :, None] # Broadcast f to (BATCH_SIZE, n, n) + g_max = tl.max(g, axis=1) + g = tl.log(tl.sum(tl.exp(g - g_max[:, None, :]), axis=1)) # logsumexp over rows + g = log_nu - g - g_max + + log_P = f[:, :, None] + x + g[:, None, :] + log_P = tl.reshape( + log_P, + ( + BATCH_SIZE, + n * n, + ), + ) + P = tl.exp(log_P) + + output_ptrs = 
@triton.autotune(
    configs=sinkhorn_config(),
    key=["M"],
)
@triton.jit
def _mhc_sinkhorn_bwd_fused_recompute(
    grad_out_ptr,
    output_ptr,
    grad_x_ptr,
    x_ptr,
    hist_f_ptr,
    hist_g_ptr,
    stride_grad_out_m,
    stride_grad_out_n,
    stride_out_m,
    stride_out_n,
    stride_grad_xm,
    stride_grad_xn,
    stride_xm,
    stride_xn,
    M,
    n: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    iters,
):
    """
    Backward pass for the fused Sinkhorn-Knopp algorithm with intermediate
    values recomputed.

    First replays the forward iterations, filling ``hist_f_ptr`` /
    ``hist_g_ptr`` (scratch buffers laid out as (iters+1, M, n)), then walks
    that history in reverse to accumulate gradients.

    :param grad_out_ptr: pointer to the gradient of the output, (M, n*n)
    :param output_ptr: pointer to the forward output P, (M, n*n)
    :param grad_x_ptr: pointer to the gradient of the input, (M, n*n)
    :param x_ptr: pointer to the input tensor, (M, n*n)
    :param hist_f_ptr: scratch buffer for the f history, (iters+1, M, n)
    :param hist_g_ptr: scratch buffer for the g history, (iters+1, M, n)
    :param M: number of n x n sub-matrices (sequence length * batch size)
    :param n: size of the sub-matrix (n x n); the kernel assumes n == 4
    :param BLOCK_SIZE: elements handled per program; must be divisible by n*n
    :param iters: number of iterations
    """
    pid = tl.program_id(0)  # 1D grid

    tl.static_assert(BLOCK_SIZE % (n * n) == 0, "BLOCK_SIZE must be divisible by n*n")
    tl.assume(M > 0 and iters > 0)
    tl.assume(n == 4)

    BATCH_SIZE: tl.constexpr = BLOCK_SIZE // (n * n)  # Assume there's no remainder for simplicity

    offs_batch = pid * BATCH_SIZE + tl.arange(0, BATCH_SIZE)
    offs_nn = tl.arange(0, n * n)
    offs_n_hist = tl.arange(0, n)
    mask_batch = offs_batch < M

    x_ptrs = x_ptr + offs_batch[:, None] * stride_xm + offs_nn[None, :] * stride_xn
    x = tl.load(x_ptrs, mask=mask_batch[:, None], other=0.0)  # (BATCH_SIZE, n*n)
    x = tl.reshape(x, (BATCH_SIZE, n, n))  # (BATCH_SIZE, n, n)
    exp_x = tl.exp(x)

    P_ptrs = output_ptr + offs_batch[:, None] * stride_out_m + offs_nn[None, :] * stride_out_n
    P = tl.load(P_ptrs, mask=mask_batch[:, None], other=0.0)  # (BATCH_SIZE, n*n)
    P = tl.reshape(P, (BATCH_SIZE, n, n))

    grad_out_ptrs = (
        grad_out_ptr
        + offs_batch[:, None] * stride_grad_out_m
        + offs_nn[None, :] * stride_grad_out_n
    )
    grad_out = tl.load(grad_out_ptrs, mask=mask_batch[:, None], other=0.0)  # (BATCH_SIZE, n*n)
    grad_out = tl.reshape(grad_out, (BATCH_SIZE, n, n))  # (BATCH_SIZE, n, n)

    sbn = M * n  # stride between consecutive history snapshots

    # Recompute the full history of f and g
    log_mu = tl.zeros((BATCH_SIZE, n), dtype=x.dtype)  # (BATCH_SIZE, n)
    log_nu = tl.zeros((BATCH_SIZE, n), dtype=x.dtype)  # (BATCH_SIZE, n)

    f = tl.zeros((BATCH_SIZE, n), dtype=x.dtype)  # (BATCH_SIZE, n)
    g = tl.zeros((BATCH_SIZE, n), dtype=x.dtype)  # (BATCH_SIZE, n)

    # Snapshot 0 holds the initial (zero) potentials
    f_hist_ptrs = hist_f_ptr + offs_batch[:, None] * n + offs_n_hist[None, :]
    g_hist_ptrs = hist_g_ptr + offs_batch[:, None] * n + offs_n_hist[None, :]
    tl.store(f_hist_ptrs, f, mask=mask_batch[:, None])
    tl.store(g_hist_ptrs, g, mask=mask_batch[:, None])

    for iter_idx in range(iters):
        # Update f: logsumexp over the column dimension (1)
        f = x + g[:, None, :]  # Broadcast g to (BATCH_SIZE, n, n)
        f_max = tl.max(f, axis=2)
        f = tl.log(tl.sum(tl.exp(f - f_max[:, :, None]), axis=2))  # logsumexp over columns
        f = log_mu - f - f_max

        f_hist_ptrs = (
            hist_f_ptr + (iter_idx + 1) * sbn + offs_batch[:, None] * n + offs_n_hist[None, :]
        )
        tl.store(f_hist_ptrs, f, mask=mask_batch[:, None])

        # Update g: logsumexp over the row dimension (2)
        g = x + f[:, :, None]  # Broadcast f to (BATCH_SIZE, n, n)
        g_max = tl.max(g, axis=1)
        g = tl.log(tl.sum(tl.exp(g - g_max[:, None, :]), axis=1))  # logsumexp over rows
        g = log_nu - g - g_max

        g_hist_ptrs = (
            hist_g_ptr + (iter_idx + 1) * sbn + offs_batch[:, None] * n + offs_n_hist[None, :]
        )
        tl.store(g_hist_ptrs, g, mask=mask_batch[:, None])

    # Backward pass
    grad_log_P = grad_out * P  # (BATCH_SIZE, n, n)
    zeros = tl.zeros_like(grad_log_P)
    grad_g = tl.sum(grad_log_P, axis=1)  # (BATCH_SIZE, n)
    grad_x = grad_log_P

    # Start from the last snapshot of g and walk the history in reverse
    g_hist_ptrs = hist_g_ptr + iters * sbn + offs_batch[:, None] * n + offs_n_hist[None, :]
    g = tl.load(g_hist_ptrs, mask=mask_batch[:, None], other=0.0)
    g = tl.reshape(g, (BATCH_SIZE, n))

    for iter_idx in range(iters, 0, -1):
        f_hist_ptrs = hist_f_ptr + iter_idx * sbn + offs_batch[:, None] * n + offs_n_hist[None, :]
        f = tl.load(f_hist_ptrs, mask=mask_batch[:, None], other=0.0)
        f = tl.reshape(f, (BATCH_SIZE, n))

        g_hist_ptrs = (
            hist_g_ptr + (iter_idx - 1) * sbn + offs_batch[:, None] * n + offs_n_hist[None, :]
        )
        g_next = tl.load(g_hist_ptrs, mask=mask_batch[:, None], other=0.0)
        g_next = tl.reshape(g_next, (BATCH_SIZE, n))

        term_g = -grad_g[:, None, :] * tl.exp(f[:, :, None] + g[:, None, :]) * exp_x
        grad_f = tl.sum(term_g + grad_log_P, axis=2)  # (BATCH_SIZE, n)
        # Only the last iteration's f will contribute to gradients with both grad_g1 and grad_log_P
        grad_log_P = zeros  # Zero out grad_log_P for next iterations

        g = g_next

        term_f = -grad_f[:, :, None] * tl.exp(f[:, :, None] + g[:, None, :]) * exp_x
        grad_g = tl.sum(term_f, axis=1)  # (BATCH_SIZE, n)

        grad_x += term_f + term_g

    grad_x_ptrs = (
        grad_x_ptr + offs_batch[:, None] * stride_grad_xm + offs_nn[None, :] * stride_grad_xn
    )
    tl.store(
        grad_x_ptrs,
        tl.reshape(
            grad_x,
            (
                BATCH_SIZE,
                n * n,
            ),
        ),
        mask=mask_batch[:, None],
    )
@triton.autotune(
    configs=sinkhorn_config(),
    key=["M"],
)
@triton.jit
def _mhc_sinkhorn_fwd_fused(
    x_ptr,  # (M, n*n)
    output_ptr,
    hist_f_ptr,
    hist_g_ptr,  # Assume this is contiguous and laid out as (iters+1, M, n)
    stride_xm,
    stride_xn,
    stride_out_m,
    stride_out_n,
    M,
    n: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    iters,
):
    """
    Fused Sinkhorn-Knopp algorithm to convert each n x n sub-matrix into a
    doubly stochastic matrix. Calculated in log space for numerical stability.

    Unlike ``_mhc_sinkhorn_fwd_fused_recompute``, this variant saves the f/g
    potentials after every iteration so the backward kernel
    (``_mhc_sinkhorn_bwd_fused``) can read them instead of recomputing.

    :param x_ptr: input tensor of shape (M, n*n), where M = s * b
    :param output_ptr: output tensor of shape (M, n*n)
    :param hist_f_ptr: buffer for the f history, (iters+1, M, n)
    :param hist_g_ptr: buffer for the g history, (iters+1, M, n)
    :param M: number of n x n sub-matrices (sequence length * batch size)
    :param n: sub-matrix size; the kernel is specialized to n == 4
    :param BLOCK_SIZE: elements handled per program; must be divisible by n*n
    :param iters: number of Sinkhorn iterations
    """
    pid = tl.program_id(0)  # 1D grid

    tl.static_assert(BLOCK_SIZE % (n * n) == 0, "BLOCK_SIZE must be divisible by n*n")
    tl.assume(M > 0 and iters > 0)
    tl.assume(n == 4)

    BATCH_SIZE: tl.constexpr = BLOCK_SIZE // (n * n)  # Assume there's no remainder for simplicity

    offs_batch = pid * BATCH_SIZE + tl.arange(0, BATCH_SIZE)
    offs_nn = tl.arange(0, n * n)
    offs_n_hist = tl.arange(0, n)
    mask_batch = offs_batch < M

    x_ptrs = x_ptr + offs_batch[:, None] * stride_xm + offs_nn[None, :] * stride_xn
    x = tl.load(x_ptrs, mask=mask_batch[:, None], other=0.0)  # (BATCH_SIZE, n*n)
    x = tl.reshape(x, (BATCH_SIZE, n, n))  # (BATCH_SIZE, n, n)

    # log_mu = log_nu = 0 means the target row/column sums are exp(0) = 1
    log_mu = tl.zeros((BATCH_SIZE, n), dtype=x.dtype)  # (BATCH_SIZE, n)
    log_nu = tl.zeros((BATCH_SIZE, n), dtype=x.dtype)  # (BATCH_SIZE, n)

    f = tl.zeros((BATCH_SIZE, n), dtype=x.dtype)  # (BATCH_SIZE, n)
    g = tl.zeros((BATCH_SIZE, n), dtype=x.dtype)  # (BATCH_SIZE, n)

    sbn = M * n  # stride between consecutive history snapshots

    # Store the initial f and g to history
    f_hist_ptrs = hist_f_ptr + offs_batch[:, None] * n + offs_n_hist[None, :]
    g_hist_ptrs = hist_g_ptr + offs_batch[:, None] * n + offs_n_hist[None, :]
    tl.store(f_hist_ptrs, f, mask=mask_batch[:, None])
    tl.store(g_hist_ptrs, g, mask=mask_batch[:, None])

    for iter_idx in range(iters):
        # Update f: logsumexp over the column dimension (1)
        f = x + g[:, None, :]  # Broadcast g to (BATCH_SIZE, n, n)
        f_max = tl.max(f, axis=2)
        f = tl.log(tl.sum(tl.exp(f - f_max[:, :, None]), axis=2))  # logsumexp over columns
        f = log_mu - f - f_max

        f_hist_ptrs = (
            hist_f_ptr + (iter_idx + 1) * sbn + offs_batch[:, None] * n + offs_n_hist[None, :]
        )
        tl.store(f_hist_ptrs, f, mask=mask_batch[:, None])

        # Update g: logsumexp over the row dimension (2)
        g = x + f[:, :, None]  # Broadcast f to (BATCH_SIZE, n, n)
        g_max = tl.max(g, axis=1)
        g = tl.log(tl.sum(tl.exp(g - g_max[:, None, :]), axis=1))  # logsumexp over rows
        g = log_nu - g - g_max

        g_hist_ptrs = (
            hist_g_ptr + (iter_idx + 1) * sbn + offs_batch[:, None] * n + offs_n_hist[None, :]
        )
        tl.store(g_hist_ptrs, g, mask=mask_batch[:, None])

    # P = exp(f + x + g): the normalized, (approximately) doubly stochastic matrix
    log_P = f[:, :, None] + x + g[:, None, :]
    log_P = tl.reshape(
        log_P,
        (
            BATCH_SIZE,
            n * n,
        ),
    )
    P = tl.exp(log_P)

    output_ptrs = output_ptr + offs_batch[:, None] * stride_out_m + offs_nn[None, :] * stride_out_n
    tl.store(output_ptrs, P, mask=mask_batch[:, None])
@triton.autotune(
    configs=sinkhorn_config(),
    key=["M"],
)
@triton.jit
def _mhc_sinkhorn_bwd_fused(
    grad_out_ptr,
    output_ptr,
    grad_x_ptr,
    x_ptr,
    hist_f_ptr,
    hist_g_ptr,  # Assume this is contiguous and laid out as (iters+1, M, n)
    stride_grad_out_m,
    stride_grad_out_n,
    stride_out_m,
    stride_out_n,
    stride_grad_xm,
    stride_grad_xn,
    stride_xm,
    stride_xn,
    M,
    n: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    iters,
):
    """
    Backward pass for the fused Sinkhorn-Knopp algorithm, reading the f/g
    history saved by ``_mhc_sinkhorn_fwd_fused`` (no recomputation).

    :param grad_out_ptr: pointer to the gradient of the output, (M, n*n)
    :param output_ptr: pointer to the forward output P, (M, n*n)
    :param grad_x_ptr: pointer to the gradient of the input, (M, n*n)
    :param x_ptr: pointer to the input tensor, (M, n*n)
    :param hist_f_ptr: pointer to the tensor storing f history, (iters+1, M, n)
    :param hist_g_ptr: pointer to the tensor storing g history, (iters+1, M, n)
    :param M: number of n x n sub-matrices (sequence length * batch size)
    :param n: size of the sub-matrix (n x n); the kernel assumes n == 4
    :param BLOCK_SIZE: elements handled per program; must be divisible by n*n
    :param iters: number of iterations
    """
    pid = tl.program_id(0)  # 1D grid

    tl.static_assert(BLOCK_SIZE % (n * n) == 0, "BLOCK_SIZE must be divisible by n*n")
    tl.assume(M > 0 and iters > 0)
    tl.assume(n == 4)

    BATCH_SIZE: tl.constexpr = BLOCK_SIZE // (n * n)  # Assume there's no remainder for simplicity

    offs_batch = pid * BATCH_SIZE + tl.arange(0, BATCH_SIZE)
    offs_nn = tl.arange(0, n * n)
    offs_n_hist = tl.arange(0, n)
    mask_batch = offs_batch < M

    x_ptrs = x_ptr + offs_batch[:, None] * stride_xm + offs_nn[None, :] * stride_xn
    x = tl.load(x_ptrs, mask=mask_batch[:, None], other=0.0)  # (BATCH_SIZE, n*n)
    x = tl.reshape(x, (BATCH_SIZE, n, n))  # (BATCH_SIZE, n, n)
    exp_x = tl.exp(x)

    P_ptrs = output_ptr + offs_batch[:, None] * stride_out_m + offs_nn[None, :] * stride_out_n
    P = tl.load(P_ptrs, mask=mask_batch[:, None], other=0.0)  # (BATCH_SIZE, n*n)
    P = tl.reshape(P, (BATCH_SIZE, n, n))

    grad_out_ptrs = (
        grad_out_ptr
        + offs_batch[:, None] * stride_grad_out_m
        + offs_nn[None, :] * stride_grad_out_n
    )
    grad_out = tl.load(grad_out_ptrs, mask=mask_batch[:, None], other=0.0)  # (BATCH_SIZE, n*n)
    grad_out = tl.reshape(grad_out, (BATCH_SIZE, n, n))  # (BATCH_SIZE, n, n)

    sbn = M * n  # stride between consecutive history snapshots

    # Backward pass
    grad_log_P = grad_out * P  # (BATCH_SIZE, n, n)
    zeros = tl.zeros_like(grad_log_P)
    grad_g = tl.sum(grad_log_P, axis=1)  # (BATCH_SIZE, n)
    grad_x = grad_log_P

    # Start from the last snapshot of g and walk the history in reverse
    g_hist_ptrs = hist_g_ptr + iters * sbn + offs_batch[:, None] * n + offs_n_hist[None, :]
    g = tl.load(g_hist_ptrs, mask=mask_batch[:, None], other=0.0)
    g = tl.reshape(g, (BATCH_SIZE, n))

    for iter_idx in range(iters, 0, -1):
        f_hist_ptrs = hist_f_ptr + iter_idx * sbn + offs_batch[:, None] * n + offs_n_hist[None, :]
        f = tl.load(f_hist_ptrs, mask=mask_batch[:, None], other=0.0)
        f = tl.reshape(f, (BATCH_SIZE, n))

        g_hist_ptrs = (
            hist_g_ptr + (iter_idx - 1) * sbn + offs_batch[:, None] * n + offs_n_hist[None, :]
        )
        g_next = tl.load(g_hist_ptrs, mask=mask_batch[:, None], other=0.0)
        g_next = tl.reshape(g_next, (BATCH_SIZE, n))

        term_g = -grad_g[:, None, :] * tl.exp(f[:, :, None] + g[:, None, :]) * exp_x
        grad_f = tl.sum(term_g + grad_log_P, axis=2)  # (BATCH_SIZE, n)
        # Only the last iteration's f will contribute to gradients with both grad_g1 and grad_log_P
        grad_log_P = zeros  # Zero out grad_log_P for next iterations

        g = g_next

        term_f = -grad_f[:, :, None] * tl.exp(f[:, :, None] + g[:, None, :]) * exp_x
        grad_g = tl.sum(term_f, axis=1)  # (BATCH_SIZE, n)

        grad_x += term_f + term_g

    grad_x_ptrs = (
        grad_x_ptr + offs_batch[:, None] * stride_grad_xm + offs_nn[None, :] * stride_grad_xn
    )
    tl.store(
        grad_x_ptrs,
        tl.reshape(
            grad_x,
            (
                BATCH_SIZE,
                n * n,
            ),
        ),
        mask=mask_batch[:, None],
    )
def aggregate_config():
    """Enumerate Triton autotune candidates for the aggregate kernels.

    Sweeps BLOCK_SIZE_M x BLOCK_SIZE_C x num_warps x num_stages. When the
    ``TRITON_SKIP_AUTOTUNING`` environment variable is set to ``"1"``, only
    the first candidate is returned so kernels build without a tuning sweep.
    """
    m_sizes = [32, 64]
    c_sizes = [32, 64]
    warp_counts = [2, 4]
    stage_counts = [1, 2, 3, 4]
    configs = [
        triton.Config({"BLOCK_SIZE_M": bm, "BLOCK_SIZE_C": bc}, num_warps=nw, num_stages=ns)
        for bm, bc, nw, ns in itertools.product(m_sizes, c_sizes, warp_counts, stage_counts)
    ]
    if os.environ.get("TRITON_SKIP_AUTOTUNING", "0") == "1":
        return configs[:1]
    return configs
@triton.autotune(
    configs=aggregate_config(),
    key=["M", "C"],
)
@triton.jit
def _mhc_aggregate_fwd(
    x_ptr,  # (M, C, n)
    H_pre_ptr,  # (M, n)
    output_ptr,  # (M, C)
    M,
    C,
    n: tl.constexpr,
    stride_xm,
    stride_xCn,
    stride_output_m,
    stride_output_c,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_C: tl.constexpr,
):
    """
    Aggregate the n Hyper-Connection streams with per-row weights:

        output = x @ H_pre: (M, C, n) @ (M, n, 1) = (M, C, 1)

    :param x_ptr: input streams, (M, C, n), innermost dim contiguous
    :param H_pre_ptr: per-row aggregation weights, (M, n), contiguous
    :param output_ptr: aggregated output, (M, C)
    :param M: number of rows (sequence length * batch size)
    :param C: hidden dimension per stream
    :param n: number of Hyper-Connection streams; specialized to n == 4
    """
    pid_m = tl.program_id(1)
    pid_c = tl.program_id(0)

    tl.static_assert(n == 4)
    tl.assume(M > 0)
    tl.assume(C > 0)
    tl.assume(n == 4)
    tl.assume(stride_xm > 0 and stride_xCn == 1)
    tl.assume(stride_output_m > 0 and stride_output_c == 1)

    tl.assume(BLOCK_SIZE_M % 32 == 0)
    tl.assume(BLOCK_SIZE_C % 32 == 0)

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_c = pid_c * BLOCK_SIZE_C + tl.arange(0, BLOCK_SIZE_C)
    offs_cn = pid_c * BLOCK_SIZE_C * n + tl.arange(0, BLOCK_SIZE_C * n)
    mask_m = offs_m < M
    mask_c = offs_c < C
    mask_cn = offs_cn < C * n

    offs_H_pre = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n)
    H_pre = tl.load(
        H_pre_ptr + offs_H_pre, mask=offs_H_pre < M * n, other=0.0, cache_modifier=".ca"
    )  # (BLOCK_SIZE_M * n)
    # Split the n=4 weights into four scalar columns via two binary splits
    H_pre = H_pre.reshape(BLOCK_SIZE_M, 2, 2)
    H_pre01, H_pre23 = tl.split(H_pre)
    H_pre0, H_pre1 = tl.split(H_pre01)
    H_pre2, H_pre3 = tl.split(H_pre23)  # (BLOCK_SIZE_M, 1)

    x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_cn[None, :] * stride_xCn
    x = tl.load(
        x_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0
    )  # (BLOCK_SIZE_M, BLOCK_SIZE_C * n)

    x = tl.reshape(x, (BLOCK_SIZE_M, BLOCK_SIZE_C, 2, 2))
    x01, x23 = tl.split(x)
    x0, x1 = tl.split(x01)
    x2, x3 = tl.split(x23)  # (BLOCK_SIZE_M, BLOCK_SIZE_C)

    # x @ H_pre: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, 1)
    # triton doesn't support dot prod with inner dimension < 16, so we need to manually unroll the computation for n=4:
    # x @ H_pre = x[:, :, 0] * H_pre[:, 0]
    #           + x[:, :, 1] * H_pre[:, 1]
    #           + x[:, :, 2] * H_pre[:, 2]
    #           + x[:, :, 3] * H_pre[:, 3]
    out_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C), dtype=tl.float32)
    out_acc = tl.fma(x0, H_pre0[:, None], out_acc)
    out_acc = tl.fma(x1, H_pre1[:, None], out_acc)
    out_acc = tl.fma(x2, H_pre2[:, None], out_acc)
    out_acc = tl.fma(x3, H_pre3[:, None], out_acc)

    out = out_acc.to(x.dtype)

    output_ptrs = output_ptr + offs_m[:, None] * stride_output_m + offs_c[None, :] * stride_output_c
    tl.store(output_ptrs, out, mask=mask_m[:, None] & mask_c[None, :])
@triton.autotune(configs=aggregate_config(), key=["M", "C"], reset_to_zero=["grad_H_pre_ptr"])
@triton.jit
def _mhc_aggregate_bwd(
    grad_output_ptr,  # (M, C)
    H_pre_ptr,
    grad_H_pre_ptr,
    x_ptr,
    grad_x_ptr,  # (M, C, n)
    M,
    C,
    n: tl.constexpr,
    stride_grad_output_m,
    stride_grad_output_c,
    stride_xm,
    stride_xCn,
    stride_grad_xm,
    stride_grad_xCn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_C: tl.constexpr,
    precision: tl.constexpr,  # input_precision for tl.dot (e.g. "tf32" / "ieee")
):
    """
    Backward pass for ``_mhc_aggregate_fwd``.

    Forward:
        out = x @ H_pre: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, 1) = (BLOCK_SIZE_M, BLOCK_SIZE_C, 1)
    Backward:
        grad_H_pre = x.T @ grad_output: (BLOCK_SIZE_M, n, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) = (BLOCK_SIZE_M, n, 1)
        grad_H_pre.T = grad_output.T @ x: (BLOCK_SIZE_M, 1, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, 1, n)
        which is easier to compute since transposing grad_H_pre and grad_output is just view change
        grad_x = grad_output @ H_pre.T: (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n)

    grad_H_pre is accumulated with atomics across the C-dimension grid; the
    autotuner's ``reset_to_zero`` zeroes it before each tuning run.
    """
    pid_m = tl.program_id(1)
    pid_c = tl.program_id(0)

    tl.static_assert(n == 4)
    tl.assume(M > 0)
    tl.assume(C > 0)
    tl.assume(n == 4)
    tl.assume(stride_xm > 0 and stride_xCn == 1)
    tl.assume(stride_grad_xm > 0 and stride_grad_xCn == 1)
    tl.assume(stride_grad_output_m > 0 and stride_grad_output_c == 1)

    tl.assume(BLOCK_SIZE_M % 32 == 0)
    tl.assume(BLOCK_SIZE_C % 32 == 0)

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_c = pid_c * BLOCK_SIZE_C + tl.arange(0, BLOCK_SIZE_C)
    offs_cn = pid_c * BLOCK_SIZE_C * n + tl.arange(0, BLOCK_SIZE_C * n)
    mask_m = offs_m < M
    mask_c = offs_c < C
    mask_cn = offs_cn < C * n

    grad_output_ptrs = (
        grad_output_ptr
        + offs_m[:, None] * stride_grad_output_m
        + offs_c[None, :] * stride_grad_output_c
    )
    grad_output = tl.load(
        grad_output_ptrs, mask=mask_m[:, None] & mask_c[None, :], other=0.0
    )  # (BLOCK_SIZE_M, BLOCK_SIZE_C)

    x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_cn[None, :] * stride_xCn
    x = tl.load(
        x_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0
    )  # (BLOCK_SIZE_M, BLOCK_SIZE_C * n)

    # grad_H_pre.T = grad_output.T @ x, computed as a batched dot
    grad_H_pre = tl.dot(
        tl.reshape(grad_output, (BLOCK_SIZE_M, 1, BLOCK_SIZE_C)),
        tl.reshape(x, (BLOCK_SIZE_M, BLOCK_SIZE_C, n)),
        input_precision=precision,
        out_dtype=tl.float32,
    )
    grad_H_pre = tl.reshape(grad_H_pre, (BLOCK_SIZE_M * n,))  # (BLOCK_SIZE_M * n)
    offs_grad_H_pre = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n)
    grad_H_pre_ptrs = grad_H_pre_ptr + offs_grad_H_pre
    tl.atomic_add(grad_H_pre_ptrs, grad_H_pre, mask=offs_grad_H_pre < M * n, sem="relaxed")

    H_pre_offs = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n)
    H_pre = tl.load(
        H_pre_ptr + H_pre_offs, mask=H_pre_offs < M * n, other=0.0, cache_modifier=".ca"
    )  # (BLOCK_SIZE_M * n)
    H_pre = tl.reshape(H_pre, (BLOCK_SIZE_M, n))  # (BLOCK_SIZE_M, n)

    # grad_x = grad_output @ H_pre.T: (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n)
    grad_x = grad_output[:, :, None] * H_pre[:, None, :]  # (BLOCK_SIZE_M, BLOCK_SIZE_C, n)
    grad_x = tl.reshape(grad_x, (BLOCK_SIZE_M, BLOCK_SIZE_C * n))

    grad_x_ptrs = grad_x_ptr + offs_m[:, None] * stride_grad_xm + offs_cn[None, :] * stride_grad_xCn
    tl.store(
        grad_x_ptrs,
        grad_x,
        mask=mask_m[:, None] & mask_cn[None, :],
    )
def expand_combine_config():
    """Enumerate Triton autotune candidates for the expand/combine kernels.

    Sweeps BLOCK_SIZE_M x BLOCK_SIZE_C x num_warps x num_stages. When the
    ``TRITON_SKIP_AUTOTUNING`` environment variable is set to ``"1"``, only
    the first candidate is returned so kernels build without a tuning sweep.
    """
    m_sizes = [32, 64]
    c_sizes = [32, 64]
    warp_counts = [2, 4]
    stage_counts = [1, 2, 3, 4]
    configs = [
        triton.Config({"BLOCK_SIZE_M": bm, "BLOCK_SIZE_C": bc}, num_warps=nw, num_stages=ns)
        for bm, bc, nw, ns in itertools.product(m_sizes, c_sizes, warp_counts, stage_counts)
    ]
    if os.environ.get("TRITON_SKIP_AUTOTUNING", "0") == "1":
        return configs[:1]
    return configs
@triton.autotune(
    configs=expand_combine_config(),
    key=["M", "C"],
)
@triton.jit
def _mhc_expand_combine_fwd(
    f_ptr,  # (M, C)
    H_post_ptr,  # (M, n)
    x_ptr,  # (M, C, n)
    H_res_ptr,  # (M, n, n)
    output_ptr,  # (M, C, n)
    M,
    C,
    n: tl.constexpr,
    stride_fm,
    stride_fc,
    stride_xm,
    stride_xCn,
    stride_output_m,
    stride_output_Cn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_C: tl.constexpr,
):
    """
    Expand the module output back to n streams and combine with the residual:

    output = f @ H_post: (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n)
           + x @ H_res: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n)

    :param f_ptr: attention / FFN output, (M, C)
    :param H_post_ptr: per-row expansion weights, (M, n), contiguous
    :param x_ptr: skip-connection input streams, (M, C, n)
    :param H_res_ptr: per-row residual mixing matrices, (M, n, n), contiguous
    :param output_ptr: combined output, (M, C, n)
    """
    pid_m = tl.program_id(1)
    pid_c = tl.program_id(0)

    tl.static_assert(n == 4)
    tl.assume(M > 0)
    tl.assume(C > 0)
    tl.assume(n == 4)
    tl.assume(stride_fm > 0 and stride_fc == 1)
    tl.assume(stride_xm > 0 and stride_xCn == 1)
    tl.assume(stride_output_m > 0 and stride_output_Cn == 1)

    tl.assume(BLOCK_SIZE_M % 32 == 0)
    tl.assume(BLOCK_SIZE_C % 32 == 0)

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_c = pid_c * BLOCK_SIZE_C + tl.arange(0, BLOCK_SIZE_C)
    offs_cn = pid_c * BLOCK_SIZE_C * n + tl.arange(0, BLOCK_SIZE_C * n)
    mask_m = offs_m < M
    mask_c = offs_c < C
    mask_cn = offs_cn < C * n

    f_ptrs = f_ptr + offs_m[:, None] * stride_fm + offs_c[None, :] * stride_fc
    f = tl.load(f_ptrs, mask=mask_m[:, None] & mask_c[None, :], other=0.0)

    offs_H_post = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n)
    H_post = tl.load(
        H_post_ptr + offs_H_post, mask=offs_H_post < M * n, other=0.0, cache_modifier=".ca"
    )
    H_post = tl.reshape(H_post, (BLOCK_SIZE_M, n))  # (BLOCK_SIZE_M, n)

    # Residual connection path: res_out = f @ H_post:
    # (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, n, BLOCK_SIZE_C)
    # Due to broadcasting, it's equivalent to a multiplication
    out_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C, n), dtype=tl.float32)
    out_acc = tl.fma(f[:, :, None], H_post[:, None, :], out_acc)

    H_res_offs = pid_m * BLOCK_SIZE_M * n * n + tl.arange(0, BLOCK_SIZE_M * n * n)
    H_res = tl.load(
        H_res_ptr + H_res_offs, mask=H_res_offs < M * n * n, other=0.0, cache_modifier=".ca"
    )
    H_res = tl.reshape(H_res, (BLOCK_SIZE_M, n, n))  # (BLOCK_SIZE_M, n, n)

    x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_cn[None, :] * stride_xCn
    x = tl.load(
        x_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0
    )  # (BLOCK_SIZE_M, BLOCK_SIZE_C, n)

    # Manifold connection path: manifold_out = H_res @ x:
    # (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n)
    # triton doesn't support dot prod with inner dimension < 16, so we need to manually unroll the computation for n=4:
    # x @ H_res = x[:, :, 0] @ H_res[:, 0, :]
    #           + x[:, :, 1] @ H_res[:, 1, :]
    #           + x[:, :, 2] @ H_res[:, 2, :]
    #           + x[:, :, 3] @ H_res[:, 3, :]

    x_reshape = tl.reshape(x, (BLOCK_SIZE_M, BLOCK_SIZE_C, 2, 2))
    x01, x23 = tl.split(
        x_reshape
    )  # (BLOCK_SIZE_M, BLOCK_SIZE_C, 2), (BLOCK_SIZE_M, BLOCK_SIZE_C, 2)
    x0, x1 = tl.split(x01)  # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C)
    x2, x3 = tl.split(x23)  # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C)

    H_resT = tl.reshape(tl.trans(H_res, (0, 2, 1)), (BLOCK_SIZE_M, n, 2, 2))
    H_res01, H_res23 = tl.split(H_resT)  # (BLOCK_SIZE_M, n, 2), (BLOCK_SIZE_M, n, 2)
    H_res0, H_res1 = tl.split(H_res01)  # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n)
    H_res2, H_res3 = tl.split(H_res23)  # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n)

    out_acc = tl.fma(x0[:, :, None], H_res0[:, None, :], out_acc)
    out_acc = tl.fma(x1[:, :, None], H_res1[:, None, :], out_acc)
    out_acc = tl.fma(x2[:, :, None], H_res2[:, None, :], out_acc)
    out_acc = tl.fma(x3[:, :, None], H_res3[:, None, :], out_acc)

    out = out_acc.to(x.dtype)
    out = tl.reshape(out, (BLOCK_SIZE_M, BLOCK_SIZE_C * n))  # (BLOCK_SIZE_M, BLOCK_SIZE_C*n)

    output_ptrs = (
        output_ptr + offs_m[:, None] * stride_output_m + offs_cn[None, :] * stride_output_Cn
    )
    tl.store(output_ptrs, out, mask=mask_m[:, None] & mask_cn[None, :])
@triton.autotune(
    configs=expand_combine_config(),
    key=["M", "C"],
    reset_to_zero=["grad_H_post_ptr", "grad_H_res_ptr"],
)
@triton.jit
def _mhc_expand_combine_bwd(
    grad_output_ptr,  # (M, C, n)
    f_ptr,  # (M, C)
    H_post_ptr,  # (M, n)
    x_ptr,  # (M, C, n)
    H_res_ptr,  # (M, n, n)
    grad_H_post_ptr,  # (M, n)
    grad_f_ptr,  # (M, C)
    grad_H_res_ptr,  # (M, n, n)
    grad_x_ptr,  # (M, C, n)
    M,
    C,
    n: tl.constexpr,
    stride_grad_output_m,
    stride_grad_output_Cn,
    stride_fm,
    stride_fc,
    stride_xm,
    stride_xCn,
    stride_grad_fm,
    stride_grad_fc,
    stride_grad_xm,
    stride_grad_xCn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_C: tl.constexpr,
    precision: tl.constexpr,  # input_precision for tl.dot (e.g. "tf32" / "ieee")
):
    """
    Backward pass for ``_mhc_expand_combine_fwd``.

    Each block
    It reads
    - (BLOCK_SIZE_M, BLOCK_SIZE_C) of f, which is the output of the attention / FFN module
    - (BLOCK_SIZE_M, n) of H_post, which is applied for the transformation of the attention / FFN output
    - (BLOCK_SIZE_M, BLOCK_SIZE_C, n) of x, which is the skip connection's input
    - (BLOCK_SIZE_M, n*n) of H_res, which is applied for the transformation of the skip connection
    and writes
    - (BLOCK_SIZE_M, n) of grad_H_post
    - (BLOCK_SIZE_M, BLOCK_SIZE_C) of grad_f
    - (BLOCK_SIZE_M, n, n) of grad_H_res
    - (BLOCK_SIZE_M, BLOCK_SIZE_C, n) of grad_x

    Forward:
        out = f @ H_post + x @ H_res
    Backward:
        GEMM:
            grad_H_post = f.T @ grad_output: (BLOCK_SIZE_M, 1, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, 1, n)
            grad_H_res = x.T @ grad_output: (BLOCK_SIZE_M, n, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, n, n)
        Not GEMM:
            grad_f = grad_output @ H_post.T: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, 1) = (BLOCK_SIZE_M, BLOCK_SIZE_C, 1)
            grad_x = grad_output @ H_res.T: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n)

    grad_H_post / grad_H_res are accumulated with atomics across the
    C-dimension grid; ``reset_to_zero`` zeroes them before each tuning run.
    """

    pid_m = tl.program_id(1)
    pid_c = tl.program_id(0)

    tl.static_assert(n == 4)
    tl.assume(M > 0)
    tl.assume(C > 0)
    tl.assume(n == 4)
    tl.assume(stride_fm > 0 and stride_fc == 1)
    tl.assume(stride_xm > 0 and stride_xCn == 1)
    tl.assume(stride_grad_output_m > 0 and stride_grad_output_Cn == 1)
    tl.assume(stride_grad_fm > 0 and stride_grad_fc == 1)
    tl.assume(stride_grad_xm > 0 and stride_grad_xCn == 1)

    tl.assume(BLOCK_SIZE_M % 32 == 0)
    tl.assume(BLOCK_SIZE_C % 32 == 0)

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_c = pid_c * BLOCK_SIZE_C + tl.arange(0, BLOCK_SIZE_C)
    offs_cn = pid_c * BLOCK_SIZE_C * n + tl.arange(0, BLOCK_SIZE_C * n)
    mask_m = offs_m < M
    mask_c = offs_c < C
    mask_cn = offs_cn < C * n

    f_ptrs = f_ptr + offs_m[:, None] * stride_fm + offs_c[None, :] * stride_fc
    f = tl.load(f_ptrs, mask=mask_m[:, None] & mask_c[None, :], other=0.0)

    H_post_offs = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n)
    H_post = tl.load(H_post_ptr + H_post_offs, mask=H_post_offs < M * n, other=0.0)
    H_post = tl.reshape(H_post, (BLOCK_SIZE_M, n))  # (BLOCK_SIZE_M, n)

    H_res_offs = pid_m * BLOCK_SIZE_M * n * n + tl.arange(0, BLOCK_SIZE_M * n * n)
    H_res = tl.load(
        H_res_ptr + H_res_offs, mask=H_res_offs < M * n * n, other=0.0
    )  # (BLOCK_SIZE_M, n, n)
    H_res = tl.reshape(H_res, (BLOCK_SIZE_M, n, n))  # (BLOCK_SIZE_M, n, n)

    grad_out_ptrs = (
        grad_output_ptr
        + offs_m[:, None] * stride_grad_output_m
        + offs_cn[None, :] * stride_grad_output_Cn
    )
    grad_out = tl.load(
        grad_out_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0
    )  # (BLOCK_SIZE_M, BLOCK_SIZE_C * n)
    grad_out = tl.reshape(
        grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, n)
    )  # (BLOCK_SIZE_M, BLOCK_SIZE_C, n)

    # grad_H_post = f.T @ grad_output # (BLOCK_SIZE_M, 1, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, 1, n)
    grad_H_post = tl.dot(
        tl.reshape(f, (BLOCK_SIZE_M, 1, BLOCK_SIZE_C)),
        tl.reshape(grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, n)),
        input_precision=precision,
        out_dtype=tl.float32,
    )  # (BLOCK_SIZE_M, 1, n)
    grad_H_post = tl.reshape(grad_H_post, (BLOCK_SIZE_M * n,))  # (BLOCK_SIZE_M * n)
    offs_grad_H_post = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n)
    grad_H_post_ptrs = grad_H_post_ptr + offs_grad_H_post
    tl.atomic_add(grad_H_post_ptrs, grad_H_post, mask=offs_grad_H_post < M * n, sem="relaxed")

    x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_cn[None, :] * stride_xCn
    x = tl.load(
        x_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0
    )  # (BLOCK_SIZE_M, BLOCK_SIZE_C*n)
    x = tl.reshape(x, (BLOCK_SIZE_M, BLOCK_SIZE_C, n))  # (BLOCK_SIZE_M, BLOCK_SIZE_C, n)

    # grad_H_res = x.T @ grad_output: (BLOCK_SIZE_M, n, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, n, n)
    grad_H_res = tl.dot(
        tl.trans(x, (0, 2, 1)), grad_out, input_precision=precision, out_dtype=tl.float32
    )  # (BLOCK_SIZE_M, n, n)
    grad_H_res = tl.reshape(grad_H_res, (BLOCK_SIZE_M * n * n,))  # (BLOCK_SIZE_M * n * n)
    offs_grad_H_res = pid_m * BLOCK_SIZE_M * n * n + tl.arange(0, BLOCK_SIZE_M * n * n)
    grad_H_res_ptrs = grad_H_res_ptr + offs_grad_H_res
    tl.atomic_add(
        grad_H_res_ptrs, grad_H_res.to(tl.float32), mask=offs_grad_H_res < M * n * n, sem="relaxed"
    )

    grad_out_reshape = tl.reshape(
        grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, 2, 2)
    )  # (BLOCK_SIZE_M, BLOCK_SIZE_C, 2, 2)
    grad_out01, grad_out23 = tl.split(
        grad_out_reshape
    )  # (BLOCK_SIZE_M, BLOCK_SIZE_C, 2), (BLOCK_SIZE_M, BLOCK_SIZE_C, 2)
    grad_out0, grad_out1 = tl.split(
        grad_out01
    )  # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C)
    grad_out2, grad_out3 = tl.split(
        grad_out23
    )  # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C)

    # grad_f = grad_output @ H_post.T: (BLOCK_SIZE_M, 1, n) @ (BLOCK_SIZE_M, n, BLOCK_SIZE_C) = (BLOCK_SIZE_M, 1, BLOCK_SIZE_C)
    # Triton doesn't support dot prod with inner dimension < 16, so we need to hack this:
    # = grad_out[:, :, 0] @ H_post.T[:, 0, :] (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, 1)
    # + grad_out[:, :, 1] @ H_post.T[:, 1, :]
    # + grad_out[:, :, 2] @ H_post.T[:, 2, :]
    # + grad_out[:, :, 3] @ H_post.T[:, 3, :]
    # where H_post.T[:, i, :] = H_post[:, :, i]
    H_post = tl.reshape(H_post, (BLOCK_SIZE_M, 2, 2))
    H_post01, H_post23 = tl.split(H_post)  # (BLOCK_SIZE_M, 2), (BLOCK_SIZE_M, 2)
    H_post0, H_post1 = tl.split(H_post01)  # (BLOCK_SIZE_M,), (BLOCK_SIZE_M,)
    H_post2, H_post3 = tl.split(H_post23)  # (BLOCK_SIZE_M,), (BLOCK_SIZE_M,)

    grad_f_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C), dtype=tl.float32)
    # (BLOCK_SIZE_M, BLOCK_SIZE_C) * (BLOCK_SIZE_M, 1) -> (BLOCK_SIZE_M, BLOCK_SIZE_C)
    grad_f_acc = tl.fma(grad_out0, H_post0[:, None], grad_f_acc)
    grad_f_acc = tl.fma(grad_out1, H_post1[:, None], grad_f_acc)
    grad_f_acc = tl.fma(grad_out2, H_post2[:, None], grad_f_acc)
    grad_f_acc = tl.fma(grad_out3, H_post3[:, None], grad_f_acc)
    grad_f = grad_f_acc.to(f.dtype)

    grad_f_ptrs = grad_f_ptr + offs_m[:, None] * stride_grad_fm + offs_c[None, :] * stride_grad_fc
    tl.store(grad_f_ptrs, grad_f, mask=mask_m[:, None] & mask_c[None, :])

    # grad_x = grad_output @ H_res.T: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, n, BLOCK_SIZE_C)
    # The inner dim is n=4 which is too small for triton, so we will manually unroll the matmul
    # grad_x = grad_out[:, :, 0] @ H_res.T[:, 0, :]
    #        + grad_out[:, :, 1] @ H_res.T[:, 1, :]
    #        + grad_out[:, :, 2] @ H_res.T[:, 2, :]
    #        + grad_out[:, :, 3] @ H_res.T[:, 3, :]
    # where H_res.T[:, i, :] = H_res[:, :, i]
    # Due to broadcasting, it's equivalent to multiplying each H_res[:, i, :].T with grad_out[:, i, :]

    H_res_reshape = tl.reshape(H_res, (BLOCK_SIZE_M, n, 2, 2))  # (BLOCK_SIZE_M, n, 2, 2)
    H_res01, H_res23 = tl.split(H_res_reshape)  # (BLOCK_SIZE_M, n, 2), (BLOCK_SIZE_M, n, 2)
    H_res0, H_res1 = tl.split(H_res01)  # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n)
    H_res2, H_res3 = tl.split(H_res23)  # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n)

    grad_x_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C, n), dtype=tl.float32)
    grad_x_acc = tl.fma(grad_out0[:, :, None], H_res0[:, None, :], grad_x_acc)
    grad_x_acc = tl.fma(grad_out1[:, :, None], H_res1[:, None, :], grad_x_acc)
    grad_x_acc = tl.fma(grad_out2[:, :, None], H_res2[:, None, :], grad_x_acc)
    grad_x_acc = tl.fma(grad_out3[:, :, None], H_res3[:, None, :], grad_x_acc)

    grad_x = grad_x_acc.to(x.dtype)
    grad_x = tl.reshape(grad_x, (BLOCK_SIZE_M, BLOCK_SIZE_C * n))  # (BLOCK_SIZE_M, BLOCK_SIZE_C*n)

    grad_x_ptrs = grad_x_ptr + offs_m[:, None] * stride_grad_xm + offs_cn[None, :] * stride_grad_xCn
    tl.store(grad_x_ptrs, grad_x, mask=mask_m[:, None] & mask_cn[None, :])
@triton.autotune(
    configs=expand_combine_config(),
    key=["M", "C"],
)
@triton.jit
def _mhc_expand_combine_with_bias_fwd(
    f_ptr,  # (M, C)
    bias_ptr,  # (C,)
    H_post_ptr,  # (M, n)
    x_ptr,  # (M, C, n)
    H_res_ptr,  # (M, n, n)
    output_ptr,  # (M, C, n)
    M,
    C,
    n: tl.constexpr,
    stride_fm,
    stride_fc,
    stride_bias,
    stride_xm,
    stride_xCn,
    stride_output_m,
    stride_output_Cn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_C: tl.constexpr,
):
    """
    Same as ``_mhc_expand_combine_fwd`` but with a per-channel bias folded
    into the module output before expansion:

    output = (f + bias[None, :, None]) @ H_post: (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n)
           + x @ H_res: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n)

    :param f_ptr: attention / FFN output, (M, C)
    :param bias_ptr: per-channel bias, (C,)
    :param H_post_ptr: per-row expansion weights, (M, n), contiguous
    :param x_ptr: skip-connection input streams, (M, C, n)
    :param H_res_ptr: per-row residual mixing matrices, (M, n, n), contiguous
    :param output_ptr: combined output, (M, C, n)
    """
    pid_m = tl.program_id(1)
    pid_c = tl.program_id(0)

    tl.static_assert(n == 4)
    tl.assume(M > 0)
    tl.assume(C > 0)
    tl.assume(n == 4)
    tl.assume(stride_fm > 0 and stride_fc == 1)
    tl.assume(stride_bias == 1)
    tl.assume(stride_xm > 0 and stride_xCn == 1)
    tl.assume(stride_output_m > 0 and stride_output_Cn == 1)

    tl.assume(BLOCK_SIZE_M % 32 == 0)
    tl.assume(BLOCK_SIZE_C % 32 == 0)

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_c = pid_c * BLOCK_SIZE_C + tl.arange(0, BLOCK_SIZE_C)
    offs_cn = pid_c * BLOCK_SIZE_C * n + tl.arange(0, BLOCK_SIZE_C * n)
    mask_m = offs_m < M
    mask_c = offs_c < C
    mask_cn = offs_cn < C * n

    f_ptrs = f_ptr + offs_m[:, None] * stride_fm + offs_c[None, :] * stride_fc
    f = tl.load(f_ptrs, mask=mask_m[:, None] & mask_c[None, :], other=0.0)
    bias = tl.load(bias_ptr + offs_c * stride_bias, mask=mask_c, other=0.0)  # (BLOCK_SIZE_C,)

    offs_H_post = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n)
    H_post = tl.load(
        H_post_ptr + offs_H_post, mask=offs_H_post < M * n, other=0.0, cache_modifier=".ca"
    )
    H_post = tl.reshape(H_post, (BLOCK_SIZE_M, n))  # (BLOCK_SIZE_M, n)

    # Residual connection path: res_out = f @ H_post + bias @ H_post:
    # (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, n, BLOCK_SIZE_C)
    # Due to broadcasting, it's equivalent to a multiplication
    out_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C, n), dtype=tl.float32)
    out_acc = tl.fma(bias[None, :, None], H_post[:, None, :], out_acc)
    out_acc = tl.fma(f[:, :, None], H_post[:, None, :], out_acc)

    H_res_offs = pid_m * BLOCK_SIZE_M * n * n + tl.arange(0, BLOCK_SIZE_M * n * n)
    H_res = tl.load(
        H_res_ptr + H_res_offs, mask=H_res_offs < M * n * n, other=0.0, cache_modifier=".ca"
    )
    H_res = tl.reshape(H_res, (BLOCK_SIZE_M, n, n))  # (BLOCK_SIZE_M, n, n)

    x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_cn[None, :] * stride_xCn
    x = tl.load(
        x_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0
    )  # (BLOCK_SIZE_M, BLOCK_SIZE_C, n)

    # Manifold connection path: manifold_out = H_res @ x:
    # (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n)
    # triton doesn't support dot prod with inner dimension < 16, so we need to manually unroll the computation for n=4:
    # x @ H_res = x[:, :, 0] @ H_res[:, 0, :]
    #           + x[:, :, 1] @ H_res[:, 1, :]
    #           + x[:, :, 2] @ H_res[:, 2, :]
    #           + x[:, :, 3] @ H_res[:, 3, :]

    x_reshape = tl.reshape(x, (BLOCK_SIZE_M, BLOCK_SIZE_C, 2, 2))
    x01, x23 = tl.split(
        x_reshape
    )  # (BLOCK_SIZE_M, BLOCK_SIZE_C, 2), (BLOCK_SIZE_M, BLOCK_SIZE_C, 2)
    x0, x1 = tl.split(x01)  # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C)
    x2, x3 = tl.split(x23)  # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C)

    H_resT = tl.reshape(tl.trans(H_res, (0, 2, 1)), (BLOCK_SIZE_M, n, 2, 2))
    H_res01, H_res23 = tl.split(H_resT)  # (BLOCK_SIZE_M, n, 2), (BLOCK_SIZE_M, n, 2)
    H_res0, H_res1 = tl.split(H_res01)  # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n)
    H_res2, H_res3 = tl.split(H_res23)  # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n)

    out_acc = tl.fma(x0[:, :, None], H_res0[:, None, :], out_acc)
    out_acc = tl.fma(x1[:, :, None], H_res1[:, None, :], out_acc)
    out_acc = tl.fma(x2[:, :, None], H_res2[:, None, :], out_acc)
    out_acc = tl.fma(x3[:, :, None], H_res3[:, None, :], out_acc)

    out = out_acc.to(x.dtype)
    out = tl.reshape(out, (BLOCK_SIZE_M, BLOCK_SIZE_C * n))  # (BLOCK_SIZE_M, BLOCK_SIZE_C*n)

    output_ptrs = (
        output_ptr + offs_m[:, None] * stride_output_m + offs_cn[None, :] * stride_output_Cn
    )
    tl.store(output_ptrs, out, mask=mask_m[:, None] & mask_cn[None, :])
:] + # + x[:, :, 2] @ H_res[:, 2, :] + # + x[:, :, 3] @ H_res[:, 3, :] + + x_reshape = tl.reshape(x, (BLOCK_SIZE_M, BLOCK_SIZE_C, 2, 2)) + x01, x23 = tl.split( + x_reshape + ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, 2), (BLOCK_SIZE_M, BLOCK_SIZE_C, 2) + x0, x1 = tl.split(x01) # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C) + x2, x3 = tl.split(x23) # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C) + + H_resT = tl.reshape(tl.trans(H_res, (0, 2, 1)), (BLOCK_SIZE_M, n, 2, 2)) + H_res01, H_res23 = tl.split(H_resT) # (BLOCK_SIZE_M, n, 2), (BLOCK_SIZE_M, n, 2) + H_res0, H_res1 = tl.split(H_res01) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) + H_res2, H_res3 = tl.split(H_res23) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) + + out_acc = tl.fma(x0[:, :, None], H_res0[:, None, :], out_acc) + out_acc = tl.fma(x1[:, :, None], H_res1[:, None, :], out_acc) + out_acc = tl.fma(x2[:, :, None], H_res2[:, None, :], out_acc) + out_acc = tl.fma(x3[:, :, None], H_res3[:, None, :], out_acc) + + out = out_acc.to(x.dtype) + out = tl.reshape(out, (BLOCK_SIZE_M, BLOCK_SIZE_C * n)) # (BLOCK_SIZE_M, BLOCK_SIZE_C*n) + + output_ptrs = ( + output_ptr + offs_m[:, None] * stride_output_m + offs_cn[None, :] * stride_output_Cn + ) + tl.store(output_ptrs, out, mask=mask_m[:, None] & mask_cn[None, :]) + + +@triton.autotune( + configs=expand_combine_config(), + key=["M", "C"], + reset_to_zero=["grad_H_post_ptr", "grad_H_res_ptr", "grad_bias_ptr"], +) +@triton.jit +def _mhc_expand_combine_with_bias_bwd( + grad_output_ptr, # (M, C, n) + f_ptr, # (M, C) + bias_ptr, # (C,) + H_post_ptr, # (M, n) + x_ptr, # (M, C, n) + H_res_ptr, # (M, n, n) + grad_H_post_ptr, # (M, n) + grad_f_ptr, # (M, C) + grad_bias_ptr, # (C,) + grad_H_res_ptr, # (M, n, n) + grad_x_ptr, # (M, C, n) + M, + C, + n: tl.constexpr, + stride_grad_output_m, + stride_grad_output_Cn, + stride_fm, + stride_fc, + stride_bias, + stride_xm, + stride_xCn, + stride_grad_fm, + stride_grad_fc, + stride_grad_bias, + stride_grad_xm, + stride_grad_xCn, + 
# Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_C: tl.constexpr, + precision: tl.constexpr, +): + """ + Each block + It reads + - (BLOCK_SIZE_M, BLOCK_SIZE_C) of f, which is the output of the attention / FFN module + - (BLOCK_SIZE_M, n) of H_post, which is applied for the transformation of the attention / FFN output + - (BLOCK_SIZE_M, BLOCK_SIZE_C, n) of x, which is the skip connection's input + - (BLOCK_SIZE_M, n*n) of H_res, which is applied for the transformation of the skip connection + and writes + - (BLOCK_SIZE_M, n) of grad_H_post + - (BLOCK_SIZE_M, BLOCK_SIZE_C) of grad_f + - (BLOCK_SIZE_M, n, n) of grad_H_res + - (BLOCK_SIZE_M, BLOCK_SIZE_C, n) of grad_x + + Forward: + out = f @ H_post + x @ H_res + Backward: + GEMM: + grad_H_post = f.T @ grad_output: (BLOCK_SIZE_M, 1, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, 1, n) + grad_H_res = x.T @ grad_output: (BLOCK_SIZE_M, n, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, n, n) + Not GEMM: + grad_f = grad_output @ H_post.T: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, 1) = (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) + grad_x = grad_output @ H_res.T: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) + """ + + pid_m = tl.program_id(1) + pid_c = tl.program_id(0) + + tl.static_assert(n == 4) + tl.assume(M > 0) + tl.assume(C > 0) + tl.assume(n == 4) + tl.assume(stride_fm > 0 and stride_fc == 1) + tl.assume(stride_bias == 1) + tl.assume(stride_xm > 0 and stride_xCn == 1) + tl.assume(stride_grad_output_m > 0 and stride_grad_output_Cn == 1) + tl.assume(stride_grad_fm > 0 and stride_grad_fc == 1) + tl.assume(stride_grad_bias == 1) + tl.assume(stride_grad_xm > 0 and stride_grad_xCn == 1) + + tl.assume(BLOCK_SIZE_M % 32 == 0) + tl.assume(BLOCK_SIZE_C % 32 == 0) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_c = pid_c * BLOCK_SIZE_C + tl.arange(0, BLOCK_SIZE_C) + offs_cn = pid_c * BLOCK_SIZE_C * n + tl.arange(0, 
BLOCK_SIZE_C * n) + mask_m = offs_m < M + mask_c = offs_c < C + mask_cn = offs_cn < C * n + + f_ptrs = f_ptr + offs_m[:, None] * stride_fm + offs_c[None, :] * stride_fc + f = tl.load(f_ptrs, mask=mask_m[:, None] & mask_c[None, :], other=0.0) + + bias = tl.load(bias_ptr + offs_c * stride_bias, mask=mask_c, other=0.0) # (BLOCK_SIZE_C,) + + H_post_offs = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) + H_post = tl.load(H_post_ptr + H_post_offs, mask=H_post_offs < M * n, other=0.0) + H_post = tl.reshape(H_post, (BLOCK_SIZE_M, n)) # (BLOCK_SIZE_M, n) + + H_res_offs = pid_m * BLOCK_SIZE_M * n * n + tl.arange(0, BLOCK_SIZE_M * n * n) + H_res = tl.load( + H_res_ptr + H_res_offs, mask=H_res_offs < M * n * n, other=0.0 + ) # (BLOCK_SIZE_M, n, n) + H_res = tl.reshape(H_res, (BLOCK_SIZE_M, n, n)) # (BLOCK_SIZE_M, n, n) + + grad_out_ptrs = ( + grad_output_ptr + + offs_m[:, None] * stride_grad_output_m + + offs_cn[None, :] * stride_grad_output_Cn + ) + grad_out = tl.load( + grad_out_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 + ) # (BLOCK_SIZE_M, BLOCK_SIZE_C * n) + grad_out = tl.reshape( + grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, n) + ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, n) + + # grad_H_post = f.T @ grad_output # (BLOCK_SIZE_M, 1, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, 1, n) + grad_H_post = tl.dot( + tl.reshape(f, (BLOCK_SIZE_M, 1, BLOCK_SIZE_C)), + tl.reshape(grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, n)), + input_precision=precision, + out_dtype=tl.float32, + ) # (BLOCK_SIZE_M, 1, n) + grad_H_post = tl.dot( + tl.broadcast_to(bias[None, None, :], (BLOCK_SIZE_M, 1, BLOCK_SIZE_C)), + tl.reshape(grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, n)), + acc=grad_H_post, + input_precision=precision, + out_dtype=tl.float32, + ) # (BLOCK_SIZE_M, 1, n) + grad_H_post = tl.reshape(grad_H_post, (BLOCK_SIZE_M * n,)) # (BLOCK_SIZE_M * n) + offs_grad_H_post = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) + grad_H_post_ptrs = grad_H_post_ptr + 
offs_grad_H_post + tl.atomic_add(grad_H_post_ptrs, grad_H_post, mask=offs_grad_H_post < M * n, sem="relaxed") + + x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_cn[None, :] * stride_xCn + x = tl.load( + x_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 + ) # (BLOCK_SIZE_M, BLOCK_SIZE_C*n) + x = tl.reshape(x, (BLOCK_SIZE_M, BLOCK_SIZE_C, n)) # (BLOCK_SIZE_M, BLOCK_SIZE_C, n) + + # grad_H_res = x.T @ grad_output: (BLOCK_SIZE_M, n, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, n, n) + grad_H_res = tl.dot( + tl.trans(x, (0, 2, 1)), grad_out, input_precision=precision, out_dtype=tl.float32 + ) # (BLOCK_SIZE_M, n, n) + grad_H_res = tl.reshape(grad_H_res, (BLOCK_SIZE_M * n * n,)) # (BLOCK_SIZE_M * n * n) + offs_grad_H_res = pid_m * BLOCK_SIZE_M * n * n + tl.arange(0, BLOCK_SIZE_M * n * n) + grad_H_res_ptrs = grad_H_res_ptr + offs_grad_H_res + tl.atomic_add( + grad_H_res_ptrs, grad_H_res.to(tl.float32), mask=offs_grad_H_res < M * n * n, sem="relaxed" + ) + + grad_out_reshape = tl.reshape( + grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, 2, 2) + ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, 2, 2) + grad_out01, grad_out23 = tl.split( + grad_out_reshape + ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, 2), (BLOCK_SIZE_M, BLOCK_SIZE_C, 2) + grad_out0, grad_out1 = tl.split( + grad_out01 + ) # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C) + grad_out2, grad_out3 = tl.split( + grad_out23 + ) # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C) + + # grad_f = grad_output @ H_post.T: (BLOCK_SIZE_M, 1, n) @ (BLOCK_SIZE_M, n, BLOCK_SIZE_C) = (BLOCK_SIZE_M, 1, BLOCK_SIZE_C) + # Triton doesn't support dot prod with inner dimension < 16, so we need to hack this: + # = grad_out[:, :, 0] @ H_post.T[:, 0, :] (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, 1) + # + grad_out[:, :, 1] @ H_post.T[:, 1, :] + # + grad_out[:, :, 2] @ H_post.T[:, 2, :] + # + grad_out[:, :, 3] @ H_post.T[:, 3, :] + # where H_post.T[:, i, :] = H_post[:, :, i] + H_post = tl.reshape(H_post, 
(BLOCK_SIZE_M, 2, 2)) + H_post01, H_post23 = tl.split(H_post) # (BLOCK_SIZE_M, 2), (BLOCK_SIZE_M, 2) + H_post0, H_post1 = tl.split(H_post01) # (BLOCK_SIZE_M,), (BLOCK_SIZE_M,) + H_post2, H_post3 = tl.split(H_post23) # (BLOCK_SIZE_M,), (BLOCK_SIZE_M,) + + grad_f_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C), dtype=tl.float32) + # (BLOCK_SIZE_M, BLOCK_SIZE_C) * (BLOCK_SIZE_M, 1) -> (BLOCK_SIZE_M, BLOCK_SIZE_C) + grad_f_acc = tl.fma(grad_out0, H_post0[:, None], grad_f_acc) + grad_f_acc = tl.fma(grad_out1, H_post1[:, None], grad_f_acc) + grad_f_acc = tl.fma(grad_out2, H_post2[:, None], grad_f_acc) + grad_f_acc = tl.fma(grad_out3, H_post3[:, None], grad_f_acc) + grad_f = grad_f_acc.to(f.dtype) + + grad_f_ptrs = grad_f_ptr + offs_m[:, None] * stride_grad_fm + offs_c[None, :] * stride_grad_fc + tl.store(grad_f_ptrs, grad_f, mask=mask_m[:, None] & mask_c[None, :]) + + grad_bias = grad_f.sum(axis=0) # (BLOCK_SIZE_C,) + grad_bias_ptrs = grad_bias_ptr + offs_c * stride_grad_bias + tl.atomic_add(grad_bias_ptrs, grad_bias, mask=mask_c, sem="relaxed") + + # grad_x = grad_output @ H_res.T: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, n, BLOCK_SIZE_C) + # The inner dim is n=4 which is too small for triton, so we will manually unroll the matmul + # grad_x = grad_out[:, :, 0] @ H_res.T[:, 0, :] + # + grad_out[:, :, 1] @ H_res.T[:, 1, :] + # + grad_out[:, :, 2] @ H_res.T[:, 2, :] + # + grad_out[:, :, 3] @ H_res.T[:, 3, :] + # where H_res.T[:, i, :] = H_res[:, :, i] + # Due to broadcasting, it's equivalent to multiplying each H_res[:, i, :].T with grad_out[:, i, :] + + H_res_reshape = tl.reshape(H_res, (BLOCK_SIZE_M, n, 2, 2)) # (BLOCK_SIZE_M, n, 2, 2) + H_res01, H_res23 = tl.split(H_res_reshape) # (BLOCK_SIZE_M, n, 2), (BLOCK_SIZE_M, n, 2) + H_res0, H_res1 = tl.split(H_res01) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) + H_res2, H_res3 = tl.split(H_res23) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) + + grad_x_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C, n), 
dtype=tl.float32) + grad_x_acc = tl.fma(grad_out0[:, :, None], H_res0[:, None, :], grad_x_acc) + grad_x_acc = tl.fma(grad_out1[:, :, None], H_res1[:, None, :], grad_x_acc) + grad_x_acc = tl.fma(grad_out2[:, :, None], H_res2[:, None, :], grad_x_acc) + grad_x_acc = tl.fma(grad_out3[:, :, None], H_res3[:, None, :], grad_x_acc) + + grad_x = grad_x_acc.to(x.dtype) + grad_x = tl.reshape(grad_x, (BLOCK_SIZE_M, BLOCK_SIZE_C * n)) # (BLOCK_SIZE_M, BLOCK_SIZE_C*n) + + grad_x_ptrs = grad_x_ptr + offs_m[:, None] * stride_grad_xm + offs_cn[None, :] * stride_grad_xCn + tl.store(grad_x_ptrs, grad_x, mask=mask_m[:, None] & mask_cn[None, :]) diff --git a/transformer_engine/pytorch/triton/__init__.py b/transformer_engine/pytorch/triton/__init__.py index d86cededd7..6d3141253d 100644 --- a/transformer_engine/pytorch/triton/__init__.py +++ b/transformer_engine/pytorch/triton/__init__.py @@ -3,3 +3,4 @@ # See LICENSE for license information. """PyTorch wrappers for Triton kernels.""" +from transformer_engine.pytorch.triton import mhc diff --git a/transformer_engine/pytorch/triton/mhc.py b/transformer_engine/pytorch/triton/mhc.py new file mode 100644 index 0000000000..5a994e038e --- /dev/null +++ b/transformer_engine/pytorch/triton/mhc.py @@ -0,0 +1,813 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 

import torch
import triton
import triton.language as tl

from transformer_engine.common.triton.mhc import (
    _mhc_scale_fwd_fused,
    _mhc_scale_bwd_fused,
    _mhc_expand_combine_with_bias_fwd,
    _mhc_expand_combine_with_bias_bwd,
    _mhc_expand_combine_fwd,
    _mhc_expand_combine_bwd,
    _mhc_aggregate_fwd,
    _mhc_aggregate_bwd,
    _mhc_projection_fwd_fused,
    _mhc_projection_bwd_fused,
    _mhc_sinkhorn_fwd_fused,
    _mhc_sinkhorn_fwd_fused_recompute,
    _mhc_sinkhorn_bwd_fused,
    _mhc_sinkhorn_bwd_fused_recompute,
)


class mHCProjectionOp(torch.autograd.Function):
    """
    Fused projection operation to compute H matrices and mean square for RMSNorm (see eq. 14-15, section 4.3.1 of the DeepSeek mHC paper)
    :param x: input tensor of shape (M, K), where M=s*b is the batch size and K=nC is the hidden dimension after expansion.
    :param phi: projection matrix of shape (N, K), where N=n+n+n*n

    H = x @ phi^T: (M, K) @ (K, N) -> (M, N), which is padded to (M, 32) for better memory access pattern in the next kernels.
    ms = mean(x^2, dim=-1): (M,)

    :return: H of shape (M, 32), where only the first N elements in the last dimension are valid
    :return: ms of shape (M,), which is the mean square used for RMSNorm in the next kernel

    Note: the current implementation only supports n=4
    """

    @staticmethod
    def forward(ctx, x, phi, use_tf32=True):
        x = x.contiguous()

        ctx.use_tf32 = use_tf32
        ctx.dtype = x.dtype

        M, K = x.shape
        device = x.device

        N = phi.shape[0]
        assert N == 24, "Currently only n=4 is supported, which means phi should have 24 rows"

        # Pad H to (M, 32) for better memory access pattern in the kernel, but only the
        # first N elements in the last dimension are valid.
        H = torch.zeros((M, 32), device=device, dtype=torch.float32)
        ms = torch.zeros(
            (M), device=device, dtype=torch.float32
        )  # Mean square for s, used to compute RMSNorm in the next kernel

        grid = lambda META: (
            triton.cdiv(M, META["BLOCK_SIZE_M"]),
            triton.cdiv(K, META["BLOCK_SIZE_K"]),
        )

        # The tf32/ieee launch paths only differed in `precision`; use a single launch.
        _mhc_projection_fwd_fused[grid](
            x_ptr=x,  # (M, K)
            phi_ptr=phi,  # (N, K)
            h_ptr=H,  # (M, 32)
            ms_ptr=ms,  # (M,)
            M=M,
            N=N,
            K=K,
            stride_xm=K,
            stride_xk=1,
            stride_phin=K,
            stride_phik=1,
            stride_hm=32,
            stride_hn=1,
            stride_ms=1,
            BLOCK_SIZE_N=32,
            precision="tf32" if use_tf32 else "ieee",
        )

        ctx.save_for_backward(x, phi, ms)
        ctx.phi_dtype = phi.dtype

        return H.to(ctx.dtype), ms  # Keep ms in fp32

    @staticmethod
    def backward(ctx, grad_H, grad_ms):
        x, phi, ms = ctx.saved_tensors
        M, K = x.shape
        device = x.device

        N = phi.shape[0]

        grad_H = grad_H.contiguous().view(M, -1)
        grad_ms = grad_ms.contiguous().view(M)
        ms = ms.contiguous().view(M)

        grad_x = torch.zeros((M, K), device=device, dtype=x.dtype)
        # grad_phi = grad_H.T @ x restricted to the valid rows:
        # (2n + n^2, M) @ (M, nC) = (2n + n^2, nC); the last dimension of grad_H is
        # already padded to 32 so we slice off the padding rows.
        # Fix: keep grad_phi in phi's own dtype (ctx.phi_dtype) — the original
        # re-cast it to x's dtype on return, which is wrong when phi and x differ.
        grad_phi = (grad_H.T @ x)[:N, :].to(ctx.phi_dtype)

        grid = lambda META: (
            triton.cdiv(M, META["BLOCK_SIZE_M"]),
            triton.cdiv(K, META["BLOCK_SIZE_K"]),
        )

        _mhc_projection_bwd_fused[grid](
            x_ptr=x,
            grad_x_ptr=grad_x,  # (M, K)
            phi_ptr=phi,  # (N, K)
            grad_h_ptr=grad_H,  # (M, 32)
            grad_ms_ptr=grad_ms,  # (M,)
            M=M,
            N=N,
            K=K,
            stride_xm=K,
            stride_xk=1,
            stride_grad_xm=K,
            stride_grad_xk=1,
            stride_phin=K,
            stride_phik=1,
            stride_grad_phin=K,
            stride_grad_phik=1,
            stride_grad_hm=32,
            stride_grad_hn=1,
            stride_grad_ms=1,
            BLOCK_SIZE_N=32,
            precision="tf32" if ctx.use_tf32 else "ieee",
        )

        # grad_x in x's dtype, grad_phi in phi's dtype, None for use_tf32.
        return grad_x.to(ctx.dtype), grad_phi, None
class mHCScaleFusedOp(torch.autograd.Function):
    """
    Fused scale operation to compute the scaled H matrices (see eq. 16-18, section 4.3.1 of the DeepSeek mHC paper)
    :param H: input H matrix of shape (M, 32), where M=s*b, and only the first N elements in the last dimension are valid
    :param alpha: scaling factor for H, of shape (3,), where
        alpha[0] is applied to H[:, 0:n] for H_pre
        alpha[1] is applied to H[:, n:2n] for H_post
        alpha[2] is applied to H[:, 2n:2n+n*n] for H_res
    :param beta: bias term for H, of shape (2*n+n*n,), where
        beta[0:n] is applied to H[:, 0:n] for H_pre
        beta[n:2n] is applied to H[:, n:2n] for H_post
        beta[2n:2n+n*n] is applied to H[:, 2n:2n+n*n] for H_res
    :param ms: mean square for each row of H from the projection kernel, of shape (M,), used for RMSNorm scaling
    :param n: number of hyper connections, where only n=4 is supported in the current implementation

    H_pre = H[:, 0:n] * alpha[0] / sqrt(ms) + beta[0:n]
    H_post = H[:, n:2n] * alpha[1] / sqrt(ms) + beta[n:2n]
    H_res = H[:, 2n:2n+n*n] * alpha[2] / sqrt(ms) + beta[2n:2n+n*n]

    H_pre = sigmoid(H_pre)
    H_post = 2*sigmoid(H_post)

    :return: out of shape (M, 32), where only the first N elements in the last dimension are valid
    """

    @staticmethod
    def forward(ctx, H, alpha, beta, ms, n):
        assert n == 4, "Only n=4 is supported in this implementation"

        # The whole op runs in fp32; the original dtype is restored on return.
        ctx.dtype = H.dtype
        H = H.to(torch.float32)
        alpha = alpha.to(torch.float32)
        beta = beta.to(torch.float32)
        ms = ms.to(torch.float32)

        M, _ = H.shape

        H = H.contiguous()
        beta = beta.contiguous()
        ms = ms.contiguous()

        out = torch.empty(
            (M, 32), device=H.device, dtype=H.dtype
        )  # Pad the output to 32 in the last dimension

        # 1-D grid over rows; each program handles BLOCK_SIZE_M rows.
        grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]),)

        _mhc_scale_fwd_fused[grid](
            h_ptr=H,  # (M, N), which is padded to (M, 32)
            b_ptr=beta,  # (N,)
            a_ptr=alpha,  # (N,)
            ms_ptr=ms,  # (M,)
            out_ptr=out,  # (M, N), which is padded to (M, 32)
            M=M,
            n=n,
            stride_hm=32,
            stride_hn=1,
            stride_a=1,
            stride_b=1,
            stride_ms=1,
            stride_out_m=32,
            stride_out_n=1,  # strides for out, which is padded to 32 in the last dimension
            BLOCK_SIZE_N=32,
            eps=torch.finfo(ms.dtype).eps,
        )

        # Note: the fp32 copies (not the caller's originals) are saved for backward.
        ctx.save_for_backward(H, alpha, ms, out)
        ctx.n = n

        return out.to(ctx.dtype)  # Cast back to the original dtype of H

    @staticmethod
    def backward(ctx, grad_out):
        H, alpha, ms, out = ctx.saved_tensors
        n = ctx.n

        grad_out = grad_out.contiguous()
        grad_out = grad_out.to(torch.float32)

        M, _ = grad_out.shape
        N = 2 * n + n * n

        grad_h = torch.zeros(
            (M, 32), device=grad_out.device, dtype=grad_out.dtype
        )  # Pad the grad_h to 32 in the last dimension
        grad_alpha = torch.zeros((3,), device=grad_out.device, dtype=grad_out.dtype)
        grad_beta_padded = torch.zeros((1, 32), device=grad_out.device, dtype=grad_out.dtype)
        # The slice shares storage with grad_beta_padded; the kernel writes the
        # first N entries and the padding stays zero.
        grad_beta = grad_beta_padded[
            :, :N
        ]  # Use only the first N elements for grad_beta, the rest are just padding
        grad_ms = torch.zeros((M,), device=grad_out.device, dtype=grad_out.dtype)

        grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]),)

        _mhc_scale_bwd_fused[grid](
            grad_out_ptr=grad_out,
            out_ptr=out,
            grad_h_ptr=grad_h,
            h_ptr=H,
            grad_a_ptr=grad_alpha,
            a_ptr=alpha,
            grad_b_ptr=grad_beta,
            grad_ms_ptr=grad_ms,
            ms_ptr=ms,
            M=M,
            n=n,
            stride_grad_out_m=32,
            stride_grad_out_n=1,
            stride_out_m=32,
            stride_out_n=1,
            stride_grad_hm=32,
            stride_grad_hn=1,
            stride_hm=32,
            stride_hn=1,
            stride_grad_a=1,
            stride_a=1,
            stride_grad_b=1,
            stride_grad_ms=1,
            stride_ms=1,
            BLOCK_SIZE_N=32,
            eps=torch.finfo(ms.dtype).eps,
        )

        # One gradient per forward input (H, alpha, beta, ms) plus None for n.
        return (
            grad_h.to(ctx.dtype),
            grad_alpha.to(ctx.dtype),
            grad_beta.to(ctx.dtype),
            grad_ms.to(ctx.dtype),
            None,
        )
class mHCSinkhornOp(torch.autograd.Function):
    """
    Sinkhorn operation to compute the final H_res matrix (see eq. 19, section 4.3.1 of the DeepSeek mHC paper)
    :param H_res: input H_res matrix of shape (M, n*n)
    :param n: number of hyper connections, where only n=4 is supported in the current implementation
    :param recompute_hist: whether to recompute the intermediate history in the backward pass to save memory
    :param iters: number of Sinkhorn iterations, according to the DeepSeek paper 20 is enough for convergence

    Sinkhorn operation conducts iterative normalization process that alternately rescales rows and columns to sum to 1.
    This kernel performance this operation in the log space for numerical stability.

    :return: out of shape (s, b, n, n), which is the final H_res after Sinkhorn normalization
    """

    @staticmethod
    def forward(ctx, H_res, n=4, recompute_hist=True, iters=20):
        assert n == 4, "Only n=4 is supported in this implementation"

        # NOTE: H_res arrives as (s, b, n, n) and is flattened to (s*b, n*n) below.
        s, b, _, _ = H_res.shape

        ctx.dtype = H_res.dtype
        H_res = H_res.to(torch.float32)

        H_res = H_res.contiguous().view(s * b, n * n)

        hist_f, hist_g = None, None
        if not recompute_hist:
            # History buffers: (iters+1, s, b, n) — the per-iteration row/column
            # scalings, stored so backward doesn't have to re-run the iterations.
            hist_f = torch.empty((iters + 1, s, b, n), device=H_res.device, dtype=H_res.dtype)
            hist_g = torch.empty((iters + 1, s, b, n), device=H_res.device, dtype=H_res.dtype)
        H_res_out = torch.empty_like(H_res)  # (s*b, n*n)

        grid = lambda meta: (triton.cdiv(s * b * n * n, meta["BLOCK_SIZE"]),)

        if recompute_hist:
            # Memory-saving path: no history is written in forward.
            _mhc_sinkhorn_fwd_fused_recompute[grid](
                x_ptr=H_res,
                output_ptr=H_res_out,
                stride_xm=n * n,
                stride_xn=1,
                stride_out_m=n * n,
                stride_out_n=1,
                M=s * b,
                n=n,
                iters=iters,
            )
        else:
            _mhc_sinkhorn_fwd_fused[grid](
                x_ptr=H_res,
                output_ptr=H_res_out,
                hist_f_ptr=hist_f,
                hist_g_ptr=hist_g,
                stride_xm=n * n,
                stride_xn=1,
                stride_out_m=n * n,
                stride_out_n=1,
                M=s * b,
                n=n,
                iters=iters,
            )

        if recompute_hist:
            ctx.save_for_backward(H_res, H_res_out)
        else:
            ctx.save_for_backward(H_res, H_res_out, hist_f, hist_g)
        ctx.recompute_hist = recompute_hist
        ctx.iters = iters
        ctx.n = n

        H_res_out = H_res_out.view(s, b, n, n)
        return H_res_out.to(ctx.dtype)  # Cast back to the original dtype of H

    @staticmethod
    def backward(ctx, grad_out):

        s, b, n, _ = grad_out.shape
        M = s * b

        hist_f, hist_g = None, None
        recompute_hist = ctx.recompute_hist
        iters = ctx.iters
        if recompute_hist:
            # Recompute path: allocate scratch history buffers for the backward
            # kernel to refill by re-running the forward iterations.
            H_res, H_res_out = ctx.saved_tensors
            hist_f = torch.empty((iters + 1, s, b, n), device=H_res.device, dtype=H_res.dtype)
            hist_g = torch.empty((iters + 1, s, b, n), device=H_res.device, dtype=H_res.dtype)
        else:
            H_res, H_res_out, hist_f, hist_g = ctx.saved_tensors

        # NOTE(review): `iters` was already read above; this re-read is redundant but harmless.
        iters = ctx.iters
        n = ctx.n

        grad_res_out = grad_out.clone().contiguous().view(M, n * n)

        grad_res = torch.empty_like(H_res)

        grid = lambda meta: (triton.cdiv(M * n * n, meta["BLOCK_SIZE"]),)

        if recompute_hist:
            _mhc_sinkhorn_bwd_fused_recompute[grid](
                grad_out_ptr=grad_res_out,
                output_ptr=H_res_out,
                grad_x_ptr=grad_res,
                x_ptr=H_res,
                hist_f_ptr=hist_f,
                hist_g_ptr=hist_g,
                stride_grad_out_m=n * n,
                stride_grad_out_n=1,
                stride_out_m=n * n,
                stride_out_n=1,
                stride_grad_xm=n * n,
                stride_grad_xn=1,
                stride_xm=n * n,
                stride_xn=1,
                M=M,
                n=n,
                iters=iters,
            )
        else:
            _mhc_sinkhorn_bwd_fused[grid](
                grad_out_ptr=grad_res_out,
                output_ptr=H_res_out,
                grad_x_ptr=grad_res,
                x_ptr=H_res,
                hist_f_ptr=hist_f,
                hist_g_ptr=hist_g,
                stride_grad_out_m=n * n,
                stride_grad_out_n=1,
                stride_out_m=n * n,
                stride_out_n=1,
                stride_grad_xm=n * n,
                stride_grad_xn=1,
                stride_xm=n * n,
                stride_xn=1,
                M=M,
                n=n,
                iters=iters,
            )

        grad_res = grad_res.view(s, b, n, n)

        # Gradient for H_res only; None for n, recompute_hist, iters.
        return grad_res.to(ctx.dtype), None, None, None
class mHCAggregateOp(torch.autograd.Function):
    """
    Aggregate operation to merge n activation streams to one (see section 4.3.1 of the DeepSeek mHC paper)
    :param x: input activation tensor of shape (s, b, C, n),
        where s is the sequence length, b is the batch size, C is the hidden dimension per hyper connection,
        and n is the number of hyper connections. Note that C is equal to the original hidden dimension divided by n.
    :param H_pre: input H_pre matrix of shape (s, b, n)
    :param n: number of hyper connections, where only n=4 is supported in the current implementation
    :param use_tf32: whether to use TF32 precision for matmul operations. If False, it will use ieee for better precision.
        This is mainly used by our unittests since TF32 precision will introduce some errors and cause tests to fail

    out = x @ H_pre: (s, b, C, n) @ (s, b, n, 1) -> (s, b, C, 1) -> (s, b, C) after squeezing the last dimension

    :return: out of shape (s, b, C), which is the aggregated output after merging n hyper connections
    """

    @staticmethod
    def forward(ctx, x, H_pre, n, use_tf32=True):
        assert n == 4, "Only n=4 is supported in this implementation"

        x = x.contiguous()
        H_pre = H_pre.contiguous()

        # x's trailing dimension is authoritative for n (asserted == 4 above).
        s, b, C, n = x.shape
        nC = n * C
        M = s * b

        out = torch.empty((s, b, C), device=x.device, dtype=x.dtype)

        grid = lambda META: (
            triton.cdiv(C, META["BLOCK_SIZE_C"]),
            triton.cdiv(M, META["BLOCK_SIZE_M"]),
        )

        _mhc_aggregate_fwd[grid](
            x_ptr=x,
            H_pre_ptr=H_pre,
            output_ptr=out,
            M=M,
            C=C,
            n=n,
            stride_xm=nC,
            stride_xCn=1,
            stride_output_m=C,
            stride_output_c=1,
        )

        ctx.save_for_backward(x, H_pre)
        ctx.n = n
        ctx.use_tf32 = use_tf32

        return out

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = grad_output.contiguous()

        x, H_pre = ctx.saved_tensors

        s, b, C, n = x.shape
        nC = n * C
        assert n == 4, "Only n=4 is supported in this implementation"
        M = s * b

        grad_x = torch.empty_like(x)
        grad_H_pre = torch.zeros(
            (s, b, n), dtype=torch.float32, device=H_pre.device
        )  # We need to use atomic_add for this so we need higher precision

        grid = lambda META: (
            triton.cdiv(C, META["BLOCK_SIZE_C"]),
            triton.cdiv(M, META["BLOCK_SIZE_M"]),
        )

        # The original tf32/ieee branches were byte-identical apart from
        # `precision`; collapse them into a single launch.
        _mhc_aggregate_bwd[grid](
            grad_output_ptr=grad_output,
            H_pre_ptr=H_pre,
            grad_H_pre_ptr=grad_H_pre,
            x_ptr=x,
            grad_x_ptr=grad_x,
            M=M,
            C=C,
            n=n,
            stride_grad_output_m=C,
            stride_grad_output_c=1,
            stride_xm=nC,
            stride_xCn=1,
            stride_grad_xm=nC,
            stride_grad_xCn=1,
            precision="tf32" if ctx.use_tf32 else "ieee",
        )

        grad_H_pre = grad_H_pre.to(H_pre.dtype)  # Cast back to the original dtype of H_pre

        # Gradients for (x, H_pre); None for n and use_tf32.
        return grad_x, grad_H_pre, None, None
class mHCExpandCombineOp(torch.autograd.Function):
    """
    Expand and combine operation for merging n hyper connections (see section 4.3.1 of the DeepSeek mHC paper)
    :param f: input activation tensor of shape (s, b, C), which is the output from the attention / FFN sub-layer in a transformer block
    :param bias: optional bias tensor of shape C from the last linear layer, where f + bias is fused in this kernel for better performance
    :param H_post: input H_post matrix of shape (s, b, n)
    :param x: input activation tensor of shape (s, b, C, n), which is the hyper connection input before the aggregation operation
    :param H_res: input H_res matrix of shape (s, b, n)
    :param n: number of hyper connections
    :param use_tf32: whether to use TF32 precision for matmul operations. If False, it will use ieee for better precision.
        This is mainly used by our unittests since TF32 precision will introduce some errors and cause tests to fail

    out = (f [+ bias]) @ H_post + x @ H_res: (s, b, C, 1) @ (s, b, 1, n) + (s, b, C, n) @ (s, b, n, n) -> (s, b, C, n)

    :return: out of shape (s, b, C, n), which is the expanded and combined output after merging n hyper connections
    """

    @staticmethod
    def forward(ctx, f, bias, H_post, x, H_res, n, use_tf32=True):
        assert n == 4, "Only n=4 is supported in this implementation"

        x = x.contiguous()
        f = f.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        H_post = H_post.contiguous()
        H_res = H_res.contiguous()

        # x's trailing dimension is authoritative for n (asserted == 4 above).
        s, b, C, n = x.shape
        Cn = C * n
        M = s * b

        out = torch.empty((s, b, C, n), device=x.device, dtype=x.dtype)

        grid = lambda META: (
            triton.cdiv(C, META["BLOCK_SIZE_C"]),
            triton.cdiv(M, META["BLOCK_SIZE_M"]),
        )

        # Two kernel variants: with/without fused bias addition.
        if bias is None:
            _mhc_expand_combine_fwd[grid](
                f_ptr=f,
                H_post_ptr=H_post,
                x_ptr=x,
                H_res_ptr=H_res,
                output_ptr=out,
                M=M,
                C=C,
                n=n,
                stride_fm=C,
                stride_fc=1,
                stride_xm=Cn,
                stride_xCn=1,
                stride_output_m=Cn,
                stride_output_Cn=1,
            )
        else:
            _mhc_expand_combine_with_bias_fwd[grid](
                f_ptr=f,
                bias_ptr=bias,
                H_post_ptr=H_post,
                x_ptr=x,
                H_res_ptr=H_res,
                output_ptr=out,
                M=M,
                C=C,
                n=n,
                stride_fm=C,
                stride_fc=1,
                stride_bias=1,
                stride_xm=Cn,
                stride_xCn=1,
                stride_output_m=Cn,
                stride_output_Cn=1,
            )

        ctx.n = n
        ctx.have_bias = bias is not None
        if bias is not None:
            ctx.save_for_backward(f, bias, H_post, x, H_res)
        else:
            ctx.save_for_backward(f, H_post, x, H_res)
        ctx.use_tf32 = use_tf32

        return out

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = grad_output.contiguous()
        s, b, C, n = grad_output.shape

        if ctx.have_bias:
            f, bias, H_post, x, H_res = ctx.saved_tensors
        else:
            bias = None
            f, H_post, x, H_res = ctx.saved_tensors
        M = s * b

        grad_f = torch.empty_like(f)
        grad_bias = torch.zeros_like(bias, dtype=torch.float32) if bias is not None else None
        grad_H_post = torch.zeros_like(
            H_post, dtype=torch.float32
        )  # We need to use atomic_add for this so we need higher precision
        grad_x = torch.empty_like(x)
        grad_H_res = torch.zeros_like(
            H_res, dtype=torch.float32
        )  # We need to use atomic_add for this so we need higher precision

        grid = lambda META: (
            triton.cdiv(C, META["BLOCK_SIZE_C"]),
            triton.cdiv(M, META["BLOCK_SIZE_M"]),
        )

        # The original four launch branches (tf32/ieee x bias/no-bias) were
        # identical except for `precision`; fold the precision choice into a
        # variable and keep only the bias/no-bias split.
        precision = "tf32" if ctx.use_tf32 else "ieee"

        if bias is None:
            _mhc_expand_combine_bwd[grid](
                grad_output_ptr=grad_output,
                f_ptr=f,
                H_post_ptr=H_post,
                x_ptr=x,
                H_res_ptr=H_res,
                grad_H_post_ptr=grad_H_post,
                grad_f_ptr=grad_f,
                grad_H_res_ptr=grad_H_res,
                grad_x_ptr=grad_x,
                M=M,
                C=C,
                n=n,
                stride_grad_output_m=n * C,
                stride_grad_output_Cn=1,
                stride_fm=C,
                stride_fc=1,
                stride_xm=n * C,
                stride_xCn=1,
                stride_grad_fm=C,
                stride_grad_fc=1,
                stride_grad_xm=n * C,
                stride_grad_xCn=1,
                precision=precision,
            )
        else:
            _mhc_expand_combine_with_bias_bwd[grid](
                grad_output_ptr=grad_output,
                f_ptr=f,
                bias_ptr=bias,
                H_post_ptr=H_post,
                x_ptr=x,
                H_res_ptr=H_res,
                grad_H_post_ptr=grad_H_post,
                grad_f_ptr=grad_f,
                grad_bias_ptr=grad_bias,
                grad_H_res_ptr=grad_H_res,
                grad_x_ptr=grad_x,
                M=M,
                C=C,
                n=n,
                stride_grad_output_m=n * C,
                stride_grad_output_Cn=1,
                stride_fm=C,
                stride_fc=1,
                stride_bias=1,
                stride_xm=n * C,
                stride_xCn=1,
                stride_grad_fm=C,
                stride_grad_fc=1,
                stride_grad_bias=1,
                stride_grad_xm=n * C,
                stride_grad_xCn=1,
                precision=precision,
            )

        grad_H_post = grad_H_post.to(H_post.dtype)  # Cast back to the original dtype of H_post
        grad_H_res = grad_H_res.to(H_res.dtype)  # Cast back to the original dtype of H_res
        if bias is not None:
            grad_bias = grad_bias.to(bias.dtype)

        # Gradients for (f, bias, H_post, x, H_res); None for n and use_tf32.
        return grad_f, grad_bias, grad_H_post, grad_x, grad_H_res, None, None