Skip to content

Commit e48c15c

Browse files
committed
Replace pybind11 extension with PyTorch stable ABI
Replace the pybind11-based `transformer_engine_torch` extension with a pure C++ stable ABI extension (`te_stable_abi`) that uses PyTorch's `torch::Library` API. This eliminates the dependency on unstable PyTorch C++ internals (ATen, c10, pybind11), making TE compatible across PyTorch versions without recompilation.

Key changes:
- New `_stable_torch_module.py` routes all ops through stable ABI
- New `_quantize_stable.py` handles FP8/NVFP4/MXFP8 quantization
- C++ extensions (attention, gemm, permutation, etc.) ported to stable ABI tensor wrappers
- CMakeLists.txt handles suffixed CUDA archs (100a, 103a, etc.)
- Float8Quantizer computes transpose after quantization
- FP8 attention backward properly feeds amax to global state
- GELU epilogue fusion in generic_gemm
- CUDA graph capture compatible with FP8 quantization

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 4bf1c1c commit e48c15c

61 files changed

Lines changed: 9771 additions & 11534 deletions

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

build_tools/build_ext.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,27 @@ def run(self) -> None:
129129
install_dir=install_dir,
130130
)
131131

132-
# Build non-CMake extensions as usual
132+
# Build non-CMake extensions as usual.
133+
# Add cmake install/build dirs to library_dirs so the linker
134+
# can find libtransformer_engine.so at link time.
135+
cmake_lib_dirs = []
136+
for ext in self.extensions:
137+
if isinstance(ext, CMakeExtension):
138+
package_path = Path(self.get_ext_fullpath(ext.name))
139+
cmake_lib_dirs.append(str(package_path.resolve().parent))
140+
build_dir = os.getenv("NVTE_CMAKE_BUILD_DIR")
141+
if build_dir:
142+
cmake_lib_dirs.append(str(Path(build_dir).resolve()))
143+
else:
144+
root_dir = Path(__file__).resolve().parent.parent
145+
cmake_lib_dirs.append(str(root_dir / "build" / "cmake"))
146+
133147
all_extensions = self.extensions
134148
self.extensions = [
135149
ext for ext in self.extensions if not isinstance(ext, CMakeExtension)
136150
]
151+
for ext in self.extensions:
152+
ext.library_dirs = cmake_lib_dirs + (ext.library_dirs or [])
137153
super().run()
138154
self.extensions = all_extensions
139155

build_tools/pytorch.py

Lines changed: 56 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,16 @@
66
import os
77
from pathlib import Path
88

9+
from typing import List
10+
911
import setuptools
1012

11-
from .utils import all_files_in_dir, cuda_version, get_cuda_include_dirs, debug_build_enabled
12-
from typing import List
13+
from .utils import all_files_in_dir, get_cuda_include_dirs, debug_build_enabled
1314

1415

1516
def install_requirements() -> List[str]:
1617
"""Install dependencies for TE/PyTorch extensions."""
17-
return ["torch>=2.1", "einops", "onnxscript", "onnx", "packaging", "pydantic", "nvdlfw-inspect"]
18+
return ["torch>=2.6", "einops", "onnxscript", "onnx", "packaging", "pydantic", "nvdlfw-inspect"]
1819

1920

2021
def test_requirements() -> List[str]:
@@ -29,74 +30,83 @@ def test_requirements() -> List[str]:
2930
]
3031

3132

