Skip to content
254 changes: 125 additions & 129 deletions transformer_engine/pytorch/module/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,131 @@ def fill_userbuffers_buffer_for_all_gather(
raise ValueError(f"Unsupported quantizer for Userbuffers ({quantizer})")


def _is_weight_workspace_valid(
    workspace: QuantizedTensorStorage,
    quantizer: Quantizer,
) -> bool:
    """Return ``True`` if a cached weight workspace matches the quantizer's current usage.

    A workspace goes stale when the quantizer now requires a data layout
    (row-wise, column-wise, or FP8 transpose) that the workspace never
    materialized, or when debug wrapping does not line up.
    """
    # Debug workspaces are only compatible with debug quantizers, and vice versa.
    if isinstance(workspace, DebugQuantizedTensor) != isinstance(quantizer, DebugQuantizer):
        return False
    if isinstance(workspace, Float8TensorStorage):
        # Column-wise use needs an explicit transpose unless non-TN FP8 GEMMs
        # are supported on this platform.
        missing_transpose = (
            quantizer.columnwise_usage
            and not is_non_tn_fp8_gemm_supported()
            and workspace._transpose is None
        )
        return not missing_transpose
    if isinstance(workspace, (MXFP8TensorStorage, NVFP4TensorStorage)):
        # Block-scaled formats store row-wise and column-wise copies separately;
        # both requested layouts must be present.
        if quantizer.rowwise_usage and workspace._rowwise_data is None:
            return False
        if quantizer.columnwise_usage and workspace._columnwise_data is None:
            return False
    return True


def quantize_weight(
    *,
    tensor: Optional[torch.Tensor] = None,
    quantizer: Optional[Quantizer] = None,
    workspace: Optional[QuantizedTensorStorage] = None,
    update_workspace: bool = True,
    skip_update_flag: Optional[torch.Tensor] = None,
    fsdp_group: Optional["dist_group_type"] = None,
    workspace_dtype: Optional[torch.dtype] = None,
    cache: bool = False,
) -> Tuple[QuantizedTensorStorage, Optional[QuantizedTensorStorage]]:
    """Quantize a weight tensor, optionally reusing a cached workspace.

    Parameters
    ----------
    tensor: torch.Tensor, optional
        Weight tensor to quantize.
    quantizer: Quantizer, optional
        Quantizer for casting the weight.
    workspace: QuantizedTensorStorage, optional
        Previously cached workspace (from the module's ``_fp8_workspaces``).
        ``None`` indicates a cache miss.
    update_workspace: bool, default = True
        Whether to update an existing workspace with fresh values.
    skip_update_flag: torch.Tensor, optional
        GPU flag to conditionally skip the update. Takes precedence over
        ``update_workspace`` when provided.
    fsdp_group: dist_group_type, optional
        FSDP process group the weights are distributed over.
    workspace_dtype: torch.dtype, optional
        High-precision dtype for debug quantization workspaces.
    cache: bool, default = False
        If ``True`` and a new workspace is created, it will be returned
        as the second element so the caller can store it.

    Returns
    -------
    (weightmat, new_workspace)
        *weightmat*: quantized weight ready for GEMM.
        *new_workspace*: non-``None`` only when a brand-new workspace was
        created **and** ``cache=True``. The caller should store it in
        ``_fp8_workspaces``.

    Raises
    ------
    ValueError
        If ``tensor`` (and, on a cache miss, ``quantizer``) is missing when
        a workspace must be constructed or updated.
    """

    # Already-quantized weight (primary FP8 parameters).
    # Make sure the weight has the required usages, but do not destroy
    # other usages since they may be needed later.
    if isinstance(tensor, QuantizedTensor):
        update_rowwise = True if quantizer.rowwise_usage else None
        update_columnwise = True if quantizer.columnwise_usage else None
        tensor.update_usage(
            rowwise_usage=update_rowwise,
            columnwise_usage=update_columnwise,
        )
        if isinstance(quantizer, DebugQuantizer):
            tensor = quantizer.wrap_quantized_tensor(tensor)
        return tensor, None

    # Discard a cached workspace that no longer matches the quantizer's usages
    if workspace is not None and quantizer is not None:
        if not _is_weight_workspace_valid(workspace, quantizer):
            workspace = None

    # Gather a cached workspace that is sharded over FSDP ranks
    if (
        workspace is not None
        and tensor is not None
        and fsdp_group is not None
        and workspace.data.shape != tensor.data.shape
    ):
        _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], workspace)

    # Cache hit — update in-place and return
    if workspace is not None:
        if skip_update_flag is not None:
            # The GPU-side flag decides at runtime, so always launch the update
            update_workspace = True
        if update_workspace:
            if tensor is None:
                raise ValueError("tensor kwarg must be provided to update FP8 workspace")
            if hasattr(workspace, "quantize_"):
                workspace.quantize_(tensor, noop_flag=skip_update_flag)
            else:
                tex.quantize(tensor, quantizer, workspace, skip_update_flag)
        return workspace, None

    # Cache miss — create new workspace
    if tensor is None or quantizer is None:
        raise ValueError("tensor and quantizer kwargs must be provided to construct FP8 workspace")
    if not cache:
        return quantizer.quantize(tensor, dtype=workspace_dtype), None

    # Ensure the tensor in the cache is an instance of torch.Tensor,
    # as it persists beyond a single forward pass.
    # Setting internal=True would cause the data to be removed in prepare_for_saving(...).
    saved_internal = quantizer.internal
    quantizer.internal = False
    try:
        out = quantizer.quantize(tensor, dtype=workspace_dtype)
    finally:
        # Restore the flag even if quantization fails so a shared quantizer
        # is not left mutated (original code skipped restoration on error).
        quantizer.internal = saved_internal
    return out, out


