27 changes: 13 additions & 14 deletions transformer_engine/jax/triton_extensions/utils.py
@@ -474,23 +474,22 @@ def lowering(ctx, x, *, block_size):

kernel_calls.append((config_call, str(config)))

# IMPORTANT: We pass an empty tuple for input_output_aliases_with_sizes.
#
# Background:
# 1. jax.ffi.ffi_lowering(operand_output_aliases=...) is a HINT to XLA that an
# output can reuse an input's buffer. XLA may or may not honor this.
# 2. TritonAutotunedKernelCall's input_output_aliases_with_sizes triggers
# save/restore logic during autotuning (see jaxlib/gpu/triton_kernels.cc:630-701).
#
# The problem: the save phase (triton_kernels.cc:632) only saves a copy when
# buffers[input_idx] == buffers[output_idx], but the restore phase
# (triton_kernels.cc:697-700) unconditionally iterates over all aliases and
# accesses input_copies[input_idx]. If XLA did not actually alias the
# buffers, that lookup creates an empty vector whose .data() returns
# nullptr, causing CUDA_ERROR_INVALID_VALUE during the restore memcpy.
#
# WAR: Don't pass aliases to TritonAutotunedKernelCall.
input_output_aliases_with_sizes = ()
Collaborator:
We should guard this on the JAX version that includes the fixed XLA, since users may still be on older JAX/XLA. Currently that would be a dev nightly release, so we can guard on that for now and add a TODO(tdophung) to update to an official release once the fix ships in one.
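The suggested guard could look like the following sketch. The helper names and the minimum version tuple are placeholders; the real threshold should be the first JAX release (or nightly) that ships the fixed XLA:

```python
def _version_tuple(version_str):
    """Parse 'major.minor.patch' into an int tuple, ignoring dev/rc suffixes."""
    parts = []
    for p in version_str.split(".")[:3]:
        digits = "".join(ch for ch in p if ch.isdigit())
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def supports_autotune_aliases(jax_version, min_version=(0, 4, 35)):
    """True if the installed JAX is new enough to pass aliases safely.

    min_version is a placeholder; in the real guard it would be compared
    against jax.__version__.
    """
    return _version_tuple(jax_version) >= min_version
```

The lowering would then pass the computed aliases only when the guard returns True, and an empty tuple otherwise.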

if input_output_aliases:
num_inputs = len(ctx.avals_in)
aliases = []
for input_idx, output_idx in input_output_aliases.items():
aval = ctx.avals_in[input_idx]
size_bytes = aval.size * jnp.dtype(aval.dtype).itemsize
# AutotunedKernelCall expects buffer indices (inputs + outputs).
buffer_output_idx = num_inputs + output_idx
aliases.append((input_idx, buffer_output_idx, size_bytes))
input_output_aliases_with_sizes = tuple(aliases)

kernel_call = gpu_triton.TritonAutotunedKernelCall(
f"{actual_kernel_fn.__name__}_autotuned",
kernel_calls,
(), # Empty to avoid buggy save/restore in jaxlib/gpu/triton_kernels.cc
input_output_aliases_with_sizes,
)

else:
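The alias computation in the diff can be illustrated with a standalone toy version (the real code reads element counts and item sizes from `ctx.avals_in` and `jnp.dtype`; the helper below is a hypothetical stand-in):

```python
def build_aliases(input_shapes, input_output_aliases):
    """Compute (input_idx, buffer_output_idx, size_bytes) alias tuples.

    input_shapes: list of (num_elements, itemsize_bytes), one per input.
    input_output_aliases: dict mapping input index -> output index.
    """
    num_inputs = len(input_shapes)
    aliases = []
    for input_idx, output_idx in input_output_aliases.items():
        num_elements, itemsize = input_shapes[input_idx]
        # The kernel call's buffer list is inputs followed by outputs,
        # so output k lives at buffer index num_inputs + k.
        aliases.append((input_idx, num_inputs + output_idx, num_elements * itemsize))
    return tuple(aliases)
```

For two inputs of 1024 float32 and 256 float16 elements with input 0 aliased to output 0, this yields `((0, 2, 4096),)`: the output occupies buffer slot 2 and the saved/restored region is 4096 bytes.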