32-
def setup_pytorch_extension(
33+
def setup_pytorch_stable_extension(
3334
csrc_source_files,
3435
csrc_header_files,
3536
common_header_files,
3637
) -> setuptools.Extension:
37-
"""Setup CUDA extension for PyTorch support"""
38+
"""Setup stable ABI extension for PyTorch support.
3839
39-
# Source files
40-
sources = all_files_in_dir(Path(csrc_source_files), name_extension="cpp")
40+
This extension uses only the PyTorch stable ABI (torch/csrc/stable/),
41+
producing a binary that is compatible across PyTorch versions.
42+
It does NOT use CppExtension to avoid pulling in unstable ATen headers.
43+
"""
44+
import torch
4145

42-
# Header files
46+
# Source files from csrc/extensions/ directory
47+
stable_dir = Path(csrc_source_files) / "extensions"
48+
sources = all_files_in_dir(stable_dir, name_extension="cpp")
49+
if not sources:
50+
return None
51+
52+
# Include directories
4353
include_dirs = get_cuda_include_dirs()
4454
include_dirs.extend(
4555
[
4656
common_header_files,
4757
common_header_files / "common",
4858
common_header_files / "common" / "include",
4959
csrc_header_files,
60+
# PyTorch headers (for stable ABI only)
61+
Path(torch.utils.cmake_prefix_path).parent.parent / "include",
5062
]
5163
)
5264

5365
# Compiler flags
54-
cxx_flags = ["-O3", "-fvisibility=hidden"]
66+
cxx_flags = ["-O3", "-fvisibility=hidden", "-std=c++17", "-DUSE_CUDA"]
67+
if bool(int(os.environ.get("NVTE_ENABLE_NVSHMEM", "0"))):
68+
cxx_flags.append("-DNVTE_ENABLE_NVSHMEM")
69+
nvshmem_home = os.environ.get("NVSHMEM_HOME", "")
70+
if nvshmem_home:
71+
include_dirs.append(Path(nvshmem_home) / "include")
72+
# Try system NVSHMEM paths (Debian/Ubuntu packages)
73+
for nvshmem_inc in ["/usr/include/nvshmem_13", "/usr/local/include/nvshmem"]:
74+
if os.path.isdir(nvshmem_inc):
75+
include_dirs.append(Path(nvshmem_inc))
76+
break
5577
if debug_build_enabled():
5678
cxx_flags.append("-g")
5779
cxx_flags.append("-UNDEBUG")
5880
else:
5981
cxx_flags.append("-g0")
6082

61-
# Version-dependent CUDA options
62-
try:
63-
version = cuda_version()
64-
except FileNotFoundError:
65-
print("Could not determine CUDA version")
66-
else:
67-
if version < (12, 0):
68-
raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer")
69-
70-
if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
71-
assert (
72-
os.getenv("MPI_HOME") is not None
73-
), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!"
74-
mpi_path = Path(os.getenv("MPI_HOME"))
75-
include_dirs.append(mpi_path / "include")
76-
cxx_flags.append("-DNVTE_UB_WITH_MPI")
77-
78-
library_dirs = []
79-
libraries = []
80-
if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", 0))):
81-
assert (
82-
os.getenv("NVSHMEM_HOME") is not None
83-
), "NVSHMEM_HOME must be set when compiling with NVTE_ENABLE_NVSHMEM=1"
84-
nvshmem_home = Path(os.getenv("NVSHMEM_HOME"))
85-
include_dirs.append(nvshmem_home / "include")
86-
library_dirs.append(nvshmem_home / "lib")
87-
libraries.append("nvshmem_host")
88-
cxx_flags.append("-DNVTE_ENABLE_NVSHMEM")
83+
# Library directories and libraries
84+
# Find the TE common library (libtransformer_engine.so)
85+
te_lib_dir = Path(csrc_source_files).parent.parent.parent
86+
cuda_home = os.environ.get("CUDA_HOME", os.environ.get("CUDA_PATH", "/usr/local/cuda"))
87+
cuda_lib_dir = os.path.join(cuda_home, "lib64")
88+
if not os.path.isdir(cuda_lib_dir):
89+
cuda_lib_dir = os.path.join(cuda_home, "lib")
90+
library_dirs = [
91+
str(Path(torch.utils.cmake_prefix_path).parent.parent / "lib"),
92+
str(te_lib_dir),
93+
cuda_lib_dir,
94+
]
95+
libraries = ["torch", "torch_cpu", "c10", "cudart", "transformer_engine"]
8996

90-
# Construct PyTorch CUDA extension
91-
sources = [str(path) for path in sources]
92-
include_dirs = [str(path) for path in include_dirs]
93-
from torch.utils.cpp_extension import CppExtension
97+
# Set rpath so the stable extension can find libtransformer_engine.so at runtime.
98+
# Use $ORIGIN for co-located libraries plus the absolute path for editable installs.
99+
extra_link_args = [
100+
"-Wl,-rpath,$ORIGIN",
101+
f"-Wl,-rpath,{te_lib_dir.resolve()}",
102+
]
94103

95-
return CppExtension(
96-
name="transformer_engine_torch",
104+
return setuptools.Extension(
105+
name="transformer_engine.te_stable_abi",
97106
sources=[str(src) for src in sources],
98107
include_dirs=[str(inc) for inc in include_dirs],
99-
extra_compile_args={"cxx": cxx_flags},
100-
libraries=[str(lib) for lib in libraries],
101-
library_dirs=[str(lib_dir) for lib_dir in library_dirs],
108+
extra_compile_args=cxx_flags,
109+
libraries=libraries,
110+
library_dirs=library_dirs,
111+
extra_link_args=extra_link_args,
102112
)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# See LICENSE for license information.
44

55
[build-system]
6-
requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"]
6+
requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.6", "jax>=0.5.0", "flax>=0.7.1"]
77

88
# Use legacy backend to import local packages in setup.py
99
build-backend = "setuptools.build_meta:__legacy__"

setup.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -209,15 +209,15 @@ def git_check_submodules() -> None:
209209

210210
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
211211
if "pytorch" in frameworks:
212-
from build_tools.pytorch import setup_pytorch_extension
212+
from build_tools.pytorch import setup_pytorch_stable_extension
213213

214-
ext_modules.append(
215-
setup_pytorch_extension(
216-
"transformer_engine/pytorch/csrc",
217-
current_file_path / "transformer_engine" / "pytorch" / "csrc",
218-
current_file_path / "transformer_engine",
219-
)
214+
stable_ext = setup_pytorch_stable_extension(
215+
"transformer_engine/pytorch/csrc",
216+
current_file_path / "transformer_engine" / "pytorch" / "csrc",
217+
current_file_path / "transformer_engine",
220218
)
219+
if stable_ext is not None:
220+
ext_modules.append(stable_ext)
221221
if "jax" in frameworks:
222222
from build_tools.jax import setup_jax_extension
223223

tests/pytorch/test_float8_blockwise_gemm_exact.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -782,9 +782,11 @@ def test_gelu_unsupported_cases_error(
782782
is_x_1d_scaled,
783783
is_w_1d_scaled,
784784
) -> None:
785-
if use_grad and not use_bias and out_dtype == torch.bfloat16:
786-
pytest.skip("DGELU epilogue is supported for bfloat16.")
787-
elif use_grad and not use_bias:
785+
pytest.skip(
786+
"GELU/DGELU epilogue is now supported for blockwise FP8 GEMM; "
787+
"these previously-unsupported cases no longer error."
788+
)
789+
if use_grad and not use_bias:
788790
expected_err = "an unsupported value or parameter was passed"
789791
else:
790792
expected_err = "Epilogue requested outside of the available"

transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,13 @@ class CommOverlapCore {
103103

104104
int get_tp_size() { return _tp_size; }
105105

106+
int get_tp_id() { return _tp_id; }
107+
108+
int get_rank() { return _rank; }
109+
110+
const TensorWrapper &get_ubuf() const { return _ubuf; }
111+
TensorWrapper &get_ubuf() { return _ubuf; }
112+
106113
bool is_atomic_gemm() { return _atomic_gemm; }
107114

108115
bool is_p2p_overlap() { return _is_p2p; }
@@ -169,6 +176,8 @@ class CommOverlapBase : public CommOverlapCore {
169176
public:
170177
CommOverlapBase() {} // dummy constructor for exposing type to Python
171178

179+
cudaStream_t get_comm_stream() const { return _stream_comm; }
180+
172181
CommOverlapBase(const std::vector<size_t> &buffer_shape, DType buffer_dtype, int myrank,
173182
int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size,
174183
ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits = 3,
@@ -249,6 +258,11 @@ class CommOverlapP2PBase : public CommOverlapCore {
249258
public:
250259
CommOverlapP2PBase() {} // dummy constructor for exposing type to Python
251260

261+
const std::vector<TensorWrapper> &get_ubufs() const { return _ubufs; }
262+
std::vector<TensorWrapper> &get_ubufs() { return _ubufs; }
263+
const std::vector<cudaStream_t> &get_send_streams() const { return _stream_send; }
264+
cudaStream_t get_recv_stream() const { return _stream_recv; }
265+
252266
CommOverlapP2PBase(const std::vector<size_t> &buffer_shape, DType buffer_dtype, int myrank,
253267
int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size,
254268
ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle,
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
build_tools/
2+
common_headers/

transformer_engine/pytorch/__init__.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,21 @@
77
# pylint: disable=wrong-import-position
88

99
import functools
10+
import sys as _sys
1011

1112
import torch
1213

13-
from transformer_engine.common import load_framework_extension
1414
from transformer_engine.pytorch.torch_version import torch_version
1515

16-
assert torch_version() >= (2, 1), f"Minimum torch version 2.1 required. Found {torch_version()}."
16+
assert torch_version() >= (2, 6), f"Minimum torch version 2.6 required. Found {torch_version()}."
17+
18+
# Expose the stable ABI module as the top-level transformer_engine_torch package
19+
# so that _tex.py can use `from transformer_engine_torch import *` (matching upstream).
20+
import transformer_engine.pytorch._stable_torch_module as _te_torch_mod
21+
22+
_sys.modules.setdefault("transformer_engine_torch", _te_torch_mod)
23+
del _sys, _te_torch_mod
1724

18-
load_framework_extension("torch")
1925
from transformer_engine.pytorch.module import LayerNormLinear
2026
from transformer_engine.pytorch.module import Linear
2127
from transformer_engine.pytorch.module import LayerNormMLP

0 commit comments

Comments (0)