Merged
Changes from all commits
Commits
18 commits
015d346
add fp8 determinism support
sudhakarsingh27 Jan 23, 2026
6b71500
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into fp8_d…
sudhakarsingh27 Jan 24, 2026
d785c52
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 24, 2026
8ce534d
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into fp8_d…
sudhakarsingh27 Feb 10, 2026
1cfc2ce
update cudnn fe to 1.18
sudhakarsingh27 Feb 11, 2026
0fe3ab0
Merge branch 'fp8_determinism_sm100' of github.com:sudhakarsingh27/Tr…
sudhakarsingh27 Feb 11, 2026
ae7ff3b
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into fp8_d…
sudhakarsingh27 Feb 11, 2026
8a98792
Merge branch 'main' into fp8_determinism_sm100
sudhakarsingh27 Feb 17, 2026
37e9d28
Merge branch 'main' into fp8_determinism_sm100
sudhakarsingh27 Feb 18, 2026
cbcc973
resolve conflicts while mergin main
sudhakarsingh27 Feb 20, 2026
1c684b7
enable determinism for sm90
sudhakarsingh27 Feb 23, 2026
9e04dcb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 23, 2026
96aa4e0
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into fp8_d…
sudhakarsingh27 Feb 23, 2026
9483df4
Update transformer_engine/pytorch/attention/dot_product_attention/uti…
sudhakarsingh27 Feb 23, 2026
923db5e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 23, 2026
ce2dc79
Apply suggestion from @greptile-apps[bot]
sudhakarsingh27 Feb 23, 2026
8b0c874
remove extraneous `deterministic` test input arg
sudhakarsingh27 Feb 24, 2026
75cd00d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 24, 2026
11 changes: 8 additions & 3 deletions tests/pytorch/attention/test_attention.py
@@ -1834,10 +1834,16 @@ def get_model(dtype, config):
@pytest.mark.parametrize("is_training", [True, False])
@pytest.mark.parametrize("scaling_mode", ["delayed", "current"])
def test_mha_fp8_vs_f16(
dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training, scaling_mode
dtype,
model,
qkv_format,
input_layernorm,
fp8_dpa_bwd,
RoPE,
is_training,
scaling_mode,
):
"""Test MultiHeadAttention module in FP8"""
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
config = model_configs_fp8_vs_f16[model]

@@ -2094,7 +2100,6 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
# config.dropout_p = 0.1

os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "1"

# Test backend availability
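Note on the test change above: with `NVTE_ALLOW_NONDETERMINISTIC_ALGO` no longer forced to "1", the FP8 MHA test now runs under whatever determinism setting the environment requests. Below is a minimal sketch of exercising the deterministic FP8 attention backward from user code; the `DotProductAttention` arguments and the `fp8_dpa` recipe field are assumptions based on the public TransformerEngine PyTorch API and the surrounding tests, not part of this diff.

```python
import os

# Disallow nondeterministic algorithms (the tests no longer override this to "1").
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
# Keep the backward pass in FP8, mirroring the fp8_dpa_bwd=True test case.
os.environ["NVTE_FP8_DPA_BWD"] = "1"

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

def run_once(seed: int = 0) -> torch.Tensor:
    torch.manual_seed(seed)
    # Constructor/recipe argument names follow the public API; adjust for your TE version.
    attn = te.DotProductAttention(num_attention_heads=16, kv_channels=64)
    q, k, v = (
        torch.randn(128, 2, 16, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
        for _ in range(3)
    )
    with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling(fp8_dpa=True)):
        out = attn(q, k, v)
    out.sum().backward()
    return q.grad.clone()

# With the deterministic FP8 backward path, repeated runs should match bit-for-bit.
assert torch.equal(run_once(), run_once())
```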
22 changes: 11 additions & 11 deletions transformer_engine/common/fused_attn/fused_attn.cpp
@@ -770,10 +770,10 @@ void nvte_fused_attn_bwd_qkvpacked(
Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride);

fused_attn_fp8_bwd(b, h, h, max_seqlen, max_seqlen, d, attn_scale, dropout, qkv_layout,
bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_O, input_dO,
input_M, input_ZInv, input_S, input_output_dP, &dQ_view, &dK_view, &dV_view,
input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace, stream,
handle);
bias_type, attn_mask_type, deterministic, &Q_view, &K_view, &V_view, input_O,
input_dO, input_M, input_ZInv, input_S, input_output_dP, &dQ_view, &dK_view,
&dV_view, input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace,
stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n");
#endif
@@ -1087,10 +1087,10 @@ void nvte_fused_attn_bwd_kvpacked(
Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride);

fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout,
qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, input_O,
input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, &dK_view,
&dV_view, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace,
stream, handle);
qkv_layout, bias_type, attn_mask_type, deterministic, input_Q, &K_view,
&V_view, input_O, input_dO, input_M, input_ZInv, input_S, input_output_dP,
output_dQ, &dK_view, &dV_view, input_cu_seqlens_q, input_cu_seqlens_kv,
input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n");
#endif
@@ -1323,9 +1323,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout,
qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_O,
input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ,
output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv,
qkv_layout, bias_type, attn_mask_type, deterministic, input_Q, input_K,
input_V, input_O, input_dO, input_M, input_ZInv, input_S, input_output_dP,
output_dQ, output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv,
input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n");
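All three backward entry points in this file now thread the `deterministic` flag through to `fused_attn_fp8_bwd`. For orientation, here is a sketch of how that boolean is typically derived on the PyTorch side before it reaches these C entry points; the exact expression in TransformerEngine may differ.

```python
import os
import torch

def fused_attn_backward_is_deterministic() -> bool:
    # Determinism is requested either by turning off the TE override
    # or via PyTorch's global deterministic-algorithms switch.
    allow_nondeterministic = bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
    return (not allow_nondeterministic) or torch.are_deterministic_algorithms_enabled()
```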
41 changes: 23 additions & 18 deletions transformer_engine/common/fused_attn/fused_attn_fp8.cu
@@ -1982,13 +1982,13 @@ void fused_attn_fp8_fwd_impl_v1(
void fused_attn_fp8_bwd_impl_v1(
int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, float scaling_factor,
float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM,
void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrdQ, void* devPtrdK, void* devPtrdV,
void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleO,
void* devPtrDescaledO, void* devPtrDescaleS, void* devPtrDescaledP, void* devPtrScaleS,
void* devPtrScaledP, void* devPtrScaledQ, void* devPtrScaledK, void* devPtrScaledV,
void* devPtrAmaxdP, void* devPtrAmaxdQ, void* devPtrAmaxdK, void* devPtrAmaxdV,
void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed,
NVTE_Mask_Type mask_type, bool deterministic, void* devPtrQ, void* devPtrK, void* devPtrV,
void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrdQ, void* devPtrdK,
void* devPtrdV, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV,
void* devPtrDescaleO, void* devPtrDescaledO, void* devPtrDescaleS, void* devPtrDescaledP,
void* devPtrScaleS, void* devPtrScaledP, void* devPtrScaledQ, void* devPtrScaledK,
void* devPtrScaledV, void* devPtrAmaxdP, void* devPtrAmaxdQ, void* devPtrAmaxdK,
void* devPtrAmaxdV, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed,
void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type,
cudnn_frontend::DataType_t o_tensor_type, cudnn_frontend::DataType_t do_tensor_type,
cudnn_frontend::DataType_t dqkv_tensor_type, void* workspace, size_t* workspace_size,
@@ -2003,6 +2003,7 @@ void fused_attn_fp8_bwd_impl_v1(
bool is_dropout = (dropout_probability != 0.0f);
auto bias_b = b;
auto bias_h = h;
const auto cudnn_runtime_version = cudnnGetVersion();
auto bias_sq = s_q;
auto bias_skv = s_kv;
NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!");
@@ -2045,7 +2046,7 @@
0,
0,
true,
false,
deterministic,
qkv_tensor_type,
o_tensor_type,
do_tensor_type,
@@ -2216,6 +2217,10 @@
// }
// }

if (cudnn_runtime_version >= 91900) {
sdpa_backward_options.set_deterministic_algorithm(deterministic);
}
Review comment from greptile-apps[bot] (Contributor) on lines +2220 to +2222:
logic: Version check uses 91900 (cuDNN 9.19.0), but related PR #2584 and description mention 9.18.1+ requirement. Should this be 91810 instead?

Suggested change
(current)
if (cudnn_runtime_version >= 91900) {
sdpa_backward_options.set_deterministic_algorithm(deterministic);
}
(suggested)
if (cudnn_runtime_version >= 91810) {
sdpa_backward_options.set_deterministic_algorithm(deterministic);
}

Is there a specific reason FP8 requires cuDNN 9.19.0+ while FP16/BF16 only needs 9.18.1+?


if (is_padding) {
seq_q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("seq_q")
@@ -2519,11 +2524,11 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor* input_Q,
const Tensor* input_K, const Tensor* input_V, const Tensor* input_O,
const Tensor* input_dO, const Tensor* input_M, const Tensor* input_ZInv,
const Tensor* input_S, Tensor* input_output_dP, const Tensor* output_dQ,
const Tensor* output_dK, const Tensor* output_dV,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, bool deterministic,
const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V,
const Tensor* input_O, const Tensor* input_dO, const Tensor* input_M,
const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP,
const Tensor* output_dQ, const Tensor* output_dK, const Tensor* output_dV,
const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv,
const Tensor* rng_state, Tensor* workspace, cudaStream_t stream,
cudnnHandle_t handle) {
@@ -2581,11 +2586,11 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) {
fused_attn::fused_attn_fp8_bwd_impl_v1(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale,
p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv,
devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK,
devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP,
devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP,
devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV,
p_dropout, qkv_layout, bias_type, mask_type, deterministic, devPtrQ, devPtrK, devPtrV,
devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ,
devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS,
devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV,
devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV,
devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
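The new `set_deterministic_algorithm(deterministic)` call in this file is gated on `cudnn_runtime_version >= 91900`. A small sketch of the version arithmetic behind that threshold, assuming the cuDNN 9.x encoding `major*10000 + minor*100 + patch`:

```python
def encode_cudnn_version(major: int, minor: int, patch: int) -> int:
    # cuDNN 9.x runtime versions are reported as major*10000 + minor*100 + patch.
    return major * 10000 + minor * 100 + patch

assert encode_cudnn_version(9, 19, 0) == 91900  # threshold used in the gate above
assert encode_cudnn_version(9, 18, 1) == 91801  # 9.18.1, the version raised in the review comment
```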
10 changes: 5 additions & 5 deletions transformer_engine/common/fused_attn/fused_attn_fp8.h
@@ -28,11 +28,11 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_M, const Tensor *input_ZInv,
const Tensor *input_S, Tensor *input_output_dP, const Tensor *output_dQ,
const Tensor *output_dK, const Tensor *output_dV,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, bool deterministic,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
const Tensor *input_O, const Tensor *input_dO, const Tensor *input_M,
const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP,
const Tensor *output_dQ, const Tensor *output_dK, const Tensor *output_dV,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream,
cudnnHandle_t handle);
transformer_engine/pytorch/attention/dot_product_attention/utils.py
@@ -1067,8 +1067,15 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
)
use_fused_attention = False
fused_attention_backend = None
if fused_attention_backend == FusedAttnBackend["FP8"] and is_training:
logger.debug("Disabling FusedAttention for determinism reasons with FP8")
if (
fused_attention_backend == FusedAttnBackend["FP8"]
and is_training
and (device_compute_capability < (9, 0) or cudnn_version < (9, 19, 0))
):
logger.debug(
"Disabling FusedAttention for determinism reasons with FP8 on arch < sm90 or cuDNN"
" < 9.19.0"
)
use_fused_attention = False
fused_attention_backend = None
if (
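A standalone restatement of the new backend gate above, as a sketch (the helper name is hypothetical and the real selection logic carries more state than shown):

```python
def fp8_fused_attn_allowed_for_training(device_compute_capability: tuple,
                                        cudnn_version: tuple) -> bool:
    # The FP8 fused-attention backward is kept for training only on sm90+ GPUs
    # with cuDNN 9.19.0+; otherwise the backend is disabled for determinism reasons.
    return device_compute_capability >= (9, 0) and cudnn_version >= (9, 19, 0)

assert fp8_fused_attn_allowed_for_training((9, 0), (9, 19, 0))
assert not fp8_fused_attn_allowed_for_training((8, 9), (9, 19, 0))  # pre-sm90
assert not fp8_fused_attn_allowed_for_training((9, 0), (9, 18, 1))  # cuDNN too old
```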