23 changes: 18 additions & 5 deletions tests/pytorch/test_onnx_export.py
@@ -713,6 +713,14 @@ def test_export_layernorm_mlp_activation(seed_default_rng, activation):
     _test_export_layernorm_mlp(activation=activation)


+# FP8 recipes with fp8_dpa=True for attention FP8 emulation export test
+fp8_dpa_recipes = [None]  # None = no FP8
+if fp8_available:
+    fp8_dpa_recipes.append(recipe.DelayedScaling(fp8_dpa=True))
+    fp8_dpa_recipes.append(recipe.Float8CurrentScaling(fp8_dpa=True))
+
+
+@pytest.mark.parametrize("fp8_recipe", fp8_dpa_recipes)
 @pytest.mark.parametrize(
     "precision, use_mask, attn_mask_type",
     [
@@ -730,6 +738,7 @@ def test_export_core_attention(
     precision: torch.dtype,
     use_mask: bool,
     attn_mask_type: str,
+    fp8_recipe: recipe.Recipe,
 ):
     # Set dimensions (these are arbitrary).
     seq_len, batch_size, num_attention_heads, kv_channels = (64, 4, 1, 64)
@@ -749,22 +758,26 @@

     mask_str = get_attn_mask_str(use_mask, attn_mask_type)
     high_prec_str = dtype2str(precision)
-    fname = f"te.core_attention{mask_str}{high_prec_str}.onnx"
+    fp8_str = "_fp8_dpa" if fp8_recipe is not None else ""
+    fname = f"te.core_attention{fp8_str}{mask_str}{high_prec_str}.onnx"

+    is_fp8 = fp8_recipe is not None
+
     model = te.attention.DotProductAttention(
         num_attention_heads=num_attention_heads,
         kv_channels=kv_channels,
-        attention_dropout=0.5,
+        attention_dropout=0.0 if is_fp8 else 0.5,  # Disable dropout for FP8 deterministic results
         qkv_format=qkv_format,
         attn_mask_type=attn_mask_type,
     ).to(device="cuda")
-    do_export(model, inp, fname, input_names=input_names, fp8_recipe=None)
-    te_outputs = te_infer(model, inp, is_fp8=False, fp8_recipe=None)
+    do_export(model, inp, fname, input_names=input_names, fp8_recipe=fp8_recipe)
+    te_outputs = te_infer(model, inp, is_fp8=is_fp8, fp8_recipe=fp8_recipe)
     serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
     if precision in (torch.bfloat16,):
         return
+    atol = 5e-2 if is_fp8 else 1e-2  # Higher tolerance for FP8 due to quantization effects
     validate_result(
-        fname, inp, model, is_fp8=True, atol=1e-2, input_names=input_names, te_outputs=te_outputs
+        fname, inp, model, is_fp8=True, atol=atol, input_names=input_names, te_outputs=te_outputs
     )
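For context, a minimal sketch of what the parametrized fp8_recipe is expected to drive. The do_export/te_infer helpers are not shown in this diff; the sketch below only illustrates running the attention forward under te.fp8_autocast with an fp8_dpa recipe (argument values mirror the test above, and it assumes FP8-capable hardware or the emulation path added in this PR):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Illustrative only: enable FP8 attention (fp8_dpa=True) around the forward pass,
# which is what the parametrized fp8_recipe is expected to control in the helpers.
fp8_recipe = recipe.DelayedScaling(fp8_dpa=True)
attn = te.attention.DotProductAttention(
    num_attention_heads=1,
    kv_channels=64,
    attention_dropout=0.0,  # deterministic, as in the FP8 branch of the test
    qkv_format="sbhd",
    attn_mask_type="no_mask",
).to(device="cuda")
q, k, v = (torch.randn(64, 4, 1, 64, dtype=torch.float16, device="cuda") for _ in range(3))
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = attn(q, k, v)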


@@ -164,6 +164,11 @@ class FP8EmulationFunc(torch.autograd.Function):
     @staticmethod
     def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout):
         # pylint: disable=missing-function-docstring
+        if is_in_onnx_export_mode():
+            return FP8EmulationFunc.onnx_forward(
+                tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout
+            )
+
         if quantizer_name == "QKV_quantizer":
             query_layer, key_layer, value_layer = [
                 x.contiguous() for x in [tensor1, tensor2, tensor3]
@@ -202,6 +207,62 @@ def backward(ctx, grad1, grad2, grad3):
         tensors = grad1, grad2, grad3
         return tensors[0], tensors[1], tensors[2], None, None, None

+    @staticmethod
+    def onnx_forward(tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout=None):
+        """
+        ONNX-compatible forward for FP8 emulation using operations with defined ONNX translations.
+
+        This method performs quantize + dequantize to emulate FP8 effects using ONNX-compatible ops.
+        Uses the quantizer's onnx_quantize/onnx_dequantize methods for proper scaling behavior.
+
+        Parameters
+        ----------
+        tensor1, tensor2, tensor3 : torch.Tensor
+            Input tensors (e.g., Q, K, V for QKV_quantizer, or single tensor for S/O quantizers)
+        quantizer : Quantizer
+            The quantizer object with onnx_quantize/onnx_dequantize methods
+        quantizer_name : str
+            Name of quantizer: "QKV_quantizer", "S_quantizer", "O_quantizer", etc.
+        qkv_layout : str, optional
+            QKV layout string (not used in ONNX path)
+
+        Returns
+        -------
+        Tuple of emulated tensors
+        """
+        # pylint: disable=unused-argument
+
+        if quantizer_name == "QKV_quantizer":
+            # Combine Q, K, V -> quantize together -> split back
+            orig_dtype = tensor1.dtype
+            shapes = [tensor1.shape, tensor2.shape, tensor3.shape]
+            numels = [tensor1.numel(), tensor2.numel(), tensor3.numel()]
+
+            # Flatten and concatenate
+            combined = torch.cat(
+                [tensor1.reshape(-1), tensor2.reshape(-1), tensor3.reshape(-1)], dim=0
+            )
+
+            # Quantize + dequantize combined tensor using quantizer's ONNX methods
+            combined_fp8 = quantizer.onnx_quantize(combined)
+            out = quantizer.onnx_dequantize(combined_fp8).to(orig_dtype)
+
+            # Split back
+            out1 = out[: numels[0]].reshape(shapes[0])
+            out2 = out[numels[0] : numels[0] + numels[1]].reshape(shapes[1])
+            out3 = out[numels[0] + numels[1] :].reshape(shapes[2])
+
+            return out1, out2, out3
+        elif quantizer_name in ["S_quantizer", "O_quantizer"]:
+            # Emulate FP8 on single tensor using quantizer's ONNX methods
+            orig_dtype = tensor1.dtype
+            t_fp8 = quantizer.onnx_quantize(tensor1)
+            out = quantizer.onnx_dequantize(t_fp8).to(orig_dtype)
+            return out, tensor2, tensor3
+        else:
+            # Pass-through
+            return tensor1, tensor2, tensor3
+

 class UnfusedDotProductAttention(torch.nn.Module):
     """Parallel attention w/o QKV and Proj Gemms
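The quantize + dequantize round trip that onnx_forward delegates to the quantizer can be pictured with plain PyTorch ops. The sketch below is an illustration of the idea only — the actual scale handling lives in the quantizer's onnx_quantize/onnx_dequantize methods, whose implementation is not part of this diff:

import torch

FP8_E4M3_MAX = 448.0  # largest finite magnitude representable in float8_e4m3fn

def fake_fp8_quant_dequant(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    """Emulate FP8 by scaling into range, casting to float8 and back (illustration only)."""
    x_scaled = torch.clamp(x.float() * scale, -FP8_E4M3_MAX, FP8_E4M3_MAX)
    x_fp8 = x_scaled.to(torch.float8_e4m3fn)  # quantize: round to the nearest FP8 value
    return (x_fp8.to(torch.float32) / scale).to(x.dtype)  # dequantize back to the input dtype

# The output only contains values representable on the FP8 grid, which is why the
# test above relaxes atol to 5e-2 whenever an FP8 recipe is active.
x = torch.randn(4, 8, dtype=torch.float16)
x_emulated = fake_fp8_quant_dequant(x)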
21 changes: 18 additions & 3 deletions transformer_engine/pytorch/jit.py
@@ -51,9 +51,24 @@ def wrapper(*args, **kwargs):
 import torch._dynamo

 if torch.__version__ >= "2.1":
-    no_torch_dynamo = lambda recursive=True: lambda f: (
-        f if is_in_onnx_export_mode() else torch._dynamo.disable(f, recursive=recursive)
-    )
+
+    def no_torch_dynamo(recursive=True):
+        """Decorator to disable Torch Dynamo, except during ONNX export."""
+
+        def decorator(f):
+            disabled_f = torch._dynamo.disable(f, recursive=recursive)
+
+            @wraps(f)
+            def wrapper(*args, **kwargs):
+                # Check dynamically at call time, not at decoration time
+                if is_in_onnx_export_mode():
+                    return f(*args, **kwargs)
+                return disabled_f(*args, **kwargs)
+
+            return wrapper
+
+        return decorator
+
 else:
     # no "recursive" option in pyTorch 2.0 - it acts as if recursive was True
     no_torch_dynamo = lambda recursive=True: torch._dynamo.disable
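The practical effect of the jit.py change is that the ONNX-export check now happens per call instead of once at decoration time, so a function decorated at import time is still traceable during export. A hypothetical usage sketch (it assumes no_torch_dynamo can be imported from transformer_engine.pytorch.jit; the decorated function is made up for illustration):

import torch
from transformer_engine.pytorch.jit import no_torch_dynamo

@no_torch_dynamo()
def bias_gelu(bias: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Decorated once at import time; whether Dynamo is bypassed is decided per call.
    return torch.nn.functional.gelu(x + bias, approximate="tanh")

# Normal execution goes through torch._dynamo.disable(bias_gelu).
y = bias_gelu(torch.zeros(8), torch.randn(8))

# While is_in_onnx_export_mode() returns True (inside TE's ONNX export context),
# the wrapper calls the original function directly so it can be traced into the graph.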