
Conversation

@cyanguwa (Collaborator) commented Jan 12, 2026

Description

This PR enables determinism for FusedAttention on Blackwell for FP16/BF16 precisions and cuDNN >= 9.18.0.

To run with determinism, please set this flag: export NVTE_ALLOW_NONDETERMINISTIC_ALGO=0.
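
A minimal usage sketch (Python; the environment variable is the real knob, everything else here is illustrative):

import os

# Request deterministic algorithms before importing Transformer Engine,
# so that fused-attention backend selection sees the flag.
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"

import transformer_engine.pytorch as te  # modules built on FusedAttention

# On Blackwell, deterministic backward additionally requires
# cuDNN >= 9.18.0, no bias, and dropout = 0; otherwise another
# backend is selected.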

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please see Description.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Charlene Yang <[email protected]>
@cyanguwa cyanguwa changed the title [Common] Enable determinism for SDPA on Blackwell [Common] Enable determinism for cuDNN >= 9.18 on Blackwell Jan 12, 2026
greptile-apps bot (Contributor) commented Jan 12, 2026

Greptile Summary

This PR enables deterministic FusedAttention on Blackwell GPUs (SM 100+) for FP16/BF16 with cuDNN >= 9.18.0.

Key Changes:

  • Added deterministic parameter to nvte_get_fused_attn_backend() API and propagated through JAX and PyTorch layers
  • Implemented Blackwell-specific backend selection logic: non-deterministic backward (cuDNN 9.7+) requires dropout=0 OR bias=NONE; deterministic backward (cuDNN 9.18+) requires dropout=0 AND bias=NONE
  • Updated cudnn-frontend submodule to support new deterministic algorithms
  • Removed blanket Blackwell determinism check from PyTorch utils
  • Added comprehensive test coverage with NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 in CI for both PyTorch and JAX
  • Fixed NVTE_UNFUSED_ATTN environment variable handling in PyTorch tests

Implementation Notes:

  • Forward passes always use deterministic algorithms (existing behavior)
  • Backward passes can now use deterministic or non-deterministic algorithms based on user preference and hardware/library support (see the sketch after this list)
  • The logic correctly handles version-gated features and constraint validation
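
A hedged restatement of those backward-pass constraints (Python pseudocode; the real check lives in the C++ backend selection, and all names here are illustrative):

def blackwell_bwd_supported(deterministic, cudnn_version, dropout, has_bias):
    # Illustrative Blackwell (sm_arch >= 100) backward-pass rules.
    if deterministic:
        # Deterministic backward: cuDNN >= 9.18.0, dropout = 0, AND no bias.
        return cudnn_version >= (9, 18, 0) and dropout == 0.0 and not has_bias
    # Non-deterministic backward: cuDNN >= 9.7.0, and dropout = 0 OR no bias.
    return cudnn_version >= (9, 7, 0) and (dropout == 0.0 or not has_bias)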

Confidence Score: 4/5

  • This PR is safe to merge with careful testing on Blackwell hardware with cuDNN 9.18+
  • The implementation is sound and addresses the previous review comments. The core logic for backend selection with determinism constraints is correct, API changes are properly propagated through all layers, and comprehensive test coverage has been added. The score reflects that this is a hardware and library-specific feature requiring validation on actual Blackwell GPUs with the specified cuDNN version.
  • Most important file to review is transformer_engine/common/fused_attn/fused_attn.cpp:447-452 to verify the Blackwell determinism logic matches cuDNN 9.18 capabilities

Important Files Changed

  • transformer_engine/common/fused_attn/fused_attn.cpp: Adds deterministic parameter to backend selection logic and implements Blackwell cuDNN 9.18.0+ determinism support with proper constraints for dropout and bias
  • transformer_engine/jax/cpp_extensions/attention.py: Updates JAX backend assertions for Blackwell determinism requirements and passes deterministic flag to C++ layer
  • transformer_engine/pytorch/attention/dot_product_attention/utils.py: Removes blanket Blackwell determinism check, now delegates to backend selection logic with deterministic parameter
  • tests/jax/test_fused_attn.py: Updates test skip conditions to match new Blackwell determinism constraints for cuDNN 9.7+ and 9.18+
  • tests/pytorch/attention/test_attention.py: Adds deterministic flag reading from environment and passes to backend selection calls, plus fixes NVTE_UNFUSED_ATTN handling

Sequence Diagram

sequenceDiagram
    participant User
    participant PyTorch/JAX
    participant Backend Selection
    participant cuDNN Frontend
    participant Forward Pass
    participant Backward Pass

    User->>PyTorch/JAX: Set NVTE_ALLOW_NONDETERMINISTIC_ALGO
    Note over User,PyTorch/JAX: 0=deterministic, 1=non-deterministic

    PyTorch/JAX->>Backend Selection: get_fused_attn_backend(deterministic)
    Note over Backend Selection: New parameter: deterministic

    alt Blackwell (sm_arch >= 100) Training
        Backend Selection->>Backend Selection: Check cuDNN version & constraints
        alt Non-deterministic (cuDNN >= 9.7.0)
            Note over Backend Selection: Requires: dropout=0 OR bias=NONE
        else Deterministic (cuDNN >= 9.18.0)
            Note over Backend Selection: Requires: dropout=0 AND bias=NONE
        end
    end

    Backend Selection->>PyTorch/JAX: Return backend (arbitrary_seqlen or max512)

    PyTorch/JAX->>Forward Pass: nvte_fused_attn_fwd(deterministic=false)
    Note over Forward Pass: Always uses deterministic algorithm

    Forward Pass->>cuDNN Frontend: Execute deterministic forward
    cuDNN Frontend-->>Forward Pass: Return O, aux tensors

    alt Training Mode
        PyTorch/JAX->>Backward Pass: nvte_fused_attn_bwd(deterministic)
        Note over Backward Pass: Uses actual deterministic flag
        
        alt Deterministic
            Backward Pass->>cuDNN Frontend: Execute deterministic backward (9.18+)
        else Non-deterministic
            Backward Pass->>cuDNN Frontend: Execute non-deterministic backward (9.7+)
        end
        
        cuDNN Frontend-->>Backward Pass: Return dQ, dK, dV
    end

    Backward Pass-->>User: Gradients

@greptile-apps greptile-apps bot left a comment

1 file reviewed, 1 comment

greptile-apps bot (Contributor) commented Jan 12, 2026

Greptile Summary

Overview

This PR enables determinism for FusedAttention on Blackwell GPUs (SM 100) with cuDNN version 9.18.0 or higher. The implementation moves determinism checking logic from Python to the C++ backend selection layer.

Key Changes

  1. Backend Selection Logic: Added a new condition in nvte_get_fused_attn_backend() that disables the arbitrary sequence length backend for Blackwell when:

    • Training mode is enabled
    • Determinism is required
    • Any of: cuDNN < 9.18.0, bias is used, or dropout > 0
  2. API Updates: Added deterministic parameter to the backend selection function across Python, C++, and JAX interfaces. Forward passes hardcode deterministic=true while backward passes accept it as a parameter (sketched after this list).

  3. Code Migration: Moved Blackwell determinism checks from Python (utils.py) to C++ backend selection, consolidating version, bias, and dropout checks in one place.

  4. Test Infrastructure: Added environment variable NVTE_ALLOW_NONDETERMINISTIC_ALGO to control determinism in tests, and added explicit NVTE_UNFUSED_ATTN=0 settings to ensure proper backend isolation.

  5. Dependency Update: Updated cudnn-frontend submodule to version 1.17 to support the new determinism features.
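
A sketch of the forward/backward split from item 2 (Python stubs; the real entry points are nvte_fused_attn_fwd and nvte_fused_attn_bwd in C++):

def fused_attn_fwd(*tensors, deterministic: bool = True):
    # Forward is always deterministic, so callers hardcode deterministic=True
    # regardless of NVTE_ALLOW_NONDETERMINISTIC_ALGO.
    ...

def fused_attn_bwd(*tensors, deterministic: bool):
    # Backward receives the user's actual preference; on Blackwell this picks
    # between cuDNN 9.18+ deterministic and cuDNN 9.7+ non-deterministic kernels.
    ...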

Architecture

The change follows a layered approach:

  • User API Level: Python tests set deterministic flag via environment variable or torch settings
  • Python Layer: Extracts deterministic flag and passes to C++ extension
  • C++ Backend Selection: Evaluates hardware, cuDNN version, bias, and dropout to determine if deterministic FusedAttention is supported
  • Execution: If requirements aren't met, falls back to other backends (FlashAttention or UnfusedDotProductAttention)

The implementation correctly restricts deterministic FusedAttention to cases where cuDNN guarantees deterministic behavior, avoiding silent non-determinism.
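
A sketch of that fallback behavior (the environment variable name is real; the helper names are hypothetical):

import os

def is_deterministic_required() -> bool:
    # NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 means only deterministic algorithms.
    return os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1") == "0"

def select_attention_backend(fused_supports_deterministic: bool) -> str:
    # If deterministic FusedAttention is unavailable for this configuration,
    # fall back rather than silently running non-deterministically.
    if is_deterministic_required() and not fused_supports_deterministic:
        return "UnfusedDotProductAttention"
    return "FusedAttention"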

Confidence Score: 4/5

  • This PR is safe to merge with minor issues that should be addressed
  • The implementation is sound and correctly adds determinism support for Blackwell GPUs. The core logic properly checks cuDNN version, bias, and dropout constraints. However, two issues lower the confidence: (1) inconsistent tab/space indentation in the critical condition on line 444 of fused_attn.cpp, and (2) duplicate XML output file in test.sh causing test results to be overwritten. Both are non-critical but should be fixed before merge.
  • Pay attention to transformer_engine/common/fused_attn/fused_attn.cpp (line 444 indentation) and qa/L0_pytorch_unittest/test.sh (line 48 XML filename collision)

Important Files Changed

File Analysis

  • transformer_engine/common/fused_attn/fused_attn.cpp (4/5): Added determinism check for Blackwell (sm100) to disable FusedAttention when cuDNN < 9.18.0 or bias/dropout are used. Contains tab indentation inconsistency on line 444.
  • transformer_engine/pytorch/attention/dot_product_attention/utils.py (5/5): Removed Python-side Blackwell determinism check, now handled in C++. Added deterministic parameter to backend selection call.
  • tests/pytorch/attention/test_attention.py (5/5): Added deterministic flag from environment variable and torch settings. Updated tests to explicitly set NVTE_UNFUSED_ATTN=0 to ensure correct backend isolation.
  • qa/L0_pytorch_unittest/test.sh (3/5): Added deterministic test run with NVTE_ALLOW_NONDETERMINISTIC_ALGO=0. Both test runs write to the same XML file, causing results to be overwritten.

Sequence Diagram

sequenceDiagram
    participant User as User/Test
    participant PyAPI as Python API
    participant Utils as utils.py
    participant CppExt as C++ Extensions
    participant Backend as Backend Selection
    participant cuDNN as cuDNN Library

    User->>PyAPI: Call attention with deterministic=True
    PyAPI->>Utils: get_attention_backend(params)
    Utils->>Utils: Extract deterministic from params
    Utils->>CppExt: get_fused_attn_backend(..., deterministic)
    CppExt->>Backend: nvte_get_fused_attn_backend(..., deterministic)
    
    alt Blackwell (sm_arch >= 100) & Training & Deterministic
        Backend->>Backend: Check cuDNN version >= 9.18.0
        Backend->>Backend: Check bias_type == NO_BIAS
        Backend->>Backend: Check dropout == 0.0
        alt All checks pass
            Backend-->>CppExt: F16_arbitrary_seqlen backend
        else Any check fails
            Backend-->>CppExt: No_Backend (disabled)
        end
    else Other architectures or inference
        Backend->>Backend: Apply standard backend selection
        Backend-->>CppExt: Selected backend
    end
    
    CppExt-->>Utils: Backend choice
    Utils-->>PyAPI: Backend configuration
    
    alt Forward Pass
        PyAPI->>CppExt: nvte_fused_attn_fwd(..., deterministic=true)
        Note over PyAPI,CppExt: Forward always uses deterministic=true
    else Backward Pass
        PyAPI->>CppExt: nvte_fused_attn_bwd(..., deterministic)
        Note over PyAPI,CppExt: Backward respects user's deterministic flag
    end
    
    CppExt->>cuDNN: Execute attention operation
    cuDNN-->>CppExt: Results
    CppExt-->>PyAPI: Output tensors
    PyAPI-->>User: Attention output

@greptile-apps greptile-apps bot left a comment

2 files reviewed, 2 comments

make .xml file specific to deterministic tests in qa/

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Charlene Yang <[email protected]>
@greptile-apps greptile-apps bot left a comment

1 file reviewed, 1 comment

@greptile-apps greptile-apps bot left a comment

No files reviewed, no comments

@greptile-apps greptile-apps bot left a comment

1 file reviewed, 1 comment

@greptile-apps greptile-apps bot left a comment

1 file reviewed, 1 comment

cyanguwa and others added 3 commits January 13, 2026 06:00
fix typo

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Charlene Yang <[email protected]>
fix indentation

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Charlene Yang <[email protected]>
@greptile-apps greptile-apps bot left a comment

1 file reviewed, 1 comment

@greptile-apps greptile-apps bot left a comment

1 file reviewed, 1 comment

@greptile-apps greptile-apps bot left a comment

3 files reviewed, 3 comments

@greptile-apps greptile-apps bot left a comment

2 files reviewed, 2 comments

@greptile-apps greptile-apps bot left a comment

3 files reviewed, 3 comments

@greptile-apps greptile-apps bot left a comment

1 file reviewed, 1 comment

@cyanguwa (Collaborator Author):

/te-ci L0

@greptile-apps greptile-apps bot left a comment

1 file reviewed, 1 comment

@greptile-apps greptile-apps bot left a comment

No files reviewed, no comments

greptile-apps bot (Contributor) commented Jan 14, 2026

Greptile's behavior is changing!

From now on, if a review finishes with no comments, we will not post an additional "statistics" comment to confirm that our review found nothing to comment on. However, you can confirm that we reviewed your changes in the status check section.

This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".

@cyanguwa (Collaborator Author):

/te-ci L0

@cyanguwa (Collaborator Author):

/te-ci jax L0

@cyanguwa (Collaborator Author):

/te-ci L0

@cyanguwa (Collaborator Author):

/te-ci L1

@cyanguwa (Collaborator Author):

/te-ci L1

@greptile-apps greptile-apps bot left a comment

13 files reviewed, 3 comments

cyanguwa and others added 3 commits January 15, 2026 06:57
Signed-off-by: Charlene Yang <[email protected]>
fix and/or logic

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Charlene Yang <[email protected]>
@cyanguwa (Collaborator Author):

/te-ci L1

@liayan commented Jan 16, 2026

Cool, we are currently suffering from this issue.
Do we have a rough timeline for when it could be merged?
Let me know if there is anything I can do to help, such as testing.

@KshitijLakhani KshitijLakhani (Collaborator) left a comment

Left a few comments - some suggested changes and some questions. Looks good to me otherwise. Approving so as not to block the merge, if urgent.

It would be helpful to include a table in the PR description of what's supported for cuDNN < 9.18 vs. >= 9.18, < sm100 vs. sm100+, dropout, dbias, etc.

I would also suggest looking into the number of tests being run and their timing (you can compare this PR's L0 JAX and L0 PyTorch timings to those in TE 2.11 or in TE main CI); we would not want to go over our timing budget, for sure. If you can report the timing in the PR, that would be helpful as well.
Worst case, if urgent, we can merge this PR and address the QA bit (which runs in the CI) in a separate PR.

Lastly, this might be some effort but would ensure correctness: since the test-skipping code in the TE JAX tests has been modified, it would be good to compare the test count before and after this PR to verify that tests that should not be skipped are not being skipped incorrectly.

mkdir -p "$XML_LOG_DIR"

python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*"
NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_fused_attn_deterministic.xml $TE_PATH/tests/jax/test_fused_attn.py || test_fail "tests/jax/test_fused_attn.py"
@KshitijLakhani (Collaborator):
It seems like this will first run the non-deterministic fused attn tests as part of L31, which runs all non-distributed tests, followed by the deterministic fused attn tests as part of L32.
Is that the intention - to run fused attn twice, with and without determinism?

That will greatly increase our test time and might be unnecessary. The last pipeline launched was for L1, so I am unsure I can track the effect this change will have on timing, as this is an L0 change. Could you report that in the PR, please?
Thanks!

@KshitijLakhani (Collaborator):
Maybe we could come up with an approach that runs half the fused attn tests deterministically and the other half non-deterministically?
Or run all of them deterministically only?

Comment on lines 47 to +48
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention.py"
@KshitijLakhani (Collaborator):
I noticed that part of my question about the JAX-side L0 tests is answered here.
It seems the intention is to run the attention tests twice - with and without determinism.
I think we'd have to think more about this, as I'd assume it would consume a substantial testing budget.
Thoughts?

float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right, bool return_max_logit, bool cuda_graph);
int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic);
@KshitijLakhani (Collaborator):
nit: To be consistent, should we call this flag is_deterministic, similar to the first arg, is_training?

window_size[1],
return_max_logit,
cuda_graph,
deterministic,
@KshitijLakhani (Collaborator):
nit: To be consistent, should we call this flag is_deterministic, similar to the first arg, is_training?

float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left,
int64_t window_size_right);
int64_t window_size_right, bool deterministic);
@KshitijLakhani (Collaborator):
nit: To be consistent, should we call this flag is_deterministic, similar to the first arg, is_training?


os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
@KshitijLakhani (Collaborator):
This is just an effort to be explicit with the env vars and has nothing to do with determinism directly, right ?

if any(x >= 100 for x in compute_capabilities) and is_training:
assert (
FusedAttnHelper.is_non_deterministic_allowed()
and get_cudnn_version() >= (9, 7, 0)
@KshitijLakhani (Collaborator):
What is the min cuDNN version check for? I.e., what was supported from cuDNN 9.7 onwards?

assert (
FusedAttnHelper.is_non_deterministic_allowed()
and get_cudnn_version() >= (9, 7, 0)
and (attn_bias_type == AttnBiasType.NO_BIAS or dropout_probability == 0.0)
@KshitijLakhani (Collaborator):
My understanding was that:

  1. >=sm100 + dropout + no dbias = supported, but non-deterministic, as dropout requires choosing a non-deterministic kernel

  2. >=sm100 + no dropout + dbias = not supported, as dbias requires choosing the deterministic path

  3. >=sm100 + no dropout + no dbias = supported

If this is true, wouldn't case #2 falsely pass even though it is not supported?
Or is my understanding incorrect?

"For sm100+, bprop kernel support for dropout + determinism (bias) is not supported"
)

if get_device_compute_capability(0) >= 100 and self.is_training:
@KshitijLakhani (Collaborator):
Thanks for adding the is_training flag in the check

(self.dropout_prob != 0.0 and self.attn_bias_type != AttnBiasType.NO_BIAS)
or get_cudnn_version() < 90700
):
pytest.skip(
@KshitijLakhani (Collaborator):
For sm100+ non-deterministic bprop (cuDNN 9.7+), ONLY bias or ONLY dropout is supported, but not both at the same time, right?

@KshitijLakhani (Collaborator):

/te-ci L0 L1