class TransformerEngineBaseModule(torch.nn.Module, ABC):
"""Base TE module."""

Expand Down Expand Up @@ -1392,135 +1517,6 @@ def clear(self):
def forward(self):
    """Forward pass; must be overridden by subclasses."""

def get_weight_workspace(
    self,
    *,
    tensor: Optional[torch.Tensor] = None,
    quantizer: Optional[Quantizer] = None,
    cache_name: Optional[str] = None,
    update_workspace: bool = True,
    skip_update_flag: Optional[torch.Tensor] = None,
    fsdp_group: Optional[dist_group_type] = None,
    workspace_dtype: Optional[torch.dtype] = None,
) -> QuantizedTensor:
    """Get workspace buffer for weights and maybe update its values

    The workspace buffer may be cached for future function calls.

    Parameters
    ----------
    tensor : torch.Tensor, optional
        Values to copy into workspace. Required if the workspace
        is being constructed or updated.
    quantizer: Quantizer, optional
        Quantizer used to cast the weights. Required if the
        workspace is being constructed or updated.
    cache_name: str, optional
        Key for caching in ``self._fp8_workspaces``.
    update_workspace: bool, default = True
        Update workspace with values from `tensor`.
    skip_update_flag: torch.Tensor, optional
        GPU flag to skip updating the workspace. Take precedence
        over `update_workspace` if provided.
    fsdp_group: bool, default = None
        FSDP process group that the weights are distributed over.
    workspace_dtype: torch.dtype, default = None
        If weight workspace contains high-precision tensor - for example
        for debug quantization, this is dtype of the tensor.

    Returns
    -------
    QuantizedTensor
        Quantized weight buffer, either freshly constructed or reused
        from the ``self._fp8_workspaces`` cache.
    """

    # Handle case where weights are already quantized
    # Note: Make sure weights have required usages, but do not
    # destroy unnecessary usages since they may be used later.
    if isinstance(tensor, QuantizedTensor):
        update_rowwise_usage = True if quantizer.rowwise_usage else None
        update_columnwise_usage = True if quantizer.columnwise_usage else None
        tensor.update_usage(
            rowwise_usage=update_rowwise_usage,
            columnwise_usage=update_columnwise_usage,
        )

        # Debug quantizers expect weights wrapped in a debug tensor
        if isinstance(quantizer, DebugQuantizer):
            tensor = quantizer.wrap_quantized_tensor(tensor)

        return tensor

    # Try getting workspace from cache
    out = None
    if cache_name is not None:
        out = self._fp8_workspaces.get(cache_name, None)

    # Reset cache if workspace is invalid, i.e. it lacks a data layout
    # that the quantizer now requires
    if out is not None and quantizer is not None:
        reset_cache = False
        if isinstance(out, Float8TensorStorage):
            # Column-wise use needs an explicit transpose unless non-TN
            # FP8 GEMMs are supported on this platform
            if (
                not is_non_tn_fp8_gemm_supported()
                and quantizer.columnwise_usage
                and out._transpose is None
            ):
                reset_cache = True
        elif isinstance(out, MXFP8TensorStorage):
            if quantizer.rowwise_usage and out._rowwise_data is None:
                reset_cache = True
            elif quantizer.columnwise_usage and out._columnwise_data is None:
                reset_cache = True
        elif isinstance(out, NVFP4TensorStorage):
            if quantizer.rowwise_usage and out._rowwise_data is None:
                reset_cache = True
            elif quantizer.columnwise_usage and out._columnwise_data is None:
                reset_cache = True
        # Debug workspaces are only compatible with debug quantizers
        if isinstance(out, DebugQuantizedTensor) != isinstance(quantizer, DebugQuantizer):
            reset_cache = True
        if reset_cache:
            out = None
            del self._fp8_workspaces[cache_name]

    # Gather cached Fp8 workspace if it's distributed
    # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work
    # for models initialized with Fp8 primary weights.
    if (
        out is not None
        and tensor is not None
        and fsdp_group is not None
        and out.data.shape != tensor.data.shape
    ):
        _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out)

    # Construct workspace if needed
    if out is None:
        if tensor is None or quantizer is None:
            raise ValueError(
                "tensor and quantizer kwargs must be provided to construct FP8 workspace"
            )

        if cache_name is not None:
            # Ensure the tensor in the cache is an instance of torch.Tensor,
            # as it persists beyond a single forward pass.
            # Setting internal=True would cause the data to be removed in prepare_for_saving(...).
            quantizer_internal = quantizer.internal
            quantizer.internal = False
        out = quantizer.quantize(tensor, dtype=workspace_dtype)
        if cache_name is not None:
            quantizer.internal = quantizer_internal

        # Update cache
        if cache_name is not None:
            self._fp8_workspaces[cache_name] = out
        return out

    # Update workspace if needed
    # Note: the GPU-side skip flag takes precedence over `update_workspace`
    if skip_update_flag is not None:
        update_workspace = True
    if update_workspace:
        if tensor is None:
            raise ValueError("tensor kwarg must be provided to update FP8 workspace")
        if hasattr(out, "quantize_"):
            out.quantize_(tensor, noop_flag=skip_update_flag)
        else:
            tex.quantize(tensor, quantizer, out, skip_update_flag)
    return out

