
Pass input_output_alias to TritonAutotunedKernelCall#2814

Open
tdophung wants to merge 2 commits into NVIDIA:main from tdophung:sort_chunks_WAR_p2

Conversation

@tdophung
Collaborator

@tdophung tdophung commented Mar 31, 2026

Description

https://nvbugspro.nvidia.com/bug/5810384
Removes the WAR that was put in place for this bug.

This should also serve as part 2 of the WAR for the intermittent sort_chunks_by_index bug previously seen in #2730

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: JAX Toolbox <jax@nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Mar 31, 2026

Greptile Summary

This PR removes the workaround that was passing an empty input_output_aliases_with_sizes tuple to TritonAutotunedKernelCall, and replaces it with the correct alias computation. The workaround was masking a bug in jaxlib/gpu/triton_kernels.cc where the restore phase would unconditionally access input_copies[input_idx] even when the save phase had skipped saving (because XLA did not actually alias the buffers), leading to a CUDA_ERROR_INVALID_VALUE. The referenced internal bug (5810384) claims the C++ bug is now fixed.
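To make the save/restore hazard concrete, here is a minimal Python sketch of the guarded pattern the fixed C++ presumably follows. All names here are hypothetical stand-ins; the real logic lives in jaxlib/gpu/triton_kernels.cc and operates on CUDA device buffers, not bytearrays.

```python
def autotune_with_aliases(configs, launch, buffers, aliases_with_sizes, xla_aliased):
    """Sketch of the guarded save/restore pattern around aliased buffers.

    buffers: list of bytearrays (inputs first, then outputs).
    aliases_with_sizes: (input_idx, buffer_output_idx, size_bytes) tuples.
    xla_aliased: set of input indices XLA actually aliased.
    All names are hypothetical; the real code is C++ in jaxlib.
    """
    input_copies = {}
    # Save phase: copy an aliased input only when XLA actually aliased it.
    for input_idx, _, size_bytes in aliases_with_sizes:
        if input_idx in xla_aliased:
            input_copies[input_idx] = bytes(buffers[input_idx][:size_bytes])

    best_time, best_config = None, None
    for config in configs:
        elapsed = launch(config, buffers)  # kernel may overwrite aliased inputs
        if best_time is None or elapsed < best_time:
            best_time, best_config = elapsed, config
        # Restore phase: only touch indices the save phase recorded, instead
        # of unconditionally indexing input_copies[input_idx] (the old bug).
        for input_idx, _, size_bytes in aliases_with_sizes:
            if input_idx in input_copies:
                buffers[input_idx][:size_bytes] = input_copies[input_idx]
    return best_config
```

The key point is the `if input_idx in input_copies` guard: the crash described above came from the restore phase assuming a copy always existed.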

Key changes:

  • Computes (input_idx, buffer_output_idx, size_bytes) tuples for each alias entry in input_output_aliases, where buffer_output_idx = num_inputs + output_idx to match the C++ buffer-indexing convention (inputs first, then outputs).
  • Passes the computed tuple to TritonAutotunedKernelCall instead of the previously hardcoded empty ().
  • This also fixes a silent correctness issue: without proper save/restore, subsequent autotuning configs were running on already-overwritten input data, which could select the wrong config (the root cause of the intermittent sort_chunks_by_index bug addressed by the WAR in #2730).

Minor note: size_bytes is derived from the input aval's shape/dtype; an assertion that the aliased output has the same shape/dtype would make the assumption explicit and catch any unexpected future misuse.
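The alias computation and the suggested assertion can be sketched together. This is a hypothetical helper, not the PR's actual code; the `Aval` dataclass is a minimal stand-in for a JAX abstract value with `shape`, `dtype`, and `itemsize`.

```python
import math
from dataclasses import dataclass

@dataclass(frozen=True)
class Aval:
    """Minimal stand-in for a JAX abstract value."""
    shape: tuple
    dtype: str
    itemsize: int

    @property
    def size(self):
        return math.prod(self.shape)

def compute_aliases_with_sizes(input_avals, output_avals, input_output_aliases):
    """Hypothetical sketch of the alias computation this PR adds.

    input_output_aliases maps input_idx -> output_idx. The C++ kernel call
    indexes one flat buffer list (inputs first, then outputs), so the output
    position is shifted by the number of inputs.
    """
    num_inputs = len(input_avals)
    aliases_with_sizes = []
    for input_idx, output_idx in input_output_aliases.items():
        in_aval = input_avals[input_idx]
        out_aval = output_avals[output_idx]
        # The suggested assertion: size_bytes is derived from the input aval,
        # so the aliased output must have the same shape and dtype.
        assert in_aval.shape == out_aval.shape and in_aval.dtype == out_aval.dtype
        size_bytes = in_aval.size * in_aval.itemsize
        aliases_with_sizes.append((input_idx, num_inputs + output_idx, size_bytes))
    return aliases_with_sizes
```

For a float32 input of shape (4, 8) aliased to output 0 with two inputs total, this yields `(0, 2, 128)`: input index 0, buffer index 2, and 32 elements × 4 bytes.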

Confidence Score: 5/5

  • Safe to merge — the alias computation is logically correct and removes both the crash risk and the silent autotuning correctness bug.
  • The only finding is a P2 style suggestion (add an assertion that aliased input/output shapes match). The core logic — computing buffer indices with num_inputs + output_idx and byte sizes with aval.size * itemsize — is correct. The WAR removal is well-motivated by the referenced C++ bugfix, and the previous workaround was already causing silent correctness issues in autotuning.
  • No files require special attention.

Important Files Changed

Filename Overview
transformer_engine/jax/triton_extensions/utils.py Removes the empty-alias workaround for the jaxlib save/restore bug and properly computes input_output_aliases_with_sizes for TritonAutotunedKernelCall; correctness depends on the referenced C++ bug being fixed in all targeted jaxlib versions.

Sequence Diagram

sequenceDiagram
    participant Caller as triton_call_lowering
    participant Alias as alias computation (new)
    participant ATCK as TritonAutotunedKernelCall (C++)
    participant CUDA as CUDA runtime

    Caller->>Alias: iterate input_output_aliases.items()
    Alias-->>Caller: (input_idx, num_inputs+output_idx, size_bytes) tuples
    Caller->>ATCK: TritonAutotunedKernelCall(name, kernel_calls, aliases_with_sizes)

    loop for each autotuning config
        ATCK->>CUDA: cudaMemcpy — save aliased input buffer (save phase)
        ATCK->>CUDA: launch kernel config N
        ATCK->>CUDA: record timing
        ATCK->>CUDA: cudaMemcpy — restore original input buffer (restore phase)
    end

    ATCK-->>Caller: best config selected, correct input state preserved

