diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 28e3f08e18..ebec1b3cc9 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -43,8 +43,10 @@ import jax.numpy as jnp from ..version_utils import ( + TRITON_AUTOTUNED_INPUT_OUTPUT_ALIAS_MIN_JAX_VERSION, TRITON_EXTENSION_MIN_JAX_VERSION, is_triton_extension_supported, + jax_version_meet_requirement, ) @@ -474,23 +476,31 @@ def lowering(ctx, x, *, block_size): kernel_calls.append((config_call, str(config))) - # IMPORTANT: We pass an empty tuple for input_output_aliases_with_sizes. - # - # Background: - # 1. jax.ffi.ffi_lowering(operand_output_aliases=...) is a HINT to XLA that an - # output can reuse an input's buffer. XLA may or may not honor this. - # 2. TritonAutotunedKernelCall's input_output_aliases_with_sizes triggers - # save/restore logic during autotuning (see jaxlib/gpu/triton_kernels.cc:630-701). - # - # The problem: The save phase (triton_kernels.cc:632) only saves if buffers[input_idx] == buffers[output_idx], - # but the restore phase (triton_kernels.cc:697-700) unconditionally iterates over all aliases and tries - # to access input_copies[input_idx]. If XLA didn't actually alias the buffers, input_copies[input_idx] doesn't exist, creating an empty vector whose .data() returns nullptr, causing CUDA_ERROR_INVALID_VALUE during the restore memcpy. - # - # WAR: Don't pass aliases to TritonAutotunedKernelCall. + input_output_aliases_with_sizes = () + if input_output_aliases: + if jax_version_meet_requirement(TRITON_AUTOTUNED_INPUT_OUTPUT_ALIAS_MIN_JAX_VERSION): + num_inputs = len(ctx.avals_in) + aliases = [] + for input_idx, output_idx in input_output_aliases.items(): + aval = ctx.avals_in[input_idx] + size_bytes = aval.size * jnp.dtype(aval.dtype).itemsize + # AutotunedKernelCall expects buffer indices (inputs + outputs). + buffer_output_idx = num_inputs + output_idx + aliases.append((input_idx, buffer_output_idx, size_bytes)) + input_output_aliases_with_sizes = tuple(aliases) + else: + warnings.warn( + f"JAX >= {TRITON_AUTOTUNED_INPUT_OUTPUT_ALIAS_MIN_JAX_VERSION} is required " + "to safely pass input_output_aliases to TritonAutotunedKernelCall. " + "Passing empty aliases as a workaround (jax-ml/jax#35218).", + UserWarning, + stacklevel=2, + ) + kernel_call = gpu_triton.TritonAutotunedKernelCall( f"{actual_kernel_fn.__name__}_autotuned", kernel_calls, - (), # Empty to avoid buggy save/restore in jaxlib/gpu/triton_kernels.cc + input_output_aliases_with_sizes, ) else: diff --git a/transformer_engine/jax/version_utils.py b/transformer_engine/jax/version_utils.py index 04b7ff879a..63598481a2 100644 --- a/transformer_engine/jax/version_utils.py +++ b/transformer_engine/jax/version_utils.py @@ -25,6 +25,15 @@ def jax_version_meet_requirement(version: str): # Minimum JAX version required for Triton kernel dispatch (jaxlib < 0.8.0 segfaults). TRITON_EXTENSION_MIN_JAX_VERSION = "0.8.0" +# Minimum JAX version for safe input_output_aliases in TritonAutotunedKernelCall. +# jaxlib/gpu/triton_kernels.cc had a bug in the autotuning save/restore loop: +# it iterated over all declared aliases unconditionally, but input_copies only +# contains entries for aliases where XLA actually shared buffers at runtime. +# Accessing a missing entry produced a null vector → CUDA_ERROR_INVALID_VALUE. +# Fixed by: https://github.com/jax-ml/jax/pull/35218 (merged 2026-03-17, main). +# Ships in JAX 0.9.3 (not yet released as of 2026-03-31). +TRITON_AUTOTUNED_INPUT_OUTPUT_ALIAS_MIN_JAX_VERSION = "0.9.3" + def is_triton_extension_supported() -> bool: """Return True if the current JAX version supports Triton kernel dispatch. @@ -40,4 +49,5 @@ def is_triton_extension_supported() -> bool: "jax_version_meet_requirement", "is_triton_extension_supported", "TRITON_EXTENSION_MIN_JAX_VERSION", + "TRITON_AUTOTUNED_INPUT_OUTPUT_ALIAS_MIN_JAX_VERSION", ]