def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
Expand Down
42 changes: 31 additions & 11 deletions transformer_engine/pytorch/module/grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor
from .base import (
get_dummy_wgrad,
quantize_weight,
TransformerEngineBaseModule,
_2X_ACC_FPROP,
_2X_ACC_DGRAD,
Expand Down Expand Up @@ -69,7 +70,7 @@ def forward(
inp: torch.Tensor,
non_tensor_args: Tuple,
*weights_and_biases,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, list]:
# pylint: disable=missing-function-docstring

# Reduce number of arguments to autograd function in order
Expand All @@ -92,7 +93,8 @@ def forward(
sequence_parallel,
activation_dtype,
is_grad_enabled,
module,
weight_workspaces,
cache_weight,
skip_fp8_weight_update,
save_original_input,
debug,
Expand Down Expand Up @@ -166,18 +168,19 @@ def forward(

# Initialize weights
weights_fp8: list
new_workspaces = [None] * num_gemms
if fp8 or debug:
# FP8 cast to workspace buffer
weights_fp8 = []
update_workspace = is_first_microbatch is None or is_first_microbatch
update_ws = is_first_microbatch is None or is_first_microbatch
for i in range(num_gemms):
weight_fp8 = module.get_weight_workspace(
weight_fp8, new_workspaces[i] = quantize_weight(
tensor=weights[i],
quantizer=weight_quantizers[i],
cache_name=(None if is_first_microbatch is None else f"weight{i}"),
update_workspace=update_workspace,
workspace=weight_workspaces[i] if weight_workspaces else None,
update_workspace=update_ws,
skip_update_flag=skip_fp8_weight_update,
workspace_dtype=activation_dtype,
cache=cache_weight,
)
weights_fp8.append(weight_fp8)

Expand Down Expand Up @@ -310,10 +313,12 @@ def forward(
ctx.input_quantizers = input_quantizers

# [*, in_features] -> [*, out_features] except first dimension changes for SP
return out.view(-1, *inp.shape[1:-1], out.shape[-1])
return out.view(-1, *inp.shape[1:-1], out.shape[-1]), new_workspaces

@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
def backward(
ctx, grad_output: torch.Tensor, _grad_workspaces
) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
with get_nvtx_range_context("_GroupedLinear_backward"):
saved_tensors = restore_from_func_ctx(ctx)
Expand Down Expand Up @@ -987,6 +992,13 @@ def forward(
linear_fn = _GroupedLinear.forward
autograd_ctx = [None]

num_gemms = len(m_splits)
cache_weight = is_first_microbatch is not None
weight_workspaces = [
self._fp8_workspaces.get(f"weight{i}") if cache_weight else None
for i in range(num_gemms)
]

non_tensor_args = (
m_splits,
self.apply_bias,
Expand All @@ -1005,12 +1017,20 @@ def forward(
self.sequence_parallel,
self.activation_dtype,
is_grad_enabled,
self,
weight_workspaces,
cache_weight,
None, # skip_fp8_weight_update
self.save_original_input,
debug,
)
out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors)
out, new_workspaces = linear_fn(
*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors
)

if cache_weight:
for i, ws in enumerate(new_workspaces):
if ws is not None:
self._fp8_workspaces[f"weight{i}"] = ws

finally:
self.end_forward()
Expand Down
Loading
Loading