diff --git a/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu index 04e965a9da..231d522f3a 100644 --- a/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu @@ -65,19 +65,13 @@ __device__ __forceinline__ size_t get_current_tensor_id( template <typename IType> __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4], - IType* in_sh_ptr, uint32_t& local_pre_rht_amax_reg, + IType* in_sh_ptr, int swizzle_idx, + uint32_t& local_pre_rht_amax_reg, uint32_t& local_amax_reg, uint32_t& local_amax_t_reg) { uint32_t a_frag[4]; // A matrix fragment uint32_t c_frag[4]; // Result fragment - int warp_id = threadIdx.x / kThreadsPerWarp; - int local_rank = (threadIdx.x % kThreadsPerWarp); - - int ld_row_idx = local_rank % kHadamardDimension; - int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2; - int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx); - uint32_t temp_amax_reg; uint32_t temp_amax_t_reg; @@ -322,6 +316,12 @@ __global__ void GraphSafeGroupHadamardAmaxTmaKernel( uint32_t local_amax_reg = *reinterpret_cast<uint32_t*>(&local_amax); uint32_t local_amax_t_reg = *reinterpret_cast<uint32_t*>(&local_amax_t); + const int warp_id = threadIdx.x / kThreadsPerWarp; + const int local_rank = threadIdx.x % kThreadsPerWarp; + const int ld_row_idx = local_rank % kHadamardDimension; + const int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2; + const int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx); + for (int stage_y = 0; stage_y < STAGES_Y; ++stage_y) { for (int stage_x = 0; stage_x < STAGES_X; ++stage_x) { int stage = STAGES_X * stage_y + stage_x; @@ -364,7 +364,7 @@ __global__ void GraphSafeGroupHadamardAmaxTmaKernel( had_frag_i, had_frag_t, in_sh_ptr + in_row_offset + (compute_stage_x * kHadamardDimension
* (THREADS_PER_CHUNK / kThreadsPerWarp)), - local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); + swizzle_idx, local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); } // Ensure all threads have finished their computation before new data over-writes the shared diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu index 07813be059..8b7f079072 100644 --- a/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu @@ -41,19 +41,13 @@ constexpr int kThreadsPerWarp = 32; template <typename IType> __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4], - IType* in_sh_ptr, uint32_t& local_pre_rht_amax_reg, + IType* in_sh_ptr, int swizzle_idx, + uint32_t& local_pre_rht_amax_reg, uint32_t& local_amax_reg, uint32_t& local_amax_t_reg) { uint32_t a_frag[4]; // A matrix fragment uint32_t c_frag[4]; // Result fragment - int warp_id = threadIdx.x / kThreadsPerWarp; - int local_rank = (threadIdx.x % kThreadsPerWarp); - - int ld_row_idx = local_rank % kHadamardDimension; - int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2; - int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx); - uint32_t temp_amax_reg; uint32_t temp_amax_t_reg; @@ -305,6 +299,12 @@ __global__ void GroupHadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap t uint32_t local_amax_reg = *reinterpret_cast<uint32_t*>(&local_amax); uint32_t local_amax_t_reg = *reinterpret_cast<uint32_t*>(&local_amax_t); + const int warp_id = threadIdx.x / kThreadsPerWarp; + const int local_rank = threadIdx.x % kThreadsPerWarp; + const int ld_row_idx = local_rank % kHadamardDimension; + const int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2; + const int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx); + for (int stage_y = 0; stage_y < STAGES_Y; ++stage_y) { for (int stage_x = 0;
stage_x < STAGES_X; ++stage_x) { int stage = STAGES_X * stage_y + stage_x; @@ -347,7 +347,7 @@ __global__ void GroupHadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap t had_frag_i, had_frag_t, in_sh_ptr + in_row_offset + (compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)), - local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); + swizzle_idx, local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); } // Ensure all threads have finished their computation before new data over-writes the shared diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index 4adc836886..216ed1930a 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -26,19 +26,13 @@ constexpr int kThreadsPerWarp = 32; template <typename IType> __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4], - IType* in_sh_ptr, uint32_t& local_pre_rht_amax_reg, + IType* in_sh_ptr, int swizzle_idx, + uint32_t& local_pre_rht_amax_reg, uint32_t& local_amax_reg, uint32_t& local_amax_t_reg) { uint32_t a_frag[4]; // A matrix fragment uint32_t c_frag[4]; // Result fragment - int warp_id = threadIdx.x / kThreadsPerWarp; - int local_rank = (threadIdx.x % kThreadsPerWarp); - - int ld_row_idx = local_rank % kHadamardDimension; - int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2; - int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx); - uint32_t temp_amax_reg; uint32_t temp_amax_t_reg; @@ -248,6 +242,12 @@ __global__ void HadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap tensor uint32_t local_amax_reg = *reinterpret_cast<uint32_t*>(&local_amax); uint32_t local_amax_t_reg = *reinterpret_cast<uint32_t*>(&local_amax_t); + const int warp_id = threadIdx.x / kThreadsPerWarp; + const int local_rank = threadIdx.x % kThreadsPerWarp; + const int ld_row_idx = local_rank %
kHadamardDimension; + const int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2; + const int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx); + for (int stage_y = 0; stage_y < STAGES_Y; ++stage_y) { for (int stage_x = 0; stage_x < STAGES_X; ++stage_x) { int stage = STAGES_X * stage_y + stage_x; @@ -290,7 +290,7 @@ __global__ void HadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap tensor had_frag_i, had_frag_t, in_sh_ptr + in_row_offset + (compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)), - local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); + swizzle_idx, local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); } // Ensure all threads have finished their computation before new data over-writes the shared