
Precomputed swizzle_idx into group Hadamard ComputeKernel#2808

Open
cael-ling wants to merge 3 commits into NVIDIA:main from cael-ling:refactor/grp-hadamard-swizzle-outside-loops

Conversation

@cael-ling

@cael-ling cael-ling commented Mar 29, 2026

Description

ComputeKernel used to derive warp_id, local_rank, ld_row_idx, ld_col_idx, and swizzle_idx from threadIdx.x on every call. Those quantities depend only on the thread’s position in the block and template constants; they do not change with pipeline stage, compute_stage_y / compute_stage_x, or the per-tile in_sh_ptr offset.

GroupHadamardAmaxTmaKernel now computes them once per thread before the for (stage_y) loop and passes swizzle_idx into ComputeKernel, avoiding redundant work in the hot nested loops. Behavior is unchanged; this is a small micro-optimization and a clearer separation of loop-invariant thread mapping from per-tile pointer arithmetic.

Type of change

  • Documentation change (a change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Hoist the computation of warp_id, local_rank, ld_row_idx, ld_col_idx, and swizzle_idx out of the stage loops into the kernel prologue.
  • Pass the precomputed swizzle_idx into ComputeKernel as an explicit parameter.
  • Apply the same refactor to all three kernel variants (group_hadamard_transform.cu, hadamard_transform.cu, graph_safe_group_hadamard_transform.cu).

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Contributor

greptile-apps bot commented Mar 29, 2026

Greptile Summary

This PR hoists the computation of warp_id, local_rank, ld_row_idx, ld_col_idx, and swizzle_idx out of the innermost hot loops and into the kernel prologue, then passes the precomputed swizzle_idx as an explicit parameter to ComputeKernel. The change is applied consistently across all three sibling kernel files (group_hadamard_transform.cu, hadamard_transform.cu, graph_safe_group_hadamard_transform.cu).

  • The refactor is semantically correct: swizzle_128B_atom_32B(ld_row_idx, ld_col_idx) is a pure function whose inputs depend only on threadIdx.x and the compile-time constant kHadamardDimension — both are invariant over all stage/compute-stage loop iterations.
  • The tile-varying pointer arithmetic (in_sh_ptr + in_row_offset + compute_stage_x * …) remains at the call site, correctly separating loop-invariant indexing from per-tile offsets.
  • Note that warp_id (x-dimension only, used for swizzle) and warpid (full 2D index, used for the final ReduceMax) are intentionally distinct, a distinction preserved from the original code.
  • No behavioral changes are introduced; this is a pure micro-optimization and code-clarity improvement.

Confidence Score: 5/5

Safe to merge — pure loop-invariant hoisting with no behavioral change across all three kernel files.

The optimization is mathematically correct (swizzle_idx depends only on threadIdx.x and a compile-time constant), applied uniformly across all three kernel variants, and introduces no new variables or logic paths. No P0 or P1 issues found.

No files require special attention.

Important Files Changed

Filename Overview
transformer_engine/common/hadamard_transform/group_hadamard_transform.cu Hoists warp_id / local_rank / swizzle_idx computation out of the hot nested loop into the kernel prologue, and passes swizzle_idx as a parameter to ComputeKernel — correct and semantically equivalent.
transformer_engine/common/hadamard_transform/hadamard_transform.cu Same swizzle_idx hoisting applied to HadamardAmaxTmaKernel; change is consistent with the group variant and equally correct.
transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu Same swizzle_idx hoisting applied to GraphSafeGroupHadamardAmaxTmaKernel; change is consistent with the other two variants and equally correct.

Reviews (3): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment on lines +302 to +306
const int warp_id = threadIdx.x / kThreadsPerWarp;
const int local_rank = threadIdx.x % kThreadsPerWarp;
const int ld_row_idx = local_rank % kHadamardDimension;
const int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2;
const int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx);

P2 Same optimization not applied to sibling files

hadamard_transform.cu and graph_safe_group_hadamard_transform.cu contain near-identical ComputeKernel definitions that still recompute warp_id, local_rank, ld_row_idx, ld_col_idx, and swizzle_idx inside the function body on every invocation. If the goal is to eliminate redundant per-iteration work, those two files have the same hot-loop structure and would benefit from the same refactor.

This is not a bug — since ComputeKernel is __forceinline__, the compiler can already hoist these invariants under optimization. But for consistency and to complete the stated intent of the PR, consider applying the same pattern to:

  • hadamard_transform.cu:35-40 / call site at ~line 288
  • graph_safe_group_hadamard_transform.cu:74-79 / call site at ~line 362

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Collaborator


good point

Signed-off-by: Cael Ling <caell@nvidia.com>
@cael-ling cael-ling force-pushed the refactor/grp-hadamard-swizzle-outside-loops branch from c5b2087 to f101b02 Compare March 29, 2026 10:20
@cael-ling cael-ling changed the title refactor(hadamard): pass precomputed swizzle_idx into group Hadamard ComputeKernel Precomputed swizzle_idx into group Hadamard ComputeKernel Mar 29, 2026
@zhongbozhu
Collaborator

LGTM, can you introduce this change to other similar variants of the kernel? https://github.com/NVIDIA/TransformerEngine/tree/main/transformer_engine/common/hadamard_transform

@cael-ling
Author

LGTM, can you introduce this change to other similar variants of the kernel? https://github.com/NVIDIA/TransformerEngine/tree/main/transformer_engine/common/hadamard_transform

OK, will do

cael-ling and others added 2 commits March 30, 2026 19:10
@cael-ling
Author

The change has been applied to all three variants: group_hadamard_transform.cu, hadamard_transform.cu, and graph_safe_group_hadamard_transform.cu.
