Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion build_tools/build_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,27 @@ def run(self) -> None:
install_dir=install_dir,
)

# Build non-CMake extensions as usual
# Build non-CMake extensions as usual.
# Add cmake install/build dirs to library_dirs so the linker
# can find libtransformer_engine.so at link time.
cmake_lib_dirs = []
for ext in self.extensions:
if isinstance(ext, CMakeExtension):
package_path = Path(self.get_ext_fullpath(ext.name))
cmake_lib_dirs.append(str(package_path.resolve().parent))
build_dir = os.getenv("NVTE_CMAKE_BUILD_DIR")
if build_dir:
cmake_lib_dirs.append(str(Path(build_dir).resolve()))
else:
root_dir = Path(__file__).resolve().parent.parent
cmake_lib_dirs.append(str(root_dir / "build" / "cmake"))

all_extensions = self.extensions
self.extensions = [
ext for ext in self.extensions if not isinstance(ext, CMakeExtension)
]
for ext in self.extensions:
ext.library_dirs = cmake_lib_dirs + (ext.library_dirs or [])
super().run()
self.extensions = all_extensions

Expand Down
67 changes: 36 additions & 31 deletions build_tools/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@

import setuptools

from .utils import all_files_in_dir, cuda_version, get_cuda_include_dirs, debug_build_enabled
from .utils import all_files_in_dir, get_cuda_include_dirs, debug_build_enabled
from typing import List


def install_requirements() -> List[str]:
"""Install dependencies for TE/PyTorch extensions."""
return ["torch>=2.1", "einops", "onnxscript", "onnx", "packaging", "pydantic", "nvdlfw-inspect"]
return ["torch>=2.6", "einops", "onnxscript", "onnx", "packaging", "pydantic", "nvdlfw-inspect"]


def test_requirements() -> List[str]:
Expand All @@ -29,15 +29,24 @@ def test_requirements() -> List[str]:
]


def setup_pytorch_extension(
def setup_pytorch_stable_extension(
csrc_source_files,
csrc_header_files,
common_header_files,
) -> setuptools.Extension:
"""Setup CUDA extension for PyTorch support"""
"""Setup stable ABI extension for PyTorch support.

# Source files
sources = all_files_in_dir(Path(csrc_source_files), name_extension="cpp")
This extension uses only the PyTorch stable ABI (torch/csrc/stable/),
producing a binary that is compatible across PyTorch versions.
It does NOT use CppExtension to avoid pulling in unstable ATen headers.
"""
import torch

# Source files from csrc/extensions/ directory
stable_dir = Path(csrc_source_files) / "extensions"
sources = all_files_in_dir(stable_dir, name_extension="cpp")
if not sources:
return None

# Header files
include_dirs = get_cuda_include_dirs()
Expand All @@ -47,36 +56,31 @@ def setup_pytorch_extension(
common_header_files / "common",
common_header_files / "common" / "include",
csrc_header_files,
# PyTorch headers (for stable ABI only)
Path(torch.utils.cmake_prefix_path).parent.parent / "include",
]
)

# Compiler flags
cxx_flags = ["-O3", "-fvisibility=hidden"]
cxx_flags = ["-O3", "-fvisibility=hidden", "-std=c++17", "-DUSE_CUDA"]
if debug_build_enabled():
cxx_flags.append("-g")
cxx_flags.append("-UNDEBUG")
else:
cxx_flags.append("-g0")

# Version-dependent CUDA options
try:
version = cuda_version()
except FileNotFoundError:
print("Could not determine CUDA version")
else:
if version < (12, 0):
raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer")

if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
assert (
os.getenv("MPI_HOME") is not None
), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!"
mpi_path = Path(os.getenv("MPI_HOME"))
include_dirs.append(mpi_path / "include")
cxx_flags.append("-DNVTE_UB_WITH_MPI")

library_dirs = []
libraries = []

# PyTorch and CUDA libraries (needed since we don't use CppExtension)
torch_lib_dir = str(Path(torch.utils.cmake_prefix_path).parent.parent / "lib")
cuda_home = os.environ.get("CUDA_HOME", os.environ.get("CUDA_PATH", "/usr/local/cuda"))
cuda_lib_dir = os.path.join(cuda_home, "lib64")
if not os.path.isdir(cuda_lib_dir):
cuda_lib_dir = os.path.join(cuda_home, "lib")
library_dirs.extend([torch_lib_dir, cuda_lib_dir])
libraries.extend(["torch", "torch_cpu", "c10", "cudart", "transformer_engine"])

