
[JAX] Warmup FFIs with "initialize" stage#2800

Open
jberchtold-nvidia wants to merge 2 commits into NVIDIA:main from jberchtold-nvidia:jberchtold/warmup-xla-ffis
Conversation


jberchtold-nvidia (Collaborator) commented Mar 25, 2026

Description

Add "initialize" stage to TE FFIs that didn't previously have them.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add initialize FFI handlers in JAX .cpp extensions
  • Register them as "initialize" stage in pybind.cpp

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia marked this pull request as draft March 25, 2026 21:19
jberchtold-nvidia (Collaborator, Author) commented:

/te-ci L1 jax


greptile-apps bot commented Mar 25, 2026

Greptile Summary

This PR adds XLA FFI "initialize" stage handlers to warm up CUDA/cuDNN kernels for JAX custom calls, addressing JIT compilation latency on first use. Each new *InitializeFFI function wraps its corresponding execute FFI in wrapInStreamCapture, which begins CUDA stream capture (in relaxed mode), invokes the original kernel to trigger JIT compilation, ends capture, then immediately destroys the captured graph — so no side effects occur. The warmed-up kernels then run with lower latency when the "execute" stage fires for real.

Key changes:

  • New initialize handlers added for: DBiasQuantize, Dequantize, all six softmax variants, FusedAttnForward/Backward, GemmV2, Gemm (deprecated), and all four router operations.
  • Each new handler is correctly registered in pybind.cpp as the initialize key in its FFI dict alongside the existing execute key.
  • GroupedQuantizeFFI is correctly excluded from getting an initialize handler — the function calls cudaStreamSynchronize(stream) internally, which is incompatible with CUDA stream capture (as noted by the existing comment "Note: This may break cudaGraph").
  • ScaledMaskedSoftmaxBackwardInitializeHandler correctly reuses ScaledSoftmaxBackwardInitializeFFI, consistent with how the execute-stage handler for the same op already reuses ScaledSoftmaxBackwardFFI.

Confidence Score: 5/5

  • Safe to merge; the new initialize handlers are mechanically correct and consistent with existing patterns.
  • All new code follows the established wrapInStreamCapture pattern already used for normalization and other handlers. No P0/P1 issues found. The only notable observation (exception safety in wrapInStreamCapture) is a pre-existing limitation in a helper not modified by this PR, surfaced here only because the scope of handlers using it has grown.
  • No files require special attention. ffi.h (not in the diff) warrants a follow-up to address exception safety in wrapInStreamCapture.

Important Files Changed

Filename Overview
transformer_engine/jax/csrc/extensions.h Adds XLA_FFI_DECLARE_HANDLER_SYMBOL declarations for all new initialize-stage handlers (quantization, softmax, attention, GEMM, router).
transformer_engine/jax/csrc/extensions/quantization.cpp Adds DBiasQuantizeInitializeFFI and DequantizeInitializeFFI using wrapInStreamCapture; GroupedQuantize correctly omitted (incompatible with stream capture due to cudaStreamSynchronize).
transformer_engine/jax/csrc/extensions/softmax.cpp Adds initialize-stage wrappers for all six softmax handlers; ScaledMaskedSoftmaxBackwardInitializeHandler correctly reuses ScaledSoftmaxBackwardInitializeFFI, consistent with the execute-stage handler.
transformer_engine/jax/csrc/extensions/attention.cpp Adds FusedAttnForwardInitializeFFI and FusedAttnBackwardInitializeFFI via wrapInStreamCapture; note that for cuDNN < 9.3.0 with ragged (THD) format, nvte_get_runtime_num_segments is called inside capture which does a D2H copy (pre-existing limitation).
transformer_engine/jax/csrc/extensions/gemm.cpp Adds GemmInitializeFFI and GemmV2InitializeFFI; GemmInitializeFFI wraps the deprecated GemmFFI (which has std::once_flag deprecation warnings that will fire during the initialize stage, a minor cosmetic side effect).
transformer_engine/jax/csrc/extensions/pybind.cpp Registers all new initialize handlers in the Registrations() dict; all updated entries consistently include both "initialize" and "execute" keys, except te_grouped_quantize_ffi which is correctly kept as a bare execute handler.
transformer_engine/jax/csrc/extensions/router.cpp Adds initialize-stage wrappers for all four router handlers (FusedTopkWithScoreFunctionForward/Backward, FusedMoEAuxLossForward/Backward) following the same pattern.

Sequence Diagram

sequenceDiagram
    participant JAX as JAX/XLA Runtime
    participant Prepare as prepare stage
    participant Init as initialize stage<br/>(new in this PR)
    participant Exec as execute stage

    JAX->>Prepare: prepare (e.g. CudnnHandleInitHandler)
    Note over Prepare: cuDNN/cuBLAS handles init,<br/>UB buffer allocation

    JAX->>Init: initialize (e.g. FusedAttnForwardInitializeHandler)
    Note over Init: wrapInStreamCapture(...)
    Init->>Init: cudaStreamBeginCapture(stream, Relaxed)
    Init->>Init: OriginalFFI(stream, args...)
    Note over Init: CUDA/cuDNN kernels JIT-compiled<br/>but not executed
    Init->>Init: cudaStreamEndCapture(stream, &graph)
    Init->>Init: cudaGraphDestroy(graph)
    Note over Init: Graph discarded — warmup only

    JAX->>Exec: execute (e.g. FusedAttnForwardHandler)
    Note over Exec: Actual kernel execution<br/>uses pre-warmed JIT cache

Reviews (2): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

@jberchtold-nvidia jberchtold-nvidia marked this pull request as ready for review March 30, 2026 19:28