diff --git a/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu index 04e965a9da..aadc77570f 100644 --- a/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu @@ -366,12 +366,10 @@ __global__ void GraphSafeGroupHadamardAmaxTmaKernel( (compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)), local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); } - - // Ensure all threads have finished their computation before new data over-writes the shared - // memory. - __syncthreads(); } - + // Ensure all threads have finished their computation before new data over-writes the shared + // memory. + __syncthreads(); // Ensure generic shared-memory accesses are visible before the next TMA write. ptx::fence_proxy_async_shared_cta(); } diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu index 07813be059..38418415c8 100644 --- a/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu @@ -349,11 +349,10 @@ __global__ void GroupHadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap t (compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)), local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); } - - // Ensure all threads have finished their computation before new data over-writes the shared - // memory. - __syncthreads(); } + // Ensure all threads have finished their computation before new data over-writes the shared + // memory. + __syncthreads(); // Ensure generic shared-memory accesses are visible before the next TMA write. ptx::fence_proxy_async_shared_cta(); diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index 4adc836886..8feef7d078 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -292,12 +292,10 @@ __global__ void HadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap tensor (compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)), local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); } - - // Ensure all threads have finished their computation before new data over-writes the shared - // memory. - __syncthreads(); } - + // Ensure all threads have finished their computation before new data over-writes the shared + // memory. + __syncthreads(); // Ensure generic shared-memory accesses are visible before the next TMA write. ptx::fence_proxy_async_shared_cta(); }