 import os
 from pathlib import Path
 
+from typing import List
+
 import setuptools
 
-from .utils import all_files_in_dir, cuda_version, get_cuda_include_dirs, debug_build_enabled
-from typing import List
+from .utils import all_files_in_dir, get_cuda_include_dirs, debug_build_enabled
 
 
 def install_requirements() -> List[str]:
     """Install dependencies for TE/PyTorch extensions."""
-    return ["torch>=2.1", "einops", "onnxscript", "onnx", "packaging", "pydantic", "nvdlfw-inspect"]
+    return ["torch>=2.6", "einops", "onnxscript", "onnx", "packaging", "pydantic", "nvdlfw-inspect"]
 
 
 def test_requirements() -> List[str]:
@@ -29,74 +30,83 @@ def test_requirements() -> List[str]:
     ]
 
 
-def setup_pytorch_extension(
+def setup_pytorch_stable_extension(
     csrc_source_files,
     csrc_header_files,
     common_header_files,
 ) -> setuptools.Extension:
-    """Setup CUDA extension for PyTorch support"""
+    """Setup stable ABI extension for PyTorch support.
 
-    # Source files
-    sources = all_files_in_dir(Path(csrc_source_files), name_extension="cpp")
+    This extension uses only the PyTorch stable ABI (torch/csrc/stable/),
+    producing a binary that is compatible across PyTorch versions.
+    It does NOT use CppExtension to avoid pulling in unstable ATen headers.
+    """
+    import torch
 
-    # Header files
+    # Source files from csrc/extensions/ directory
+    stable_dir = Path(csrc_source_files) / "extensions"
+    sources = all_files_in_dir(stable_dir, name_extension="cpp")
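+    # An empty source list means no stable-ABI extension can be built; None tells the caller to skip it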
+    if not sources:
+        return None
+
+    # Include directories
     include_dirs = get_cuda_include_dirs()
     include_dirs.extend(
         [
             common_header_files,
             common_header_files / "common",
             common_header_files / "common" / "include",
             csrc_header_files,
+            # PyTorch headers (for stable ABI only)
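+            # torch.utils.cmake_prefix_path points at <torch>/share/cmake, so two .parent hops reach the torch root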
+            Path(torch.utils.cmake_prefix_path).parent.parent / "include",
         ]
     )
 
     # Compiler flags
-    cxx_flags = ["-O3", "-fvisibility=hidden"]
+    cxx_flags = ["-O3", "-fvisibility=hidden", "-std=c++17", "-DUSE_CUDA"]
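+    # -fvisibility=hidden keeps symbols private by default; only explicitly exported entry points remain visible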
+    if bool(int(os.environ.get("NVTE_ENABLE_NVSHMEM", "0"))):
+        cxx_flags.append("-DNVTE_ENABLE_NVSHMEM")
+        nvshmem_home = os.environ.get("NVSHMEM_HOME", "")
+        if nvshmem_home:
+            include_dirs.append(Path(nvshmem_home) / "include")
+        # Try system NVSHMEM paths (Debian/Ubuntu packages)
+        for nvshmem_inc in ["/usr/include/nvshmem_13", "/usr/local/include/nvshmem"]:
+            if os.path.isdir(nvshmem_inc):
+                include_dirs.append(Path(nvshmem_inc))
+                break
     if debug_build_enabled():
         cxx_flags.append("-g")
         cxx_flags.append("-UNDEBUG")
     else:
         cxx_flags.append("-g0")
 
-    # Version-dependent CUDA options
-    try:
-        version = cuda_version()
-    except FileNotFoundError:
-        print("Could not determine CUDA version")
-    else:
-        if version < (12, 0):
-            raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer")
-
-    if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
-        assert (
-            os.getenv("MPI_HOME") is not None
-        ), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!"
-        mpi_path = Path(os.getenv("MPI_HOME"))
-        include_dirs.append(mpi_path / "include")
-        cxx_flags.append("-DNVTE_UB_WITH_MPI")
-
-    library_dirs = []
-    libraries = []
-    if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", 0))):
-        assert (
-            os.getenv("NVSHMEM_HOME") is not None
-        ), "NVSHMEM_HOME must be set when compiling with NVTE_ENABLE_NVSHMEM=1"
-        nvshmem_home = Path(os.getenv("NVSHMEM_HOME"))
-        include_dirs.append(nvshmem_home / "include")
-        library_dirs.append(nvshmem_home / "lib")
-        libraries.append("nvshmem_host")
-        cxx_flags.append("-DNVTE_ENABLE_NVSHMEM")
+    # Library directories and libraries
+    # Find the TE common library (libtransformer_engine.so)
+    te_lib_dir = Path(csrc_source_files).parent.parent.parent
+    cuda_home = os.environ.get("CUDA_HOME", os.environ.get("CUDA_PATH", "/usr/local/cuda"))
+    cuda_lib_dir = os.path.join(cuda_home, "lib64")
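+    # Some CUDA installs ship lib/ rather than lib64/ (e.g. conda toolkits); fall back below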
+    if not os.path.isdir(cuda_lib_dir):
+        cuda_lib_dir = os.path.join(cuda_home, "lib")
+    library_dirs = [
+        str(Path(torch.utils.cmake_prefix_path).parent.parent / "lib"),
+        str(te_lib_dir),
+        cuda_lib_dir,
+    ]
+    libraries = ["torch", "torch_cpu", "c10", "cudart", "transformer_engine"]
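+    # Link the torch runtime libraries explicitly; CppExtension would normally add these automatically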
 
-    # Construct PyTorch CUDA extension
-    sources = [str(path) for path in sources]
-    include_dirs = [str(path) for path in include_dirs]
-    from torch.utils.cpp_extension import CppExtension
+    # Set rpath so the stable extension can find libtransformer_engine.so at runtime.
+    # Use $ORIGIN for co-located libraries plus the absolute path for editable installs.
+    extra_link_args = [
+        "-Wl,-rpath,$ORIGIN",
+        f"-Wl,-rpath,{te_lib_dir.resolve()}",
+    ]
 
-    return CppExtension(
-        name="transformer_engine_torch",
+    return setuptools.Extension(
+        name="transformer_engine.te_stable_abi",
         sources=[str(src) for src in sources],
         include_dirs=[str(inc) for inc in include_dirs],
-        extra_compile_args={"cxx": cxx_flags},
-        libraries=[str(lib) for lib in libraries],
-        library_dirs=[str(lib_dir) for lib_dir in library_dirs],
+        extra_compile_args=cxx_flags,
+        libraries=libraries,
+        library_dirs=library_dirs,
+        extra_link_args=extra_link_args,
     )
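
For reference, a minimal sketch of how a setup.py might consume this builder. The import path, the directory arguments, and the package layout are illustrative assumptions, not the repository's actual wiring; only the None-return contract comes from the code above:

    # hypothetical caller, not part of this diff
    from pathlib import Path

    import setuptools

    from build_tools.pytorch import setup_pytorch_stable_extension  # assumed import path

    root = Path(__file__).parent
    ext = setup_pytorch_stable_extension(
        csrc_source_files=root / "transformer_engine" / "pytorch" / "csrc",  # assumed layout
        csrc_header_files=root / "transformer_engine" / "pytorch" / "csrc",  # assumed layout
        common_header_files=root / "transformer_engine",                     # assumed layout
    )
    # setup_pytorch_stable_extension returns None when no stable-ABI sources exist
    setuptools.setup(ext_modules=[ext] if ext is not None else [])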