diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py
index 50cd150c4e..c55b0cc193 100644
--- a/tests/pytorch/test_onnx_export.py
+++ b/tests/pytorch/test_onnx_export.py
@@ -713,6 +713,14 @@ def test_export_layernorm_mlp_activation(seed_default_rng, activation):
     _test_export_layernorm_mlp(activation=activation)
 
 
+# FP8 recipes with fp8_dpa=True for attention FP8 emulation export test
+fp8_dpa_recipes = [None]  # None = no FP8
+if fp8_available:
+    fp8_dpa_recipes.append(recipe.DelayedScaling(fp8_dpa=True))
+    fp8_dpa_recipes.append(recipe.Float8CurrentScaling(fp8_dpa=True))
+
+
+@pytest.mark.parametrize("fp8_recipe", fp8_dpa_recipes)
 @pytest.mark.parametrize(
     "precision, use_mask, attn_mask_type",
     [
@@ -730,6 +738,7 @@ def test_export_core_attention(
     precision: torch.dtype,
     use_mask: bool,
     attn_mask_type: str,
+    fp8_recipe: recipe.Recipe,
 ):
     # Set dimensions (these are arbitrary).
     seq_len, batch_size, num_attention_heads, kv_channels = (64, 4, 1, 64)
@@ -749,22 +758,26 @@ def test_export_core_attention(
     mask_str = get_attn_mask_str(use_mask, attn_mask_type)
     high_prec_str = dtype2str(precision)
-    fname = f"te.core_attention{mask_str}{high_prec_str}.onnx"
+    fp8_str = "_fp8_dpa" if fp8_recipe is not None else ""
+    fname = f"te.core_attention{fp8_str}{mask_str}{high_prec_str}.onnx"
+
+    is_fp8 = fp8_recipe is not None
 
     model = te.attention.DotProductAttention(
         num_attention_heads=num_attention_heads,
         kv_channels=kv_channels,
-        attention_dropout=0.5,
+        attention_dropout=0.0 if is_fp8 else 0.5,  # Disable dropout for deterministic results with FP8
         qkv_format=qkv_format,
         attn_mask_type=attn_mask_type,
     ).to(device="cuda")
 
-    do_export(model, inp, fname, input_names=input_names, fp8_recipe=None)
-    te_outputs = te_infer(model, inp, is_fp8=False, fp8_recipe=None)
+    do_export(model, inp, fname, input_names=input_names, fp8_recipe=fp8_recipe)
+    te_outputs = te_infer(model, inp, is_fp8=is_fp8, fp8_recipe=fp8_recipe)
     serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
     if precision in (torch.bfloat16,):
         return
+    atol = 5e-2 if is_fp8 else 1e-2  # Higher tolerance for FP8 due to quantization effects
     validate_result(
-        fname, inp, model, is_fp8=True, atol=1e-2, input_names=input_names, te_outputs=te_outputs
+        fname, inp, model, is_fp8=True, atol=atol, input_names=input_names, te_outputs=te_outputs
     )
 
 
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
index c726ed8849..96791f72df 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
@@ -164,6 +164,11 @@ class FP8EmulationFunc(torch.autograd.Function):
     @staticmethod
     def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout):
         # pylint: disable=missing-function-docstring
+        if is_in_onnx_export_mode():
+            return FP8EmulationFunc.onnx_forward(
+                tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout
+            )
+
         if quantizer_name == "QKV_quantizer":
             query_layer, key_layer, value_layer = [
                 x.contiguous() for x in [tensor1, tensor2, tensor3]
@@ -202,6 +207,62 @@ def backward(ctx, grad1, grad2, grad3):
         tensors = grad1, grad2, grad3
         return tensors[0], tensors[1], tensors[2], None, None, None
 
+    @staticmethod
+    def onnx_forward(tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout=None):
+        """
+        ONNX-compatible forward for FP8 emulation using operations with defined ONNX translations.
+
+        This method performs quantize + dequantize to emulate FP8 effects using ONNX-compatible ops.
+        Uses the quantizer's onnx_quantize/onnx_dequantize methods for proper scaling behavior.
+
+        Parameters
+        ----------
+        tensor1, tensor2, tensor3 : torch.Tensor
+            Input tensors (e.g., Q, K, V for QKV_quantizer, or a single tensor for S/O quantizers)
+        quantizer : Quantizer
+            The quantizer object with onnx_quantize/onnx_dequantize methods
+        quantizer_name : str
+            Name of the quantizer: "QKV_quantizer", "S_quantizer", "O_quantizer", etc.
+        qkv_layout : str, optional
+            QKV layout string (not used in the ONNX path)
+
+        Returns
+        -------
+        Tuple of emulated tensors
+        """
+        # pylint: disable=unused-argument
+
+        if quantizer_name == "QKV_quantizer":
+            # Combine Q, K, V -> quantize together -> split back
+            orig_dtype = tensor1.dtype
+            shapes = [tensor1.shape, tensor2.shape, tensor3.shape]
+            numels = [tensor1.numel(), tensor2.numel(), tensor3.numel()]
+
+            # Flatten and concatenate
+            combined = torch.cat(
+                [tensor1.reshape(-1), tensor2.reshape(-1), tensor3.reshape(-1)], dim=0
+            )
+
+            # Quantize + dequantize the combined tensor using the quantizer's ONNX methods
+            combined_fp8 = quantizer.onnx_quantize(combined)
+            out = quantizer.onnx_dequantize(combined_fp8).to(orig_dtype)
+
+            # Split back
+            out1 = out[: numels[0]].reshape(shapes[0])
+            out2 = out[numels[0] : numels[0] + numels[1]].reshape(shapes[1])
+            out3 = out[numels[0] + numels[1] :].reshape(shapes[2])
+
+            return out1, out2, out3
+        elif quantizer_name in ["S_quantizer", "O_quantizer"]:
+            # Emulate FP8 on a single tensor using the quantizer's ONNX methods
+            orig_dtype = tensor1.dtype
+            t_fp8 = quantizer.onnx_quantize(tensor1)
+            out = quantizer.onnx_dequantize(t_fp8).to(orig_dtype)
+            return out, tensor2, tensor3
+        else:
+            # Pass-through
+            return tensor1, tensor2, tensor3
+
 
 class UnfusedDotProductAttention(torch.nn.Module):
     """Parallel attention w/o QKV and Proj Gemms
diff --git a/transformer_engine/pytorch/jit.py b/transformer_engine/pytorch/jit.py
index 5884188b7e..c19b24944f 100644
--- a/transformer_engine/pytorch/jit.py
+++ b/transformer_engine/pytorch/jit.py
@@ -51,9 +51,24 @@ def wrapper(*args, **kwargs):
     import torch._dynamo
 
     if torch.__version__ >= "2.1":
-        no_torch_dynamo = lambda recursive=True: lambda f: (
-            f if is_in_onnx_export_mode() else torch._dynamo.disable(f, recursive=recursive)
-        )
+
+        def no_torch_dynamo(recursive=True):
+            """Decorator to disable Torch Dynamo, except during ONNX export."""
+
+            def decorator(f):
+                disabled_f = torch._dynamo.disable(f, recursive=recursive)
+
+                @wraps(f)
+                def wrapper(*args, **kwargs):
+                    # Check dynamically at call time, not at decoration time
+                    if is_in_onnx_export_mode():
+                        return f(*args, **kwargs)
+                    return disabled_f(*args, **kwargs)
+
+                return wrapper
+
+            return decorator
+
     else:
         # no "recursive" option in pyTorch 2.0 - it acts as if recursive was True
         no_torch_dynamo = lambda recursive=True: torch._dynamo.disable
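For context, the ONNX path added in FP8EmulationFunc.onnx_forward reduces to a quantize-dequantize round trip built from exportable ops. The sketch below mirrors that concat -> quantize -> dequantize -> split flow in isolation; the helper name is illustrative, and a plain cast to torch.float8_e4m3fn stands in for the quantizer's onnx_quantize/onnx_dequantize calls (which additionally apply the recipe's scale factors), so this approximates the behavior rather than reproducing the TE implementation.

import torch


def fake_fp8_roundtrip_qkv(q, k, v, fp8_dtype=torch.float8_e4m3fn):
    """Emulate FP8 quantization effects on Q/K/V with ONNX-friendly ops.

    Illustrative stand-in for FP8EmulationFunc.onnx_forward: the real code
    calls quantizer.onnx_quantize/onnx_dequantize, which also handle scaling.
    Requires PyTorch >= 2.1 for the float8 dtypes.
    """
    orig_dtype = q.dtype
    shapes = [q.shape, k.shape, v.shape]
    numels = [q.numel(), k.numel(), v.numel()]

    # Concatenate so all three tensors go through a single quantization step.
    combined = torch.cat([q.reshape(-1), k.reshape(-1), v.reshape(-1)], dim=0)

    # Quantize + dequantize: cast down to FP8 and back to the original dtype.
    emulated = combined.to(fp8_dtype).to(orig_dtype)

    # Split back into the original shapes.
    q_out = emulated[: numels[0]].reshape(shapes[0])
    k_out = emulated[numels[0] : numels[0] + numels[1]].reshape(shapes[1])
    v_out = emulated[numels[0] + numels[1] :].reshape(shapes[2])
    return q_out, k_out, v_out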
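The jit.py change addresses a decoration-time vs. call-time distinction: the old lambda evaluated is_in_onnx_export_mode() once, when the decorator was applied (typically at import time), so functions decorated before torch.onnx.export ran still went through the dynamo-disabled wrapper during export. Below is a generic sketch of the call-time dispatch pattern, with a toy export_mode flag standing in for is_in_onnx_export_mode() and slow_path standing in for torch._dynamo.disable; the names are illustrative, not part of the PR.

from functools import wraps

export_mode = False  # toy stand-in for is_in_onnx_export_mode()


def dispatch_at_call_time(slow_path):
    """Decorator factory that picks the code path when the wrapped
    function is called, not when the decorator is applied."""

    def decorator(f):
        slow_f = slow_path(f)  # analogous to torch._dynamo.disable(f)

        @wraps(f)
        def wrapper(*args, **kwargs):
            # The flag is read here, on every call.
            if export_mode:
                return f(*args, **kwargs)
            return slow_f(*args, **kwargs)

        return wrapper

    return decorator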