From a4dea3ba50256e5dbdbca3f1cc91a960b94ebc8c Mon Sep 17 00:00:00 2001 From: Cael Ling Date: Fri, 27 Mar 2026 00:19:52 -0700 Subject: [PATCH 1/3] one __syncthreads per stage in GroupHadamardAmaxTmaKernel Signed-off-by: Cael Ling Made-with: Cursor --- .../hadamard_transform/group_hadamard_transform.cu | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu index 07813be059..c352d506de 100644 --- a/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu @@ -350,11 +350,12 @@ __global__ void GroupHadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap t local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); } - // Ensure all threads have finished their computation before new data over-writes the shared - // memory. - __syncthreads(); - } + } + // Ensure all threads have finished their computation before new data over-writes the shared + // memory. + __syncthreads(); + // Ensure generic shared-memory accesses are visible before the next TMA write. ptx::fence_proxy_async_shared_cta(); } From 9822e079afc40d2af5d6595d10ce369f799f927d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 29 Mar 2026 07:04:32 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/hadamard_transform/group_hadamard_transform.cu | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu index c352d506de..38418415c8 100644 --- a/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu @@ -349,13 +349,11 @@ __global__ void GroupHadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap t (compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)), local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); } - - } // Ensure all threads have finished their computation before new data over-writes the shared - // memory. + // memory. __syncthreads(); - + // Ensure generic shared-memory accesses are visible before the next TMA write. ptx::fence_proxy_async_shared_cta(); } From c89118d5abfca2b465b7e042c3343cc0b80efb99 Mon Sep 17 00:00:00 2001 From: Cael Ling Date: Mon, 30 Mar 2026 18:26:23 -0700 Subject: [PATCH 3/3] Apply the change to other variants Signed-off-by: Cael Ling --- .../graph_safe_group_hadamard_transform.cu | 8 +++----- .../common/hadamard_transform/hadamard_transform.cu | 8 +++----- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu index 04e965a9da..aadc77570f 100644 --- a/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu @@ -366,12 +366,10 @@ __global__ void GraphSafeGroupHadamardAmaxTmaKernel( (compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)), local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); } - - // Ensure all threads have finished their computation before new data over-writes the shared - // memory. - __syncthreads(); } - + // Ensure all threads have finished their computation before new data over-writes the shared + // memory. + __syncthreads(); // Ensure generic shared-memory accesses are visible before the next TMA write. ptx::fence_proxy_async_shared_cta(); } diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index 4adc836886..8feef7d078 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -292,12 +292,10 @@ __global__ void HadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap tensor (compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)), local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); } - - // Ensure all threads have finished their computation before new data over-writes the shared - // memory. - __syncthreads(); } - + // Ensure all threads have finished their computation before new data over-writes the shared + // memory. + __syncthreads(); // Ensure generic shared-memory accesses are visible before the next TMA write. ptx::fence_proxy_async_shared_cta(); }