-
Notifications
You must be signed in to change notification settings - Fork 609
[JAX] HLO FFI tests #2593
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[JAX] HLO FFI tests #2593
Changes from all commits
04072e3
4551979
2f1e29a
6d5d210
b9ba7c5
28a07c1
069f751
121f259
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -2,13 +2,16 @@ | |||||||||||||||
| # | ||||||||||||||||
| # See LICENSE for license information. | ||||||||||||||||
|
|
||||||||||||||||
| from io import StringIO | ||||||||||||||||
| import jax | ||||||||||||||||
| import jax.numpy as jnp | ||||||||||||||||
| import pytest | ||||||||||||||||
| from jax import jit, value_and_grad | ||||||||||||||||
| from functools import reduce | ||||||||||||||||
| from typing import Union | ||||||||||||||||
| import operator | ||||||||||||||||
| import os | ||||||||||||||||
| import re | ||||||||||||||||
|
|
||||||||||||||||
| from utils import ( | ||||||||||||||||
| assert_allclose, | ||||||||||||||||
|
|
@@ -1921,3 +1924,209 @@ def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): | |||||||||||||||
| assert_allclose(prim_dgrad, ref_dgrad, dtype=bwd_dtype) | ||||||||||||||||
| assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype) | ||||||||||||||||
| assert_allclose(prim_dbias, ref_dbias, dtype=dtype) | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
| @pytest.mark.skipif(not is_fp4_supported, reason=fp4_unsupported_reason) | ||||||||||||||||
| class TestFFICompatibility: | ||||||||||||||||
|
|
||||||||||||||||
| HLO_DIR = os.path.join(os.path.dirname(__file__), "ffi_hlo") | ||||||||||||||||
|
|
||||||||||||||||
| @pytest.fixture(name="ffi_hlo_name") | ||||||||||||||||
| def hlo_fixture(shape): | ||||||||||||||||
| for file in os.listdir(TestFFICompatibility.HLO_DIR): | ||||||||||||||||
| file_path = os.path.join(TestFFICompatibility.HLO_DIR, file) | ||||||||||||||||
| if os.path.isfile(file_path): | ||||||||||||||||
| yield file.split(".")[0] | ||||||||||||||||
|
|
||||||||||||||||
| @pytest.mark.skipif( | ||||||||||||||||
| os.getenv("NVTE_JAX_FFI_HLO_GENERATE", "0") != "1", reason="HLO generation not enabled" | ||||||||||||||||
| ) | ||||||||||||||||
| def test_generate_hlo(self): | ||||||||||||||||
| """Run this test with NVTE_JAX_FFI_HLO_GENERATE=1 to generate StableHLO text files for FFI compatibility tests. Use this when intentionally changing FFI bindings and breaking compatibility changes are required. | ||||||||||||||||
|
|
||||||||||||||||
| Instructions: | ||||||||||||||||
| 1. `CUDA_VISIBLE_DEVICES=0 XLA_FLAGS="$XLA_FLAGS --xla_dump_to=./tests/jax/ffi_hlo_dump" NVTE_JAX_FFI_HLO_GENERATE=1 pytest tests/jax/test_custom_call_compute.py::TestFFICompatibility::test_generate_hlo -s` | ||||||||||||||||
| 2. Find `tests/jax/ffi_hlo_dump/jit_train_step_<some numbers>/module.mlir` and copy it to the `tests/jax/ffi_hlo/` directory named transformer_stablehlo.txt | ||||||||||||||||
| """ | ||||||||||||||||
| import math | ||||||||||||||||
| from transformer_engine.common.recipe import NVFP4BlockScaling, Float8CurrentScaling | ||||||||||||||||
| from transformer_engine.jax import autocast, MeshResource, softmax | ||||||||||||||||
| from transformer_engine.jax.flax import TransformerLayer | ||||||||||||||||
| import flax.linen as nn | ||||||||||||||||
|
|
||||||||||||||||
| with autocast(enabled=True, recipe=NVFP4BlockScaling(), mesh_resource=MeshResource()): | ||||||||||||||||
|
|
||||||||||||||||
| class Model(nn.Module): | ||||||||||||||||
| """This module does not represent any meaningful model, it is just to cover all FFI calls.""" | ||||||||||||||||
|
|
||||||||||||||||
| @nn.compact | ||||||||||||||||
| def __call__(self, x): | ||||||||||||||||
| # Covers most of the FFI calls | ||||||||||||||||
| x = TransformerLayer( | ||||||||||||||||
| hidden_dropout=0.0, | ||||||||||||||||
| attention_dropout=0.0, | ||||||||||||||||
| intermediate_dropout=0.0, | ||||||||||||||||
| dtype=jnp.bfloat16, | ||||||||||||||||
| )(x) | ||||||||||||||||
|
|
||||||||||||||||
| # Arbitrarily call softmax multiple times to cover all softmax FFI calls | ||||||||||||||||
| x = x.reshape((1, *x.shape)) | ||||||||||||||||
| x = softmax.softmax(x, softmax_fusion_type=softmax.SoftmaxFusionType.SCALED) | ||||||||||||||||
| mask1 = self.variable( | ||||||||||||||||
| "collection", | ||||||||||||||||
| "mask1", | ||||||||||||||||
| lambda: jax.random.bernoulli(jax.random.PRNGKey(0), shape=x.shape).astype( | ||||||||||||||||
| jnp.bfloat16 | ||||||||||||||||
| ), | ||||||||||||||||
| ).value.astype(jnp.uint8) | ||||||||||||||||
| x = softmax.softmax( | ||||||||||||||||
| x, mask=mask1, softmax_fusion_type=softmax.SoftmaxFusionType.SCALED_MASKED | ||||||||||||||||
| ) | ||||||||||||||||
| mask2 = self.variable( | ||||||||||||||||
| "collection", | ||||||||||||||||
| "mask2", | ||||||||||||||||
| lambda: (1.0 - jnp.tril(jnp.ones_like(x))).astype(jnp.bfloat16), | ||||||||||||||||
| ).value.astype(jnp.uint8) | ||||||||||||||||
| x = x.reshape((-1, 1, 32, 32)) | ||||||||||||||||
| x = softmax.softmax( | ||||||||||||||||
| x, | ||||||||||||||||
| mask=mask2, | ||||||||||||||||
| softmax_fusion_type=softmax.SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED, | ||||||||||||||||
| ) | ||||||||||||||||
| return x | ||||||||||||||||
|
|
||||||||||||||||
| model = Model() | ||||||||||||||||
| input_shape = (1, 128, 512) | ||||||||||||||||
| x = jnp.ones(input_shape, dtype=jnp.bfloat16) | ||||||||||||||||
|
|
||||||||||||||||
| var_collect = model.init(jax.random.PRNGKey(0), x) | ||||||||||||||||
|
|
||||||||||||||||
| def f(var_collect, x): | ||||||||||||||||
| x = model.apply(var_collect, x, rngs={"sr_rng": jax.random.PRNGKey(0)}) | ||||||||||||||||
| x = jnp.mean(x) # fake loss function for value_and_grad | ||||||||||||||||
| return x | ||||||||||||||||
|
|
||||||||||||||||
| @jax.jit | ||||||||||||||||
| def train_step(var_collect, x, grouped_kernel): | ||||||||||||||||
| loss, grads = jax.value_and_grad(f)(var_collect, x) | ||||||||||||||||
|
|
||||||||||||||||
| # Arbitrarily call grouped quantize and GEMM to cover remaining FFI calls | ||||||||||||||||
| x = x.reshape((-1, x.shape[-1])) | ||||||||||||||||
| x = grouped_dense( | ||||||||||||||||
| x, | ||||||||||||||||
| grouped_kernel, | ||||||||||||||||
| contracting_dims=((1,), (1,)), | ||||||||||||||||
| group_sizes=jnp.array([x.shape[0]], dtype=jnp.int32), | ||||||||||||||||
| quantizer_set=QuantizerFactory.create_set( | ||||||||||||||||
| n_groups=1, | ||||||||||||||||
| fp8_recipe=Float8CurrentScaling(), | ||||||||||||||||
| quantize_meta_set=QuantizeMetaSet( | ||||||||||||||||
| QuantizeMeta(), QuantizeMeta(), QuantizeMeta() | ||||||||||||||||
| ), | ||||||||||||||||
| ), | ||||||||||||||||
| ) | ||||||||||||||||
| loss += jnp.mean(x) | ||||||||||||||||
|
|
||||||||||||||||
| return loss, grads | ||||||||||||||||
|
|
||||||||||||||||
| grouped_kernel = jnp.zeros((1, x.shape[-1], x.shape[-1]), dtype=jnp.bfloat16) | ||||||||||||||||
| train_step(var_collect, x, grouped_kernel) | ||||||||||||||||
|
|
||||||||||||||||
| def _get_hlo_text_from_file(self, hlo_name: str) -> str: | ||||||||||||||||
| """Reads the StableHLO text from a file given its name.""" | ||||||||||||||||
| hlo_file_path = os.path.join(self.HLO_DIR, f"{hlo_name}.txt") | ||||||||||||||||
| with open(hlo_file_path, "r") as f: | ||||||||||||||||
| hlo_text = f.read() | ||||||||||||||||
| return hlo_text | ||||||||||||||||
|
|
||||||||||||||||
| def _make_args_based_on_input_tensor_shape_and_dtype(self, stablehlo_text: str): | ||||||||||||||||
| """Parses the StableHLO text to extract input tensor shapes and dtypes, and creates dummy JAX arrays accordingly.""" | ||||||||||||||||
| # Parse function signature to extract argument information | ||||||||||||||||
|
Comment on lines
+2043
to
+2044
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. logic: The regex pattern uses non-greedy matching
Suggested change
|
||||||||||||||||
| # Pattern matches: @main(%arg0: tensor<32x32xbf16>, %arg1: tensor<64xf32>, ...) | ||||||||||||||||
| pattern = r"@main\((.*?)\{" | ||||||||||||||||
| match = re.search(pattern, stablehlo_text) | ||||||||||||||||
|
|
||||||||||||||||
| if not match: | ||||||||||||||||
| raise ValueError("Could not find @main function signature in StableHLO text") | ||||||||||||||||
|
|
||||||||||||||||
| args_str = match.group(1) | ||||||||||||||||
|
|
||||||||||||||||
| # Parse individual arguments | ||||||||||||||||
|
Comment on lines
+2052
to
+2054
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. logic: The parsing logic assumes shape dimensions are separated by 'x' and the last element is always the dtype. This will fail for scalar tensors (e.g.,
Suggested change
|
||||||||||||||||
| # Pattern matches: %arg0: tensor<32x32xbf16> | ||||||||||||||||
| arg_pattern = r"%arg(\d+):\s*tensor<([^>]+)>" | ||||||||||||||||
| arg_matches = re.findall(arg_pattern, args_str) | ||||||||||||||||
|
|
||||||||||||||||
| parsed_args = [] | ||||||||||||||||
| for arg_num, shape_and_dtype_str in arg_matches: | ||||||||||||||||
| print(f"Parsing argument {arg_num} with shape and dtype: {shape_and_dtype_str}") | ||||||||||||||||
| # Parse shape: "32x32xbf16" -> [32, 32] | ||||||||||||||||
| dtype_str = shape_and_dtype_str.split("x")[-1] | ||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. logic: Missing handling for
Suggested change
|
||||||||||||||||
| shape = [int(dim) for dim in shape_and_dtype_str.split("x")[:-1]] | ||||||||||||||||
|
|
||||||||||||||||
| # Map StableHLO dtype to JAX dtype | ||||||||||||||||
| dtype_map = { | ||||||||||||||||
| "bf16": jnp.bfloat16, | ||||||||||||||||
| "f32": jnp.float32, | ||||||||||||||||
| "f16": jnp.float16, | ||||||||||||||||
| "f8E4M3FN": jnp.float8_e4m3fn, | ||||||||||||||||
| "f8E5M2": jnp.float8_e5m2, | ||||||||||||||||
| "i32": jnp.int32, | ||||||||||||||||
| "ui32": jnp.uint32, | ||||||||||||||||
| } | ||||||||||||||||
| dtype = dtype_map.get(dtype_str) | ||||||||||||||||
|
|
||||||||||||||||
| parsed_args.append(jnp.ones(shape, dtype=dtype)) | ||||||||||||||||
| return parsed_args | ||||||||||||||||
|
|
||||||||||||||||
| def test_ffi_compatibility(self, ffi_hlo_name): | ||||||||||||||||
| """Tests that the current FFI bindings are compatible with the provided HLO and there are no API mismatches.""" | ||||||||||||||||
| from jax.extend.backend import get_backend | ||||||||||||||||
|
|
||||||||||||||||
| stablehlo_text = self._get_hlo_text_from_file(ffi_hlo_name) | ||||||||||||||||
| args = self._make_args_based_on_input_tensor_shape_and_dtype(stablehlo_text) | ||||||||||||||||
|
|
||||||||||||||||
| client = get_backend("cuda") | ||||||||||||||||
| executable = client.compile_and_load( | ||||||||||||||||
| stablehlo_text.encode("utf-8"), executable_devices=jax.devices()[:1] | ||||||||||||||||
| ) | ||||||||||||||||
| results = executable.execute(args) | ||||||||||||||||
| print(results) # No need to assert anything here, just ensure it runs without error | ||||||||||||||||
|
|
||||||||||||||||
| def test_all_primitive_ffi_tested(self): | ||||||||||||||||
| """Ensures that all our TE primitives with FFI bindings are included in the FFI HLO compatibility tests.""" | ||||||||||||||||
| # Open all HLO files and extract primitive FFI names | ||||||||||||||||
| tested_hlos = set() | ||||||||||||||||
| for file in os.listdir(self.HLO_DIR): | ||||||||||||||||
| file_path = os.path.join(self.HLO_DIR, file) | ||||||||||||||||
| if os.path.isfile(file_path) and file.endswith(".txt"): | ||||||||||||||||
| with open(file_path, "r") as f: | ||||||||||||||||
| hlo_text = f.read() | ||||||||||||||||
| # Extract primitive name from HLO text | ||||||||||||||||
| pattern = r"stablehlo.custom_call @(.+?)\(" | ||||||||||||||||
| matches = re.findall(pattern, hlo_text) | ||||||||||||||||
| if matches: | ||||||||||||||||
| for match in matches: | ||||||||||||||||
| primitive_name = match | ||||||||||||||||
| tested_hlos.add(primitive_name) | ||||||||||||||||
|
|
||||||||||||||||
| # Assert that all registered primitives have corresponding FFI tests | ||||||||||||||||
| import transformer_engine_jax | ||||||||||||||||
|
|
||||||||||||||||
| KNOWN_MISSING_FFI_TESTS = { | ||||||||||||||||
| # dequantize does not have a JAX primitive currently | ||||||||||||||||
| "te_dequantize_ffi", | ||||||||||||||||
| # needs testing | ||||||||||||||||
| "te_grouped_gemm_d2h_group_sizes_ffi", | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| unmatched_primitives = set() | ||||||||||||||||
| for primitive_ffi_name, _ in transformer_engine_jax.registrations().items(): | ||||||||||||||||
| if ( | ||||||||||||||||
| primitive_ffi_name not in tested_hlos | ||||||||||||||||
| and primitive_ffi_name not in KNOWN_MISSING_FFI_TESTS | ||||||||||||||||
| ): | ||||||||||||||||
| unmatched_primitives.add(primitive_ffi_name) | ||||||||||||||||
|
|
||||||||||||||||
| assert ( | ||||||||||||||||
| len(unmatched_primitives) == 0 | ||||||||||||||||
| ), f"The following primitives do not have FFI tests: {unmatched_primitives}" | ||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
syntax: The fixture parameter
shapeis undefined and not used in the function body. This will cause an error when pytest tries to parametrize this fixture.