diff --git a/tests/jax/test_distributed_permutation.py b/tests/jax/test_distributed_permutation.py new file mode 100644 index 0000000000..5b6d8fec47 --- /dev/null +++ b/tests/jax/test_distributed_permutation.py @@ -0,0 +1,597 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tests for distributed/sharded execution of MoE permutation primitives. + +Testing Strategy: +================= +MoE permutation is data-dependent - the destination index for each token depends +on how many tokens before it are routed to the same expert. This means: + +1. We CANNOT compare sharded output against global reference directly +2. Instead, we verify that each GPU's LOCAL output is correct according to its + LOCAL routing (which produces LOCAL row_id_map with LOCAL indices) + +For data-parallel MoE without expert parallelism: +- Each GPU has ALL experts replicated +- Each GPU processes a subset of tokens (sharded on token/batch dimension) +- Each GPU computes its own local row_id_map from its local routing_map slice +- Each GPU's output is local and doesn't need to match global output + +These tests verify: +1. Local token_dispatch: sharded input -> local row_id_map -> local permute (forward + backward) +2. 
Local roundtrip: dispatch + combine recovers original input (forward + backward) +""" + +import pytest + +import jax +import jax.numpy as jnp +import numpy as np +from jax.sharding import Mesh, NamedSharding, PartitionSpec + +from distributed_test_base import generate_configs +from utils import assert_allclose, pytest_parametrize_wrapper + +# High-level API with VJP support +from transformer_engine.jax.permutation import ( + token_dispatch, + token_combine, +) + +# Reference implementations from test_permutation.py +from test_permutation import ( + reference_make_row_id_map, + _reference_permute_impl, + _reference_unpermute_impl, + reference_token_combine, +) + +# Dispatch/combine test cases: (num_tokens, num_experts, hidden_size, topk) +# topk = number of experts each token is routed to +# Includes small, medium-large, and largest stress test cases. +ALL_DISPATCH_COMBINE_CASES = [ + (128, 4, 64, 2), + (4096, 32, 1280, 2), + (4096, 256, 4096, 6), +] +DISPATCH_COMBINE_CASES = { + "L0": ALL_DISPATCH_COMBINE_CASES[0:1], + "L2": ALL_DISPATCH_COMBINE_CASES, +} + +# Dispatch/combine with padding test cases: (num_tokens, num_experts, hidden_size, topk, align_size) +ALL_DISPATCH_COMBINE_PADDING_CASES = [ + (128, 4, 64, 2, 8), + (4096, 32, 1280, 2, 128), + (4096, 256, 4096, 6, 16), +] +DISPATCH_COMBINE_PADDING_CASES = { + "L0": ALL_DISPATCH_COMBINE_PADDING_CASES[0:1], + "L2": ALL_DISPATCH_COMBINE_PADDING_CASES, +} + +# Dtypes for testing +ALL_DTYPES = [jnp.float32, jnp.bfloat16] +DTYPES = { + "L0": [jnp.float32], + "L2": ALL_DTYPES, +} + + +class TestDistributedPermutation: + """Test distributed/sharded execution of MoE permutation primitives. + + These tests validate that custom partitioning produces correct LOCAL results + when inputs are sharded across multiple devices. + + Key insight: With data-parallel MoE, each GPU independently processes its + local tokens. The row_id_map is generated locally and contains LOCAL indices. 
+ We verify correctness by comparing each shard's output against the reference + implementation run on that shard's local data. + """ + + @staticmethod + def compute_padded_output_size( + num_tokens: int, + num_experts: int, + topk: int, + align_size: int, + num_dp_devices: int, + ) -> int: + """Compute global_num_out_tokens for distributed padding tests. + + Each device processes local_num_tokens tokens. We compute the worst-case + padded output size per device, then multiply by num_dp_devices to get + a global size that ensures global / num_dp >= local_worst. + """ + local_num_tokens = num_tokens // num_dp_devices + local_raw_out = local_num_tokens * topk + local_worst = ((local_raw_out + num_experts * (align_size - 1)) // align_size) * align_size + return local_worst * num_dp_devices + + @staticmethod + def generate_routing_map( + num_tokens: int, + num_experts: int, + topk: int = 2, # Number of experts each token is routed to (max 1s per row). + key: jax.Array = None, + ): + if key is None: + key = jax.random.PRNGKey(0) + + routing_map = jnp.zeros((num_tokens, num_experts), dtype=jnp.int32) + for token_idx in range(num_tokens): + key, subkey = jax.random.split(key) + expert_indices = jax.random.choice(subkey, num_experts, shape=(topk,), replace=False) + routing_map = routing_map.at[token_idx, expert_indices].set(1) + + return routing_map + + @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) + @pytest_parametrize_wrapper( + "num_tokens,num_experts,hidden_size,topk", + DISPATCH_COMBINE_CASES, + ) + @pytest_parametrize_wrapper("dtype", DTYPES) + @pytest_parametrize_wrapper("use_shardy", [False, True]) + def test_local_token_dispatch( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + num_tokens, + num_experts, + hidden_size, + topk, + dtype, + use_shardy, + ): + """ + Test token_dispatch with sharded inputs. 
+ + Verifies that sharded execution produces the same result as chunk-wise + reference execution. The sharded primitive: + 1. Receives global num_out_tokens (partition function divides it) + 2. Each GPU operates on its local shard independently + 3. Results are gathered (concatenated) across GPUs + + Output ordering: [GPU0_expert0, GPU0_expert1, ... | GPU1_expert0, ...] + + The reference processes each chunk independently and concatenates, + matching the sharded execution's output ordering. + Tests both forward pass (output values) and backward pass (gradients). + """ + jax.config.update("jax_use_shardy_partitioner", use_shardy) + key = jax.random.PRNGKey(42) + + # Generate global inputs + key, inp_key, prob_key = jax.random.split(key, 3) + inp = jax.random.uniform( + inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0 + ) + routing_map = self.generate_routing_map(num_tokens, num_experts, topk, key) + probs = jax.random.uniform( + prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0 + ) + + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + + # Shard on token (batch) dimension + dp_axis = mesh_resource.dp_resource + sharded_pspec = PartitionSpec(dp_axis, None) + + # Compute num_out_tokens as concrete values + # Global num_out_tokens is passed to token_dispatch (partition function divides it) + # Local num_out_tokens is used for reference implementation + num_dp_devices = mesh.shape[dp_axis] if dp_axis else 1 + global_num_out_tokens = num_tokens * topk + local_num_tokens = num_tokens // num_dp_devices + local_num_out_tokens = local_num_tokens * topk + + with mesh: + inp_sharding = NamedSharding(mesh, sharded_pspec) + routing_sharding = NamedSharding(mesh, sharded_pspec) + probs_sharding = NamedSharding(mesh, sharded_pspec) + + # Shard the inputs + inp_sharded = jax.device_put(inp, inp_sharding) + routing_sharded = jax.device_put(routing_map, routing_sharding) + 
probs_sharded = jax.device_put(probs, probs_sharding) + + # ================================================================ + # Forward pass test + # ================================================================ + @jax.jit + def target_dispatch(x, rm, p): + # Pass global num_out_tokens - partition function divides it + out, perm_probs, rid_map, _, _ = token_dispatch( + x, rm, global_num_out_tokens, probs=p + ) + return out, perm_probs, rid_map + + # Reference: process each GPU's shard independently, then concatenate + # This matches how the sharded primitive operates: + # - Each GPU processes its local shard + # - Results are gathered (concatenated) across GPUs + # Output ordering: [GPU0_exp0, GPU0_exp1, ... | GPU1_exp0, GPU1_exp1, ...] + inp_shards = jnp.reshape(inp, (num_dp_devices, local_num_tokens, hidden_size)) + routing_shards = jnp.reshape( + routing_map, (num_dp_devices, local_num_tokens, num_experts) + ) + probs_shards = jnp.reshape(probs, (num_dp_devices, local_num_tokens, num_experts)) + + ref_outputs = [] + ref_perm_probs_list = [] + ref_rid_maps = [] + for i in range(num_dp_devices): + shard_rid_map = reference_make_row_id_map(routing_shards[i]) + shard_out, shard_perm_probs = _reference_permute_impl( + inp_shards[i], shard_rid_map, probs_shards[i], local_num_out_tokens + ) + ref_outputs.append(shard_out) + ref_perm_probs_list.append(shard_perm_probs) + ref_rid_maps.append(shard_rid_map) + + # Concatenate like all_gather would + ref_out = jnp.concatenate(ref_outputs, axis=0) + ref_perm_probs = jnp.concatenate(ref_perm_probs_list, axis=0) + ref_rid_map = jnp.concatenate(ref_rid_maps, axis=0) + + # Run target on sharded inputs + target_out, target_perm_probs, target_rid_map = target_dispatch( + inp_sharded, routing_sharded, probs_sharded + ) + + # Compare forward outputs + assert_allclose(jax.device_get(target_out), ref_out, dtype=dtype) + assert_allclose(jax.device_get(target_perm_probs), ref_perm_probs, dtype=dtype) + + # Verify row_id_map n_routed 
column matches routing_map sum + target_rid_map_np = jax.device_get(target_rid_map) + assert jnp.array_equal( + target_rid_map_np[:, -1], ref_rid_map[:, -1] + ), "n_routed column mismatch" + + # Sanity checks + target_out_np = jax.device_get(target_out) + target_perm_probs_np = jax.device_get(target_perm_probs) + assert not np.any(np.isnan(target_out_np)), "Output contains NaN" + assert not np.any(np.isnan(target_perm_probs_np)), "Permuted probs contain NaN" + assert np.all(target_perm_probs_np >= 0), "Permuted probs contain negative values" + + # ================================================================ + # Backward pass test (gradients) + # ================================================================ + def target_loss(x, rm, p): + out, perm_probs, _, _, _ = token_dispatch(x, rm, global_num_out_tokens, probs=p) + return jnp.sum(out**2) + jnp.sum(perm_probs**2) + + # Reference loss: process chunks independently and sum + def ref_chunk_loss(inp_chunk, routing_chunk, probs_chunk): + rid_map = reference_make_row_id_map(routing_chunk) + out, perm_probs = _reference_permute_impl( + inp_chunk, rid_map, probs_chunk, local_num_out_tokens + ) + return jnp.sum(out**2) + jnp.sum(perm_probs**2) + + target_grad_fn = jax.jit(jax.grad(target_loss, argnums=(0, 2))) + ref_chunk_grad_fn = jax.jit(jax.grad(ref_chunk_loss, argnums=(0, 2))) + + target_inp_grad, target_probs_grad = target_grad_fn( + inp_sharded, routing_sharded, probs_sharded + ) + + # Compute reference gradients per chunk, then concatenate + ref_inp_grads = [] + ref_probs_grads = [] + for i in range(num_dp_devices): + chunk_inp_grad, chunk_probs_grad = ref_chunk_grad_fn( + inp_shards[i], routing_shards[i], probs_shards[i] + ) + ref_inp_grads.append(chunk_inp_grad) + ref_probs_grads.append(chunk_probs_grad) + + ref_inp_grad = jnp.concatenate(ref_inp_grads, axis=0) + ref_probs_grad = jnp.concatenate(ref_probs_grads, axis=0) + + assert_allclose(jax.device_get(target_inp_grad), ref_inp_grad, dtype=dtype) + 
assert_allclose(jax.device_get(target_probs_grad), ref_probs_grad, dtype=dtype) + + @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) + @pytest_parametrize_wrapper( + "num_tokens,num_experts,hidden_size,topk", + DISPATCH_COMBINE_CASES, + ) + @pytest_parametrize_wrapper("dtype", DTYPES) + @pytest_parametrize_wrapper("use_shardy", [False, True]) + def test_local_roundtrip( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + num_tokens, + num_experts, + hidden_size, + topk, + dtype, + use_shardy, + ): + """ + Test roundtrip: token_dispatch followed by token_combine with sharded inputs. + + Each GPU: + 1. Gets a shard of the input and routing_map + 2. Performs local dispatch (permute) + 3. Performs local combine (unpermute) + 4. With uniform merging probs, should recover original input + + Tests both forward pass and backward pass (gradient should be 2*x). + """ + jax.config.update("jax_use_shardy_partitioner", use_shardy) + key = jax.random.PRNGKey(42) + + # Generate global inputs + key, inp_key = jax.random.split(key, 2) + inp = jax.random.uniform( + inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0 + ) + routing_map = self.generate_routing_map(num_tokens, num_experts, topk, key) + + # Uniform merging probs for perfect roundtrip + uniform_merging_probs = routing_map.astype(dtype) / jnp.maximum( + jnp.sum(routing_map, axis=1, keepdims=True), 1.0 + ) + + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + + dp_axis = mesh_resource.dp_resource + sharded_pspec = PartitionSpec(dp_axis, None) + + # Compute num_out_tokens as concrete value + # Global num_out_tokens is passed to token_dispatch (partition function divides it) + global_num_out_tokens = num_tokens * topk + + with mesh: + inp_sharding = NamedSharding(mesh, sharded_pspec) + routing_sharding = NamedSharding(mesh, sharded_pspec) + merging_sharding = NamedSharding(mesh, 
sharded_pspec) + + inp_sharded = jax.device_put(inp, inp_sharding) + routing_sharded = jax.device_put(routing_map, routing_sharding) + merging_sharded = jax.device_put(uniform_merging_probs, merging_sharding) + + # ================================================================ + # Forward pass test + # ================================================================ + @jax.jit + def roundtrip(x, rm, mprobs): + dispatched, _, rid_map, _, _ = token_dispatch(x, rm, global_num_out_tokens) + return token_combine(dispatched, rid_map, mprobs) + + roundtrip_out = roundtrip(inp_sharded, routing_sharded, merging_sharded) + + # Should recover original input + assert_allclose(jax.device_get(roundtrip_out), jax.device_get(inp_sharded), dtype=dtype) + + # ================================================================ + # Backward pass test (gradients) + # ================================================================ + def roundtrip_loss(x, rm, mprobs): + dispatched, _, rid_map, _, _ = token_dispatch(x, rm, global_num_out_tokens) + combined = token_combine(dispatched, rid_map, mprobs) + return jnp.sum(combined**2) + + # With uniform merging probs, roundtrip is identity, so gradient should be 2*x + grad_fn = jax.jit(jax.grad(roundtrip_loss, argnums=0)) + computed_grad = grad_fn(inp_sharded, routing_sharded, merging_sharded) + + expected_grad = 2.0 * inp_sharded + + assert_allclose( + jax.device_get(computed_grad), jax.device_get(expected_grad), dtype=dtype + ) + + @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) + @pytest_parametrize_wrapper( + "num_tokens,num_experts,hidden_size,topk,align_size", + DISPATCH_COMBINE_PADDING_CASES, + ) + @pytest_parametrize_wrapper("dtype", DTYPES) + @pytest_parametrize_wrapper("use_shardy", [False, True]) + def test_local_token_dispatch_with_padding( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + num_tokens, + num_experts, + hidden_size, + topk, + align_size, + dtype, + 
use_shardy, + ): + """ + Test token_dispatch with padding using sharded inputs. + + Tests both forward pass (output values) and backward pass (gradients). + """ + jax.config.update("jax_use_shardy_partitioner", use_shardy) + key = jax.random.PRNGKey(42) + + # Generate global inputs + key, inp_key, prob_key = jax.random.split(key, 3) + inp = jax.random.uniform( + inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0 + ) + routing_map = self.generate_routing_map(num_tokens, num_experts, topk, key) + probs = jax.random.uniform( + prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0 + ) + + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + + dp_axis = mesh_resource.dp_resource + sharded_pspec = PartitionSpec(dp_axis, None) + num_dp_devices = mesh.shape[dp_axis] if dp_axis else 1 + + # For padding + sharding, we need to account for per-shard padding overhead. + # Each shard needs E*(A-1) extra space for worst-case padding. + # Compute global_num_out_tokens such that global / num_dp >= local_worst. 
+ global_num_out_tokens = self.compute_padded_output_size( + num_tokens, num_experts, topk, align_size, num_dp_devices + ) + + with mesh: + inp_sharding = NamedSharding(mesh, sharded_pspec) + routing_sharding = NamedSharding(mesh, sharded_pspec) + probs_sharding = NamedSharding(mesh, sharded_pspec) + + inp_sharded = jax.device_put(inp, inp_sharding) + routing_sharded = jax.device_put(routing_map, routing_sharding) + probs_sharded = jax.device_put(probs, probs_sharding) + + # ================================================================ + # Forward pass test + # ================================================================ + @jax.jit + def dispatch_with_padding(x, rm, p): + out, perm_probs, rid_map, pad_offsets, _ = token_dispatch( + x, rm, global_num_out_tokens, probs=p, align_size=align_size + ) + return out, perm_probs, rid_map, pad_offsets + + out, perm_probs, rid_map, pad_offsets = dispatch_with_padding( + inp_sharded, routing_sharded, probs_sharded + ) + + # Sanity checks + out_np = jax.device_get(out) + perm_probs_np = jax.device_get(perm_probs) + assert not np.any(np.isnan(out_np)), "Output contains NaN" + assert not np.any(np.isnan(perm_probs_np)), "Permuted probs contain NaN" + assert np.all(perm_probs_np >= 0), "Permuted probs contain negative values" + + # ================================================================ + # Backward pass test (gradients) + # ================================================================ + def loss_with_padding(x, rm, p): + out, perm_probs, _, _, _ = token_dispatch( + x, rm, global_num_out_tokens, probs=p, align_size=align_size + ) + return jnp.sum(out**2) + jnp.sum(perm_probs**2) + + grad_fn = jax.jit(jax.grad(loss_with_padding, argnums=(0, 2))) + inp_grad, probs_grad = grad_fn(inp_sharded, routing_sharded, probs_sharded) + + # Gradients should not contain NaN + assert not np.any(np.isnan(jax.device_get(inp_grad))), "Input gradient contains NaN" + assert not np.any(np.isnan(jax.device_get(probs_grad))), "Probs 
gradient contains NaN" + + @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) + @pytest_parametrize_wrapper( + "num_tokens,num_experts,hidden_size,topk,align_size", + DISPATCH_COMBINE_PADDING_CASES, + ) + @pytest_parametrize_wrapper("dtype", DTYPES) + @pytest_parametrize_wrapper("use_shardy", [False, True]) + def test_local_roundtrip_with_padding( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + num_tokens, + num_experts, + hidden_size, + topk, + align_size, + dtype, + use_shardy, + ): + """ + Test roundtrip with padding/alignment using sharded inputs. + + With uniform merging probs, should recover original input. + Tests both forward pass and backward pass. + """ + jax.config.update("jax_use_shardy_partitioner", use_shardy) + key = jax.random.PRNGKey(42) + + # Generate inputs + key, inp_key = jax.random.split(key, 2) + inp = jax.random.uniform( + inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0 + ) + routing_map = self.generate_routing_map(num_tokens, num_experts, topk, key) + + # Uniform merging probs + uniform_merging_probs = routing_map.astype(dtype) / jnp.maximum( + jnp.sum(routing_map, axis=1, keepdims=True), 1.0 + ) + + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + + dp_axis = mesh_resource.dp_resource + sharded_pspec = PartitionSpec(dp_axis, None) + num_dp_devices = mesh.shape[dp_axis] if dp_axis else 1 + + # For padding + sharding, we need to account for per-shard padding overhead. + # Each shard needs E*(A-1) extra space for worst-case padding. + # Compute global_num_out_tokens such that global / num_dp >= local_worst. 
+ global_num_out_tokens = self.compute_padded_output_size( + num_tokens, num_experts, topk, align_size, num_dp_devices + ) + + with mesh: + inp_sharding = NamedSharding(mesh, sharded_pspec) + routing_sharding = NamedSharding(mesh, sharded_pspec) + merging_sharding = NamedSharding(mesh, sharded_pspec) + + inp_sharded = jax.device_put(inp, inp_sharding) + routing_sharded = jax.device_put(routing_map, routing_sharding) + merging_sharded = jax.device_put(uniform_merging_probs, merging_sharding) + + # ================================================================ + # Forward pass test + # ================================================================ + @jax.jit + def roundtrip_with_padding(x, rm, mprobs): + dispatched, _, rid_map, pad_offsets, _ = token_dispatch( + x, rm, global_num_out_tokens, align_size=align_size + ) + return token_combine(dispatched, rid_map, mprobs, pad_offsets) + + roundtrip_out = roundtrip_with_padding(inp_sharded, routing_sharded, merging_sharded) + + # Should recover original input + assert_allclose(jax.device_get(roundtrip_out), jax.device_get(inp_sharded), dtype=dtype) + + # ================================================================ + # Backward pass test (gradients) + # ================================================================ + def roundtrip_loss_with_padding(x, rm, mprobs): + dispatched, _, rid_map, pad_offsets, _ = token_dispatch( + x, rm, global_num_out_tokens, align_size=align_size + ) + combined = token_combine(dispatched, rid_map, mprobs, pad_offsets) + return jnp.sum(combined**2) + + # With uniform merging probs, roundtrip is identity, so gradient should be 2*x + grad_fn = jax.jit(jax.grad(roundtrip_loss_with_padding, argnums=0)) + computed_grad = grad_fn(inp_sharded, routing_sharded, merging_sharded) + + expected_grad = 2.0 * inp_sharded + + assert_allclose( + jax.device_get(computed_grad), jax.device_get(expected_grad), dtype=dtype + ) diff --git a/transformer_engine/common/triton/permutation.py 
b/transformer_engine/common/triton/permutation.py index e53b2a9455..4602f41cfd 100644 --- a/transformer_engine/common/triton/permutation.py +++ b/transformer_engine/common/triton/permutation.py @@ -201,8 +201,15 @@ def _permute_kernel( scale_ptr, permuted_scale_ptr, pad_offsets_ptr, + # Pre-allocated output buffers for JAX input_output_aliases. + # These are aliased to output_ptr/permuted_probs_ptr in JAX, so they point to the same memory. + # In PyTorch, pass the same tensors as output_ptr/permuted_probs_ptr. + output_buf_ptr, # pylint: disable=unused-argument + permuted_probs_buf_ptr, # pylint: disable=unused-argument # sizes scale_hidden_dim, + num_tokens, # pylint: disable=unused-argument + num_out_tokens, # pylint: disable=unused-argument # strides stride_row_id_map_token, stride_row_id_map_expert, @@ -228,12 +235,17 @@ def _permute_kernel( FUSION_PAD: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): + # Note: When FUSION_PAD=True, output buffers should be pre-zeroed by the caller + # to ensure padding positions contain zeros. + # PyTorch: Use torch.zeros() for output buffer allocation + # JAX: Pre-zeroed buffers should be passed (when input_output_aliases works) expert_idx = 0 pid_t = tl.program_id(0) pid_h = tl.program_id(1) cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = cur_off < hidden_size + src_row = pid_t.to(tl.int64) input_off = src_row * stride_input_token + cur_off * stride_input_hidden inp = tl.load(input_ptr + input_off, mask=mask) @@ -306,6 +318,10 @@ def _unpermute_kernel( merging_probs_ptr, permuted_probs_ptr, pad_offsets_ptr, + # Dummy parameters for JAX input_output_aliases compatibility (matches _permute_kernel signature pattern) + # These are unused in the unpermute kernel but maintain consistency with the permute kernel. 
+ output_buf_ptr, # pylint: disable=unused-argument + unpermuted_probs_buf_ptr, # pylint: disable=unused-argument # strides stride_row_id_map_token, stride_row_id_map_expert, diff --git a/transformer_engine/jax/permutation.py b/transformer_engine/jax/permutation.py index 2e16e674cc..405d5f7661 100644 --- a/transformer_engine/jax/permutation.py +++ b/transformer_engine/jax/permutation.py @@ -137,7 +137,7 @@ def token_dispatch( ) -@partial(jax.custom_vjp, nondiff_argnums=(1, 3, 4, 5, 6)) +@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6)) def _token_dispatch( inp: jnp.ndarray, routing_map: jnp.ndarray, @@ -240,6 +240,7 @@ def _token_dispatch_fwd_rule( num_experts, worst_case_out_tokens, hidden_size, + align_size=align_size, ) else: # No padding @@ -268,7 +269,6 @@ def _token_dispatch_fwd_rule( def _token_dispatch_bwd_rule( - _routing_map: jnp.ndarray, _num_out_tokens: int, _worst_case_out_tokens: int, _align_size: Optional[int], @@ -281,8 +281,12 @@ def _token_dispatch_bwd_rule( Optional[jnp.ndarray], Optional[jnp.ndarray], ], -) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: - """Backward pass rule for token_dispatch.""" +) -> Tuple[jnp.ndarray, None, Optional[jnp.ndarray]]: + """Backward pass rule for token_dispatch. + + Returns gradients for (inp, routing_map, probs). + routing_map gradient is None since it's a discrete routing decision. 
+ """ row_id_map, pad_offsets, num_tokens, num_experts, hidden_size, with_probs = residuals output_grad, permuted_probs_grad, _, _, _ = g # Ignore row_id_map, pad_offsets, target grads @@ -309,7 +313,9 @@ def _token_dispatch_bwd_rule( hidden_size, ) - return inp_grad, probs_grad if with_probs else None + # Return gradients for (inp, routing_map, probs) + # routing_map is non-differentiable (discrete routing), so return None + return inp_grad, None, probs_grad if with_probs else None _token_dispatch.defvjp(_token_dispatch_fwd_rule, _token_dispatch_bwd_rule) @@ -497,6 +503,8 @@ def _token_combine_bwd_rule( else: # Simple case: just permute gradients back if pad_offsets is not None: + # Note: align_size uses default (128) since buffer sizes are already + # determined from forward pass (stored in residuals as num_out_tokens) inp_grad, _ = permute_with_mask_map_and_pad( output_grad, row_id_map, @@ -506,6 +514,7 @@ def _token_combine_bwd_rule( num_experts, num_out_tokens, hidden_size, + align_size=128, # Default, sizes already computed in forward ) # The permute kernel only writes to positions that tokens map to. # Padded positions may contain uninitialized (NaN) values - replace with zeros. 
diff --git a/transformer_engine/jax/triton_extensions/permutation.py b/transformer_engine/jax/triton_extensions/permutation.py index 849673fe31..bd8bd8ff13 100644 --- a/transformer_engine/jax/triton_extensions/permutation.py +++ b/transformer_engine/jax/triton_extensions/permutation.py @@ -8,9 +8,13 @@ import jax import jax.numpy as jnp +from jax.sharding import PartitionSpec +from jax.experimental.custom_partitioning import SdyShardingRule import triton from transformer_engine.jax.cpp_extensions.base import BasePrimitive, register_primitive +from transformer_engine.jax.cpp_extensions.misc import get_padded_spec, NamedSharding +from transformer_engine.jax.sharding import get_mesh_axis_size from transformer_engine.common.triton.permutation import ( _row_id_map_pass_1_kernel, _row_id_map_pass_2_kernel, @@ -93,7 +97,6 @@ def impl(routing_map, num_tokens, num_experts, block_size): @staticmethod def lowering(ctx, routing_map, *, num_tokens, num_experts, block_size): """MLIR lowering using triton_call_lowering.""" - # Compute strides routing_stride_token = num_experts routing_stride_expert = 1 row_id_stride_token = num_experts * 2 + 1 @@ -101,11 +104,10 @@ def lowering(ctx, routing_map, *, num_tokens, num_experts, block_size): grid = (num_experts, triton.cdiv(num_tokens, block_size)) - # All scalar arguments must be passed as constexprs return triton_call_lowering( ctx, _row_id_map_pass_1_kernel, - routing_map, # Only tensor arguments here + routing_map, grid=grid, constexprs={ "num_tokens": num_tokens, @@ -117,6 +119,76 @@ def lowering(ctx, routing_map, *, num_tokens, num_experts, block_size): }, ) + @staticmethod + def infer_sharding_from_operands( + num_tokens, num_experts, block_size, mesh, arg_infos, result_infos + ): + """Infer output sharding from input sharding.""" + del num_tokens, num_experts, block_size, result_infos + routing_map_spec = get_padded_spec(arg_infos[0]) + # row_id_map has same token dimension sharding as routing_map + # Shape: (num_tokens, 
num_experts * 2 + 1) + row_id_map_sharding = NamedSharding( + mesh, + PartitionSpec(routing_map_spec[0], None), + desc="RowIdMapPass1.row_id_map_sharding", + ) + # Workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE)) + workspace_sharding = NamedSharding( + mesh, + PartitionSpec(None, None), + desc="RowIdMapPass1.workspace_sharding", + ) + return [row_id_map_sharding, workspace_sharding] + + @staticmethod + def partition(num_tokens, num_experts, block_size, mesh, arg_infos, result_infos): + """Row id map 1st pass partition.""" + del num_tokens, result_infos + routing_map_spec = get_padded_spec(arg_infos[0]) + + # Input sharding + arg_shardings = (arg_infos[0].sharding,) + + # Output shardings + row_id_map_sharding = NamedSharding( + mesh, + PartitionSpec(routing_map_spec[0], None), + desc="RowIdMapPass1.row_id_map_sharding", + ) + workspace_sharding = NamedSharding( + mesh, + PartitionSpec(None, None), + desc="RowIdMapPass1.workspace_sharding", + ) + out_shardings = [row_id_map_sharding, workspace_sharding] + + def sharded_impl(routing_map): + # Each shard processes its local tokens + local_num_tokens = routing_map.shape[0] + return RowIdMapPass1Primitive.impl( + routing_map, + num_tokens=local_num_tokens, + num_experts=num_experts, + block_size=block_size, + ) + + return mesh, sharded_impl, out_shardings, arg_shardings + + @staticmethod + def shardy_sharding_rule(num_tokens, num_experts, block_size, mesh, value_types, result_types): + """Shardy sharding rule for this primitive.""" + del num_tokens, num_experts, block_size, mesh, value_types, result_types + prefix = "RowIdMapPass1" + # routing_map shape: (num_tokens, num_experts) + input_spec = (f"{prefix}_tokens", f"{prefix}_experts") + # row_id_map shape: (num_tokens, num_experts * 2 + 1) + # Note: row_id_cols != experts since it's num_experts * 2 + 1 + row_id_map_spec = (f"{prefix}_tokens", f"{prefix}_row_id_cols") + # workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE)) + workspace_spec = 
(f"{prefix}_experts", f"{prefix}_ws_blocks") + return SdyShardingRule((input_spec,), (row_id_map_spec, workspace_spec)) + register_primitive(RowIdMapPass1Primitive) @@ -185,6 +257,69 @@ def lowering(ctx, row_id_map, workspace, *, num_tokens, num_experts, block_size) }, ) + @staticmethod + def infer_sharding_from_operands( + num_tokens, num_experts, block_size, mesh, arg_infos, result_infos + ): + """Infer output sharding from input sharding.""" + del num_tokens, num_experts, block_size, result_infos + row_id_map_spec = get_padded_spec(arg_infos[0]) + # Output has same sharding as input (in-place operation) + row_id_map_sharding = NamedSharding( + mesh, + PartitionSpec(*row_id_map_spec), + desc="RowIdMapPass2.row_id_map_sharding", + ) + workspace_sharding = NamedSharding( + mesh, + PartitionSpec(None, None), + desc="RowIdMapPass2.workspace_sharding", + ) + return [row_id_map_sharding, workspace_sharding] + + @staticmethod + def partition(num_tokens, num_experts, block_size, mesh, arg_infos, result_infos): + """Partition the primitive for distributed execution.""" + del num_tokens, result_infos + row_id_map_spec = get_padded_spec(arg_infos[0]) + + # Input shardings + arg_shardings = (arg_infos[0].sharding, arg_infos[1].sharding) + + # Output shardings (same as inputs for in-place operation) + row_id_map_sharding = NamedSharding( + mesh, + PartitionSpec(*row_id_map_spec), + desc="RowIdMapPass2.row_id_map_sharding", + ) + workspace_sharding = NamedSharding( + mesh, + PartitionSpec(None, None), + desc="RowIdMapPass2.workspace_sharding", + ) + out_shardings = [row_id_map_sharding, workspace_sharding] + + def sharded_impl(row_id_map, workspace): + local_num_tokens = row_id_map.shape[0] + return RowIdMapPass2Primitive.impl( + row_id_map, + workspace, + num_tokens=local_num_tokens, + num_experts=num_experts, + block_size=block_size, + ) + + return mesh, sharded_impl, out_shardings, arg_shardings + + @staticmethod + def shardy_sharding_rule(num_tokens, num_experts, 
block_size, mesh, value_types, result_types): + """Shardy sharding rule for this primitive.""" + del num_tokens, num_experts, block_size, mesh, value_types, result_types + prefix = "RowIdMapPass2" + row_id_map_spec = (f"{prefix}_tokens", f"{prefix}_cols") + workspace_spec = (f"{prefix}_ws_experts", f"{prefix}_ws_blocks") + return SdyShardingRule((row_id_map_spec, workspace_spec), (row_id_map_spec, workspace_spec)) + register_primitive(RowIdMapPass2Primitive) @@ -240,6 +375,52 @@ def lowering(ctx, row_id_map, *, num_tokens, num_experts): }, ) + @staticmethod + def infer_sharding_from_operands(num_tokens, num_experts, mesh, arg_infos, result_infos): + """Infer output sharding from input sharding.""" + del num_tokens, num_experts, result_infos + row_id_map_spec = get_padded_spec(arg_infos[0]) + # Output has same sharding as input (in-place operation) + return NamedSharding( + mesh, + PartitionSpec(*row_id_map_spec), + desc="RowIdMapPass3.row_id_map_sharding", + ) + + @staticmethod + def partition(num_tokens, num_experts, mesh, arg_infos, result_infos): + """Partition the primitive for distributed execution.""" + del num_tokens, result_infos + row_id_map_spec = get_padded_spec(arg_infos[0]) + + # Input sharding + arg_shardings = (arg_infos[0].sharding,) + + # Output sharding (same as input for in-place operation) + out_sharding = NamedSharding( + mesh, + PartitionSpec(*row_id_map_spec), + desc="RowIdMapPass3.row_id_map_sharding", + ) + + def sharded_impl(row_id_map): + local_num_tokens = row_id_map.shape[0] + return RowIdMapPass3Primitive.impl( + row_id_map, + num_tokens=local_num_tokens, + num_experts=num_experts, + ) + + return mesh, sharded_impl, out_sharding, arg_shardings + + @staticmethod + def shardy_sharding_rule(num_tokens, num_experts, mesh, value_types, result_types): + """Shardy sharding rule for this primitive.""" + del num_tokens, num_experts, mesh, value_types, result_types + prefix = "RowIdMapPass3" + row_id_map_spec = (f"{prefix}_tokens", 
f"{prefix}_cols") + return SdyShardingRule((row_id_map_spec,), (row_id_map_spec,)) + register_primitive(RowIdMapPass3Primitive) @@ -251,8 +432,12 @@ class PermuteWithMaskMapPrimitive(BasePrimitive): name = "te_permute_with_mask_map_triton" multiple_results = True - # scale, permuted_scale are dummy inputs (not used when PERMUTE_SCALE=False) - # pad_offsets can be shape (0,) when not doing padding, or (num_experts,) when padding + # Outer primitive has 6 tensor inputs: inp, row_id_map, probs, scale, permuted_scale, pad_offsets + # Static args for outer primitive: num_tokens, num_experts, num_out_tokens, hidden_size, + # with_probs, with_pad, align_size + # Inner primitive adds output_buf, permuted_probs_buf) + + # impl_static_args is for the outer primitive's impl() which has 6 tensor inputs. impl_static_args = ( 6, 7, @@ -260,7 +445,8 @@ class PermuteWithMaskMapPrimitive(BasePrimitive): 9, 10, 11, - ) # num_tokens, num_experts, num_out_tokens, hidden_size, with_probs, with_pad + 12, + ) inner_primitive = None outer_primitive = None @@ -272,6 +458,8 @@ def abstract( scale_aval, # dummy, same shape as inp permuted_scale_aval, # dummy, same shape as inp pad_offsets_aval, + output_buf_aval=None, # Pre-zeroed output buffer (inner primitive only) + permuted_probs_buf_aval=None, # Pre-zeroed permuted_probs buffer (inner primitive only) *, num_tokens, num_experts, @@ -279,10 +467,12 @@ def abstract( hidden_size, with_probs, with_pad, + align_size, ): """Shape/dtype inference for permute.""" del row_id_map_aval, scale_aval, permuted_scale_aval, pad_offsets_aval - del num_tokens, num_experts, with_pad + del num_tokens, num_experts, with_pad, align_size + del output_buf_aval, permuted_probs_buf_aval # Used for input_output_aliases only output_shape = (num_out_tokens, hidden_size) output_aval = jax.core.ShapedArray(output_shape, inp_aval.dtype) @@ -308,9 +498,29 @@ def impl( hidden_size, with_probs, with_pad, + align_size, # align_size is only used for sharding, but must be 
passed since abstract() requires it ): """Forward to inner primitive.""" + assert PermuteWithMaskMapPrimitive.inner_primitive is not None + + # Create pre-zeroed output buffers for the inner primitive. + # When with_pad=True, this ensures padding positions contain zeros. + # These buffers are aliased to the outputs via input_output_aliases in the lowering. + if with_pad: + output_buf = jnp.zeros((num_out_tokens, hidden_size), dtype=inp.dtype) + if with_probs: + permuted_probs_buf = jnp.zeros((num_out_tokens,), dtype=probs.dtype) + else: + permuted_probs_buf = jnp.zeros((0,), dtype=inp.dtype) + else: + # When not padding, use empty buffers (kernel ignores them, lowering skips aliasing) + output_buf = jnp.empty((num_out_tokens, hidden_size), dtype=inp.dtype) + if with_probs: + permuted_probs_buf = jnp.empty((num_out_tokens,), dtype=probs.dtype) + else: + permuted_probs_buf = jnp.empty((0,), dtype=inp.dtype) + return PermuteWithMaskMapPrimitive.inner_primitive.bind( inp, row_id_map, @@ -318,12 +528,15 @@ def impl( scale, permuted_scale, pad_offsets, + output_buf, + permuted_probs_buf, num_tokens=num_tokens, num_experts=num_experts, num_out_tokens=num_out_tokens, hidden_size=hidden_size, with_probs=with_probs, with_pad=with_pad, + align_size=align_size, ) @staticmethod @@ -335,6 +548,8 @@ def lowering( scale, permuted_scale, pad_offsets, + output_buf, # Pre-zeroed output buffer (for input_output_aliases) + permuted_probs_buf, # Pre-zeroed permuted_probs buffer (for input_output_aliases) *, num_tokens, num_experts, @@ -342,9 +557,10 @@ def lowering( hidden_size, with_probs, with_pad, + align_size, ): """MLIR lowering using triton_call_lowering.""" - del num_out_tokens + del align_size inp_stride_token = hidden_size inp_stride_hidden = 1 output_stride_token = hidden_size @@ -371,6 +587,18 @@ def lowering( block_size = _get_min_block_size(_permute_kernel) grid = (num_tokens, triton.cdiv(hidden_size, block_size)) + # Use input_output_aliases to alias pre-zeroed buffers to 
outputs. + # This ensures padding positions contain zeros since the kernel only writes valid positions. + # Input indices: 0=inp, 1=row_id_map, 2=probs, 3=scale, 4=permuted_scale, + # 5=pad_offsets, 6=output_buf, 7=permuted_probs_buf + # Output indices: 0=output, 1=permuted_probs + if with_pad: + input_output_aliases = {6: 0} + if with_probs: + input_output_aliases[7] = 1 + else: + input_output_aliases = None + return triton_call_lowering( ctx, _permute_kernel, @@ -380,9 +608,14 @@ def lowering( scale, permuted_scale, pad_offsets, + output_buf, + permuted_probs_buf, grid=grid, + input_output_aliases=input_output_aliases, constexprs={ "scale_hidden_dim": 0, + "num_tokens": num_tokens, + "num_out_tokens": num_out_tokens, "stride_row_id_map_token": row_id_stride_token, "stride_row_id_map_expert": row_id_stride_expert, "stride_input_token": inp_stride_token, @@ -405,24 +638,242 @@ def lowering( }, ) + @staticmethod + def infer_sharding_from_operands( + num_tokens, + num_experts, + num_out_tokens, + hidden_size, + with_probs, + with_pad, + align_size, + mesh, + arg_infos, + result_infos, + ): + """Infer output sharding from input sharding. 
+ + For batch-dimension partitioning: + - Input (num_tokens, hidden_size) is sharded on token dim + - Output (num_out_tokens, hidden_size) gets same token dim sharding + - Permuted probs (num_out_tokens,) gets same token dim sharding + """ + del align_size # Used only in partition + del num_tokens, num_experts, num_out_tokens, hidden_size, with_pad, result_infos + inp_spec = get_padded_spec(arg_infos[0]) + # Output has same sharding pattern: (token_shard, None) + output_sharding = NamedSharding( + mesh, + PartitionSpec(inp_spec[0], None), + desc="PermuteWithMaskMap.output_sharding", + ) + if with_probs: + permuted_probs_sharding = NamedSharding( + mesh, + PartitionSpec(inp_spec[0]), + desc="PermuteWithMaskMap.permuted_probs_sharding", + ) + else: + permuted_probs_sharding = NamedSharding( + mesh, + PartitionSpec(None), + desc="PermuteWithMaskMap.permuted_probs_sharding_empty", + ) + return [output_sharding, permuted_probs_sharding] + + @staticmethod + def partition( + num_tokens, + num_experts, + num_out_tokens, + hidden_size, + with_probs, + with_pad, + align_size, + mesh, + arg_infos, + result_infos, + ): + """Partition the primitive for distributed execution. + + For batch-dimension partitioning, each GPU processes its local tokens + independently. The row_id_map contains local destination indices, + so no inter-GPU communication is needed. 
+ """ + del num_tokens, result_infos + inp_spec = get_padded_spec(arg_infos[0]) + + # Input shardings - preserve original shardings + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + + # Output shardings + output_sharding = NamedSharding( + mesh, + PartitionSpec(inp_spec[0], None), + desc="PermuteWithMaskMap.output_sharding", + ) + if with_probs: + permuted_probs_sharding = NamedSharding( + mesh, + PartitionSpec(inp_spec[0]), + desc="PermuteWithMaskMap.permuted_probs_sharding", + ) + else: + permuted_probs_sharding = NamedSharding( + mesh, + PartitionSpec(None), + desc="PermuteWithMaskMap.permuted_probs_sharding_empty", + ) + out_shardings = [output_sharding, permuted_probs_sharding] + + # Get number of data parallel devices from the batch sharding axis + batch_axis = inp_spec[0] + if batch_axis is not None: + num_dp_devices = get_mesh_axis_size(batch_axis, mesh) + else: + num_dp_devices = 1 + + def sharded_impl(inp, row_id_map, probs, scale, permuted_scale, pad_offsets): + # Each shard processes its local tokens independently (data parallelism) + local_num_tokens = inp.shape[0] + + # ========================================================================= + # MoE Permutation Sharding (data parallelism, no expert parallelism) + # ========================================================================= + # Each GPU has ALL experts and processes its local batch of tokens. + # + # TopK bounds output: each token goes to at most topK experts, so: + # global_num_out_tokens = global_num_in_tokens * topK + # local_num_out_tokens = local_num_in_tokens * topK + # = global_num_out_tokens / num_dp_devices + # + # E = num_experts + # A = align_size for padding to group gemm size in cuBLAS + # With padding (align_size != 128, which is the default/no-op value): + # The global num_out_tokens passed here is already worst_case_out_tokens. + # We need to recalculate local worst-case from local raw tokens. 
+ # local_raw_out_tokens = global_raw_out_tokens / num_dp_devices + # local_worst_case = ((local_raw_out + E*(A-1)) // A) * A + # + # Local permute produces output ordered by expert: [E0 | E1 | ... | EN] + # where each expert section contains tokens routed to that expert. + # + # Global assembly (if needed) should be done outside this primitive. + + # ========================================================================= + # Output size calculation + # ========================================================================= + # For both padding and non-padding cases, use simple division. + # The global num_out_tokens is already the worst-case buffer size. + # + # IMPORTANT for padding + sharding: + # Padding overhead is per-shard (each shard needs E*(A-1) extra space). + # The caller must account for this by passing a sufficiently large + # global num_out_tokens such that: global_worst / num_dp >= local_worst + # where local_worst = ((local_raw + E*(A-1)) // A) * A + + local_num_out_tokens = num_out_tokens // num_dp_devices + + # Local permute - output stays sharded on this GPU + local_output, local_permuted_probs = PermuteWithMaskMapPrimitive.impl( + inp, + row_id_map, + probs, + scale, + permuted_scale, + pad_offsets, + num_tokens=local_num_tokens, + num_experts=num_experts, + num_out_tokens=local_num_out_tokens, + hidden_size=hidden_size, + with_probs=with_probs, + with_pad=with_pad, + align_size=align_size, + ) + + return local_output, local_permuted_probs + + return mesh, sharded_impl, out_shardings, arg_shardings + + @staticmethod + def shardy_sharding_rule( + num_tokens, + num_experts, + num_out_tokens, + hidden_size, + with_probs, + with_pad, + align_size, + mesh, + value_types, + result_types, + ): + """Shardy sharding rule for this primitive.""" + del ( + num_tokens, + num_experts, + num_out_tokens, + hidden_size, + align_size, + mesh, + value_types, + result_types, + ) + prefix = "PermuteWithMaskMap" + # inp: (num_tokens, hidden_size) + inp_spec = 
(f"{prefix}_tokens", f"{prefix}_hidden") + # row_id_map: (num_tokens, num_experts * 2 + 1) + row_id_map_spec = (f"{prefix}_tokens", f"{prefix}_row_id_cols") + # probs: (num_tokens, num_experts) or (0,) + probs_spec = ( + (f"{prefix}_tokens", f"{prefix}_experts") if with_probs else (f"{prefix}_empty",) + ) + # scale: (num_tokens, hidden_size) - same shape as inp, permuted together + scale_spec = (f"{prefix}_tokens", f"{prefix}_hidden") + # permuted_scale: (num_out_tokens, hidden_size) - same shape as output + permuted_scale_spec = (f"{prefix}_out_tokens", f"{prefix}_hidden") + # pad_offsets: (num_experts,) or (0,) - uses same experts factor as probs + pad_offsets_spec = (f"{prefix}_experts",) if with_pad else (f"{prefix}_pad_empty",) + # output: (num_out_tokens, hidden_size) + output_spec = (f"{prefix}_out_tokens", f"{prefix}_hidden") + # permuted_probs: (num_out_tokens,) or (0,) + permuted_probs_spec = (f"{prefix}_out_tokens",) if with_probs else (f"{prefix}_empty2",) + + return SdyShardingRule( + ( + inp_spec, + row_id_map_spec, + probs_spec, + scale_spec, + permuted_scale_spec, + pad_offsets_spec, + ), + (output_spec, permuted_probs_spec), + ) + register_primitive(PermuteWithMaskMapPrimitive) class UnpermuteWithMaskMapPrimitive(BasePrimitive): """ - Unpermute the input tensor based on the row_id_map. + Unpermute the input tensor based on the row_id_map, optionally with fused unpadding. 
""" name = "te_unpermute_with_mask_map_triton" multiple_results = True + # Outer primitive has 5 tensor inputs: inp, row_id_map, merging_probs, permuted_probs, pad_offsets + # Static args for outer primitive: num_tokens, num_experts, hidden_size, + # with_merging_probs, with_probs, with_unpad + # Inner primitive has adds output_buf, unpermuted_probs_buf impl_static_args = ( 5, 6, 7, 8, 9, - ) # num_tokens, num_experts, hidden_size, with_merging_probs, with_probs + 10, + ) inner_primitive = None outer_primitive = None @@ -432,16 +883,20 @@ def abstract( row_id_map_aval, merging_probs_aval, permuted_probs_aval, - pad_offsets_aval, # dummy, not used when FUSION_UNPAD=False + pad_offsets_aval, + output_buf_aval=None, # Dummy (inner primitive only) + unpermuted_probs_buf_aval=None, # Dummy (inner primitive only) *, num_tokens, num_experts, hidden_size, with_merging_probs, with_probs, + with_unpad, ): """Shape/dtype inference for unpermute.""" - del row_id_map_aval, merging_probs_aval, with_merging_probs, pad_offsets_aval + del row_id_map_aval, merging_probs_aval, with_merging_probs, pad_offsets_aval, with_unpad + del output_buf_aval, unpermuted_probs_buf_aval output_shape = (num_tokens, hidden_size) output_aval = jax.core.ShapedArray(output_shape, inp_aval.dtype) @@ -468,20 +923,33 @@ def impl( hidden_size, with_merging_probs, with_probs, + with_unpad, ): """Forward to inner primitive.""" assert UnpermuteWithMaskMapPrimitive.inner_primitive is not None + + # Create dummy buffers for kernel signature consistency with _permute_kernel. + # These are not used for pre-zeroing since unpermute writes to all output positions. 
+ output_buf = jnp.empty((num_tokens, hidden_size), dtype=inp.dtype) + if with_probs: + unpermuted_probs_buf = jnp.empty((num_tokens, num_experts), dtype=permuted_probs.dtype) + else: + unpermuted_probs_buf = jnp.empty((0,), dtype=inp.dtype) + return UnpermuteWithMaskMapPrimitive.inner_primitive.bind( inp, row_id_map, merging_probs, permuted_probs, pad_offsets, + output_buf, + unpermuted_probs_buf, num_tokens=num_tokens, num_experts=num_experts, hidden_size=hidden_size, with_merging_probs=with_merging_probs, with_probs=with_probs, + with_unpad=with_unpad, ) @staticmethod @@ -492,12 +960,15 @@ def lowering( merging_probs, permuted_probs, pad_offsets, + output_buf, # Dummy for kernel signature consistency + unpermuted_probs_buf, # Dummy for kernel signature consistency *, num_tokens, num_experts, hidden_size, with_merging_probs, with_probs, + with_unpad, ): """MLIR lowering using triton_call_lowering.""" # Compute strides @@ -523,7 +994,6 @@ def lowering( block_size = _get_min_block_size(_unpermute_kernel) grid = (num_tokens, triton.cdiv(hidden_size, block_size)) - # Pass all 5 inputs including pad_offsets (even though FUSION_UNPAD=False) return triton_call_lowering( ctx, _unpermute_kernel, @@ -532,6 +1002,8 @@ def lowering( merging_probs, permuted_probs, pad_offsets, + output_buf, + unpermuted_probs_buf, grid=grid, constexprs={ "stride_row_id_map_token": row_id_stride_token, @@ -550,174 +1022,170 @@ def lowering( "PROBS_LOAD_WIDTH": triton.next_power_of_2(num_experts), "WITH_MERGING_PROBS": with_merging_probs, "PERMUTE_PROBS": with_probs, - "FUSION_UNPAD": False, + "FUSION_UNPAD": with_unpad, "BLOCK_SIZE": block_size, }, ) - -register_primitive(UnpermuteWithMaskMapPrimitive) - - -class UnpermuteWithMaskMapAndUnpadPrimitive(BasePrimitive): - """ - Unpermute the input tensor based on the row_id_map with fused unpadding. 
- """ - - name = "te_unpermute_with_mask_map_and_unpad_triton" - multiple_results = True - impl_static_args = ( - 5, - 6, - 7, - 8, - 9, - ) # num_tokens, num_experts, hidden_size, with_merging_probs, with_probs - inner_primitive = None - outer_primitive = None - @staticmethod - def abstract( - inp_aval, - row_id_map_aval, - merging_probs_aval, - permuted_probs_aval, - pad_offsets_aval, - *, + def infer_sharding_from_operands( num_tokens, num_experts, hidden_size, with_merging_probs, with_probs, + with_unpad, + mesh, + arg_infos, + result_infos, ): - """Shape/dtype inference for unpermute with unpadding.""" - del row_id_map_aval, merging_probs_aval, with_merging_probs, pad_offsets_aval - - output_shape = (num_tokens, hidden_size) - output_aval = jax.core.ShapedArray(output_shape, inp_aval.dtype) - + """Infer output sharding from input sharding. + + For batch-dimension partitioning: + - row_id_map (num_tokens, num_experts*2+1) is sharded on token dim + - Output (num_tokens, hidden_size) gets same token dim sharding + """ + del num_tokens, num_experts, hidden_size, with_merging_probs, with_unpad, result_infos + row_id_map_spec = get_padded_spec(arg_infos[1]) + # Output has same token dimension sharding as row_id_map + output_sharding = NamedSharding( + mesh, + PartitionSpec(row_id_map_spec[0], None), + desc="UnpermuteWithMaskMap.output_sharding", + ) if with_probs: - unpermuted_probs_shape = (num_tokens, num_experts) - unpermuted_probs_aval = jax.core.ShapedArray( - unpermuted_probs_shape, permuted_probs_aval.dtype + unpermuted_probs_sharding = NamedSharding( + mesh, + PartitionSpec(row_id_map_spec[0], None), + desc="UnpermuteWithMaskMap.unpermuted_probs_sharding", ) else: - unpermuted_probs_aval = jax.core.ShapedArray((0,), inp_aval.dtype) - - return output_aval, unpermuted_probs_aval + unpermuted_probs_sharding = NamedSharding( + mesh, + PartitionSpec(None), + desc="UnpermuteWithMaskMap.unpermuted_probs_sharding_empty", + ) + return [output_sharding, 
unpermuted_probs_sharding] @staticmethod - def impl( - inp, - row_id_map, - merging_probs, - permuted_probs, - pad_offsets, + def partition( num_tokens, num_experts, hidden_size, with_merging_probs, with_probs, + with_unpad, + mesh, + arg_infos, + result_infos, ): - """Forward to inner primitive.""" - assert UnpermuteWithMaskMapAndUnpadPrimitive.inner_primitive is not None - return UnpermuteWithMaskMapAndUnpadPrimitive.inner_primitive.bind( - inp, - row_id_map, - merging_probs, - permuted_probs, - pad_offsets, - num_tokens=num_tokens, - num_experts=num_experts, - hidden_size=hidden_size, - with_merging_probs=with_merging_probs, - with_probs=with_probs, + """Partition the primitive for distributed execution.""" + del num_tokens, result_infos + row_id_map_spec = get_padded_spec(arg_infos[1]) + + # Input shardings - preserve original shardings + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + + # Output shardings + output_sharding = NamedSharding( + mesh, + PartitionSpec(row_id_map_spec[0], None), + desc="UnpermuteWithMaskMap.output_sharding", ) + if with_probs: + unpermuted_probs_sharding = NamedSharding( + mesh, + PartitionSpec(row_id_map_spec[0], None), + desc="UnpermuteWithMaskMap.unpermuted_probs_sharding", + ) + else: + unpermuted_probs_sharding = NamedSharding( + mesh, + PartitionSpec(None), + desc="UnpermuteWithMaskMap.unpermuted_probs_sharding_empty", + ) + out_shardings = [output_sharding, unpermuted_probs_sharding] + + def sharded_impl(inp, row_id_map, merging_probs, permuted_probs, pad_offsets): + # Each shard processes its local tokens + local_num_tokens = row_id_map.shape[0] + return UnpermuteWithMaskMapPrimitive.impl( + inp, + row_id_map, + merging_probs, + permuted_probs, + pad_offsets, + num_tokens=local_num_tokens, + num_experts=num_experts, + hidden_size=hidden_size, # hidden_size is not sharded + with_merging_probs=with_merging_probs, + with_probs=with_probs, + with_unpad=with_unpad, + ) + + return mesh, sharded_impl, out_shardings, 
arg_shardings @staticmethod - def lowering( - ctx, - inp, - row_id_map, - merging_probs, - permuted_probs, - pad_offsets, - *, + def shardy_sharding_rule( num_tokens, num_experts, hidden_size, with_merging_probs, with_probs, + with_unpad, + mesh, + value_types, + result_types, ): - """MLIR lowering using triton_call_lowering.""" - # Compute strides - inp_stride_token = hidden_size - inp_stride_hidden = 1 - output_stride_token = hidden_size - output_stride_hidden = 1 - row_id_stride_token = num_experts * 2 + 1 - row_id_stride_expert = 1 - - if with_merging_probs: - merging_probs_stride_token = num_experts - merging_probs_stride_expert = 1 - else: - merging_probs_stride_token = 0 - merging_probs_stride_expert = 0 - - permuted_probs_stride_token = 1 - unpermuted_probs_stride_token = num_experts - unpermuted_probs_stride_expert = 1 - - # Grid - use minimum BLOCK_SIZE from autotune configs - block_size = _get_min_block_size(_unpermute_kernel) - grid = (num_tokens, triton.cdiv(hidden_size, block_size)) + """Shardy sharding rule for this primitive.""" + del num_tokens, num_experts, hidden_size, mesh, value_types, result_types + prefix = "UnpermuteWithMaskMap" + # inp: (num_out_tokens, hidden_size) + inp_spec = (f"{prefix}_out_tokens", f"{prefix}_hidden") + # row_id_map: (num_tokens, num_experts * 2 + 1) + row_id_map_spec = (f"{prefix}_tokens", f"{prefix}_row_id_cols") + # merging_probs: (num_tokens, num_experts) or (0,) + merging_probs_spec = ( + (f"{prefix}_tokens", f"{prefix}_experts") + if with_merging_probs + else (f"{prefix}_empty",) + ) + # permuted_probs: (num_out_tokens,) or (0,) + permuted_probs_spec = (f"{prefix}_out_tokens",) if with_probs else (f"{prefix}_empty2",) + # pad_offsets: (num_experts,) when with_unpad=True, or dummy (0,) otherwise + pad_offsets_spec = (f"{prefix}_experts",) if with_unpad else (f"{prefix}_pad_empty",) + # output: (num_tokens, hidden_size) + output_spec = (f"{prefix}_tokens", f"{prefix}_hidden") + # unpermuted_probs: (num_tokens, 
num_experts) or (0,) + unpermuted_probs_spec = ( + (f"{prefix}_tokens", f"{prefix}_experts") if with_probs else (f"{prefix}_empty3",) + ) - return triton_call_lowering( - ctx, - _unpermute_kernel, - inp, - row_id_map, - merging_probs, - permuted_probs, - pad_offsets, - grid=grid, - constexprs={ - "stride_row_id_map_token": row_id_stride_token, - "stride_row_id_map_expert": row_id_stride_expert, - "stride_input_token": inp_stride_token, - "stride_input_hidden": inp_stride_hidden, - "stride_output_token": output_stride_token, - "stride_output_hidden": output_stride_hidden, - "stride_merging_probs_token": merging_probs_stride_token, - "stride_merging_probs_expert": merging_probs_stride_expert, - "stride_permuted_probs_token": permuted_probs_stride_token, - "stride_unpermuted_probs_token": unpermuted_probs_stride_token, - "stride_unpermuted_probs_expert": unpermuted_probs_stride_expert, - "num_experts": num_experts, - "hidden_size": hidden_size, - "PROBS_LOAD_WIDTH": triton.next_power_of_2(num_experts), - "WITH_MERGING_PROBS": with_merging_probs, - "PERMUTE_PROBS": with_probs, - "FUSION_UNPAD": True, - "BLOCK_SIZE": block_size, - }, + return SdyShardingRule( + (inp_spec, row_id_map_spec, merging_probs_spec, permuted_probs_spec, pad_offsets_spec), + (output_spec, unpermuted_probs_spec), ) -register_primitive(UnpermuteWithMaskMapAndUnpadPrimitive) +register_primitive(UnpermuteWithMaskMapPrimitive) class UnpermuteBwdWithMergingProbsPrimitive(BasePrimitive): """ - Backward pass for unpermute with merging probabilities. + Backward pass for unpermute with merging probabilities, optionally with fused unpadding. This kernel computes gradients for both the input and merging_probs. 
""" name = "te_unpermute_bwd_with_merging_probs_triton" multiple_results = True - impl_static_args = (5, 6, 7, 8) # num_tokens, num_experts, num_out_tokens, hidden_size + impl_static_args = ( + 5, + 6, + 7, + 8, + 9, + ) # num_tokens, num_experts, num_out_tokens, hidden_size, with_unpad inner_primitive = None outer_primitive = None @@ -727,15 +1195,16 @@ def abstract( fwd_input_aval, merging_probs_aval, row_id_map_aval, - pad_offsets_aval, # dummy, not used when FUSION_UNPAD=False + pad_offsets_aval, *, num_tokens, num_experts, num_out_tokens, hidden_size, + with_unpad, ): """Shape/dtype inference for unpermute backward with merging probs.""" - del fwd_input_aval, row_id_map_aval, pad_offsets_aval + del fwd_input_aval, row_id_map_aval, pad_offsets_aval, with_unpad # fwd_input_grad has same shape as fwd_input fwd_input_grad_shape = (num_out_tokens, hidden_size) @@ -760,6 +1229,7 @@ def impl( num_experts, num_out_tokens, hidden_size, + with_unpad, ): """Forward to inner primitive.""" assert UnpermuteBwdWithMergingProbsPrimitive.inner_primitive is not None @@ -773,6 +1243,7 @@ def impl( num_experts=num_experts, num_out_tokens=num_out_tokens, hidden_size=hidden_size, + with_unpad=with_unpad, ) @staticmethod @@ -788,6 +1259,7 @@ def lowering( num_experts, num_out_tokens, hidden_size, + with_unpad, ): """MLIR lowering using triton_call_lowering.""" del num_out_tokens @@ -812,7 +1284,6 @@ def lowering( # Get min block size from autotune configs for consistency block_size = _get_min_block_size(_unpermute_bwd_with_merging_probs_kernel) - # Pass all 5 inputs including pad_offsets (even though FUSION_UNPAD=False) return triton_call_lowering( ctx, _unpermute_bwd_with_merging_probs_kernel, @@ -838,152 +1309,126 @@ def lowering( "num_experts": num_experts, "hidden_size": hidden_size, "PROBS_LOAD_WIDTH": triton.next_power_of_2(num_experts), - "FUSION_UNPAD": False, + "FUSION_UNPAD": with_unpad, "BLOCK_SIZE": block_size, }, ) - 
-register_primitive(UnpermuteBwdWithMergingProbsPrimitive) - - -class UnpermuteBwdWithMergingProbsAndUnpadPrimitive(BasePrimitive): - """ - Backward pass for unpermute with merging probabilities and fused unpadding. - - This kernel computes gradients for both the input and merging_probs, - while handling padded outputs. - """ - - name = "te_unpermute_bwd_with_merging_probs_and_unpad_triton" - multiple_results = True - impl_static_args = (5, 6, 7, 8) # num_tokens, num_experts, num_out_tokens, hidden_size - inner_primitive = None - outer_primitive = None - @staticmethod - def abstract( - fwd_output_grad_aval, - fwd_input_aval, - merging_probs_aval, - row_id_map_aval, - pad_offsets_aval, - *, + def infer_sharding_from_operands( num_tokens, num_experts, num_out_tokens, hidden_size, + with_unpad, + mesh, + arg_infos, + result_infos, ): - """Shape/dtype inference for unpermute backward with merging probs and unpadding.""" - del fwd_input_aval, row_id_map_aval, pad_offsets_aval - - # fwd_input_grad has same shape as fwd_input - fwd_input_grad_shape = (num_out_tokens, hidden_size) - fwd_input_grad_aval = jax.core.ShapedArray(fwd_input_grad_shape, fwd_output_grad_aval.dtype) - - # merging_probs_grad has same shape as merging_probs - merging_probs_grad_shape = (num_tokens, num_experts) - merging_probs_grad_aval = jax.core.ShapedArray( - merging_probs_grad_shape, merging_probs_aval.dtype + """Infer output sharding from input sharding.""" + del num_tokens, num_experts, num_out_tokens, hidden_size, with_unpad, result_infos + fwd_output_grad_spec = get_padded_spec(arg_infos[0]) + merging_probs_spec = get_padded_spec(arg_infos[2]) + # fwd_input_grad has same token sharding as fwd_output_grad + fwd_input_grad_sharding = NamedSharding( + mesh, + PartitionSpec(fwd_output_grad_spec[0], None), + desc="UnpermuteBwdWithMergingProbs.fwd_input_grad_sharding", ) - - return fwd_input_grad_aval, merging_probs_grad_aval + # merging_probs_grad has same sharding as merging_probs + 
merging_probs_grad_sharding = NamedSharding( + mesh, + PartitionSpec(merging_probs_spec[0], None), + desc="UnpermuteBwdWithMergingProbs.merging_probs_grad_sharding", + ) + return [fwd_input_grad_sharding, merging_probs_grad_sharding] @staticmethod - def impl( - fwd_output_grad, - fwd_input, - merging_probs, - row_id_map, - pad_offsets, + def partition( num_tokens, num_experts, num_out_tokens, hidden_size, + with_unpad, + mesh, + arg_infos, + result_infos, ): - """Forward to inner primitive.""" - assert UnpermuteBwdWithMergingProbsAndUnpadPrimitive.inner_primitive is not None - return UnpermuteBwdWithMergingProbsAndUnpadPrimitive.inner_primitive.bind( - fwd_output_grad, - fwd_input, - merging_probs, - row_id_map, - pad_offsets, - num_tokens=num_tokens, - num_experts=num_experts, - num_out_tokens=num_out_tokens, - hidden_size=hidden_size, + """Partition the primitive for distributed execution.""" + del num_tokens, num_out_tokens, result_infos + fwd_output_grad_spec = get_padded_spec(arg_infos[0]) + merging_probs_spec = get_padded_spec(arg_infos[2]) + + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + + fwd_input_grad_sharding = NamedSharding( + mesh, + PartitionSpec(fwd_output_grad_spec[0], None), + desc="UnpermuteBwdWithMergingProbs.fwd_input_grad_sharding", + ) + merging_probs_grad_sharding = NamedSharding( + mesh, + PartitionSpec(merging_probs_spec[0], None), + desc="UnpermuteBwdWithMergingProbs.merging_probs_grad_sharding", ) + out_shardings = [fwd_input_grad_sharding, merging_probs_grad_sharding] + + def sharded_impl(fwd_output_grad, fwd_input, merging_probs, row_id_map, pad_offsets): + local_num_tokens = row_id_map.shape[0] + # NOTE: local_num_out_tokens is obtained from the actual tensor shape, + # which reflects the data-dependent output size from the forward pass. 
+ local_num_out_tokens = fwd_input.shape[0] + return UnpermuteBwdWithMergingProbsPrimitive.impl( + fwd_output_grad, + fwd_input, + merging_probs, + row_id_map, + pad_offsets, + num_tokens=local_num_tokens, + num_experts=num_experts, + num_out_tokens=local_num_out_tokens, + hidden_size=hidden_size, # hidden_size is not sharded + with_unpad=with_unpad, + ) + + return mesh, sharded_impl, out_shardings, arg_shardings @staticmethod - def lowering( - ctx, - fwd_output_grad, - fwd_input, - merging_probs, - row_id_map, - pad_offsets, - *, + def shardy_sharding_rule( num_tokens, num_experts, num_out_tokens, hidden_size, + with_unpad, + mesh, + value_types, + result_types, ): - """MLIR lowering using triton_call_lowering.""" - del num_out_tokens - - # Compute strides - row_id_stride_token = num_experts * 2 + 1 - row_id_stride_expert = 1 - fwd_output_grad_stride_token = hidden_size - fwd_output_grad_stride_hidden = 1 - fwd_input_grad_stride_token = hidden_size - fwd_input_grad_stride_hidden = 1 - fwd_input_stride_token = hidden_size - fwd_input_stride_hidden = 1 - merging_probs_stride_token = num_experts - merging_probs_stride_expert = 1 - merging_probs_grad_stride_token = num_experts - merging_probs_grad_stride_expert = 1 - - # Grid - one program per token - grid = (num_tokens,) - - # Get min block size from autotune configs for consistency - block_size = _get_min_block_size(_unpermute_bwd_with_merging_probs_kernel) - - return triton_call_lowering( - ctx, - _unpermute_bwd_with_merging_probs_kernel, - fwd_output_grad, - fwd_input, - merging_probs, - row_id_map, - pad_offsets, - grid=grid, - constexprs={ - "stride_row_id_map_token": row_id_stride_token, - "stride_row_id_map_expert": row_id_stride_expert, - "stride_fwd_output_grad_token": fwd_output_grad_stride_token, - "stride_fwd_output_grad_hidden": fwd_output_grad_stride_hidden, - "stride_fwd_input_grad_token": fwd_input_grad_stride_token, - "stride_fwd_input_grad_hidden": fwd_input_grad_stride_hidden, - 
"stride_fwd_input_token": fwd_input_stride_token, - "stride_fwd_input_hidden": fwd_input_stride_hidden, - "stride_merging_probs_token": merging_probs_stride_token, - "stride_merging_probs_expert": merging_probs_stride_expert, - "stride_merging_probs_grad_token": merging_probs_grad_stride_token, - "stride_merging_probs_grad_expert": merging_probs_grad_stride_expert, - "num_experts": num_experts, - "hidden_size": hidden_size, - "PROBS_LOAD_WIDTH": triton.next_power_of_2(num_experts), - "FUSION_UNPAD": True, - "BLOCK_SIZE": block_size, - }, + """Shardy sharding rule for this primitive.""" + del num_tokens, num_experts, num_out_tokens, hidden_size, mesh, value_types, result_types + prefix = "UnpermuteBwdWithMergingProbs" + fwd_output_grad_spec = (f"{prefix}_tokens", f"{prefix}_hidden") + fwd_input_spec = (f"{prefix}_out_tokens", f"{prefix}_hidden") + merging_probs_spec = (f"{prefix}_tokens", f"{prefix}_experts") + row_id_map_spec = (f"{prefix}_tokens", f"{prefix}_row_id_cols") + # pad_offsets: (num_experts,) when with_unpad=True, or dummy (0,) otherwise + pad_offsets_spec = (f"{prefix}_experts",) if with_unpad else (f"{prefix}_pad_empty",) + fwd_input_grad_spec = (f"{prefix}_out_tokens", f"{prefix}_hidden") + merging_probs_grad_spec = (f"{prefix}_tokens", f"{prefix}_experts") + + return SdyShardingRule( + ( + fwd_output_grad_spec, + fwd_input_spec, + merging_probs_spec, + row_id_map_spec, + pad_offsets_spec, + ), + (fwd_input_grad_spec, merging_probs_grad_spec), ) -register_primitive(UnpermuteBwdWithMergingProbsAndUnpadPrimitive) +register_primitive(UnpermuteBwdWithMergingProbsPrimitive) def unpermute_bwd_with_merging_probs( @@ -1027,7 +1472,7 @@ def unpermute_bwd_with_merging_probs( merging_probs_grad : jnp.ndarray Gradient w.r.t. merging_probs of shape `[num_tokens, num_experts]`. 
""" - # Create dummy pad_offsets (not used when FUSION_UNPAD=False, but required by kernel signature) + # Create dummy pad_offsets (not used when with_unpad=False, but required by kernel signature) dummy_pad_offsets = jnp.zeros((0,), dtype=jnp.int32) # Pass arguments in kernel order: fwd_output_grad, fwd_input, merging_probs, row_id_map, pad_offsets return UnpermuteBwdWithMergingProbsPrimitive.outer_primitive.bind( @@ -1040,6 +1485,7 @@ def unpermute_bwd_with_merging_probs( num_experts=num_experts, num_out_tokens=num_out_tokens, hidden_size=hidden_size, + with_unpad=False, ) @@ -1088,7 +1534,7 @@ def unpermute_bwd_with_merging_probs_and_unpad( merging_probs_grad : jnp.ndarray Gradient w.r.t. merging_probs of shape `[num_tokens, num_experts]`. """ - return UnpermuteBwdWithMergingProbsAndUnpadPrimitive.outer_primitive.bind( + return UnpermuteBwdWithMergingProbsPrimitive.outer_primitive.bind( fwd_output_grad, fwd_input, merging_probs, @@ -1098,6 +1544,7 @@ def unpermute_bwd_with_merging_probs_and_unpad( num_experts=num_experts, num_out_tokens=num_out_tokens, hidden_size=hidden_size, + with_unpad=True, ) @@ -1147,6 +1594,54 @@ def lowering(ctx, split_sizes, sorted_indices, *, num_tokens, num_splits): }, ) + @staticmethod + def infer_sharding_from_operands(num_tokens, num_splits, mesh, arg_infos, result_infos): + """Infer output sharding from input sharding.""" + del num_tokens, num_splits, result_infos, arg_infos + # row_id_map is replicated since split_sizes and sorted_indices are typically small + return NamedSharding( + mesh, + PartitionSpec(None), + desc="MakeChunkSortMap.row_id_map_sharding", + ) + + @staticmethod + def partition(num_tokens, num_splits, mesh, arg_infos, result_infos): + """Partition the primitive for distributed execution.""" + del result_infos + + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + + out_sharding = NamedSharding( + mesh, + PartitionSpec(None), + desc="MakeChunkSortMap.row_id_map_sharding", + ) + + def 
sharded_impl(split_sizes, sorted_indices): + return MakeChunkSortMapPrimitive.impl( + split_sizes, + sorted_indices, + num_tokens=num_tokens, + num_splits=num_splits, + ) + + return mesh, sharded_impl, out_sharding, arg_shardings + + @staticmethod + def shardy_sharding_rule(num_tokens, num_splits, mesh, value_types, result_types): + """Shardy sharding rule for this primitive.""" + del num_tokens, num_splits, mesh, value_types, result_types + prefix = "MakeChunkSortMap" + split_sizes_spec = (f"{prefix}_splits",) + sorted_indices_spec = (f"{prefix}_splits",) + row_id_map_spec = (f"{prefix}_tokens",) + + return SdyShardingRule( + (split_sizes_spec, sorted_indices_spec), + (row_id_map_spec,), + ) + register_primitive(MakeChunkSortMapPrimitive) @@ -1228,6 +1723,91 @@ def lowering(ctx, inp, row_id_map, probs, *, num_tokens, hidden_size, is_forward }, ) + @staticmethod + def infer_sharding_from_operands( + num_tokens, hidden_size, is_forward, with_probs, mesh, arg_infos, result_infos + ): + """Infer output sharding from input sharding.""" + del num_tokens, hidden_size, is_forward, result_infos + inp_spec = get_padded_spec(arg_infos[0]) + output_sharding = NamedSharding( + mesh, + PartitionSpec(inp_spec[0], None), + desc="SortChunksByMap.output_sharding", + ) + if with_probs: + permuted_probs_sharding = NamedSharding( + mesh, + PartitionSpec(inp_spec[0]), + desc="SortChunksByMap.permuted_probs_sharding", + ) + else: + permuted_probs_sharding = NamedSharding( + mesh, + PartitionSpec(None), + desc="SortChunksByMap.permuted_probs_sharding_empty", + ) + return [output_sharding, permuted_probs_sharding] + + @staticmethod + def partition(num_tokens, hidden_size, is_forward, with_probs, mesh, arg_infos, result_infos): + """Partition the primitive for distributed execution.""" + del num_tokens, result_infos + inp_spec = get_padded_spec(arg_infos[0]) + + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + + output_sharding = NamedSharding( + mesh, + 
PartitionSpec(inp_spec[0], None), + desc="SortChunksByMap.output_sharding", + ) + if with_probs: + permuted_probs_sharding = NamedSharding( + mesh, + PartitionSpec(inp_spec[0]), + desc="SortChunksByMap.permuted_probs_sharding", + ) + else: + permuted_probs_sharding = NamedSharding( + mesh, + PartitionSpec(None), + desc="SortChunksByMap.permuted_probs_sharding_empty", + ) + out_shardings = [output_sharding, permuted_probs_sharding] + + def sharded_impl(inp, row_id_map, probs): + local_num_tokens = inp.shape[0] + return SortChunksByMapPrimitive.impl( + inp, + row_id_map, + probs, + num_tokens=local_num_tokens, + hidden_size=hidden_size, # hidden_size is not sharded + is_forward=is_forward, + with_probs=with_probs, + ) + + return mesh, sharded_impl, out_shardings, arg_shardings + + @staticmethod + def shardy_sharding_rule( + num_tokens, hidden_size, is_forward, with_probs, mesh, value_types, result_types + ): + """Shardy sharding rule for this primitive.""" + del num_tokens, hidden_size, is_forward, mesh, value_types, result_types + prefix = "SortChunksByMap" + inp_spec = (f"{prefix}_tokens", f"{prefix}_hidden") + row_id_map_spec = (f"{prefix}_tokens",) + probs_spec = (f"{prefix}_tokens",) if with_probs else (f"{prefix}_empty",) + output_spec = (f"{prefix}_tokens", f"{prefix}_hidden") + permuted_probs_spec = (f"{prefix}_tokens",) if with_probs else (f"{prefix}_empty2",) + + return SdyShardingRule( + (inp_spec, row_id_map_spec, probs_spec), + (output_spec, permuted_probs_spec), + ) + register_primitive(SortChunksByMapPrimitive) @@ -1356,6 +1936,7 @@ def permute_with_mask_map( hidden_size=hidden_size, with_probs=with_probs, with_pad=False, + align_size=128, # Default value, no-op for non-padding case ) if not with_probs: @@ -1373,6 +1954,7 @@ def permute_with_mask_map_and_pad( num_experts: int, num_out_tokens: int, hidden_size: int, + align_size: int = 128, ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: """ Permute the input tensor based on the row_id_map with fused 
padding. @@ -1395,13 +1977,18 @@ def permute_with_mask_map_and_pad( Number of tokens in the permuted tensor (including padding). hidden_size : int Hidden size of the input tensor. + align_size : int + Alignment size for padding (default: 128). Used for distributed sharding + to correctly compute local buffer sizes. Returns ------- output : jnp.ndarray Permuted and padded output tensor of shape `[num_out_tokens, hidden_size]`. + Padding positions are zero-filled. permuted_probs : Optional[jnp.ndarray] Permuted probabilities if probs was provided, None otherwise. + Padding positions are zero-filled. """ with_probs = probs is not None @@ -1426,8 +2013,14 @@ def permute_with_mask_map_and_pad( hidden_size=hidden_size, with_probs=with_probs, with_pad=True, + align_size=align_size, ) + # Note: Zero-filling of padding positions is handled by pre-zeroing the output + # buffers in impl() using jnp.zeros(), then aliasing them to the kernel's outputs + # via input_output_aliases. The kernel only writes to valid positions, leaving + # padding positions at zero. 
+ if not with_probs: permuted_probs = None @@ -1479,7 +2072,7 @@ def unpermute_with_mask_map( merging_probs = jnp.zeros((0,), dtype=inp.dtype) if not with_probs: permuted_probs = jnp.zeros((0,), dtype=inp.dtype) - # Create dummy pad_offsets (not used when FUSION_UNPAD=False, but required by kernel signature) + # Create dummy pad_offsets (not used when with_unpad=False, but required by kernel signature) dummy_pad_offsets = jnp.zeros((0,), dtype=jnp.int32) output, unpermuted_probs = UnpermuteWithMaskMapPrimitive.outer_primitive.bind( @@ -1493,6 +2086,7 @@ def unpermute_with_mask_map( hidden_size=hidden_size, with_merging_probs=with_merging_probs, with_probs=with_probs, + with_unpad=False, ) if not with_probs: @@ -1550,7 +2144,7 @@ def unpermute_with_mask_map_and_unpad( if not with_probs: permuted_probs = jnp.zeros((0,), dtype=inp.dtype) - output, unpermuted_probs = UnpermuteWithMaskMapAndUnpadPrimitive.outer_primitive.bind( + output, unpermuted_probs = UnpermuteWithMaskMapPrimitive.outer_primitive.bind( inp, row_id_map, merging_probs, @@ -1561,6 +2155,7 @@ def unpermute_with_mask_map_and_unpad( hidden_size=hidden_size, with_merging_probs=with_merging_probs, with_probs=with_probs, + with_unpad=True, ) if not with_probs: diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 064b2843c6..979d127128 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -249,7 +249,8 @@ def lowering(ctx, x, *, block_size): kernel_constexprs = constexprs if constexprs is not None else {} # Handle autotuned kernels - compile all configs - if isinstance(kernel_fn, autotuner.Autotuner): + is_autotuned = isinstance(kernel_fn, autotuner.Autotuner) + if is_autotuned: # Compile all configs for runtime selection kernel_calls = [] actual_kernel_fn = kernel_fn.fn @@ -290,24 +291,23 @@ def lowering(ctx, x, *, block_size): kernel_calls.append((config_call, str(config))) 
- # Create autotuned kernel call - # Convert input_output_aliases to format with sizes - if input_output_aliases is None: - input_output_aliases = {} - - input_output_aliases_with_sizes = tuple( - ( - input_idx, - output_idx, - ctx.avals_in[input_idx].size * ctx.avals_in[input_idx].dtype.itemsize, - ) - for input_idx, output_idx in input_output_aliases.items() - ) - + # IMPORTANT: We pass an empty tuple for input_output_aliases_with_sizes. + # + # Background: + # 1. jax.ffi.ffi_lowering(operand_output_aliases=...) is a HINT to XLA that an + # output can reuse an input's buffer. XLA may or may not honor this. + # 2. TritonAutotunedKernelCall's input_output_aliases_with_sizes triggers + # save/restore logic during autotuning (see jaxlib/gpu/triton_kernels.cc:630-701). + # + # The problem: The save phase (triton_kernels.cc:632) only saves a copy of the input if buffers[input_idx] == buffers[output_idx], + # but the restore phase (triton_kernels.cc:697-700) unconditionally iterates over all aliases and tries + # to access input_copies[input_idx]. If XLA didn't actually alias the buffers, no copy was saved for that index, so the lookup yields an empty vector whose .data() returns nullptr, and the restore memcpy fails with CUDA_ERROR_INVALID_VALUE. + # + # Workaround (WAR): Don't pass aliases to TritonAutotunedKernelCall.
kernel_call = gpu_triton.TritonAutotunedKernelCall( f"{actual_kernel_fn.__name__}_autotuned", kernel_calls, - input_output_aliases_with_sizes, + (), # Empty to avoid buggy save/restore in jaxlib/gpu/triton_kernels.cc ) else: @@ -338,15 +338,17 @@ def lowering(ctx, x, *, block_size): serialized_metadata = b"" call_proto = kernel_call.to_proto(actual_kernel_fn.__name__, serialized_metadata) - if input_output_aliases is None: - input_output_aliases = {} + if input_output_aliases: + ffi_operand_output_aliases = input_output_aliases + else: + ffi_operand_output_aliases = None # Use JAX FFI lowering with compressed protobuf rule = jax.ffi.ffi_lowering( "triton_kernel_call", # Custom call target registered in gpu_triton.py api_version=2, backend_config=zlib.compress(call_proto), - operand_output_aliases=input_output_aliases, + operand_output_aliases=ffi_operand_output_aliases, ) return rule(ctx, *array_args) diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py index 8c9003bb5f..6b5de9ab0f 100644 --- a/transformer_engine/pytorch/triton/permutation.py +++ b/transformer_engine/pytorch/triton/permutation.py @@ -157,8 +157,8 @@ def permute_with_mask_map( scale_hidden_dim : int Hidden size of the scale tensor. """ - # Use torch.zeros when pad_offsets is provided to ensure padding regions are zeroed, - # since the kernel doesn't write to padding positions. + # Use torch.zeros when pad_offsets is provided to ensure padding regions are zeroed. + # The kernel writes only to valid positions, leaving padding positions at zero. alloc = torch.zeros if pad_offsets is not None else torch.empty output = alloc((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda") permuted_probs = ( @@ -178,7 +178,13 @@ def permute_with_mask_map( scale, permuted_scale, pad_offsets, + # Pass output buffers as input parameters (for JAX input_output_aliases compatibility). 
+ # In PyTorch, these point to the same memory as the output pointers below. + output, + permuted_probs, scale_hidden_dim, + num_tokens, + num_out_tokens, row_id_map.stride(0), row_id_map.stride(1), inp.stride(0), @@ -252,6 +258,10 @@ def unpermute_with_mask_map( merging_probs, permuted_probs, pad_offsets, + # Dummy buffer parameters for kernel signature consistency with _permute_kernel. + # These are unused in unpermute but maintain consistent interface. + output, # output_buf_ptr (unused, passed for signature consistency) + unpermuted_probs, # unpermuted_probs_buf_ptr (unused, passed for signature consistency) row_id_map.stride(0), row_id_map.stride(1), inp.stride(0),