
[Common][JAX] Add CUB TopK MaxPairs interface#2784

Open
huanghua1994 wants to merge 3 commits into NVIDIA:main from huanghua1994:CUB-topk

Conversation

@huanghua1994
Collaborator

Description

This PR introduces the new CUB TopK API for large N and K values.
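
For context, the (unsorted) top-k-pairs semantics this kind of API provides can be sketched in plain NumPy; the function name below is illustrative and not part of this PR:

```python
import numpy as np

def topk_reference(keys: np.ndarray, k: int):
    """Return the k largest keys and their indices (no ordering guarantee)."""
    # argpartition places the k largest elements in the last k slots, unsorted,
    # mirroring the "results may be unordered" contract of device-wide top-k.
    idx = np.argpartition(keys, -k)[-k:]
    return keys[idx], idx.astype(np.int32)
```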

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Added 3rdparty/cccl as a dependency since the CTK on the machine might not be new enough
  • Added transformer_engine/common/util/cub.cu as the entry point to the CUB TopK function
  • Added JAX FFI interface and JAX tests

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Contributor

greptile-apps bot commented Mar 20, 2026

Greptile Summary

This PR introduces a JAX-accessible topk primitive backed by cub::DeviceTopK::MaxPairs from CCCL, adding a new C header (cub.h), a CUDA implementation (cub.cu), an XLA FFI handler (cub.cpp), and a Python TopKPrimitive/topk wrapper. The feature is complete and functional for the single-device path with good dtype coverage (float32, float16, bfloat16 keys, int32 values) and a well-structured test suite.

Key points:

  • Global CCCL version override (P1): The CMakeLists change replaces the CTK-provided CUB/Thrust/libcudacxx include paths with the bundled submodule for all targets, not just the new TopK code. This silently changes the CCCL version seen by all existing consumers in the project and could cause subtle regressions or build failures in environments where the bundled CCCL commit is incompatible with the installed CTK.
  • Missing batcher / partition / shardy_sharding_rule (addressed in prior thread): TopKPrimitive does not implement these abstract methods, meaning vmap and multi-device sharding will fail at runtime.
  • Hardcoded 4 MiB workspace (addressed in prior thread): No runtime guard for N > 5,000,000 or K > 100,000 inputs.
  • Missing k <= num_items guard (addressed in prior thread): Passing k > num_items to CUB is undefined behaviour.
  • Missing k > 0 guard: No validation that k_value is positive before reaching the C++ layer.
  • Discarded MaxPairs return value in cub.cu: Errors from the new execution-environment API are only caught indirectly via cudaGetLastError() later.

Confidence Score: 4/5

  • Safe to merge after addressing the global CCCL override scope and the previously-raised missing guards; the core CUDA implementation is correct.
  • The CMakeLists change globally replaces CTK-provided CCCL headers for all targets — a P1 concern with potential build-compatibility impact across the project. The previously-raised issues (missing batcher/partition, workspace bounds, k<=num_items) also remain open. The CUDA kernel logic and FFI wiring are sound, and the tests are thorough for the single-device path.
  • transformer_engine/common/CMakeLists.txt (global CCCL override), transformer_engine/jax/cpp_extensions/cub.py (missing abstract methods, k_value guard)

Important Files Changed

Filename Overview
transformer_engine/common/CMakeLists.txt Adds CCCL submodule include paths and cub.cu to the build; the global replacement of CTK-provided CCCL headers with the pinned submodule version affects all existing CUB/Thrust consumers in the library, not just the new TopK code.
transformer_engine/common/util/cub.cu Implements nvte_topk using the new CUB execution-environment API; dispatches correctly for float32/float16/bfloat16 keys with int32 values. Return value of MaxPairs is discarded.
transformer_engine/jax/cpp_extensions/cub.py Introduces TopKPrimitive and topk public API; missing batcher/partition/shardy_sharding_rule (noted in prior thread), hardcoded workspace (prior thread), and no lower-bound guard on k_value.
transformer_engine/jax/csrc/extensions/cub.cpp FFI handler TopkFFI validates dtypes and shapes well; missing k > 0 and k <= num_items guards (latter noted in prior thread).
tests/jax/test_custom_call_compute.py Adds TestTopk with parametrized dtype and size coverage; sorts outputs before comparison to handle CUB's unsorted guarantee; correctness assertions are sound for all tested (n, k) pairs.

Sequence Diagram

```mermaid
sequenceDiagram
    participant User as Python caller
    participant topk as topk() [cub.py]
    participant Outer as TopKPrimitive.outer_primitive
    participant Impl as TopKPrimitive.impl()
    participant Inner as TopKPrimitive.inner_primitive
    participant FFI as TopkFFI [cub.cpp]
    participant CUDA as nvte_topk [cub.cu]
    participant CUB as cub::DeviceTopK::MaxPairs

    User->>topk: topk(x, k_value)
    topk->>topk: values = arange(N, int32)
    topk->>Outer: outer_primitive.bind(keys, values, k_value)
    Outer->>Impl: impl(keys, values, k_value)
    Impl->>Inner: inner_primitive.bind(keys, values, k_value)
    Inner->>FFI: te_topk_ffi (XLA FFI dispatch)
    FFI->>FFI: validate dtypes, shapes
    FFI->>CUDA: nvte_topk(stream, keys_in, values_in, keys_out, values_out, workspace, N, k, ws_bytes)
    CUDA->>CUB: DeviceTopK::MaxPairs(workspace, ws_bytes, ..., N, k, env)
    CUB-->>CUDA: top-k (keys, indices) written to output buffers
    CUDA-->>FFI: return
    FFI->>FFI: ffi_with_cuda_error_check()
    FFI-->>Inner: (keys_out, values_out, workspace)
    Inner-->>Impl: (keys_out, values_out, workspace)
    Impl-->>Outer: (keys_out, values_out)
    Outer-->>topk: (top_k_values, top_k_indices)
    topk-->>User: (top_k_values, top_k_indices)
```

Reviews (2): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment on lines +27 to +94
```python
class CubTopkPrimitive(BasePrimitive):
    """
    CUB Topk Primitive
    """

    name = "te_cub_topk_ffi"
    multiple_results = True
    impl_static_args = (2,)  # k_value
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        in_keys_aval,
        in_values_aval,
        *,
        k_value,
    ):
        keys_dtype = dtypes.canonicalize_dtype(in_keys_aval.dtype)
        values_dtype = dtypes.canonicalize_dtype(in_values_aval.dtype)
        assert keys_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert values_dtype == jnp.int32

        workspace_bytes = get_cub_topk_workspace_bytes()
        out_keys_aval = jax.core.ShapedArray(shape=(k_value,), dtype=keys_dtype)
        out_values_aval = jax.core.ShapedArray(shape=(k_value,), dtype=jnp.int32)
        workspace_aval = jax.core.ShapedArray(shape=(workspace_bytes,), dtype=jnp.uint8)
        return (out_keys_aval, out_values_aval, workspace_aval)

    @staticmethod
    def outer_abstract(*args, **kwargs):
        out_keys_aval, out_values_aval, _workspace_aval = CubTopkPrimitive.abstract(*args, **kwargs)
        return (out_keys_aval, out_values_aval)

    @staticmethod
    def lowering(
        ctx,
        in_keys,
        in_values,
        k_value,
    ):
        workspace_bytes = get_cub_topk_workspace_bytes()
        return ffi.ffi_lowering(
            CubTopkPrimitive.name,
        )(
            ctx,
            in_keys,
            in_values,
            k_value=k_value,
            workbuf_bytes=workspace_bytes,
        )

    @staticmethod
    def impl(
        in_keys,
        in_values,
        k_value,
    ):
        assert CubTopkPrimitive.inner_primitive is not None
        out_keys, out_values, _workspace = CubTopkPrimitive.inner_primitive.bind(
            in_keys,
            in_values,
            k_value=k_value,
        )
        return (out_keys, out_values)


register_primitive(CubTopkPrimitive)
```
Contributor


P1 Missing batcher, partition, and shardy_sharding_rule methods

CubTopkPrimitive extends BasePrimitive, which declares batcher(), partition(), and shardy_sharding_rule() as abstract methods. CubTopkPrimitive does not implement any of them.

When register_primitive(CubTopkPrimitive) is called, base.py does:

```python
batching.primitive_batchers[outer_p] = cls.batcher  # resolves to abstract method → returns NotImplemented
outer_p_lower.def_partition(partition=cls.partition, ...)  # same
```

This means:

  • Any attempt to use vmap over cub_topk will fail at runtime because the registered batcher is the abstract method (which returns NotImplemented, not a callable that returns batched results).
  • Multi-device / sharding via custom_partitioning will similarly fail when partition is invoked.

Every other primitive in the codebase (e.g., in router.py) implements all three of these methods. If sharding and batching are intentionally unsupported for now, the methods should at minimum raise a clear NotImplementedError (rather than silently returning NotImplemented), and this limitation should be documented.
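
The "raise a clear error" option could look like the following minimal sketch (class and method names are illustrative, modeled on the pattern described above, not taken from the PR):

```python
class CubTopkPrimitiveSketch:
    """Sketch: explicitly reject unsupported transforms instead of inheriting
    abstract methods that silently return NotImplemented."""

    @staticmethod
    def batcher(batched_args, batch_dims, *, k_value):
        # vmap over the top-k primitive is not supported yet; fail loudly.
        raise NotImplementedError(
            "vmap over cub_topk is not supported; apply it per row instead"
        )

    @staticmethod
    def partition(*args, **kwargs):
        # Multi-device sharding via custom_partitioning is not supported yet.
        raise NotImplementedError("cub_topk does not support custom partitioning")
```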

Comment on lines +17 to +24
```python
def get_cub_topk_workspace_bytes() -> int:
    """
    Get the workspace size for CUB Topk
    The safe way is calling the CUB kernel to query the workspace size.
    For convenience, we use a heuristic value based on experiments.
    4 MiB is enough for N up to 5,000,000 and K up to 100,000.
    """
    return 4 * 1024 * 1024
```
Contributor


P1 Hardcoded workspace size may silently corrupt memory for large inputs

get_cub_topk_workspace_bytes() always returns a fixed 4 MiB and the docstring itself acknowledges this only covers "N up to 5,000,000 and K up to 100,000." However, there is no validation in the Python or C++ layer that the user's actual N and K do not exceed these limits.

If a caller passes N > 5_000_000 or K > 100_000, cub::DeviceTopK::MaxPairs will be given an undersized workspace buffer and will write out-of-bounds on the GPU — a silent CUDA memory corruption with no error raised back to the caller.

The correct approach is to call cub::DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, ...) with a null workspace pointer to query the required size at runtime, then allocate that exact amount. The current heuristic should at minimum be accompanied by a runtime guard that raises an error when the inputs exceed the documented limits.
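
Until the runtime size query is wired in, the minimum-viable guard can sit on the Python side. A sketch, using the limits the docstring documents for the 4 MiB heuristic (the function name and constants here are illustrative):

```python
# Limits the 4 MiB workspace heuristic was validated against (per the docstring).
CUB_TOPK_MAX_N = 5_000_000
CUB_TOPK_MAX_K = 100_000

def check_cub_topk_limits(n: int, k: int) -> None:
    """Raise before an undersized workspace can corrupt GPU memory."""
    if not 0 < k <= n:
        raise ValueError(f"k ({k}) must satisfy 0 < k <= n ({n})")
    if n > CUB_TOPK_MAX_N or k > CUB_TOPK_MAX_K:
        raise ValueError(
            f"n={n}, k={k} exceed the validated workspace limits "
            f"(n <= {CUB_TOPK_MAX_N}, k <= {CUB_TOPK_MAX_K})"
        )
```

This also covers the separate k > 0 and k <= num_items gaps raised elsewhere in this review.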

Comment on lines +39 to +40
```cpp
int num_items = static_cast<int>(keys_in_shape[0]);
int k = static_cast<int>(k_value);
```
Contributor


P1 No validation that k <= num_items

There is no check that k_value is less than or equal to num_items (the size of the input array). CUB's DeviceTopK::MaxPairs requires k <= num_items; if k > num_items the behavior is undefined and will likely produce a CUDA error or garbage output.

A guard should be added here alongside the existing shape checks:

```cpp
NVTE_CHECK(k <= num_items, "k (", k, ") must be <= num_items (", num_items, ")");
```


```python
from .base import BasePrimitive, register_primitive

__all__ = ["CubTopkPrimitive"]
```
Contributor


P2 Public function cub_topk not exported in __all__

The user-facing function cub_topk (defined at line 97) is not included in __all__. Only CubTopkPrimitive is listed. Tools and users relying on __all__ for the module's public API won't discover cub_topk. Since the test imports it as a primary API (from transformer_engine.jax.cpp_extensions.cub import cub_topk), it should be exported.

Suggested change

```diff
-__all__ = ["CubTopkPrimitive"]
+__all__ = ["CubTopkPrimitive", "cub_topk"]
```

Signed-off-by: Hua Huang <huah@nvidia.com>

2 participants