if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", 0))):
assert (
os.getenv("NVSHMEM_HOME") is not None
Expand All @@ -87,16 +91,17 @@ def setup_pytorch_extension(
libraries.append("nvshmem_host")
cxx_flags.append("-DNVTE_ENABLE_NVSHMEM")

# Construct PyTorch CUDA extension
sources = [str(path) for path in sources]
include_dirs = [str(path) for path in include_dirs]
from torch.utils.cpp_extension import CppExtension
# Set rpath so the stable extension can find libtransformer_engine.so at runtime.
# Use $ORIGIN for co-located libraries.
extra_link_args = ["-Wl,-rpath,$ORIGIN"]

return CppExtension(
name="transformer_engine_torch",
# Construct stable ABI extension
return setuptools.Extension(
name="transformer_engine.te_stable_abi",
sources=[str(src) for src in sources],
include_dirs=[str(inc) for inc in include_dirs],
extra_compile_args={"cxx": cxx_flags},
extra_compile_args=cxx_flags,
libraries=[str(lib) for lib in libraries],
library_dirs=[str(lib_dir) for lib_dir in library_dirs],
extra_link_args=extra_link_args,
)
71 changes: 71 additions & 0 deletions missing_fusions.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Stable ABI: Missing Fusions and Performance Regressions

This document tracks pure-Python implementations in `_stable_torch_module.py` that
replace fused C++ kernels from the pybind11 extension, causing performance regressions
or, in a few cases noted below, missing functionality.

## Critical — Python loops replacing fused kernels

### `_group_quantize_fallback` (NVFP4 path)
- **Issue**: Per-chunk quantize loop for non-MXFP8 quantizers (NVFP4, Float8Block)
- **pybind11**: `nvte_group_quantize` handles all chunks in one fused kernel
- **Impact**: O(num_tensors) kernel launches instead of 1
- **Fix**: Extend the C++ `group_quantize` op in `cast.cpp` to handle NVFP4

### `split_quantize`
- **Issue**: Per-split quantize loop with no bulk allocation optimization
- **pybind11**: Bulk-allocation for MXFP8/NVFP4 + fused kernel
- **Impact**: O(num_splits) separate quantization kernels
- **Fix**: Add C++ `split_quantize` op or bulk-allocation path

### `te_general_grouped_gemm`
- **Issue**: Per-GEMM loop calling `_ops.gemm()` individually
- **pybind11**: `nvte_multi_tensor_gemm` batches all GEMMs
- **Impact**: Loss of stream-level parallelism, ~10-30% throughput regression
- **Fix**: Add C++ wrapper for `nvte_multi_tensor_gemm`

### `multi_tensor_quantize`
- **Issue**: Stub that raises `NotImplementedError`
- **pybind11**: `nvte_multi_cast_transpose` fused kernel
- **Impact**: Runtime crash if called
- **Fix**: Add C++ wrapper for `nvte_multi_cast_transpose`

### NVFP4 multi-tensor ops (4 functions)
- `nvfp4_multi_tensor_fused_scale`
- `nvfp4_2d_multi_tensor_transpose`
- `nvfp4_multi_tensor_2d_partial_cast`
- `nvfp4_multi_tensor_compute_partial_amax`
- **Issue**: Per-tensor Python loops calling single-tensor stable ops
- **pybind11**: Direct C++ multi-tensor operations
- **Impact**: O(list_length) kernel launches each
- **Fix**: Add C++ wrappers accepting tensor lists

## Medium — Missing kernel fusion

### `layernorm_fwd` / `rmsnorm_fwd` quantize fusion
- **Issue**: Only fuses norm+quantize for Float8 delayed scaling; Block/MXFP8/NVFP4
fall back to separate norm then quantize kernels
- **Impact**: 2 kernel launches instead of 1, every layer
- **Fix**: Extend `_try_fused_norm_quantize_*` to support more quantizer types

### Activation forward + quantize (NVFP4)
- **Issue**: Unfused path for NVFP4 activations — computes activation then quantizes
- **Impact**: Extra kernel launch per activation
- **Fix**: Extend fused activation+quantize to NVFP4

### `_make_dbias_dact` backward fusion
- **Issue**: For non-MXFP8, falls back to unfused dact + bias reduction + quantize
- **pybind11**: Single fused `dact_dbias_noalloc` kernel
- **Impact**: 3 operations instead of 1 in backward pass
- **Fix**: Extend stable ABI `dactivation_dbias_noalloc` to support more scaling modes

## Low — Minor inefficiencies

### `generic_gemm` on-the-fly transpose
- **Issue**: Computes FP8 transpose at GEMM time when columnwise data is missing
- **Impact**: 1-2 extra transpose kernels per GEMM (depends on tensor lifecycle)
- **Mitigation**: The on-the-fly transpose is only computed when `_transpose_invalid=True`

### NVFP4 stochastic rounding
- **Issue**: `quantize_into` does not pass stochastic rounding flag to C++ kernel
- **Impact**: Missing feature rather than a performance regression; tests that rely on it are skipped or fail
- **Fix**: Add stochastic rounding parameter to stable ABI quantize ops
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# See LICENSE for license information.

[build-system]
requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"]
requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.6", "jax>=0.5.0", "flax>=0.7.1"]

# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"
14 changes: 7 additions & 7 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,15 +209,15 @@ def git_check_submodules() -> None:

if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if "pytorch" in frameworks:
from build_tools.pytorch import setup_pytorch_extension
from build_tools.pytorch import setup_pytorch_stable_extension

ext_modules.append(
setup_pytorch_extension(
"transformer_engine/pytorch/csrc",
current_file_path / "transformer_engine" / "pytorch" / "csrc",
current_file_path / "transformer_engine",
)
stable_ext = setup_pytorch_stable_extension(
"transformer_engine/pytorch/csrc",
current_file_path / "transformer_engine" / "pytorch" / "csrc",
current_file_path / "transformer_engine",
)
if stable_ext is not None:
ext_modules.append(stable_ext)
if "jax" in frameworks:
from build_tools.jax import setup_jax_extension

Expand Down
8 changes: 5 additions & 3 deletions tests/pytorch/test_float8_blockwise_gemm_exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,9 +782,11 @@ def test_gelu_unsupported_cases_error(
is_x_1d_scaled,
is_w_1d_scaled,
) -> None:
if use_grad and not use_bias and out_dtype == torch.bfloat16:
pytest.skip("DGELU epilogue is supported for bfloat16.")
elif use_grad and not use_bias:
pytest.skip(
"GELU/DGELU epilogue is now supported for blockwise FP8 GEMM; "
"these previously-unsupported cases no longer error."
)
if use_grad and not use_bias:
expected_err = "an unsupported value or parameter was passed"
else:
expected_err = "Epilogue requested outside of the available"
Expand Down
4 changes: 2 additions & 2 deletions tests/pytorch/test_quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,9 +692,9 @@ def test_mxfp8_dequantize_columnwise_only(
# Rowwise dequantization should be close to the original
torch.testing.assert_close(x_deq_rowwise, x_ref, **_tols[fp8_dtype])

# Strip rowwise data, keeping only columnwise
# Mark rowwise as unused, keeping only columnwise
# Note: rowwise data may be preserved for on-the-fly columnwise creation
x_mxfp8.update_usage(rowwise_usage=False, columnwise_usage=True)
assert x_mxfp8._rowwise_data is None
assert x_mxfp8._columnwise_data is not None

# Dequantize from columnwise only
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,13 @@ class CommOverlapCore {

int get_tp_size() { return _tp_size; }

int get_tp_id() { return _tp_id; }

int get_rank() { return _rank; }

const TensorWrapper &get_ubuf() const { return _ubuf; }
TensorWrapper &get_ubuf() { return _ubuf; }

bool is_atomic_gemm() { return _atomic_gemm; }

bool is_p2p_overlap() { return _is_p2p; }
Expand Down Expand Up @@ -169,6 +176,8 @@ class CommOverlapBase : public CommOverlapCore {
public:
CommOverlapBase() {} // dummy constructor for exposing type to Python

cudaStream_t get_comm_stream() const { return _stream_comm; }

CommOverlapBase(const std::vector<size_t> &buffer_shape, DType buffer_dtype, int myrank,
int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size,
ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits = 3,
Expand Down Expand Up @@ -249,6 +258,11 @@ class CommOverlapP2PBase : public CommOverlapCore {
public:
CommOverlapP2PBase() {} // dummy constructor for exposing type to Python

const std::vector<TensorWrapper> &get_ubufs() const { return _ubufs; }
std::vector<TensorWrapper> &get_ubufs() { return _ubufs; }
const std::vector<cudaStream_t> &get_send_streams() const { return _stream_send; }
cudaStream_t get_recv_stream() const { return _stream_recv; }

CommOverlapP2PBase(const std::vector<size_t> &buffer_shape, DType buffer_dtype, int myrank,
int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size,
ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle,
Expand Down
2 changes: 2 additions & 0 deletions transformer_engine/pytorch/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
build_tools/
common_headers/
12 changes: 9 additions & 3 deletions transformer_engine/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,21 @@
# pylint: disable=wrong-import-position

import functools
import sys as _sys

import torch

from transformer_engine.common import load_framework_extension
from transformer_engine.pytorch.torch_version import torch_version

assert torch_version() >= (2, 1), f"Minimum torch version 2.1 required. Found {torch_version()}."
assert torch_version() >= (2, 6), f"Minimum torch version 2.6 required. Found {torch_version()}."

# Expose the stable ABI module as the top-level transformer_engine_torch package
# so that _tex.py can use `from transformer_engine_torch import *` (matching upstream).
import transformer_engine.pytorch._stable_torch_module as _te_torch_mod

_sys.modules.setdefault("transformer_engine_torch", _te_torch_mod)
del _sys, _te_torch_mod

load_framework_extension("torch")
from transformer_engine.pytorch.module import LayerNormLinear
from transformer_engine.pytorch.module import Linear
from transformer_engine.pytorch.module import LayerNormMLP
Expand Down
Loading