diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py
index 6fe0ffdaee..65ca74c484 100644
--- a/tests/pytorch/attention/test_attention.py
+++ b/tests/pytorch/attention/test_attention.py
@@ -429,6 +429,15 @@ def test_dpa_softmax(dtype, model_configs, model):
     )
 
 
+@pytest.mark.skipif(get_cudnn_version() < (9, 18, 0), reason="cuDNN 9.18.0+ is required.")
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("model_configs", [model_configs_softmax])
+@pytest.mark.parametrize("model", model_configs_softmax.keys())
+def test_dpa_softmax_thd(dtype, model_configs, model):
+    """Test DotProductAttention module with different softmax types and qkv_format=thd"""
+    test_dot_product_attention(dtype, model_configs, model, True, True, "thd_thd_thd", False, False)
+
+
 model_configs_mla = {
     # test: ModelConfig(b, sq, hq, dqk)
     "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128),
diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py
index 9480b8de70..06ed6e5723 100644
--- a/tests/pytorch/attention/test_attention_with_cp.py
+++ b/tests/pytorch/attention/test_attention_with_cp.py
@@ -283,9 +283,14 @@ def test_cp_with_fused_attention(
         pytest.skip(
             "CP implementation only supports cp_comm_type=a2a for non-vanilla softmax types!"
         )
-    if config.softmax_type != "vanilla" and qkv_format == "thd":
+    if (
+        get_cudnn_version() < (9, 18, 0)
+        and config.softmax_type != "vanilla"
+        and qkv_format == "thd"
+    ):
         pytest.skip(
-            "CP implementation does not support qkv_format=thd for non-vanilla softmax types!"
+            "Before cuDNN 9.18.0, CP implementation does not support qkv_format=thd for"
+            " non-vanilla softmax types!"
         )
 
     dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
index 75b360e485..a5931188dc 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
@@ -4026,28 +4026,30 @@ def attn_forward_func_with_cp(
     assert not sliding_window_attn or cp_comm_type in [
         "a2a",
         "all_gather",
-    ], "Context parallelism does not support sliding window attention with {cp_comm_type=}!"
+    ], f"Context parallelism does not support sliding window attention with {cp_comm_type=}!"
 
     enable_mla = k.shape[-1] != v.shape[-1]
     assert not enable_mla or cp_comm_type in [
         "p2p",
         "a2a+p2p",
-    ], "Context parallelism does not support MLA with {cp_comm_type=}!"
+    ], f"Context parallelism does not support MLA with {cp_comm_type=}!"
 
     if fp8 and fp8_meta is not None:
         if fp8_meta["recipe"].fp8_dpa:
             assert (
                 softmax_type == "vanilla"
-            ), "Context parallelism does not support {softmax_type=} with FP8 attention!"
+            ), f"Context parallelism does not support {softmax_type=} with FP8 attention!"
     assert (
         softmax_type == "vanilla" or use_fused_attention
-    ), "Context parallelism only supports {softmax_type=} with FusedAttention backend!"
+    ), f"Context parallelism only supports {softmax_type=} with FusedAttention backend!"
     assert (
         softmax_type == "vanilla" or cp_comm_type == "a2a"
-    ), "Context parallelism only supports {softmax_type=} with cp_comm_type = 'a2a'!"
-    assert (
-        softmax_type == "vanilla" or qkv_format != "thd"
-    ), "Context parallelism does not support {softmax_type=} with qkv_format = 'thd'!"
+    ), f"Context parallelism only supports {softmax_type=} with cp_comm_type = 'a2a'!"
+    if get_cudnn_version() < (9, 18, 0):
+        assert softmax_type == "vanilla" or qkv_format != "thd", (
+            f"Before cuDNN 9.18.0, context parallelism does not support {softmax_type=} with"
+            " qkv_format = 'thd'!"
+        )
 
     args = [
         is_training,
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
index cb74a15e77..fcac740cc3 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
@@ -716,22 +716,14 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
             )
             use_unfused_attention = False
         if qkv_format == "thd":
-            logger.debug(
-                "Disabling FusedAttention for softmax_type = %s and qkv_format = thd", softmax_type
-            )
-            use_fused_attention = False
-            logger.debug(
-                "Disabling UnfusedDotProductAttention for softmax_type = %s and qkv_format = thd",
-                softmax_type,
-            )
-            use_unfused_attention = False
+            if cudnn_version < (9, 18, 0):
+                logger.debug(
+                    "Disabling FusedAttention for softmax_type = %s, qkv_format = thd and cuDNN"
+                    " version < 9.18",
+                    softmax_type,
+                )
+                use_fused_attention = False
         if context_parallel:
-            logger.debug(
-                "Disabling UnfusedDotProductAttention for context parallelism with softmax_type"
-                " = %s",
-                softmax_type,
-            )
-            use_unfused_attention = False
             if cp_comm_type != "a2a":
                 logger.debug(
                     "Disabling FusedAttention for context parallelism with softmax_type = %s and"
@@ -1049,6 +1041,15 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
            )
            use_flash_attention_2 = False
    if use_fused_attention and deterministic:
+        if softmax_type != "vanilla":
+            logger.debug(
+                "Disabling FusedAttention for determinism reasons with softmax_type = %s. "
+                "Sink attention (off-by-one and learnable softmax) requires "
+                "NVTE_ALLOW_NONDETERMINISTIC_ALGO=1",
+                softmax_type,
+            )
+            use_fused_attention = False
+            fused_attention_backend = None
        if fused_attention_backend == FusedAttnBackend["FP8"] and is_training:
            logger.debug("Disabling FusedAttention for determinism reasons with FP8")
            use_fused_attention = False