From 4fc0a38fd3385e072436ed4ff4309d06db317ea0 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Fri, 21 Nov 2025 23:41:19 +0000 Subject: [PATCH 01/23] Implemented persistent nvfp4 kernel Signed-off-by: Oleg Goncharov --- .../cpp/operator/test_cast_nvfp4_transpose.cu | 53 +- .../common/cast/core/common.cuh | 6 + .../cast/nvfp4/quantize_transpose_nvfp4.cuh | 7 + ...quantize_transpose_nvfp4_persistent_1D.cuh | 759 ++++++++++++++++++ transformer_engine/common/util/ptx.cuh | 257 ++++++ 5 files changed, 1059 insertions(+), 23 deletions(-) create mode 100644 transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_persistent_1D.cuh diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index afd7927da2..3315cb513c 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -349,6 +349,7 @@ void compare_nvfp4_tensors(const std::string& name, const fp4e2m1 *test_data, const fp4e2m1 *ref_data, const int rows, const int cols, double atol = 1e-5, double rtol = 1e-8) { + constexpr bool print_detailed_summary = false; std::vector mismatch_messages; size_t total_mismatches = 0; @@ -381,36 +382,42 @@ void compare_nvfp4_tensors(const std::string& name, std::to_string(t) + " vs " + std::to_string(r) + " (abs_diff: " + std::to_string(fabs(t - r)) + ", rel_diff: " + std::to_string(r == 0 ? 0.0 : fabs((t - r) / r)) + ")"; - mismatch_messages.push_back(msg); - - // Optional: limit number of detailed messages to avoid overwhelming output - if (mismatch_messages.size() <= 100) { - std::cout << "Error in tensor " << name << ": " << msg << std::endl; + if constexpr (print_detailed_summary) { + mismatch_messages.push_back(msg); + + // Optional: limit number of detailed messages to avoid overwhelming output + if (mismatch_messages.size() <= 100) { + std::cout << "Error in tensor " << name << ": " << msg << std::endl; + } + } else { + GTEST_FAIL() << "Found " << total_mismatches << " mismatches in tensor " << name; } } } } } - // Always report summary - either success or failure - std::cout << "=== SUMMARY for tensor " << name << " ===" << std::endl; - std::cout << "Total elements checked: " << (rows * cols) << std::endl; - - if (total_mismatches > 0) { - std::cout << "STATUS: FAILED for output" << std::endl; - std::cout << "Total mismatches found: " << total_mismatches << std::endl; - std::cout << "Mismatch rate: " << (100.0 * total_mismatches) / (rows * cols) << "%" << std::endl; - if (mismatch_messages.size() > 100) { - std::cout << "... and " << (mismatch_messages.size() - 100) << " more mismatches (showing first 100)" << std::endl; + if constexpr (print_detailed_summary) { + // Always report summary - either success or failure + std::cout << "=== SUMMARY for tensor " << name << " ===" << std::endl; + std::cout << "Total elements checked: " << (rows * cols) << std::endl; + + if (total_mismatches > 0) { + std::cout << "STATUS: FAILED for output" << std::endl; + std::cout << "Total mismatches found: " << total_mismatches << std::endl; + std::cout << "Mismatch rate: " << (100.0 * total_mismatches) / (rows * cols) << "%" << std::endl; + if (mismatch_messages.size() > 100) { + std::cout << "... 
and " << (mismatch_messages.size() - 100) << " more mismatches (showing first 100)" << std::endl; + } + std::cout << "============================" << std::endl; + + GTEST_FAIL() << "Found " << total_mismatches << " mismatches in tensor " << name; + } else { + std::cout << "STATUS: PASSED for output" << std::endl; + std::cout << "All elements match within tolerance!" << std::endl; + std::cout << "Tensor " << name << " is IDENTICAL to reference" << std::endl; + std::cout << "============================" << std::endl; } - std::cout << "============================" << std::endl; - - GTEST_FAIL() << "Found " << total_mismatches << " mismatches in tensor " << name; - } else { - std::cout << "STATUS: PASSED for output" << std::endl; - std::cout << "All elements match within tolerance!" << std::endl; - std::cout << "Tensor " << name << " is IDENTICAL to reference" << std::endl; - std::cout << "============================" << std::endl; } } diff --git a/transformer_engine/common/cast/core/common.cuh b/transformer_engine/common/cast/core/common.cuh index b750142f5b..c80fcdad8f 100644 --- a/transformer_engine/common/cast/core/common.cuh +++ b/transformer_engine/common/cast/core/common.cuh @@ -35,6 +35,12 @@ inline bool dimensions_supported_by_TMA(const Tensor *const t) { return cols % alignment_requirement == 0; } +__device__ __forceinline__ unsigned char *align_smem_ptr_per_TMA_requirements(unsigned char *p) { + size_t addr = reinterpret_cast(p); + addr = (addr + TMA_SHMEM_ALIGNMENT - 1) & ~(TMA_SHMEM_ALIGNMENT - 1); + return reinterpret_cast(addr); +} + namespace kernel { constexpr size_t THREADS_PER_BLOCK = 256; diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index 7322bf2655..aff388cd92 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -21,6 +21,7 @@ #include "../../util/ptx.cuh" #include "../../utils.cuh" #include "core_nvfp4.cuh" +#include "specialized/quantize_transpose_nvfp4_persistent_1D.cuh" namespace transformer_engine { namespace dispatch { @@ -1159,6 +1160,12 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, #if FP4_TYPE_SUPPORTED using namespace quantize_transpose_kernel; using namespace ptx; + + if (!use_2d_quantization && input.dtype() == DType::kBFloat16) { + quantize_transpose_persistent_1D(input, noop, output, quant_config, stream); + return; + } + bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; // If transposed output is allocated, return the transposed data. Otherwise, it's not necesary to diff --git a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_persistent_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_persistent_1D.cuh new file mode 100644 index 0000000000..09edfff321 --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_persistent_1D.cuh @@ -0,0 +1,759 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_transpose_nvfp4_persistent_1D.cuh + * \brief Persistent kernel to cast to NVFP4 and transpose. 
+ */ + +#ifndef TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_PERSISTENT_1D_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_PERSISTENT_1D_CUH_ + +#include +#include +#include +#include + +#include "../../../common.h" +#include "../../../util/math.h" +#include "../../../util/ptx.cuh" +#include "../../../utils.cuh" +#include "../core_nvfp4.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace nvfp4 { + +namespace quantize_transpose_persistent_kernel { + +using namespace quantization_and_transposition_SF; +using namespace core; +using namespace ptx; + +#if FP4_TYPE_SUPPORTED + +constexpr int SCALE_DIM = 16; // NVFP4 block (x16 elts) +static_assert(SCALE_DIM == 16 && "NVFP4 block size is 16\0"); + +constexpr int THREADS_NUM = 128; +constexpr int ELTS_PER_THREAD = 16; +constexpr int CHUNK_DIM_Y = 128; +constexpr int CHUNK_DIM_X = 128; +constexpr int TILE_DIM_Y = 64; +constexpr int TILE_DIM_X = 64; + +static_assert(THREADS_NUM == 128 && "Hardcoded and fixed parameter\0"); +static_assert(ELTS_PER_THREAD == SCALE_DIM && "Hardcoded and fixed parameter\0"); +static_assert(TILE_DIM_Y == 64 && "Hardcoded and fixed parameter\0"); +static_assert(TILE_DIM_X == 64 && "Hardcoded and fixed parameter\0"); + + +static_assert((THREADS_NUM * ELTS_PER_THREAD <= TILE_DIM_Y * TILE_DIM_X) && + "Unbalanced threads workload\0"); + +static_assert((CHUNK_DIM_Y % TILE_DIM_Y == 0) && + "Chunk size Y must be evenly divisible by the tile size Y\0"); +static_assert((CHUNK_DIM_X % TILE_DIM_X == 0) && + "Chunk size X must be evenly divisible by the tile size X\0"); + +static_assert((TILE_DIM_Y % SCALE_DIM == 0) && + "Tile size Y must be evenly divisible by the scale dim\0"); +static_assert((TILE_DIM_X % SCALE_DIM == 0) && + "Tile size X must be evenly divisible by the scale dim\0"); + +constexpr int TILES_Y = CHUNK_DIM_Y / TILE_DIM_Y; +constexpr int TILES_X = CHUNK_DIM_X / TILE_DIM_X; + +constexpr int THREADS_PER_SCALE_ROWWISE = SCALE_DIM / ELTS_PER_THREAD; + +constexpr int SCALES_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM; +constexpr int SCALES_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM; + +constexpr int SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM; +constexpr int SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; + +constexpr int STAGES_Y = TILES_Y; +constexpr int STAGES_X = TILES_X; +constexpr int STAGES = STAGES_Y * STAGES_X; + +constexpr int PREFETCH_STAGES = 1; +constexpr int BUFFS_NUM = PREFETCH_STAGES + 1; +constexpr int BUFFS_NUM_IN = BUFFS_NUM; +constexpr int BUFFS_NUM_OUT = BUFFS_NUM; +constexpr int BUFFS_NUM_OUT_TR = 2; +constexpr int BUFF_DIM_Y = TILE_DIM_Y; +constexpr int BUFF_DIM_X = TILE_DIM_X; +constexpr int BUFF_SIZE = BUFF_DIM_Y * BUFF_DIM_X; +constexpr int BUFF_SIZE_TOTAL = BUFF_SIZE * BUFFS_NUM; + +// Input buffer (BF16) +constexpr int BUFF_IN_DIM_Y = BUFF_DIM_Y; +constexpr int BUFF_IN_DIM_X = BUFF_DIM_X; +constexpr int BUFF_IN_SIZE = BUFF_IN_DIM_Y * BUFF_IN_DIM_X; +constexpr int BUFF_IN_ELTS_NUM = BUFF_IN_DIM_Y * BUFF_IN_DIM_X; + +// Output buffer (NVFP4) +constexpr int BUFF_OUT_DIM_Y = BUFF_DIM_Y; +constexpr int BUFF_OUT_DIM_X = (BUFF_DIM_X * 4) / 8; +constexpr int BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X; + +// Output transpose buffer (NVFP4) +constexpr int BUFF_OUT_T_DIM_Y = BUFF_DIM_X; +constexpr int BUFF_OUT_T_DIM_X = (BUFF_DIM_Y * 4) / 8; +constexpr int BUFF_OUT_T_SIZE = BUFF_OUT_T_DIM_Y * BUFF_OUT_T_DIM_X; + +// Manual swizzling parameters to reduce SHMEM bank conflicts +constexpr int PACK_SIZE = 8; +static_assert(PACK_SIZE == 8 && "Pack size is fixed to 8\0"); +constexpr int WAVES = 
ELTS_PER_THREAD / PACK_SIZE; + +constexpr int THREADS_X_ROWWISE = TILE_DIM_X / ELTS_PER_THREAD; +constexpr int THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE; + +constexpr int THREADS_X_TRANSP = TILE_DIM_X / 2; +constexpr int THREADS_Y_TRANSP = THREADS_NUM / THREADS_X_TRANSP; + +constexpr int ITERATIONS_NORMAL = BUFF_DIM_Y / THREADS_Y_ROWWISE; +constexpr int ITERATIONS_TRANSPOSE = SCALES_PER_TILE_Y / THREADS_Y_TRANSP; +static_assert(ITERATIONS_TRANSPOSE >= 1 && "Number of transpose iterations should be >=1\0"); +static_assert((SCALES_PER_TILE_Y % THREADS_Y_TRANSP == 0) + && "Partial transpose iterations are not supported\0"); + +constexpr int BUFF_OUT_IT_OFFSET = BUFF_OUT_T_DIM_X / ITERATIONS_TRANSPOSE / STAGES; + +static_assert(BUFF_DIM_Y >= SCALE_DIM && + "Number of buffer rows must be greater or equal to the size of the columwise " + "scaling block\0"); +static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y); +static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE && + "Number of buffer rows must be greater or equal to the number of rowwise " + "processing threads in Y dimension\0"); + +// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory +constexpr int TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr int THREADS_PER_BANK = TOTAL_BANKS_WIDTH / ELTS_PER_THREAD; + +using IType = bf16; +using IType2 = typename ptx::FPx2; +using IType3D = IType[BUFFS_NUM_IN][BUFF_IN_DIM_Y][BUFF_IN_DIM_X]; +using IType2x3D = IType2[BUFFS_NUM_IN][BUFF_IN_DIM_Y][BUFF_IN_DIM_X/2]; +using OType2x3D = fp4e2m1x2[BUFFS_NUM_OUT][BUFF_OUT_DIM_Y][BUFF_OUT_DIM_X]; +using OType2xt3D = fp4e2m1x2[BUFFS_NUM_OUT_TR][BUFF_OUT_T_DIM_Y][BUFF_OUT_T_DIM_X]; +using ScalesType2D = nvfp4_scale_t[CHUNK_DIM_Y][SCALES_PER_CHUNK_X]; +using ScalesTypeTr2D = nvfp4_scale_t[CHUNK_DIM_X][SCALES_PER_CHUNK_Y]; +using RNG_t = typename transformer_engine::curanddx::detail::philox4x32_native_state<10>; + +__device__ __forceinline__ float +get_amax_of_pair(const IType2 xormax_pair) { + return static_cast(__hmax(__habs(xormax_pair.x), __habs(xormax_pair.y))); +} + +// Compute "correct" per-block encoding scaling factor +__device__ __forceinline__ bf16 +compute_nvfp4_scaling_coefficient(const nvfp4_scale_t S_dec_block, const float S_dec) { + constexpr float float_max = detail::TypeExtrema::max; + const float scale_rcp = fminf(1.0f / (static_cast(S_dec_block) * S_dec), float_max); + return static_cast(scale_rcp); +} + +template +__device__ __forceinline__ void +colwise_scaling(const IType * __restrict__ sIn_ptr, + fp4e2m1x2 * __restrict__ sOut_tr_ptr, + nvfp4_scale_t * __restrict__ sSFcolwise_ptr, + const float S_enc_colwise, + const float S_dec_colwise, + const int stage_Y, + const int stage_X, + const int buff_in, + const int buff_out_tr, + RNG_t& rng, uint4 &random_uint4, int &rnd_idx) { + const auto &sIn2x = *reinterpret_cast(sIn_ptr); + auto &sOut_tr = *reinterpret_cast(sOut_tr_ptr); + auto &sSFcolwise = *reinterpret_cast(sSFcolwise_ptr); + + const int warp = threadIdx.x / THREADS_PER_WARP; + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + + const int tid_Y_colwise = (thread_lane % 4 + warp) % 4; + const int tid_X_colwise = thread_lane; + + const int thread_offset_Y_colwise = tid_Y_colwise * SCALE_DIM; + const int thread_offset_X_colwise = tid_X_colwise * 2; + + const int in_thread_offset_Y = thread_offset_Y_colwise; + const int in_thread_offset_X = thread_offset_X_colwise / 2; + + const int out_tr_thread_offset_Y = 
thread_offset_X_colwise; + const int out_tr_thread_offset_X = thread_offset_Y_colwise / 2; + + const int scale_tr_offset_Y = (stage_X * TILE_DIM_X) + 2 * tid_X_colwise; + const int scale_tr_offset_X = (stage_Y * SCALES_PER_TILE_Y) + tid_Y_colwise; + + __align__(8) IType rIn[2][SCALE_DIM]; + // Read (cache) a pair of input elements (S2R). Find NVFP4-block AMAX + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; + #pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + const IType2 elt_pair = ptx::ld_shared_b32(&sIn2x[buff_in][in_thread_offset_Y + i][in_thread_offset_X]); + rIn[0][i] = elt_pair.x; + rIn[1][i] = elt_pair.y; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, elt_pair); + } + const float block_amax[2] = {static_cast(__habs(thread_amax_2x.x)), + static_cast(__habs(thread_amax_2x.y))}; + #pragma unroll + for (int w = 0; w < 2; ++w) { + const nvfp4_scale_t S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax[w], S_enc_colwise); + + // Store scaling factors to SMEM buffer (R2S) + sSFcolwise[scale_tr_offset_Y + w][scale_tr_offset_X] = S_dec_b_fp8; + + const bf16 SFcoefficient = compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_dec_colwise); + + // Scale elements + __align__(8) uint32_t rOut[SCALE_DIM / 8]; + #pragma unroll + for (int e = 0; e < SCALE_DIM / 8; ++e) { + const uint64_t elts03 = *reinterpret_cast(&rIn[w][8 * e]); + const uint64_t elts47 = *reinterpret_cast(&rIn[w][8 * e + 4]); + if constexpr (USE_STOCHASTIC_ROUNDING) { + const uint32_t rbits03 = core::get_rbits(rng, random_uint4, rnd_idx); + const uint32_t rbits47 = core::get_rbits(rng, random_uint4, rnd_idx); + rOut[e] = ptx::mul_cvt_bf16_to_fp4_8x_stochastic_rounding(elts03, elts47, SFcoefficient, rbits03, rbits47); + } else { + rOut[e] = ptx::mul_cvt_bf16_to_fp4_8x_round_to_nearest(elts03, elts47, SFcoefficient); + } + } + uint64_t& out_pack_16x = *reinterpret_cast(rOut); + ptx::st_shared_b64(&sOut_tr[buff_out_tr][out_tr_thread_offset_Y + w][out_tr_thread_offset_X], out_pack_16x); + } +} + +template +__device__ __forceinline__ void +rowwise_scaling(const IType * __restrict__ sIn_ptr, + fp4e2m1x2 * __restrict__ sOut_ptr, + nvfp4_scale_t * __restrict__ sSFrowwise_ptr, + const float S_enc_rowwise, + const float S_dec_rowwise, + const int stage_Y, + const int stage_X, + const int buff_in, + const int buff_out, + RNG_t& rng, uint4 &random_uint4, int &rnd_idx) { + const auto &sIn = *reinterpret_cast(sIn_ptr); + auto &sOut = *reinterpret_cast(sOut_ptr); + auto &sSFrowwise = *reinterpret_cast(sSFrowwise_ptr); + + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; + const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; + + const int thread_offset_Y_rowwise = tid_Y_rowwise; + const int thread_offset_X_rowwise = tid_X_rowwise * ELTS_PER_THREAD; + + const int SF_thread_offset_rowwise_Y = tid_Y_rowwise; + const int SF_thread_offset_rowwise_X = tid_X_rowwise / THREADS_PER_SCALE_ROWWISE; + + const bool SF_storing_thread = (tid_X_rowwise % THREADS_PER_SCALE_ROWWISE == 0); + + const int stage_rowwise_scales_offset_Y = SF_thread_offset_rowwise_Y + stage_Y * TILE_DIM_Y; + const int stage_rowwise_scales_offset_X = SF_thread_offset_rowwise_X + stage_X * SCALES_PER_TILE_X; + #pragma unroll + for (int it = 0; it < ITERATIONS_NORMAL; ++it) { + const int it_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE; + + __align__(16) IType2 rIn[WAVES][PACK_SIZE / 2]; + + // Read (cache) 
input elements (S2R). Find NVFP4-block AMAX + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; + #pragma unroll + for (int w = 0; w < WAVES; ++w) { + uint64_t& elts03 = *reinterpret_cast(&rIn[w][0]); + uint64_t& elts47 = *reinterpret_cast(&rIn[w][2]); + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + // Load elements + ptx::ld_shared_b128(elts03, elts47, &sIn[buff_in][it_offset_Y_rowwise][swizzled_thread_idx]); + #pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, rIn[w][e]); + } + } + const float block_amax = get_amax_of_pair(thread_amax_2x); + + const nvfp4_scale_t S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + const bf16 SFcoefficient = compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_dec_rowwise); + + // Store scaling factors to SMEM buffer (R2S) + if (SF_storing_thread) { + const int scales_offset_Y = stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE; + const int scales_offset_X = stage_rowwise_scales_offset_X; + sSFrowwise[scales_offset_Y][scales_offset_X] = S_dec_b_fp8; + } + + // Scale elements + #pragma unroll + for (int w = 0; w < WAVES; ++w) { + const uint64_t elts03 = *reinterpret_cast(&rIn[w][0]); + const uint64_t elts47 = *reinterpret_cast(&rIn[w][2]); + + uint32_t out_x8; + if constexpr (USE_STOCHASTIC_ROUNDING) { + const uint32_t rbits03 = core::get_rbits(rng, random_uint4, rnd_idx); + const uint32_t rbits47 = core::get_rbits(rng, random_uint4, rnd_idx); + out_x8 = ptx::mul_cvt_bf16_to_fp4_8x_stochastic_rounding(elts03, elts47, SFcoefficient, rbits03, rbits47); + } else { + out_x8 = ptx::mul_cvt_bf16_to_fp4_8x_round_to_nearest(elts03, elts47, SFcoefficient); + } + + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD; + const int swizzled_idx = (swizzled_group_idx + thread_offset_X_rowwise) / 2; + ptx::st_shared_b32(&sOut[buff_out][it_offset_Y_rowwise][swizzled_idx], out_x8); + } + } +} + +template +__global__ void __launch_bounds__(THREADS_NUM) + quantize_transpose_nvfp4_persistent_kernel( + const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + const __grid_constant__ CUtensorMap tensor_map_output_t, + nvfp4_scale_t *const scales_ptr, + nvfp4_scale_t *const scales_t_ptr, const float *noop, + const float *const amax_rowwise_ptr, + const float *const amax_colwise_ptr, const size_t rows, + const size_t cols, const size_t scale_stride, + const size_t scale_stride_t, const size_t *rng_state) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + + const size_t rng_sequence = + threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM; + const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; + const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; + RNG_t rng; + rng.init(rng_seed, rng_sequence, rng_offset); + uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? rng.generate4() : uint4{0, 0, 0, 0}; + // Index of the random number. 
It increments each time when used and resets to 0 if reaches 4x + int rnd_idx = 0; + + const bool leading_thread = (threadIdx.x == 0); + + constexpr int buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X; + constexpr int buff_elems_total_in = BUFFS_NUM_IN * buff_elems; + + constexpr int buff_size_aligned_in = DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_out = DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_out_t = DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_T_SIZE, TMA_SHMEM_ALIGNMENT); + + constexpr int in_mem = buff_size_aligned_in; + + constexpr int out_mem_rowwise_data = buff_size_aligned_out; + constexpr int out_mem_colwise_data = RETURN_TRANSPOSE ? buff_size_aligned_out_t : 0; + constexpr int out_mem_rowwise_scales = + DIVUP_TO_MULTIPLE(CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + extern __shared__ unsigned char dynamic_shmem[]; + unsigned char *dshmem = common::align_smem_ptr_per_TMA_requirements(dynamic_shmem); + + IType *sIn_ptr = reinterpret_cast(dshmem); + fp4e2m1x2 *sOut_ptr = reinterpret_cast(dshmem + in_mem); + fp4e2m1x2 *sOut_tr_ptr = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data); + + auto &sIn = *reinterpret_cast(sIn_ptr); + auto &sOut = *reinterpret_cast(sOut_ptr); + auto &sOut_tr = *reinterpret_cast(sOut_tr_ptr); + + nvfp4_scale_t *sSFrowwise_ptr = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data); + nvfp4_scale_t *sSFcolwise_ptr = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); + + auto &sSFrowwise = *reinterpret_cast(sSFrowwise_ptr); + auto &sSFcolwise = *reinterpret_cast(sSFcolwise_ptr); + + constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + // Compute a global encoding/decoding scaling factors for all S_dec_b + const float S_enc_rowwise = (amax_rowwise_ptr == nullptr) + ? 1.0f + : core::compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); + // NOTE: This is to match with how emulation code was written. + const float S_dec_rowwise = 1.0 / S_enc_rowwise; + + const float S_enc_colwise = (amax_colwise_ptr == nullptr) + ? 
S_enc_rowwise + : core::compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); + const float S_dec_colwise = 1.0 / S_enc_colwise; + + __shared__ uint64_t workID_mbar; + __shared__ __uint128_t workID_response; + constexpr uint32_t workID_response_size = sizeof(workID_response); + static_assert(workID_response_size == 16); + + __shared__ uint64_t IN_buff_readable_mbar[BUFFS_NUM]; + + // Coordinates of the first chunk (CTA) to process + int32_t ctaid_X = blockIdx.x; + int32_t ctaid_Y = blockIdx.y; + + // Initialize shared memory barriers with the number of threads participating in them + if (leading_thread) { + #pragma unroll + for (int buff = 0; buff < BUFFS_NUM; ++buff) { + ptx::mbarrier_init(&IN_buff_readable_mbar[buff], 1); + } + ptx::mbarrier_init(&workID_mbar, 1); + ptx::fence_proxy_async_shared_cta(); + } + __syncthreads(); + + bool job_finished = false; + int buff_in = 0; + int buff_out = 0; + int buff_out_tr = 0; + int IN_buff_readable_parity[BUFFS_NUM] = {0, 0}; + int ctaid_parity = 0; + + // Prefetch input data only when processing the first chunk, + // which enables the one-iteration overlap throughout the entire kernel life + #pragma unroll + for (int stage = 0; stage < PREFETCH_STAGES; ++stage) { + const int buff_in = stage; + const int stage_Y = stage / STAGES_X; + const int stage_X = stage % STAGES_X; + + const int stage_offset_Y = stage_Y * TILE_DIM_Y; + const int stage_offset_X = stage_X * TILE_DIM_X; + + const int block_offset_Y = ctaid_Y * CHUNK_DIM_Y; + const int block_offset_X = ctaid_X * CHUNK_DIM_X; + + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X + stage_offset_X; + + uint64_t* barrier = &IN_buff_readable_mbar[buff_in]; + if (leading_thread) { + uint64_t* dst = reinterpret_cast(&sIn[buff_in]); + const uint64_t* src = reinterpret_cast(&tensor_map_input); + + // Arrive on the barrier and tell how many bytes are expected to come in + ptx::mbarrier_arrive_expect_tx(barrier, shmem_buff_size); + + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared(dst, src, global_offset_X, global_offset_Y, barrier); + } + } + + while (!job_finished) { + const int block_offset_Y = ctaid_Y * CHUNK_DIM_Y; + const int block_offset_X = ctaid_X * CHUNK_DIM_X; + + const int block_offset_Y_tr = ctaid_X * CHUNK_DIM_X; + const int block_offset_X_tr = ctaid_Y * CHUNK_DIM_Y; + + const int chunk_rows = rows - block_offset_Y; + const int chunk_cols = cols - block_offset_X; + + const int scales_block_offset_Y_rowwise = ctaid_Y * CHUNK_DIM_Y; + const int scales_block_offset_X_rowwise = ctaid_X * SCALES_PER_CHUNK_X; + const int scales_block_offset_Y_tr = ctaid_X * CHUNK_DIM_X; + const int scales_block_offset_X_tr = ctaid_Y * SCALES_PER_CHUNK_Y; + + if (leading_thread) { + ptx::mbarrier_arrive_expect_tx_cta_relaxed_shared_cta(&workID_mbar, workID_response_size); + ptx::clusterlaunchcontrol_try_cancel_async_shared_cta_mbarrier_complete_tx_bytes(&workID_mbar, &workID_response); + } + + #pragma unroll + for (int stage = 0; stage < STAGES; ++stage) { + const int stage_Y = stage / STAGES_X; + const int stage_X = stage % STAGES_X; + + const int stage_offset_Y = stage_Y * TILE_DIM_Y; + const int stage_offset_X = stage_X * TILE_DIM_X; + + if (stage == STAGES - PREFETCH_STAGES) { + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&workID_mbar, ctaid_parity); + ptx::get_cancelled_cta_2D_id(&workID_response, ctaid_X, ctaid_Y); + if (ctaid_X == -1 && ctaid_Y == -1) { + job_finished = true; + } + ctaid_parity ^= 1; + } + + // 
Prefetch next stage Input data + if (!job_finished || (stage < STAGES - PREFETCH_STAGES)) { + const int next_prefetch_buff = (buff_in + PREFETCH_STAGES) % BUFFS_NUM; + const int next_prefetch_stage = (stage + PREFETCH_STAGES) % STAGES; + const int next_prefetch_stage_Y = next_prefetch_stage / STAGES_X; + const int next_prefetch_stage_X = next_prefetch_stage % STAGES_X; + + const int next_prefetch_stage_offset_Y = next_prefetch_stage_Y * TILE_DIM_Y; + const int next_prefetch_stage_offset_X = next_prefetch_stage_X * TILE_DIM_X; + + // Offsets change, because coordinates of the next "to-be-prefetched" CTA do also chage + const int block_offset_Y = ctaid_Y * CHUNK_DIM_Y; + const int block_offset_X = ctaid_X * CHUNK_DIM_X; + + const int global_offset_Y = block_offset_Y + next_prefetch_stage_offset_Y; + const int global_offset_X = block_offset_X + next_prefetch_stage_offset_X; + + uint64_t* barrier = &IN_buff_readable_mbar[next_prefetch_buff]; + if (leading_thread) { + uint64_t* dst = reinterpret_cast(&sIn[next_prefetch_buff]); + const uint64_t* src = reinterpret_cast(&tensor_map_input); + + // Arrive on the barrier and tell how many bytes are expected to come in + ptx::mbarrier_arrive_expect_tx(barrier, shmem_buff_size); + + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared(dst, src, global_offset_X, global_offset_Y, barrier); + } + ptx::fence_proxy_async_shared_cta(); + } + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], IN_buff_readable_parity[buff_in]); + IN_buff_readable_parity[buff_in] ^= 1; + + // Wait for TMA transfer to have finished reading shared memory + // I.e. the OUT buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read(); + + // NVFP4 Quantization + rowwise_scaling( + sIn_ptr, sOut_ptr, sSFrowwise_ptr, S_enc_rowwise, S_dec_rowwise, + stage_Y, stage_X, buff_in, buff_out, rng, random_uint4, rnd_idx); + + if constexpr (RETURN_TRANSPOSE) { + colwise_scaling( + sIn_ptr, sOut_tr_ptr, sSFcolwise_ptr, S_enc_colwise, S_dec_colwise, + stage_Y, stage_X, buff_in, buff_out_tr, rng, random_uint4, rnd_idx); + } + + // Wait for shared memory writes to be visible to TMA engine + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine + + // Initiate TMA transfer to copy shared memory to global memory + if (leading_thread) { + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X + stage_offset_X; + const int global_offset_Y_tr = block_offset_Y_tr + stage_offset_X; + const int global_offset_X_tr = block_offset_X_tr + stage_offset_Y; + + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), global_offset_X, global_offset_Y, + reinterpret_cast(&sOut[buff_out])); + + if constexpr (RETURN_TRANSPOSE) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_t), global_offset_X_tr, global_offset_Y_tr, + reinterpret_cast(&sOut_tr[buff_out_tr])); + } + + // Create a "bulk async-group" out of the previous bulk copy operation + ptx::cp_async_bulk_commit_group(); + } + + buff_in = (buff_in + 1) % BUFFS_NUM_IN; + buff_out = (buff_out + 1) % BUFFS_NUM_OUT; + buff_out_tr = (buff_out_tr + 1) % BUFFS_NUM_OUT_TR; + } // end of stages + + // Vectorized store of scaling factors (S2G) + { + // Rowwise + { + using ScalesVec = Vec; + // number of scales in X dimension of this chunk + const int count = 
min(SCALES_PER_CHUNK_X, chunk_cols / SCALE_DIM); + + for (size_t row = threadIdx.x; row < CHUNK_DIM_Y; row += THREADS_NUM) { + const size_t row_global = scales_block_offset_Y_rowwise + row; + if (row_global < rows) { + ScalesVec &scales_vec = *reinterpret_cast(sSFrowwise[row]); + const size_t scale_idx_global = row_global * scale_stride + + scales_block_offset_X_rowwise; + scales_vec.store_to_elts(&scales_ptr[scale_idx_global], 0, count); + } + } + } + + // Colwise + if constexpr (RETURN_TRANSPOSE) { + using ScalesVec = Vec; + // number of scales in Y dimension of this chunk + const int count = min(SCALES_PER_CHUNK_Y, chunk_rows / SCALE_DIM); + + for (size_t row_tr = threadIdx.x; row_tr < CHUNK_DIM_X; row_tr += THREADS_NUM) { + const size_t row_tr_global = scales_block_offset_Y_tr + row_tr; + if (row_tr_global < cols) { + ScalesVec &scales_vec = *reinterpret_cast(sSFcolwise[row_tr]); + const size_t scale_idx_global = row_tr_global * scale_stride_t + + scales_block_offset_X_tr; + scales_vec.store_to_elts(&scales_t_ptr[scale_idx_global], 0, count); + } + } + } + + if (!job_finished) { + // Ensures all reads from SFs buffer have completed and it's ready to be reused + __syncthreads(); + } + } + } + + if (leading_thread) { + #pragma unroll + for (int buff = 0; buff < BUFFS_NUM; ++buff) { + ptx::mbarrier_invalid(&IN_buff_readable_mbar[buff]); + } + ptx::mbarrier_invalid(&workID_mbar); + } +#else + NVTE_DEVICE_ERROR("sm_100 or higher is required."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +#endif // FP4_TYPE_SUPPORTED +} // namespace quantize_transpose_persistent_kernel + +inline void quantize_transpose_persistent_1D(const Tensor &input, const Tensor *noop, + Tensor *output, const QuantizationConfig *quant_config, + cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + using namespace quantize_transpose_persistent_kernel; + using namespace ptx; + + const bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; + // If transposed output is allocated, return the transposed data + // Otherwise, it's not necesary to return the transposed data. + const bool return_transpose = output->has_columnwise_data(); + + checkCuDriverContext(stream); + CheckNoopTensor(*noop, "cast_noop"); + CheckInputTensor(input, "input"); + CheckOutputTensor(*output, "output", false); + + NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); + NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated."); + NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); + + if (return_transpose) { + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), + "Transposed output must have FP4 type."); + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Transposed scaling tensor must be allocated"); + } + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + + NVTE_CHECK(rows % 32 == 0, + "Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA + NVTE_CHECK(cols % 32 == 0, + "Number of tensor cols must be a multiple of 32"); // 16B alignment for TMA + + const int blocks_Y = DIVUP(rows, static_cast(CHUNK_DIM_Y)); + const int blocks_X = DIVUP(cols, static_cast(CHUNK_DIM_X)); + const dim3 grid(blocks_X, blocks_Y); + const int block_size = THREADS_NUM; + + const size_t scale_stride = output->scale_inv.shape[1]; + const size_t scale_stride_transpose = return_transpose ? 
output->columnwise_scale_inv.shape[1] : 0; + + nvfp4_scale_t *const scales_ptr = reinterpret_cast(output->scale_inv.dptr); + nvfp4_scale_t *const scales_transpose_ptr = reinterpret_cast(output->columnwise_scale_inv.dptr); + + const float *noop_ptr = reinterpret_cast(noop->data.dptr); + const float *const amax_rowwise_ptr = reinterpret_cast(output->amax.dptr); + const float *const amax_colwise_ptr = reinterpret_cast(output->columnwise_amax.dptr); + + const NVTETensor rng_state_tensor = (quant_config != nullptr) ? quant_config->rng_state : nullptr; + const size_t *rng_state = nullptr; + if (rng_state_tensor != nullptr) { + Tensor &rng_state_te_tensor = *convertNVTETensor(rng_state_tensor); + NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64, + "RNG state should contain 2 64-bit values."); + NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector{2}, + "Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape); + rng_state = reinterpret_cast(rng_state_te_tensor.data.dptr); + } + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_output{}; + alignas(64) CUtensorMap tensor_map_output_transpose{}; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, + sizeof(IType) * 8); + + create_2D_tensor_map(tensor_map_output, output->data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, + 4); + if (return_transpose) { + create_2D_tensor_map(tensor_map_output_transpose, output->columnwise_data, cols, rows, + BUFF_DIM_X, BUFF_DIM_Y, rows, 0, 4); + } + + constexpr int buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr int buff_elems_total_in = BUFFS_NUM_IN * buff_elems; + constexpr int buff_size_aligned_in = DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_out = DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_out_t = DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_T_SIZE, TMA_SHMEM_ALIGNMENT); + + constexpr int buff_size_scales = + DIVUP_TO_MULTIPLE(CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_scales_transpose = + DIVUP_TO_MULTIPLE(CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); + + const int in_mem = buff_size_aligned_in; + + const int out_data_mem = buff_size_aligned_out; + const int out_data_transpose_mem = return_transpose ? buff_size_aligned_out_t : 0; + const int out_scales_mem = buff_size_scales; + const int out_scales_transpose_mem = return_transpose ? 
buff_size_scales_transpose : 0; + + const int out_mem = out_data_mem + out_data_transpose_mem; + + const int dshmem_size = in_mem + out_mem + + out_scales_transpose_mem + out_scales_mem + + TMA_SHMEM_ALIGNMENT; + + TRANSFORMER_ENGINE_SWITCH_CONDITION(use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, + { + auto kernel = quantize_transpose_nvfp4_persistent_kernel; + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + kernel<<>>( + tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, + scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, + scale_stride, scale_stride_transpose, rng_state); + } + ); + ); +#else + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif // FP4_TYPE_SUPPORTED +} + +} // namespace nvfp4 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_PERSISTENT_1D_CUH_ diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 754cbd900a..4b7a5c9c44 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -153,6 +153,24 @@ __device__ __forceinline__ void mbarrier_arrive(uint64_t *mbar) { #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +__device__ __forceinline__ void mbarrier_arrive_relaxed_cta_shared_cta(uint64_t *mbar) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.arrive.relaxed.cta.shared::cta.b64 _, [%0], 1;" ::"r"(mbar_ptr)); +#else + NVTE_DEVICE_ERROR("mbarrier_arrive_relaxed_cta_shared_cta is only supported on SM 10.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +__device__ __forceinline__ void mbarrier_arrive_release_cta_shared_cta(uint64_t *mbar) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.arrive.release.cta.shared::cta.b64 _, [%0], 1;" ::"r"(mbar_ptr)); +#else + NVTE_DEVICE_ERROR("mbarrier_arrive_release_cta_shared_cta is only supported on SM 10.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive __device__ __forceinline__ void mbarrier_arrive_expect_tx(uint64_t *mbar, const uint32_t tx_count) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) @@ -164,6 +182,18 @@ __device__ __forceinline__ void mbarrier_arrive_expect_tx(uint64_t *mbar, const #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +__device__ __forceinline__ void mbarrier_arrive_expect_tx_cta_relaxed_shared_cta( + uint64_t *mbar, const uint32_t tx_count) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.arrive.expect_tx.relaxed.cta.shared::cta.b64 _, [%0], %1;" ::"r"(mbar_ptr), + "r"(tx_count)); +#else + NVTE_DEVICE_ERROR( + "mbarrier_arrive_expect_tx_cta_relaxed_shared_cta is only supported on SM 10.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + __device__ __forceinline__ void fence_mbarrier_init_release_cluster() { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) asm 
volatile("fence.mbarrier_init.release.cluster;"); @@ -243,6 +273,99 @@ __device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint3 #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +__device__ __forceinline__ void mbarrier_wait_parity_acquire_cta_shared_cta(uint64_t *mbar, + uint32_t phase_parity) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile( + "{\n\t" + ".reg .b64 r1; \n\t" + ".reg .pred waitComplete; \n\t" // predicate representing if barrier condition is met + "WAIT: \n\t" // loop around barrier wait + "mbarrier.try_wait.parity.acquire.cta.shared::cta.b64 waitComplete, [%0], %1; \n\t" + "@waitComplete bra DONE; \n\t" // mbarrier conditions are met + "bra WAIT; \n\t" // just a time-out, try again + "DONE: \n\t" + "}\n\t" + : + : "r"(mbar_ptr), "r"(phase_parity) + : "memory"); +#else + NVTE_DEVICE_ERROR("mbarrier_wait_parity_acquire_cta_shared_cta is only supported on SM 10.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +__device__ __forceinline__ void mbarrier_wait_parity_relaxed_cta_shared_cta(uint64_t *mbar, + uint32_t phase_parity) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile( + "{\n\t" + ".reg .b64 r1; \n\t" + ".reg .pred waitComplete; \n\t" // predicate representing if barrier condition is met + "WAIT: \n\t" // loop around barrier wait + "mbarrier.try_wait.parity.relaxed.cta.shared::cta.b64 waitComplete, [%0], %1; \n\t" + "@waitComplete bra DONE; \n\t" // mbarrier conditions are met + "bra WAIT; \n\t" // just a time-out, try again + "DONE: \n\t" + "}\n\t" + : + : "r"(mbar_ptr), "r"(phase_parity) + : "memory"); +#else + NVTE_DEVICE_ERROR("mbarrier_wait_parity_relaxed_cta_shared_cta is only supported on SM 10.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +__device__ __forceinline__ void +clusterlaunchcontrol_try_cancel_async_shared_cta_mbarrier_complete_tx_bytes( + uint64_t *mbar, __uint128_t *response_data_ptr) { + constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; + if constexpr (is_blackwell) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + uint32_t workID_response = __cvta_generic_to_shared(response_data_ptr); + asm volatile( + "clusterlaunchcontrol.try_cancel.async.mbarrier::complete_tx::bytes.multicast::cluster::" + "all.b128 " + "[%0], [%1];" ::"r"(workID_response), + "r"(mbar_ptr)); + } else { + NVTE_DEVICE_ERROR( + "Cluster Launch Control PTX instructions are architecture-specific. 
" + "Try recompiling with sm_XXXa instead of sm_XXX."); + } +} + +__device__ __forceinline__ void get_cancelled_cta_2D_id(__uint128_t *response_data_ptr, + int32_t &ctaid_X, int32_t &ctaid_Y) { + constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; + if constexpr (is_blackwell) { + uint32_t workID_response = __cvta_generic_to_shared(response_data_ptr); + asm volatile( + "{\n\t" + ".reg .s32 x_ctaid; \n\t" + ".reg .s32 y_ctaid; \n\t" + "mov .s32 x_ctaid, -1; \n\t" + "mov .s32 y_ctaid, -1; \n\t" + ".reg.b128 try_cancel_response; \n\t" + "ld.shared.b128 try_cancel_response, [%2]; \n\t" + ".reg .pred P1; \n\t" + "clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 P1, try_cancel_response; \n\t" + "@P1 clusterlaunchcontrol.query_cancel.get_first_ctaid.v4.b32.b128 {x_ctaid, y_ctaid, _, " + "_}, try_cancel_response; \n\t" + "mov .s32 %0, x_ctaid; \n\t" + "mov .s32 %1, y_ctaid; \n\t" + "}\n\t" + : "=r"(ctaid_X), "=r"(ctaid_Y) + : "r"(workID_response) + : "memory"); + } else { + NVTE_DEVICE_ERROR( + "Cluster Launch Control PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } +} + constexpr uint32_t FP32_MANTISSA_BITS = 23; constexpr uint32_t FP32_EXPONENT_BIAS = 127; @@ -657,6 +780,95 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, c return mul_cvt_fp32_to_fp4_4x_with_rn(in01, in23, scale, rbits); } } + +__device__ __forceinline__ uint32_t +mul_cvt_bf16_to_fp4_8x_round_to_nearest(const uint64_t in03, const uint64_t in47, + const bf16 scaling_coefficient) { + uint32_t out_8x = 0; + constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; + if constexpr (is_blackwell) { + asm volatile( + "{\n" + ".reg.f32 zero; \n\t" + "mov.b32 zero, 0; \n\t" + ".reg.b16 scaling_coeff; \n\t" + "mov.b16 scaling_coeff, %3; \n\t" + ".reg.b16 v0_h, v1_h, v2_h, v3_h, v4_h, v5_h, v6_h, v7_h; \n\t" + "mov.b64 {v0_h, v1_h, v2_h, v3_h}, %1; \n\t" + "mov.b64 {v4_h, v5_h, v6_h, v7_h}, %2; \n\t" + + ".reg.f32 v0, v1, v2, v3, v4, v5, v6, v7; \n\t" + "fma.rn.f32.bf16 v0, v0_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v1, v1_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v2, v2_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v3, v3_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v4, v4_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v5, v5_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v6, v6_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v7, v7_h, scaling_coeff, zero; \n\t" + + ".reg.b8 f0, f1, f2, f3; \n\t" + // Elements reordered to match e2m1x4 packing order (v1,v0) + "cvt.rn.satfinite.e2m1x2.f32 f0, v1, v0;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, v3, v2;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f2, v5, v4;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f3, v7, v6;\n\t" + "mov.b32 %0, {f0, f1, f2, f3};\n" + "}" + : "=r"(out_8x) + : "l"(in03), "l"(in47), "h"(reinterpret_cast(scaling_coefficient))); + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. 
" + "Try recompiling with sm_XXXa instead of sm_XXX."); + } + return out_8x; +} + +__device__ __forceinline__ uint32_t +mul_cvt_bf16_to_fp4_8x_stochastic_rounding(const uint64_t in03, const uint64_t in47, + const bf16 scaling_coefficient, const uint32_t rbits03, + const uint32_t rbits47) { + uint32_t out_8x = 0; + constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; + if constexpr (has_rs) { + asm volatile( + "{\n" + ".reg.f32 zero; \n\t" + "mov.b32 zero, 0; \n\t" + ".reg.b16 scaling_coeff; \n\t" + "mov.b16 scaling_coeff, %3; \n\t" + ".reg.b16 v0_h, v1_h, v2_h, v3_h, v4_h, v5_h, v6_h, v7_h; \n\t" + "mov.b64 {v0_h, v1_h, v2_h, v3_h}, %1; \n\t" + "mov.b64 {v4_h, v5_h, v6_h, v7_h}, %2; \n\t" + + ".reg.f32 v0, v1, v2, v3, v4, v5, v6, v7; \n\t" + "fma.rn.f32.bf16 v0, v0_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v1, v1_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v2, v2_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v3, v3_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v4, v4_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v5, v5_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v6, v6_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v7, v7_h, scaling_coeff, zero; \n\t" + + ".reg.b16 b03, b47; \n\t" + // Elements reordered to match e2m1x4 packing order (v3,v2,v1,v0) + "cvt.rs.satfinite.e2m1x4.f32 b03, {v3, v2, v1, v0}, %4; \n\t" + "cvt.rs.satfinite.e2m1x4.f32 b47, {v7, v6, v5, v4}, %5; \n\t" + "mov.b32 %0, {b03, b47};\n" + "}" + : "=r"(out_8x) + : "l"(in03), "l"(in47), "h"(reinterpret_cast(scaling_coefficient)), + "r"(rbits03), "r"(rbits47)); + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } + return out_8x; +} + #endif // FP4_TYPE_SUPPORTED // SIMD like "Fused" cast + multiplication (x2) @@ -1505,6 +1717,51 @@ __device__ __forceinline__ floatx4 up_cast(const bf16x4 &in) { return out; } +// Loads single BF16/FP16 element from shared memory state space +__device__ __forceinline__ bf16 +ld_shared_b16(const bf16 * __restrict__ src_smem) { + const uint32_t src_smem_ptr = __cvta_generic_to_shared(src_smem); + bf16 dst; + asm volatile("ld.shared.b16 %0, [%1];" : "=h"(reinterpret_cast(dst)) : "r"(src_smem_ptr)); + return dst; +} + +// Loads pair of BF16/FP16 values from shared memory state space +__device__ __forceinline__ bf16x2 +ld_shared_b32(const bf16x2 * __restrict__ src_smem) { + const uint32_t src_smem_ptr = __cvta_generic_to_shared(src_smem); + bf16x2 dst; + asm volatile("ld.shared.b32 %0, [%1];" : "=r"(reinterpret_cast(dst)) : "r"(src_smem_ptr)); + return dst; +} + +// Loads 8x BF16 values from shared memory state space +__device__ __forceinline__ void +ld_shared_b128(uint64_t& elts03, uint64_t& elts47, const bf16 * __restrict__ src_smem) { + const uint32_t src_smem_ptr = __cvta_generic_to_shared(src_smem); + asm volatile( + "{\n\t" + ".reg.b128 xy; \n\t" + "ld.shared.b128 xy, [%2]; \n\t" + "mov.b128 {%0, %1}, xy; \n" + "}\n" + : "=l"(elts03), "=l"(elts47) + : "r"(src_smem_ptr)); +} + +// Vectorized store of x8 FP4 elements into shared memory state space +__device__ __forceinline__ void +st_shared_b32(fp4e2m1x2 * __restrict__ dst_smem, uint32_t fp4_pack_x8) { + const uint32_t dst_smem_ptr = __cvta_generic_to_shared(dst_smem); + asm volatile("st.shared.b32 [%0], %1;" : : "r"(dst_smem_ptr), "r"(fp4_pack_x8)); +} + +// Vectorized store of x16 FP4 elements into shared memory state space +__device__ __forceinline__ void +st_shared_b64(fp4e2m1x2 * __restrict__ dst_smem, uint64_t 
fp4_pack_x16) { + const uint32_t dst_smem_ptr = __cvta_generic_to_shared(dst_smem); + asm volatile("st.shared.b64 [%0], %1;" : : "r"(dst_smem_ptr), "l"(fp4_pack_x16)); +} } // namespace ptx namespace { From 03198ae195d707911ccad7acce2cc665036c7149 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 21 Nov 2025 23:47:50 +0000 Subject: [PATCH 02/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../cpp/operator/test_cast_nvfp4_transpose.cu | 6 +- ...quantize_transpose_nvfp4_persistent_1D.cuh | 292 +++++++++--------- transformer_engine/common/util/ptx.cuh | 174 +++++------ 3 files changed, 236 insertions(+), 236 deletions(-) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 3315cb513c..8b7f8b9634 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -384,7 +384,7 @@ void compare_nvfp4_tensors(const std::string& name, ", rel_diff: " + std::to_string(r == 0 ? 0.0 : fabs((t - r) / r)) + ")"; if constexpr (print_detailed_summary) { mismatch_messages.push_back(msg); - + // Optional: limit number of detailed messages to avoid overwhelming output if (mismatch_messages.size() <= 100) { std::cout << "Error in tensor " << name << ": " << msg << std::endl; @@ -401,7 +401,7 @@ void compare_nvfp4_tensors(const std::string& name, // Always report summary - either success or failure std::cout << "=== SUMMARY for tensor " << name << " ===" << std::endl; std::cout << "Total elements checked: " << (rows * cols) << std::endl; - + if (total_mismatches > 0) { std::cout << "STATUS: FAILED for output" << std::endl; std::cout << "Total mismatches found: " << total_mismatches << std::endl; @@ -410,7 +410,7 @@ void compare_nvfp4_tensors(const std::string& name, std::cout << "... 
and " << (mismatch_messages.size() - 100) << " more mismatches (showing first 100)" << std::endl; } std::cout << "============================" << std::endl; - + GTEST_FAIL() << "Found " << total_mismatches << " mismatches in tensor " << name; } else { std::cout << "STATUS: PASSED for output" << std::endl; diff --git a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_persistent_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_persistent_1D.cuh index 09edfff321..e716383270 100644 --- a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_persistent_1D.cuh +++ b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_persistent_1D.cuh @@ -49,7 +49,6 @@ static_assert(ELTS_PER_THREAD == SCALE_DIM && "Hardcoded and fixed parameter\0") static_assert(TILE_DIM_Y == 64 && "Hardcoded and fixed parameter\0"); static_assert(TILE_DIM_X == 64 && "Hardcoded and fixed parameter\0"); - static_assert((THREADS_NUM * ELTS_PER_THREAD <= TILE_DIM_Y * TILE_DIM_X) && "Unbalanced threads workload\0"); @@ -58,9 +57,9 @@ static_assert((CHUNK_DIM_Y % TILE_DIM_Y == 0) && static_assert((CHUNK_DIM_X % TILE_DIM_X == 0) && "Chunk size X must be evenly divisible by the tile size X\0"); -static_assert((TILE_DIM_Y % SCALE_DIM == 0) && +static_assert((TILE_DIM_Y % SCALE_DIM == 0) && "Tile size Y must be evenly divisible by the scale dim\0"); -static_assert((TILE_DIM_X % SCALE_DIM == 0) && +static_assert((TILE_DIM_X % SCALE_DIM == 0) && "Tile size X must be evenly divisible by the scale dim\0"); constexpr int TILES_Y = CHUNK_DIM_Y / TILE_DIM_Y; @@ -118,8 +117,8 @@ constexpr int THREADS_Y_TRANSP = THREADS_NUM / THREADS_X_TRANSP; constexpr int ITERATIONS_NORMAL = BUFF_DIM_Y / THREADS_Y_ROWWISE; constexpr int ITERATIONS_TRANSPOSE = SCALES_PER_TILE_Y / THREADS_Y_TRANSP; static_assert(ITERATIONS_TRANSPOSE >= 1 && "Number of transpose iterations should be >=1\0"); -static_assert((SCALES_PER_TILE_Y % THREADS_Y_TRANSP == 0) - && "Partial transpose iterations are not supported\0"); +static_assert((SCALES_PER_TILE_Y % THREADS_Y_TRANSP == 0) && + "Partial transpose iterations are not supported\0"); constexpr int BUFF_OUT_IT_OFFSET = BUFF_OUT_T_DIM_X / ITERATIONS_TRANSPOSE / STAGES; @@ -140,38 +139,31 @@ constexpr int THREADS_PER_BANK = TOTAL_BANKS_WIDTH / ELTS_PER_THREAD; using IType = bf16; using IType2 = typename ptx::FPx2; using IType3D = IType[BUFFS_NUM_IN][BUFF_IN_DIM_Y][BUFF_IN_DIM_X]; -using IType2x3D = IType2[BUFFS_NUM_IN][BUFF_IN_DIM_Y][BUFF_IN_DIM_X/2]; +using IType2x3D = IType2[BUFFS_NUM_IN][BUFF_IN_DIM_Y][BUFF_IN_DIM_X / 2]; using OType2x3D = fp4e2m1x2[BUFFS_NUM_OUT][BUFF_OUT_DIM_Y][BUFF_OUT_DIM_X]; using OType2xt3D = fp4e2m1x2[BUFFS_NUM_OUT_TR][BUFF_OUT_T_DIM_Y][BUFF_OUT_T_DIM_X]; using ScalesType2D = nvfp4_scale_t[CHUNK_DIM_Y][SCALES_PER_CHUNK_X]; using ScalesTypeTr2D = nvfp4_scale_t[CHUNK_DIM_X][SCALES_PER_CHUNK_Y]; using RNG_t = typename transformer_engine::curanddx::detail::philox4x32_native_state<10>; -__device__ __forceinline__ float -get_amax_of_pair(const IType2 xormax_pair) { +__device__ __forceinline__ float get_amax_of_pair(const IType2 xormax_pair) { return static_cast(__hmax(__habs(xormax_pair.x), __habs(xormax_pair.y))); } // Compute "correct" per-block encoding scaling factor -__device__ __forceinline__ bf16 -compute_nvfp4_scaling_coefficient(const nvfp4_scale_t S_dec_block, const float S_dec) { +__device__ __forceinline__ bf16 compute_nvfp4_scaling_coefficient(const nvfp4_scale_t S_dec_block, + const float S_dec) 
{ constexpr float float_max = detail::TypeExtrema::max; const float scale_rcp = fminf(1.0f / (static_cast(S_dec_block) * S_dec), float_max); return static_cast(scale_rcp); } template -__device__ __forceinline__ void -colwise_scaling(const IType * __restrict__ sIn_ptr, - fp4e2m1x2 * __restrict__ sOut_tr_ptr, - nvfp4_scale_t * __restrict__ sSFcolwise_ptr, - const float S_enc_colwise, - const float S_dec_colwise, - const int stage_Y, - const int stage_X, - const int buff_in, - const int buff_out_tr, - RNG_t& rng, uint4 &random_uint4, int &rnd_idx) { +__device__ __forceinline__ void colwise_scaling( + const IType *__restrict__ sIn_ptr, fp4e2m1x2 *__restrict__ sOut_tr_ptr, + nvfp4_scale_t *__restrict__ sSFcolwise_ptr, const float S_enc_colwise, + const float S_dec_colwise, const int stage_Y, const int stage_X, const int buff_in, + const int buff_out_tr, RNG_t &rng, uint4 &random_uint4, int &rnd_idx) { const auto &sIn2x = *reinterpret_cast(sIn_ptr); auto &sOut_tr = *reinterpret_cast(sOut_tr_ptr); auto &sSFcolwise = *reinterpret_cast(sSFcolwise_ptr); @@ -197,17 +189,18 @@ colwise_scaling(const IType * __restrict__ sIn_ptr, __align__(8) IType rIn[2][SCALE_DIM]; // Read (cache) a pair of input elements (S2R). Find NVFP4-block AMAX IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; - #pragma unroll +#pragma unroll for (int i = 0; i < SCALE_DIM; ++i) { - const IType2 elt_pair = ptx::ld_shared_b32(&sIn2x[buff_in][in_thread_offset_Y + i][in_thread_offset_X]); + const IType2 elt_pair = + ptx::ld_shared_b32(&sIn2x[buff_in][in_thread_offset_Y + i][in_thread_offset_X]); rIn[0][i] = elt_pair.x; rIn[1][i] = elt_pair.y; ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, elt_pair); } const float block_amax[2] = {static_cast(__habs(thread_amax_2x.x)), - static_cast(__habs(thread_amax_2x.y))}; - #pragma unroll - for (int w = 0; w < 2; ++w) { + static_cast(__habs(thread_amax_2x.y))}; +#pragma unroll + for (int w = 0; w < 2; ++w) { const nvfp4_scale_t S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax[w], S_enc_colwise); // Store scaling factors to SMEM buffer (R2S) @@ -217,35 +210,31 @@ colwise_scaling(const IType * __restrict__ sIn_ptr, // Scale elements __align__(8) uint32_t rOut[SCALE_DIM / 8]; - #pragma unroll +#pragma unroll for (int e = 0; e < SCALE_DIM / 8; ++e) { const uint64_t elts03 = *reinterpret_cast(&rIn[w][8 * e]); const uint64_t elts47 = *reinterpret_cast(&rIn[w][8 * e + 4]); if constexpr (USE_STOCHASTIC_ROUNDING) { const uint32_t rbits03 = core::get_rbits(rng, random_uint4, rnd_idx); const uint32_t rbits47 = core::get_rbits(rng, random_uint4, rnd_idx); - rOut[e] = ptx::mul_cvt_bf16_to_fp4_8x_stochastic_rounding(elts03, elts47, SFcoefficient, rbits03, rbits47); + rOut[e] = ptx::mul_cvt_bf16_to_fp4_8x_stochastic_rounding(elts03, elts47, SFcoefficient, + rbits03, rbits47); } else { rOut[e] = ptx::mul_cvt_bf16_to_fp4_8x_round_to_nearest(elts03, elts47, SFcoefficient); } } - uint64_t& out_pack_16x = *reinterpret_cast(rOut); - ptx::st_shared_b64(&sOut_tr[buff_out_tr][out_tr_thread_offset_Y + w][out_tr_thread_offset_X], out_pack_16x); + uint64_t &out_pack_16x = *reinterpret_cast(rOut); + ptx::st_shared_b64(&sOut_tr[buff_out_tr][out_tr_thread_offset_Y + w][out_tr_thread_offset_X], + out_pack_16x); } } template -__device__ __forceinline__ void -rowwise_scaling(const IType * __restrict__ sIn_ptr, - fp4e2m1x2 * __restrict__ sOut_ptr, - nvfp4_scale_t * __restrict__ sSFrowwise_ptr, - const float S_enc_rowwise, - const float S_dec_rowwise, - const int stage_Y, - const int stage_X, - const int 
buff_in, - const int buff_out, - RNG_t& rng, uint4 &random_uint4, int &rnd_idx) { +__device__ __forceinline__ void rowwise_scaling( + const IType *__restrict__ sIn_ptr, fp4e2m1x2 *__restrict__ sOut_ptr, + nvfp4_scale_t *__restrict__ sSFrowwise_ptr, const float S_enc_rowwise, + const float S_dec_rowwise, const int stage_Y, const int stage_X, const int buff_in, + const int buff_out, RNG_t &rng, uint4 &random_uint4, int &rnd_idx) { const auto &sIn = *reinterpret_cast(sIn_ptr); auto &sOut = *reinterpret_cast(sOut_ptr); auto &sSFrowwise = *reinterpret_cast(sSFrowwise_ptr); @@ -265,8 +254,9 @@ rowwise_scaling(const IType * __restrict__ sIn_ptr, const bool SF_storing_thread = (tid_X_rowwise % THREADS_PER_SCALE_ROWWISE == 0); const int stage_rowwise_scales_offset_Y = SF_thread_offset_rowwise_Y + stage_Y * TILE_DIM_Y; - const int stage_rowwise_scales_offset_X = SF_thread_offset_rowwise_X + stage_X * SCALES_PER_TILE_X; - #pragma unroll + const int stage_rowwise_scales_offset_X = + SF_thread_offset_rowwise_X + stage_X * SCALES_PER_TILE_X; +#pragma unroll for (int it = 0; it < ITERATIONS_NORMAL; ++it) { const int it_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE; @@ -274,15 +264,15 @@ rowwise_scaling(const IType * __restrict__ sIn_ptr, // Read (cache) input elements (S2R). Find NVFP4-block AMAX IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; - #pragma unroll +#pragma unroll for (int w = 0; w < WAVES; ++w) { - uint64_t& elts03 = *reinterpret_cast(&rIn[w][0]); - uint64_t& elts47 = *reinterpret_cast(&rIn[w][2]); + uint64_t &elts03 = *reinterpret_cast(&rIn[w][0]); + uint64_t &elts47 = *reinterpret_cast(&rIn[w][2]); const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD; const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; // Load elements ptx::ld_shared_b128(elts03, elts47, &sIn[buff_in][it_offset_Y_rowwise][swizzled_thread_idx]); - #pragma unroll +#pragma unroll for (int e = 0; e < PACK_SIZE / 2; ++e) { ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, rIn[w][e]); } @@ -299,8 +289,8 @@ rowwise_scaling(const IType * __restrict__ sIn_ptr, sSFrowwise[scales_offset_Y][scales_offset_X] = S_dec_b_fp8; } - // Scale elements - #pragma unroll +// Scale elements +#pragma unroll for (int w = 0; w < WAVES; ++w) { const uint64_t elts03 = *reinterpret_cast(&rIn[w][0]); const uint64_t elts47 = *reinterpret_cast(&rIn[w][2]); @@ -309,7 +299,8 @@ rowwise_scaling(const IType * __restrict__ sIn_ptr, if constexpr (USE_STOCHASTIC_ROUNDING) { const uint32_t rbits03 = core::get_rbits(rng, random_uint4, rnd_idx); const uint32_t rbits47 = core::get_rbits(rng, random_uint4, rnd_idx); - out_x8 = ptx::mul_cvt_bf16_to_fp4_8x_stochastic_rounding(elts03, elts47, SFcoefficient, rbits03, rbits47); + out_x8 = ptx::mul_cvt_bf16_to_fp4_8x_stochastic_rounding(elts03, elts47, SFcoefficient, + rbits03, rbits47); } else { out_x8 = ptx::mul_cvt_bf16_to_fp4_8x_round_to_nearest(elts03, elts47, SFcoefficient); } @@ -322,17 +313,13 @@ rowwise_scaling(const IType * __restrict__ sIn_ptr, } template -__global__ void __launch_bounds__(THREADS_NUM) - quantize_transpose_nvfp4_persistent_kernel( - const __grid_constant__ CUtensorMap tensor_map_input, - const __grid_constant__ CUtensorMap tensor_map_output, - const __grid_constant__ CUtensorMap tensor_map_output_t, - nvfp4_scale_t *const scales_ptr, - nvfp4_scale_t *const scales_t_ptr, const float *noop, - const float *const amax_rowwise_ptr, - const float *const amax_colwise_ptr, const size_t rows, - const size_t cols, 
const size_t scale_stride, - const size_t scale_stride_t, const size_t *rng_state) { +__global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_persistent_kernel( + const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + const __grid_constant__ CUtensorMap tensor_map_output_t, nvfp4_scale_t *const scales_ptr, + nvfp4_scale_t *const scales_t_ptr, const float *noop, const float *const amax_rowwise_ptr, + const float *const amax_colwise_ptr, const size_t rows, const size_t cols, + const size_t scale_stride, const size_t scale_stride_t, const size_t *rng_state) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) if (noop != nullptr && noop[0] == 1.0f) { return; @@ -353,18 +340,21 @@ __global__ void __launch_bounds__(THREADS_NUM) constexpr int buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X; constexpr int buff_elems_total_in = BUFFS_NUM_IN * buff_elems; - constexpr int buff_size_aligned_in = DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT); - constexpr int buff_size_aligned_out = DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT); - constexpr int buff_size_aligned_out_t = DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_T_SIZE, TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_out = + DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_out_t = + DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_T_SIZE, TMA_SHMEM_ALIGNMENT); constexpr int in_mem = buff_size_aligned_in; constexpr int out_mem_rowwise_data = buff_size_aligned_out; constexpr int out_mem_colwise_data = RETURN_TRANSPOSE ? buff_size_aligned_out_t : 0; - constexpr int out_mem_rowwise_scales = - DIVUP_TO_MULTIPLE(CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); - - // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + constexpr int out_mem_rowwise_scales = DIVUP_TO_MULTIPLE( + CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned extern __shared__ unsigned char dynamic_shmem[]; unsigned char *dshmem = common::align_smem_ptr_per_TMA_requirements(dynamic_shmem); @@ -387,15 +377,17 @@ __global__ void __launch_bounds__(THREADS_NUM) constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; // Compute a global encoding/decoding scaling factors for all S_dec_b - const float S_enc_rowwise = (amax_rowwise_ptr == nullptr) - ? 1.0f - : core::compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); + const float S_enc_rowwise = + (amax_rowwise_ptr == nullptr) + ? 1.0f + : core::compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); // NOTE: This is to match with how emulation code was written. const float S_dec_rowwise = 1.0 / S_enc_rowwise; - const float S_enc_colwise = (amax_colwise_ptr == nullptr) - ? S_enc_rowwise - : core::compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); + const float S_enc_colwise = + (amax_colwise_ptr == nullptr) + ? 
S_enc_rowwise + : core::compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); const float S_dec_colwise = 1.0 / S_enc_colwise; __shared__ uint64_t workID_mbar; @@ -411,7 +403,7 @@ __global__ void __launch_bounds__(THREADS_NUM) // Initialize shared memory barriers with the number of threads participating in them if (leading_thread) { - #pragma unroll +#pragma unroll for (int buff = 0; buff < BUFFS_NUM; ++buff) { ptx::mbarrier_init(&IN_buff_readable_mbar[buff], 1); } @@ -427,14 +419,14 @@ __global__ void __launch_bounds__(THREADS_NUM) int IN_buff_readable_parity[BUFFS_NUM] = {0, 0}; int ctaid_parity = 0; - // Prefetch input data only when processing the first chunk, - // which enables the one-iteration overlap throughout the entire kernel life - #pragma unroll +// Prefetch input data only when processing the first chunk, +// which enables the one-iteration overlap throughout the entire kernel life +#pragma unroll for (int stage = 0; stage < PREFETCH_STAGES; ++stage) { const int buff_in = stage; - const int stage_Y = stage / STAGES_X; + const int stage_Y = stage / STAGES_X; const int stage_X = stage % STAGES_X; - + const int stage_offset_Y = stage_Y * TILE_DIM_Y; const int stage_offset_X = stage_X * TILE_DIM_X; @@ -444,16 +436,17 @@ __global__ void __launch_bounds__(THREADS_NUM) const int global_offset_Y = block_offset_Y + stage_offset_Y; const int global_offset_X = block_offset_X + stage_offset_X; - uint64_t* barrier = &IN_buff_readable_mbar[buff_in]; + uint64_t *barrier = &IN_buff_readable_mbar[buff_in]; if (leading_thread) { - uint64_t* dst = reinterpret_cast(&sIn[buff_in]); - const uint64_t* src = reinterpret_cast(&tensor_map_input); + uint64_t *dst = reinterpret_cast(&sIn[buff_in]); + const uint64_t *src = reinterpret_cast(&tensor_map_input); // Arrive on the barrier and tell how many bytes are expected to come in ptx::mbarrier_arrive_expect_tx(barrier, shmem_buff_size); // Initiate bulk tensor copy - ptx::cp_async_bulk_tensor_2d_global_to_shared(dst, src, global_offset_X, global_offset_Y, barrier); + ptx::cp_async_bulk_tensor_2d_global_to_shared(dst, src, global_offset_X, global_offset_Y, + barrier); } } @@ -474,10 +467,11 @@ __global__ void __launch_bounds__(THREADS_NUM) if (leading_thread) { ptx::mbarrier_arrive_expect_tx_cta_relaxed_shared_cta(&workID_mbar, workID_response_size); - ptx::clusterlaunchcontrol_try_cancel_async_shared_cta_mbarrier_complete_tx_bytes(&workID_mbar, &workID_response); + ptx::clusterlaunchcontrol_try_cancel_async_shared_cta_mbarrier_complete_tx_bytes( + &workID_mbar, &workID_response); } - #pragma unroll +#pragma unroll for (int stage = 0; stage < STAGES; ++stage) { const int stage_Y = stage / STAGES_X; const int stage_X = stage % STAGES_X; @@ -496,9 +490,9 @@ __global__ void __launch_bounds__(THREADS_NUM) // Prefetch next stage Input data if (!job_finished || (stage < STAGES - PREFETCH_STAGES)) { - const int next_prefetch_buff = (buff_in + PREFETCH_STAGES) % BUFFS_NUM; + const int next_prefetch_buff = (buff_in + PREFETCH_STAGES) % BUFFS_NUM; const int next_prefetch_stage = (stage + PREFETCH_STAGES) % STAGES; - const int next_prefetch_stage_Y = next_prefetch_stage / STAGES_X; + const int next_prefetch_stage_Y = next_prefetch_stage / STAGES_X; const int next_prefetch_stage_X = next_prefetch_stage % STAGES_X; const int next_prefetch_stage_offset_Y = next_prefetch_stage_Y * TILE_DIM_Y; @@ -511,22 +505,24 @@ __global__ void __launch_bounds__(THREADS_NUM) const int global_offset_Y = block_offset_Y + next_prefetch_stage_offset_Y; const int global_offset_X = 
block_offset_X + next_prefetch_stage_offset_X; - uint64_t* barrier = &IN_buff_readable_mbar[next_prefetch_buff]; + uint64_t *barrier = &IN_buff_readable_mbar[next_prefetch_buff]; if (leading_thread) { - uint64_t* dst = reinterpret_cast(&sIn[next_prefetch_buff]); - const uint64_t* src = reinterpret_cast(&tensor_map_input); + uint64_t *dst = reinterpret_cast(&sIn[next_prefetch_buff]); + const uint64_t *src = reinterpret_cast(&tensor_map_input); // Arrive on the barrier and tell how many bytes are expected to come in ptx::mbarrier_arrive_expect_tx(barrier, shmem_buff_size); // Initiate bulk tensor copy - ptx::cp_async_bulk_tensor_2d_global_to_shared(dst, src, global_offset_X, global_offset_Y, barrier); + ptx::cp_async_bulk_tensor_2d_global_to_shared(dst, src, global_offset_X, global_offset_Y, + barrier); } ptx::fence_proxy_async_shared_cta(); } - + // Wait for the data to have arrived - ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], IN_buff_readable_parity[buff_in]); + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], + IN_buff_readable_parity[buff_in]); IN_buff_readable_parity[buff_in] ^= 1; // Wait for TMA transfer to have finished reading shared memory @@ -534,14 +530,14 @@ __global__ void __launch_bounds__(THREADS_NUM) ptx::cp_async_bulk_wait_group_read(); // NVFP4 Quantization - rowwise_scaling( - sIn_ptr, sOut_ptr, sSFrowwise_ptr, S_enc_rowwise, S_dec_rowwise, - stage_Y, stage_X, buff_in, buff_out, rng, random_uint4, rnd_idx); + rowwise_scaling(sIn_ptr, sOut_ptr, sSFrowwise_ptr, S_enc_rowwise, + S_dec_rowwise, stage_Y, stage_X, buff_in, buff_out, + rng, random_uint4, rnd_idx); if constexpr (RETURN_TRANSPOSE) { - colwise_scaling( - sIn_ptr, sOut_tr_ptr, sSFcolwise_ptr, S_enc_colwise, S_dec_colwise, - stage_Y, stage_X, buff_in, buff_out_tr, rng, random_uint4, rnd_idx); + colwise_scaling(sIn_ptr, sOut_tr_ptr, sSFcolwise_ptr, + S_enc_colwise, S_dec_colwise, stage_Y, stage_X, + buff_in, buff_out_tr, rng, random_uint4, rnd_idx); } // Wait for shared memory writes to be visible to TMA engine @@ -555,17 +551,17 @@ __global__ void __launch_bounds__(THREADS_NUM) const int global_offset_X = block_offset_X + stage_offset_X; const int global_offset_Y_tr = block_offset_Y_tr + stage_offset_X; const int global_offset_X_tr = block_offset_X_tr + stage_offset_Y; - + ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output), global_offset_X, global_offset_Y, - reinterpret_cast(&sOut[buff_out])); + reinterpret_cast(&tensor_map_output), global_offset_X, + global_offset_Y, reinterpret_cast(&sOut[buff_out])); if constexpr (RETURN_TRANSPOSE) { ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_t), global_offset_X_tr, global_offset_Y_tr, - reinterpret_cast(&sOut_tr[buff_out_tr])); + reinterpret_cast(&tensor_map_output_t), global_offset_X_tr, + global_offset_Y_tr, reinterpret_cast(&sOut_tr[buff_out_tr])); } - + // Create a "bulk async-group" out of the previous bulk copy operation ptx::cp_async_bulk_commit_group(); } @@ -573,7 +569,7 @@ __global__ void __launch_bounds__(THREADS_NUM) buff_in = (buff_in + 1) % BUFFS_NUM_IN; buff_out = (buff_out + 1) % BUFFS_NUM_OUT; buff_out_tr = (buff_out_tr + 1) % BUFFS_NUM_OUT_TR; - } // end of stages + } // end of stages // Vectorized store of scaling factors (S2G) { @@ -582,13 +578,13 @@ __global__ void __launch_bounds__(THREADS_NUM) using ScalesVec = Vec; // number of scales in X dimension of this chunk const int count = min(SCALES_PER_CHUNK_X, 
chunk_cols / SCALE_DIM); - + for (size_t row = threadIdx.x; row < CHUNK_DIM_Y; row += THREADS_NUM) { const size_t row_global = scales_block_offset_Y_rowwise + row; if (row_global < rows) { ScalesVec &scales_vec = *reinterpret_cast(sSFrowwise[row]); - const size_t scale_idx_global = row_global * scale_stride - + scales_block_offset_X_rowwise; + const size_t scale_idx_global = + row_global * scale_stride + scales_block_offset_X_rowwise; scales_vec.store_to_elts(&scales_ptr[scale_idx_global], 0, count); } } @@ -604,22 +600,22 @@ __global__ void __launch_bounds__(THREADS_NUM) const size_t row_tr_global = scales_block_offset_Y_tr + row_tr; if (row_tr_global < cols) { ScalesVec &scales_vec = *reinterpret_cast(sSFcolwise[row_tr]); - const size_t scale_idx_global = row_tr_global * scale_stride_t - + scales_block_offset_X_tr; + const size_t scale_idx_global = + row_tr_global * scale_stride_t + scales_block_offset_X_tr; scales_vec.store_to_elts(&scales_t_ptr[scale_idx_global], 0, count); } } } - + if (!job_finished) { - // Ensures all reads from SFs buffer have completed and it's ready to be reused + // Ensures all reads from SFs buffer have completed and it's ready to be reused __syncthreads(); } } } if (leading_thread) { - #pragma unroll +#pragma unroll for (int buff = 0; buff < BUFFS_NUM; ++buff) { ptx::mbarrier_invalid(&IN_buff_readable_mbar[buff]); } @@ -653,13 +649,13 @@ inline void quantize_transpose_persistent_1D(const Tensor &input, const Tensor * NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated."); NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); - + if (return_transpose) { NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), - "Transposed output must have FP4 type."); + "Transposed output must have FP4 type."); NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, - "Transposed scaling tensor must be allocated"); + "Transposed scaling tensor must be allocated"); } const size_t rows = input.flat_first_dim(); @@ -676,14 +672,17 @@ inline void quantize_transpose_persistent_1D(const Tensor &input, const Tensor * const int block_size = THREADS_NUM; const size_t scale_stride = output->scale_inv.shape[1]; - const size_t scale_stride_transpose = return_transpose ? output->columnwise_scale_inv.shape[1] : 0; + const size_t scale_stride_transpose = + return_transpose ? output->columnwise_scale_inv.shape[1] : 0; nvfp4_scale_t *const scales_ptr = reinterpret_cast(output->scale_inv.dptr); - nvfp4_scale_t *const scales_transpose_ptr = reinterpret_cast(output->columnwise_scale_inv.dptr); + nvfp4_scale_t *const scales_transpose_ptr = + reinterpret_cast(output->columnwise_scale_inv.dptr); const float *noop_ptr = reinterpret_cast(noop->data.dptr); const float *const amax_rowwise_ptr = reinterpret_cast(output->amax.dptr); - const float *const amax_colwise_ptr = reinterpret_cast(output->columnwise_amax.dptr); + const float *const amax_colwise_ptr = + reinterpret_cast(output->columnwise_amax.dptr); const NVTETensor rng_state_tensor = (quant_config != nullptr) ? 
quant_config->rng_state : nullptr; const size_t *rng_state = nullptr; @@ -712,14 +711,17 @@ inline void quantize_transpose_persistent_1D(const Tensor &input, const Tensor * constexpr int buff_elems = BUFF_DIM_Y * BUFF_DIM_X; constexpr int buff_elems_total_in = BUFFS_NUM_IN * buff_elems; - constexpr int buff_size_aligned_in = DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT); - constexpr int buff_size_aligned_out = DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT); - constexpr int buff_size_aligned_out_t = DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_T_SIZE, TMA_SHMEM_ALIGNMENT); - - constexpr int buff_size_scales = - DIVUP_TO_MULTIPLE(CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); - constexpr int buff_size_scales_transpose = - DIVUP_TO_MULTIPLE(CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_out = + DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_out_t = + DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_T_SIZE, TMA_SHMEM_ALIGNMENT); + + constexpr int buff_size_scales = DIVUP_TO_MULTIPLE( + CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_scales_transpose = DIVUP_TO_MULTIPLE( + CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); const int in_mem = buff_size_aligned_in; @@ -730,23 +732,21 @@ inline void quantize_transpose_persistent_1D(const Tensor &input, const Tensor * const int out_mem = out_data_mem + out_data_transpose_mem; - const int dshmem_size = in_mem + out_mem - + out_scales_transpose_mem + out_scales_mem - + TMA_SHMEM_ALIGNMENT; - - TRANSFORMER_ENGINE_SWITCH_CONDITION(use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, - TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, - { - auto kernel = quantize_transpose_nvfp4_persistent_kernel; - - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); - kernel<<>>( - tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, - scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, - scale_stride, scale_stride_transpose, rng_state); - } - ); - ); + const int dshmem_size = + in_mem + out_mem + out_scales_transpose_mem + out_scales_mem + TMA_SHMEM_ALIGNMENT; + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { + auto kernel = + quantize_transpose_nvfp4_persistent_kernel; + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + kernel<<>>( + tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, + scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, + scale_stride, scale_stride_transpose, rng_state); + });); #else NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); #endif // FP4_TYPE_SUPPORTED diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 4b7a5c9c44..09e55ceff6 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -781,90 +781,88 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, c } 
} -__device__ __forceinline__ uint32_t -mul_cvt_bf16_to_fp4_8x_round_to_nearest(const uint64_t in03, const uint64_t in47, - const bf16 scaling_coefficient) { +__device__ __forceinline__ uint32_t mul_cvt_bf16_to_fp4_8x_round_to_nearest( + const uint64_t in03, const uint64_t in47, const bf16 scaling_coefficient) { uint32_t out_8x = 0; constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; if constexpr (is_blackwell) { asm volatile( - "{\n" - ".reg.f32 zero; \n\t" - "mov.b32 zero, 0; \n\t" - ".reg.b16 scaling_coeff; \n\t" - "mov.b16 scaling_coeff, %3; \n\t" - ".reg.b16 v0_h, v1_h, v2_h, v3_h, v4_h, v5_h, v6_h, v7_h; \n\t" - "mov.b64 {v0_h, v1_h, v2_h, v3_h}, %1; \n\t" - "mov.b64 {v4_h, v5_h, v6_h, v7_h}, %2; \n\t" - - ".reg.f32 v0, v1, v2, v3, v4, v5, v6, v7; \n\t" - "fma.rn.f32.bf16 v0, v0_h, scaling_coeff, zero; \n\t" - "fma.rn.f32.bf16 v1, v1_h, scaling_coeff, zero; \n\t" - "fma.rn.f32.bf16 v2, v2_h, scaling_coeff, zero; \n\t" - "fma.rn.f32.bf16 v3, v3_h, scaling_coeff, zero; \n\t" - "fma.rn.f32.bf16 v4, v4_h, scaling_coeff, zero; \n\t" - "fma.rn.f32.bf16 v5, v5_h, scaling_coeff, zero; \n\t" - "fma.rn.f32.bf16 v6, v6_h, scaling_coeff, zero; \n\t" - "fma.rn.f32.bf16 v7, v7_h, scaling_coeff, zero; \n\t" - - ".reg.b8 f0, f1, f2, f3; \n\t" - // Elements reordered to match e2m1x4 packing order (v1,v0) - "cvt.rn.satfinite.e2m1x2.f32 f0, v1, v0;\n\t" - "cvt.rn.satfinite.e2m1x2.f32 f1, v3, v2;\n\t" - "cvt.rn.satfinite.e2m1x2.f32 f2, v5, v4;\n\t" - "cvt.rn.satfinite.e2m1x2.f32 f3, v7, v6;\n\t" - "mov.b32 %0, {f0, f1, f2, f3};\n" - "}" - : "=r"(out_8x) - : "l"(in03), "l"(in47), "h"(reinterpret_cast(scaling_coefficient))); + "{\n" + ".reg.f32 zero; \n\t" + "mov.b32 zero, 0; \n\t" + ".reg.b16 scaling_coeff; \n\t" + "mov.b16 scaling_coeff, %3; \n\t" + ".reg.b16 v0_h, v1_h, v2_h, v3_h, v4_h, v5_h, v6_h, v7_h; \n\t" + "mov.b64 {v0_h, v1_h, v2_h, v3_h}, %1; \n\t" + "mov.b64 {v4_h, v5_h, v6_h, v7_h}, %2; \n\t" + + ".reg.f32 v0, v1, v2, v3, v4, v5, v6, v7; \n\t" + "fma.rn.f32.bf16 v0, v0_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v1, v1_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v2, v2_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v3, v3_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v4, v4_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v5, v5_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v6, v6_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v7, v7_h, scaling_coeff, zero; \n\t" + + ".reg.b8 f0, f1, f2, f3; \n\t" + // Elements reordered to match e2m1x4 packing order (v1,v0) + "cvt.rn.satfinite.e2m1x2.f32 f0, v1, v0;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, v3, v2;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f2, v5, v4;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f3, v7, v6;\n\t" + "mov.b32 %0, {f0, f1, f2, f3};\n" + "}" + : "=r"(out_8x) + : "l"(in03), "l"(in47), "h"(reinterpret_cast(scaling_coefficient))); } else { NVTE_DEVICE_ERROR( - "FP4 cvt PTX instructions are architecture-specific. " - "Try recompiling with sm_XXXa instead of sm_XXX."); + "FP4 cvt PTX instructions are architecture-specific. 
" + "Try recompiling with sm_XXXa instead of sm_XXX."); } return out_8x; } -__device__ __forceinline__ uint32_t -mul_cvt_bf16_to_fp4_8x_stochastic_rounding(const uint64_t in03, const uint64_t in47, - const bf16 scaling_coefficient, const uint32_t rbits03, - const uint32_t rbits47) { +__device__ __forceinline__ uint32_t mul_cvt_bf16_to_fp4_8x_stochastic_rounding( + const uint64_t in03, const uint64_t in47, const bf16 scaling_coefficient, + const uint32_t rbits03, const uint32_t rbits47) { uint32_t out_8x = 0; constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; if constexpr (has_rs) { asm volatile( - "{\n" - ".reg.f32 zero; \n\t" - "mov.b32 zero, 0; \n\t" - ".reg.b16 scaling_coeff; \n\t" - "mov.b16 scaling_coeff, %3; \n\t" - ".reg.b16 v0_h, v1_h, v2_h, v3_h, v4_h, v5_h, v6_h, v7_h; \n\t" - "mov.b64 {v0_h, v1_h, v2_h, v3_h}, %1; \n\t" - "mov.b64 {v4_h, v5_h, v6_h, v7_h}, %2; \n\t" - - ".reg.f32 v0, v1, v2, v3, v4, v5, v6, v7; \n\t" - "fma.rn.f32.bf16 v0, v0_h, scaling_coeff, zero; \n\t" - "fma.rn.f32.bf16 v1, v1_h, scaling_coeff, zero; \n\t" - "fma.rn.f32.bf16 v2, v2_h, scaling_coeff, zero; \n\t" - "fma.rn.f32.bf16 v3, v3_h, scaling_coeff, zero; \n\t" - "fma.rn.f32.bf16 v4, v4_h, scaling_coeff, zero; \n\t" - "fma.rn.f32.bf16 v5, v5_h, scaling_coeff, zero; \n\t" - "fma.rn.f32.bf16 v6, v6_h, scaling_coeff, zero; \n\t" - "fma.rn.f32.bf16 v7, v7_h, scaling_coeff, zero; \n\t" - - ".reg.b16 b03, b47; \n\t" - // Elements reordered to match e2m1x4 packing order (v3,v2,v1,v0) - "cvt.rs.satfinite.e2m1x4.f32 b03, {v3, v2, v1, v0}, %4; \n\t" - "cvt.rs.satfinite.e2m1x4.f32 b47, {v7, v6, v5, v4}, %5; \n\t" - "mov.b32 %0, {b03, b47};\n" - "}" - : "=r"(out_8x) - : "l"(in03), "l"(in47), "h"(reinterpret_cast(scaling_coefficient)), - "r"(rbits03), "r"(rbits47)); + "{\n" + ".reg.f32 zero; \n\t" + "mov.b32 zero, 0; \n\t" + ".reg.b16 scaling_coeff; \n\t" + "mov.b16 scaling_coeff, %3; \n\t" + ".reg.b16 v0_h, v1_h, v2_h, v3_h, v4_h, v5_h, v6_h, v7_h; \n\t" + "mov.b64 {v0_h, v1_h, v2_h, v3_h}, %1; \n\t" + "mov.b64 {v4_h, v5_h, v6_h, v7_h}, %2; \n\t" + + ".reg.f32 v0, v1, v2, v3, v4, v5, v6, v7; \n\t" + "fma.rn.f32.bf16 v0, v0_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v1, v1_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v2, v2_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v3, v3_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v4, v4_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v5, v5_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v6, v6_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v7, v7_h, scaling_coeff, zero; \n\t" + + ".reg.b16 b03, b47; \n\t" + // Elements reordered to match e2m1x4 packing order (v3,v2,v1,v0) + "cvt.rs.satfinite.e2m1x4.f32 b03, {v3, v2, v1, v0}, %4; \n\t" + "cvt.rs.satfinite.e2m1x4.f32 b47, {v7, v6, v5, v4}, %5; \n\t" + "mov.b32 %0, {b03, b47};\n" + "}" + : "=r"(out_8x) + : "l"(in03), "l"(in47), "h"(reinterpret_cast(scaling_coefficient)), + "r"(rbits03), "r"(rbits47)); } else { NVTE_DEVICE_ERROR( - "FP4 cvt PTX instructions are architecture-specific. " - "Try recompiling with sm_XXXa instead of sm_XXX."); + "FP4 cvt PTX instructions are architecture-specific. 
" + "Try recompiling with sm_XXXa instead of sm_XXX."); } return out_8x; } @@ -1718,47 +1716,49 @@ __device__ __forceinline__ floatx4 up_cast(const bf16x4 &in) { } // Loads single BF16/FP16 element from shared memory state space -__device__ __forceinline__ bf16 -ld_shared_b16(const bf16 * __restrict__ src_smem) { +__device__ __forceinline__ bf16 ld_shared_b16(const bf16 *__restrict__ src_smem) { const uint32_t src_smem_ptr = __cvta_generic_to_shared(src_smem); bf16 dst; - asm volatile("ld.shared.b16 %0, [%1];" : "=h"(reinterpret_cast(dst)) : "r"(src_smem_ptr)); + asm volatile("ld.shared.b16 %0, [%1];" + : "=h"(reinterpret_cast(dst)) + : "r"(src_smem_ptr)); return dst; } // Loads pair of BF16/FP16 values from shared memory state space -__device__ __forceinline__ bf16x2 -ld_shared_b32(const bf16x2 * __restrict__ src_smem) { +__device__ __forceinline__ bf16x2 ld_shared_b32(const bf16x2 *__restrict__ src_smem) { const uint32_t src_smem_ptr = __cvta_generic_to_shared(src_smem); bf16x2 dst; - asm volatile("ld.shared.b32 %0, [%1];" : "=r"(reinterpret_cast(dst)) : "r"(src_smem_ptr)); + asm volatile("ld.shared.b32 %0, [%1];" + : "=r"(reinterpret_cast(dst)) + : "r"(src_smem_ptr)); return dst; } // Loads 8x BF16 values from shared memory state space -__device__ __forceinline__ void -ld_shared_b128(uint64_t& elts03, uint64_t& elts47, const bf16 * __restrict__ src_smem) { +__device__ __forceinline__ void ld_shared_b128(uint64_t &elts03, uint64_t &elts47, + const bf16 *__restrict__ src_smem) { const uint32_t src_smem_ptr = __cvta_generic_to_shared(src_smem); asm volatile( - "{\n\t" - ".reg.b128 xy; \n\t" - "ld.shared.b128 xy, [%2]; \n\t" - "mov.b128 {%0, %1}, xy; \n" - "}\n" - : "=l"(elts03), "=l"(elts47) - : "r"(src_smem_ptr)); + "{\n\t" + ".reg.b128 xy; \n\t" + "ld.shared.b128 xy, [%2]; \n\t" + "mov.b128 {%0, %1}, xy; \n" + "}\n" + : "=l"(elts03), "=l"(elts47) + : "r"(src_smem_ptr)); } // Vectorized store of x8 FP4 elements into shared memory state space -__device__ __forceinline__ void -st_shared_b32(fp4e2m1x2 * __restrict__ dst_smem, uint32_t fp4_pack_x8) { +__device__ __forceinline__ void st_shared_b32(fp4e2m1x2 *__restrict__ dst_smem, + uint32_t fp4_pack_x8) { const uint32_t dst_smem_ptr = __cvta_generic_to_shared(dst_smem); asm volatile("st.shared.b32 [%0], %1;" : : "r"(dst_smem_ptr), "r"(fp4_pack_x8)); } // Vectorized store of x16 FP4 elements into shared memory state space -__device__ __forceinline__ void -st_shared_b64(fp4e2m1x2 * __restrict__ dst_smem, uint64_t fp4_pack_x16) { +__device__ __forceinline__ void st_shared_b64(fp4e2m1x2 *__restrict__ dst_smem, + uint64_t fp4_pack_x16) { const uint32_t dst_smem_ptr = __cvta_generic_to_shared(dst_smem); asm volatile("st.shared.b64 [%0], %1;" : : "r"(dst_smem_ptr), "l"(fp4_pack_x16)); } From 4f1e8d8a3da7b1637aa73af63cf8388a3c97a806 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Sat, 22 Nov 2025 00:02:34 +0000 Subject: [PATCH 03/23] Fix FP4 guard in ptx Signed-off-by: Oleg Goncharov --- transformer_engine/common/util/ptx.cuh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 09e55ceff6..bbef588832 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -1749,19 +1749,23 @@ __device__ __forceinline__ void ld_shared_b128(uint64_t &elts03, uint64_t &elts4 : "r"(src_smem_ptr)); } +#if FP4_TYPE_SUPPORTED // Vectorized store of x8 FP4 elements into shared memory state space __device__ __forceinline__ void 
st_shared_b32(fp4e2m1x2 *__restrict__ dst_smem, uint32_t fp4_pack_x8) { const uint32_t dst_smem_ptr = __cvta_generic_to_shared(dst_smem); asm volatile("st.shared.b32 [%0], %1;" : : "r"(dst_smem_ptr), "r"(fp4_pack_x8)); } +#endif // Vectorized store of x16 FP4 elements into shared memory state space +#if FP4_TYPE_SUPPORTED __device__ __forceinline__ void st_shared_b64(fp4e2m1x2 *__restrict__ dst_smem, uint64_t fp4_pack_x16) { const uint32_t dst_smem_ptr = __cvta_generic_to_shared(dst_smem); asm volatile("st.shared.b64 [%0], %1;" : : "r"(dst_smem_ptr), "l"(fp4_pack_x16)); } +#endif } // namespace ptx namespace { From 236d7ee10fa4a76993d29d8509790311712c0516 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Sat, 22 Nov 2025 00:26:25 +0000 Subject: [PATCH 04/23] Fix Signed-off-by: Oleg Goncharov --- .../cpp/operator/test_cast_nvfp4_transpose.cu | 53 ++++++++----------- .../common/cast/dispatch/quantize.cuh | 3 +- .../cast/nvfp4/quantize_transpose_nvfp4.cuh | 10 ++-- 3 files changed, 30 insertions(+), 36 deletions(-) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 8b7f8b9634..afd7927da2 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -349,7 +349,6 @@ void compare_nvfp4_tensors(const std::string& name, const fp4e2m1 *test_data, const fp4e2m1 *ref_data, const int rows, const int cols, double atol = 1e-5, double rtol = 1e-8) { - constexpr bool print_detailed_summary = false; std::vector mismatch_messages; size_t total_mismatches = 0; @@ -382,42 +381,36 @@ void compare_nvfp4_tensors(const std::string& name, std::to_string(t) + " vs " + std::to_string(r) + " (abs_diff: " + std::to_string(fabs(t - r)) + ", rel_diff: " + std::to_string(r == 0 ? 0.0 : fabs((t - r) / r)) + ")"; - if constexpr (print_detailed_summary) { - mismatch_messages.push_back(msg); - - // Optional: limit number of detailed messages to avoid overwhelming output - if (mismatch_messages.size() <= 100) { - std::cout << "Error in tensor " << name << ": " << msg << std::endl; - } - } else { - GTEST_FAIL() << "Found " << total_mismatches << " mismatches in tensor " << name; + mismatch_messages.push_back(msg); + + // Optional: limit number of detailed messages to avoid overwhelming output + if (mismatch_messages.size() <= 100) { + std::cout << "Error in tensor " << name << ": " << msg << std::endl; } } } } } - if constexpr (print_detailed_summary) { - // Always report summary - either success or failure - std::cout << "=== SUMMARY for tensor " << name << " ===" << std::endl; - std::cout << "Total elements checked: " << (rows * cols) << std::endl; - - if (total_mismatches > 0) { - std::cout << "STATUS: FAILED for output" << std::endl; - std::cout << "Total mismatches found: " << total_mismatches << std::endl; - std::cout << "Mismatch rate: " << (100.0 * total_mismatches) / (rows * cols) << "%" << std::endl; - if (mismatch_messages.size() > 100) { - std::cout << "... and " << (mismatch_messages.size() - 100) << " more mismatches (showing first 100)" << std::endl; - } - std::cout << "============================" << std::endl; - - GTEST_FAIL() << "Found " << total_mismatches << " mismatches in tensor " << name; - } else { - std::cout << "STATUS: PASSED for output" << std::endl; - std::cout << "All elements match within tolerance!" 
<< std::endl; - std::cout << "Tensor " << name << " is IDENTICAL to reference" << std::endl; - std::cout << "============================" << std::endl; + // Always report summary - either success or failure + std::cout << "=== SUMMARY for tensor " << name << " ===" << std::endl; + std::cout << "Total elements checked: " << (rows * cols) << std::endl; + + if (total_mismatches > 0) { + std::cout << "STATUS: FAILED for output" << std::endl; + std::cout << "Total mismatches found: " << total_mismatches << std::endl; + std::cout << "Mismatch rate: " << (100.0 * total_mismatches) / (rows * cols) << "%" << std::endl; + if (mismatch_messages.size() > 100) { + std::cout << "... and " << (mismatch_messages.size() - 100) << " more mismatches (showing first 100)" << std::endl; } + std::cout << "============================" << std::endl; + + GTEST_FAIL() << "Found " << total_mismatches << " mismatches in tensor " << name; + } else { + std::cout << "STATUS: PASSED for output" << std::endl; + std::cout << "All elements match within tolerance!" << std::endl; + std::cout << "Tensor " << name << " is IDENTICAL to reference" << std::endl; + std::cout << "============================" << std::endl; } } diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 9f7a4a9b01..2821cf781b 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -100,7 +100,8 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, int32_t cols = input_tensor->flat_last_dim(); auto dtype = input_tensor->dtype(); bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && - (cols % 32 == 0) && output_tensor->has_data(); + (cols % 32 == 0) && output_tensor->has_data() && + is_supported_by_CC_100(); // Launch NVFP4 quantize kernel if (use_optimized_kernel) { diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index aff388cd92..21dd383e76 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -1161,11 +1161,6 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, using namespace quantize_transpose_kernel; using namespace ptx; - if (!use_2d_quantization && input.dtype() == DType::kBFloat16) { - quantize_transpose_persistent_1D(input, noop, output, quant_config, stream); - return; - } - bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; // If transposed output is allocated, return the transposed data. Otherwise, it's not necesary to @@ -1173,6 +1168,11 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, // TODO(Frank): Is there a better way to do this? bool return_transpose = output->has_columnwise_data(); + if (!use_2d_quantization && (input.dtype() == DType::kBFloat16) && return_transpose) { + quantize_transpose_persistent_1D(input, noop, output, quant_config, stream); + return; + } + constexpr bool COMPUTE_ACTIVATIONS = false; using ParamOP = Empty; constexpr float (*OP)(float, const ParamOP &) = nullptr; From 558c1269d50d8972d742fa44d9fe382094f1feb2 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Sat, 22 Nov 2025 01:12:56 +0000 Subject: [PATCH 05/23] Fix in ptx. 
reduxf32 guard Signed-off-by: Oleg Goncharov --- transformer_engine/common/util/ptx.cuh | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index bbef588832..178316cee1 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -1077,11 +1077,11 @@ __device__ __forceinline__ void fma_f32_bf16(float &out, uint16_t const &a, uint } __device__ __forceinline__ void reduce_sync_max_abs_f32(float &out, float const &in) { -#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \ - (__CUDA_ARCH_HAS_FEATURE__(SM120_ALL))) - asm volatile("redux.sync.max.abs.f32 %0, %1, 0xFFFFFFFF;" : "=f"(out) : "f"(in)); -#else - asm volatile( + constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; + if constexpr (is_blackwell) { + asm volatile("redux.sync.max.abs.f32 %0, %1, 0xFFFFFFFF;" : "=f"(out) : "f"(in)); + } else { + asm volatile( "{\n\t" ".reg.b32 val;\n" "abs.f32 val, %1;\n" @@ -1089,7 +1089,7 @@ __device__ __forceinline__ void reduce_sync_max_abs_f32(float &out, float const "}\n\t" : "=r"(reinterpret_cast(out)) : "f"(in)); -#endif + } } __device__ __forceinline__ bf16 get_amax(bf16 a, bf16 b) { From a7a065232574af4d8c6afccfc03bd3a3d9251348 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Sat, 22 Nov 2025 01:34:17 +0000 Subject: [PATCH 06/23] Fix in ptx. reduxf32 guard Signed-off-by: Oleg Goncharov --- transformer_engine/common/util/ptx.cuh | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 178316cee1..0b1118b959 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -122,6 +122,8 @@ constexpr bool is_supported_arch() { ptx::FamilySpecific<120>) #define ARCH_HAS_STOCHASTIC_ROUNDING \ NVTE_CUDA_ARCH_MATCHES(ptx::ArchSpecific<100>, ptx::ArchSpecific<103>) +#define ARCH_HAS_REDUX_F32 \ + NVTE_CUDA_ARCH_MATCHES(ptx::FamilySpecific<100>) // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init __device__ __forceinline__ void mbarrier_init(uint64_t *mbar, const uint32_t count) { @@ -1077,8 +1079,8 @@ __device__ __forceinline__ void fma_f32_bf16(float &out, uint16_t const &a, uint } __device__ __forceinline__ void reduce_sync_max_abs_f32(float &out, float const &in) { - constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; - if constexpr (is_blackwell) { + constexpr bool has_redux_f32 = ARCH_HAS_REDUX_F32; + if constexpr (has_redux_f32) { asm volatile("redux.sync.max.abs.f32 %0, %1, 0xFFFFFFFF;" : "=f"(out) : "f"(in)); } else { asm volatile( From c8062d3690eb4785ace166424c774239d25fcc29 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 22 Nov 2025 01:35:56 +0000 Subject: [PATCH 07/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/util/ptx.cuh | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 0b1118b959..37a206eed1 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -122,8 +122,7 @@ constexpr bool is_supported_arch() { ptx::FamilySpecific<120>) #define ARCH_HAS_STOCHASTIC_ROUNDING \ 
NVTE_CUDA_ARCH_MATCHES(ptx::ArchSpecific<100>, ptx::ArchSpecific<103>) -#define ARCH_HAS_REDUX_F32 \ - NVTE_CUDA_ARCH_MATCHES(ptx::FamilySpecific<100>) +#define ARCH_HAS_REDUX_F32 NVTE_CUDA_ARCH_MATCHES(ptx::FamilySpecific<100>) // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init __device__ __forceinline__ void mbarrier_init(uint64_t *mbar, const uint32_t count) { @@ -1084,13 +1083,13 @@ __device__ __forceinline__ void reduce_sync_max_abs_f32(float &out, float const asm volatile("redux.sync.max.abs.f32 %0, %1, 0xFFFFFFFF;" : "=f"(out) : "f"(in)); } else { asm volatile( - "{\n\t" - ".reg.b32 val;\n" - "abs.f32 val, %1;\n" - "redux.sync.max.u32 %0, val, 0xFFFFFFFF;\n" - "}\n\t" - : "=r"(reinterpret_cast(out)) - : "f"(in)); + "{\n\t" + ".reg.b32 val;\n" + "abs.f32 val, %1;\n" + "redux.sync.max.u32 %0, val, 0xFFFFFFFF;\n" + "}\n\t" + : "=r"(reinterpret_cast(out)) + : "f"(in)); } } From b14b3fd65f71b326ca14809cf81fa0e12d05a24d Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Sat, 22 Nov 2025 01:41:38 +0000 Subject: [PATCH 08/23] Fix Signed-off-by: Oleg Goncharov --- transformer_engine/common/cast/dispatch/quantize.cuh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 2821cf781b..61816fad27 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -247,7 +247,8 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens int32_t cols = grad_tensor->flat_last_dim(); auto dtype = grad_tensor->dtype(); bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && - (cols % 32 == 0) && output_tensor->has_data(); + (cols % 32 == 0) && output_tensor->has_data() && + is_supported_by_CC_100(); // Launch NVFP4 quantize kernel if (use_optimized_kernel) { From f9cf5e084289b67b5e4a5caae0f03fd306c02839 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Mon, 8 Dec 2025 13:56:25 +0000 Subject: [PATCH 09/23] Fixes per PR review Signed-off-by: Oleg Goncharov --- transformer_engine/common/cast/dispatch/quantize.cuh | 6 ++---- .../quantize_transpose_nvfp4_persistent_1D.cuh | 6 +++--- transformer_engine/common/util/ptx.cuh | 10 +++++----- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 61816fad27..9f7a4a9b01 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -100,8 +100,7 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, int32_t cols = input_tensor->flat_last_dim(); auto dtype = input_tensor->dtype(); bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && - (cols % 32 == 0) && output_tensor->has_data() && - is_supported_by_CC_100(); + (cols % 32 == 0) && output_tensor->has_data(); // Launch NVFP4 quantize kernel if (use_optimized_kernel) { @@ -247,8 +246,7 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens int32_t cols = grad_tensor->flat_last_dim(); auto dtype = grad_tensor->dtype(); bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && - (cols % 32 == 0) && output_tensor->has_data() && - is_supported_by_CC_100(); + (cols % 32 == 0) && output_tensor->has_data(); // Launch NVFP4 
quantize kernel if (use_optimized_kernel) { diff --git a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_persistent_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_persistent_1D.cuh index e716383270..c8ec825fc9 100644 --- a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_persistent_1D.cuh +++ b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_persistent_1D.cuh @@ -266,12 +266,12 @@ __device__ __forceinline__ void rowwise_scaling( IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; #pragma unroll for (int w = 0; w < WAVES; ++w) { - uint64_t &elts03 = *reinterpret_cast(&rIn[w][0]); - uint64_t &elts47 = *reinterpret_cast(&rIn[w][2]); const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD; const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + // Load elements - ptx::ld_shared_b128(elts03, elts47, &sIn[buff_in][it_offset_Y_rowwise][swizzled_thread_idx]); + __uint128_t& elts_8x = *reinterpret_cast<__uint128_t *>(&rIn[w]); + elts_8x = ptx::ld_shared_b128(&sIn[buff_in][it_offset_Y_rowwise][swizzled_thread_idx]); #pragma unroll for (int e = 0; e < PACK_SIZE / 2; ++e) { ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, rIn[w][e]); diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 37a206eed1..5ec5ffe69a 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -122,7 +122,6 @@ constexpr bool is_supported_arch() { ptx::FamilySpecific<120>) #define ARCH_HAS_STOCHASTIC_ROUNDING \ NVTE_CUDA_ARCH_MATCHES(ptx::ArchSpecific<100>, ptx::ArchSpecific<103>) -#define ARCH_HAS_REDUX_F32 NVTE_CUDA_ARCH_MATCHES(ptx::FamilySpecific<100>) // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init __device__ __forceinline__ void mbarrier_init(uint64_t *mbar, const uint32_t count) { @@ -1078,8 +1077,8 @@ __device__ __forceinline__ void fma_f32_bf16(float &out, uint16_t const &a, uint } __device__ __forceinline__ void reduce_sync_max_abs_f32(float &out, float const &in) { - constexpr bool has_redux_f32 = ARCH_HAS_REDUX_F32; - if constexpr (has_redux_f32) { + constexpr bool is_sm_100f = NVTE_CUDA_ARCH_MATCHES(ptx::FamilySpecific<100>); + if constexpr (is_sm_100f) { asm volatile("redux.sync.max.abs.f32 %0, %1, 0xFFFFFFFF;" : "=f"(out) : "f"(in)); } else { asm volatile( @@ -1737,8 +1736,8 @@ __device__ __forceinline__ bf16x2 ld_shared_b32(const bf16x2 *__restrict__ src_s } // Loads 8x BF16 values from shared memory state space -__device__ __forceinline__ void ld_shared_b128(uint64_t &elts03, uint64_t &elts47, - const bf16 *__restrict__ src_smem) { +__device__ __forceinline__ __uint128_t ld_shared_b128(const bf16 *__restrict__ src_smem) { + uint64_t elts03, elts47; const uint32_t src_smem_ptr = __cvta_generic_to_shared(src_smem); asm volatile( "{\n\t" @@ -1748,6 +1747,7 @@ __device__ __forceinline__ void ld_shared_b128(uint64_t &elts03, uint64_t &elts4 "}\n" : "=l"(elts03), "=l"(elts47) : "r"(src_smem_ptr)); + return (static_cast<__uint128_t>(elts47) << 64) | static_cast<__uint128_t>(elts03); } #if FP4_TYPE_SUPPORTED From c2fd9f0f6351ed85ac5b14504a08c1790cc27b8e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 8 Dec 2025 13:57:42 +0000 Subject: [PATCH 10/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more 
information, see https://pre-commit.ci --- .../specialized/quantize_transpose_nvfp4_persistent_1D.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_persistent_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_persistent_1D.cuh index c8ec825fc9..80d438cf3b 100644 --- a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_persistent_1D.cuh +++ b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_persistent_1D.cuh @@ -270,7 +270,7 @@ __device__ __forceinline__ void rowwise_scaling( const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; // Load elements - __uint128_t& elts_8x = *reinterpret_cast<__uint128_t *>(&rIn[w]); + __uint128_t &elts_8x = *reinterpret_cast<__uint128_t *>(&rIn[w]); elts_8x = ptx::ld_shared_b128(&sIn[buff_in][it_offset_Y_rowwise][swizzled_thread_idx]); #pragma unroll for (int e = 0; e < PACK_SIZE / 2; ++e) { From 452ea66d85e57f96afa4188f07d8c93da7e5afdc Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Mon, 8 Dec 2025 17:44:15 +0000 Subject: [PATCH 11/23] Fixes per PR review. Added parameter to turn off the persistency Signed-off-by: Oleg Goncharov --- .../cast/nvfp4/quantize_transpose_nvfp4.cuh | 4 +- ... => quantize_transpose_nvfp4_tuned_1D.cuh} | 132 +++++++++--------- transformer_engine/common/util/ptx.cuh | 6 +- 3 files changed, 72 insertions(+), 70 deletions(-) rename transformer_engine/common/cast/nvfp4/specialized/{quantize_transpose_nvfp4_persistent_1D.cuh => quantize_transpose_nvfp4_tuned_1D.cuh} (87%) diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index 21dd383e76..8243bef781 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -21,7 +21,7 @@ #include "../../util/ptx.cuh" #include "../../utils.cuh" #include "core_nvfp4.cuh" -#include "specialized/quantize_transpose_nvfp4_persistent_1D.cuh" +#include "specialized/quantize_transpose_nvfp4_tuned_1D.cuh" namespace transformer_engine { namespace dispatch { @@ -1169,7 +1169,7 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, bool return_transpose = output->has_columnwise_data(); if (!use_2d_quantization && (input.dtype() == DType::kBFloat16) && return_transpose) { - quantize_transpose_persistent_1D(input, noop, output, quant_config, stream); + quantize_transpose_tuned_1D(input, noop, output, quant_config, stream); return; } diff --git a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_persistent_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh similarity index 87% rename from transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_persistent_1D.cuh rename to transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh index 80d438cf3b..e1bf1f8d10 100644 --- a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_persistent_1D.cuh +++ b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh @@ -4,12 +4,12 @@ * See LICENSE for license information. ************************************************************************/ -/*! \file quantize_transpose_nvfp4_persistent_1D.cuh - * \brief Persistent kernel to cast to NVFP4 and transpose. 
+/*! \file quantize_transpose_nvfp4_tuned_1D.cuh + * \brief Tuned kernel to cast to NVFP4 and transpose. */ -#ifndef TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_PERSISTENT_1D_CUH_ -#define TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_PERSISTENT_1D_CUH_ +#ifndef TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_TUNED_1D_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_TUNED_1D_CUH_ #include #include @@ -26,7 +26,7 @@ namespace transformer_engine { namespace dispatch { namespace nvfp4 { -namespace quantize_transpose_persistent_kernel { +namespace quantize_transpose_tuned_kernel { using namespace quantization_and_transposition_SF; using namespace core; @@ -34,27 +34,27 @@ using namespace ptx; #if FP4_TYPE_SUPPORTED -constexpr int SCALE_DIM = 16; // NVFP4 block (x16 elts) -static_assert(SCALE_DIM == 16 && "NVFP4 block size is 16\0"); +struct TunableConfig { + static constexpr int CHUNK_DIM_Y = 128; + static constexpr int CHUNK_DIM_X = 128; + static constexpr int PREFETCH_STAGES = 1; + static constexpr bool PERSISTENT = true; +}; +constexpr int SCALE_DIM = 16; // NVFP4 block (x16 elts) constexpr int THREADS_NUM = 128; constexpr int ELTS_PER_THREAD = 16; -constexpr int CHUNK_DIM_Y = 128; -constexpr int CHUNK_DIM_X = 128; constexpr int TILE_DIM_Y = 64; constexpr int TILE_DIM_X = 64; -static_assert(THREADS_NUM == 128 && "Hardcoded and fixed parameter\0"); static_assert(ELTS_PER_THREAD == SCALE_DIM && "Hardcoded and fixed parameter\0"); -static_assert(TILE_DIM_Y == 64 && "Hardcoded and fixed parameter\0"); -static_assert(TILE_DIM_X == 64 && "Hardcoded and fixed parameter\0"); static_assert((THREADS_NUM * ELTS_PER_THREAD <= TILE_DIM_Y * TILE_DIM_X) && "Unbalanced threads workload\0"); -static_assert((CHUNK_DIM_Y % TILE_DIM_Y == 0) && +static_assert((TunableConfig::CHUNK_DIM_Y % TILE_DIM_Y == 0) && "Chunk size Y must be evenly divisible by the tile size Y\0"); -static_assert((CHUNK_DIM_X % TILE_DIM_X == 0) && +static_assert((TunableConfig::CHUNK_DIM_X % TILE_DIM_X == 0) && "Chunk size X must be evenly divisible by the tile size X\0"); static_assert((TILE_DIM_Y % SCALE_DIM == 0) && @@ -62,13 +62,13 @@ static_assert((TILE_DIM_Y % SCALE_DIM == 0) && static_assert((TILE_DIM_X % SCALE_DIM == 0) && "Tile size X must be evenly divisible by the scale dim\0"); -constexpr int TILES_Y = CHUNK_DIM_Y / TILE_DIM_Y; -constexpr int TILES_X = CHUNK_DIM_X / TILE_DIM_X; +constexpr int TILES_Y = TunableConfig::CHUNK_DIM_Y / TILE_DIM_Y; +constexpr int TILES_X = TunableConfig::CHUNK_DIM_X / TILE_DIM_X; constexpr int THREADS_PER_SCALE_ROWWISE = SCALE_DIM / ELTS_PER_THREAD; -constexpr int SCALES_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM; -constexpr int SCALES_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM; +constexpr int SCALES_PER_CHUNK_Y = TunableConfig::CHUNK_DIM_Y / SCALE_DIM; +constexpr int SCALES_PER_CHUNK_X = TunableConfig::CHUNK_DIM_X / SCALE_DIM; constexpr int SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM; constexpr int SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; @@ -77,8 +77,7 @@ constexpr int STAGES_Y = TILES_Y; constexpr int STAGES_X = TILES_X; constexpr int STAGES = STAGES_Y * STAGES_X; -constexpr int PREFETCH_STAGES = 1; -constexpr int BUFFS_NUM = PREFETCH_STAGES + 1; +constexpr int BUFFS_NUM = TunableConfig::PREFETCH_STAGES + 1; constexpr int BUFFS_NUM_IN = BUFFS_NUM; constexpr int BUFFS_NUM_OUT = BUFFS_NUM; constexpr int BUFFS_NUM_OUT_TR = 2; @@ -105,7 +104,6 @@ constexpr int BUFF_OUT_T_SIZE = BUFF_OUT_T_DIM_Y * BUFF_OUT_T_DIM_X; // Manual swizzling parameters to reduce SHMEM bank conflicts constexpr int PACK_SIZE = 8; 
-static_assert(PACK_SIZE == 8 && "Pack size is fixed to 8\0"); constexpr int WAVES = ELTS_PER_THREAD / PACK_SIZE; constexpr int THREADS_X_ROWWISE = TILE_DIM_X / ELTS_PER_THREAD; @@ -125,7 +123,7 @@ constexpr int BUFF_OUT_IT_OFFSET = BUFF_OUT_T_DIM_X / ITERATIONS_TRANSPOSE / STA static_assert(BUFF_DIM_Y >= SCALE_DIM && "Number of buffer rows must be greater or equal to the size of the columwise " "scaling block\0"); -static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y); +static_assert(TunableConfig::CHUNK_DIM_Y >= BUFF_DIM_Y); static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE && "Number of buffer rows must be greater or equal to the number of rowwise " "processing threads in Y dimension\0"); @@ -142,12 +140,12 @@ using IType3D = IType[BUFFS_NUM_IN][BUFF_IN_DIM_Y][BUFF_IN_DIM_X]; using IType2x3D = IType2[BUFFS_NUM_IN][BUFF_IN_DIM_Y][BUFF_IN_DIM_X / 2]; using OType2x3D = fp4e2m1x2[BUFFS_NUM_OUT][BUFF_OUT_DIM_Y][BUFF_OUT_DIM_X]; using OType2xt3D = fp4e2m1x2[BUFFS_NUM_OUT_TR][BUFF_OUT_T_DIM_Y][BUFF_OUT_T_DIM_X]; -using ScalesType2D = nvfp4_scale_t[CHUNK_DIM_Y][SCALES_PER_CHUNK_X]; -using ScalesTypeTr2D = nvfp4_scale_t[CHUNK_DIM_X][SCALES_PER_CHUNK_Y]; +using ScalesType2D = nvfp4_scale_t[TunableConfig::CHUNK_DIM_Y][SCALES_PER_CHUNK_X]; +using ScalesTypeTr2D = nvfp4_scale_t[TunableConfig::CHUNK_DIM_X][SCALES_PER_CHUNK_Y]; using RNG_t = typename transformer_engine::curanddx::detail::philox4x32_native_state<10>; -__device__ __forceinline__ float get_amax_of_pair(const IType2 xormax_pair) { - return static_cast(__hmax(__habs(xormax_pair.x), __habs(xormax_pair.y))); +__device__ __forceinline__ float get_amax_of_pair(const IType2 pair) { + return static_cast(__hmax(__habs(pair.x), __habs(pair.y))); } // Compute "correct" per-block encoding scaling factor @@ -313,7 +311,7 @@ __device__ __forceinline__ void rowwise_scaling( } template -__global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_persistent_kernel( +__global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D_kernel( const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_output, const __grid_constant__ CUtensorMap tensor_map_output_t, nvfp4_scale_t *const scales_ptr, @@ -352,7 +350,7 @@ __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_persiste constexpr int out_mem_rowwise_data = buff_size_aligned_out; constexpr int out_mem_colwise_data = RETURN_TRANSPOSE ? 
buff_size_aligned_out_t : 0; constexpr int out_mem_rowwise_scales = DIVUP_TO_MULTIPLE( - CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); + TunableConfig::CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned extern __shared__ unsigned char dynamic_shmem[]; @@ -422,7 +420,7 @@ __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_persiste // Prefetch input data only when processing the first chunk, // which enables the one-iteration overlap throughout the entire kernel life #pragma unroll - for (int stage = 0; stage < PREFETCH_STAGES; ++stage) { + for (int stage = 0; stage < TunableConfig::PREFETCH_STAGES; ++stage) { const int buff_in = stage; const int stage_Y = stage / STAGES_X; const int stage_X = stage % STAGES_X; @@ -430,8 +428,8 @@ __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_persiste const int stage_offset_Y = stage_Y * TILE_DIM_Y; const int stage_offset_X = stage_X * TILE_DIM_X; - const int block_offset_Y = ctaid_Y * CHUNK_DIM_Y; - const int block_offset_X = ctaid_X * CHUNK_DIM_X; + const int block_offset_Y = ctaid_Y * TunableConfig::CHUNK_DIM_Y; + const int block_offset_X = ctaid_X * TunableConfig::CHUNK_DIM_X; const int global_offset_Y = block_offset_Y + stage_offset_Y; const int global_offset_X = block_offset_X + stage_offset_X; @@ -451,24 +449,25 @@ __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_persiste } while (!job_finished) { - const int block_offset_Y = ctaid_Y * CHUNK_DIM_Y; - const int block_offset_X = ctaid_X * CHUNK_DIM_X; + const int block_offset_Y = ctaid_Y * TunableConfig::CHUNK_DIM_Y; + const int block_offset_X = ctaid_X * TunableConfig::CHUNK_DIM_X; - const int block_offset_Y_tr = ctaid_X * CHUNK_DIM_X; - const int block_offset_X_tr = ctaid_Y * CHUNK_DIM_Y; + const int block_offset_Y_tr = ctaid_X * TunableConfig::CHUNK_DIM_X; + const int block_offset_X_tr = ctaid_Y * TunableConfig::CHUNK_DIM_Y; const int chunk_rows = rows - block_offset_Y; const int chunk_cols = cols - block_offset_X; - const int scales_block_offset_Y_rowwise = ctaid_Y * CHUNK_DIM_Y; + const int scales_block_offset_Y_rowwise = ctaid_Y * TunableConfig::CHUNK_DIM_Y; const int scales_block_offset_X_rowwise = ctaid_X * SCALES_PER_CHUNK_X; - const int scales_block_offset_Y_tr = ctaid_X * CHUNK_DIM_X; + const int scales_block_offset_Y_tr = ctaid_X * TunableConfig::CHUNK_DIM_X; const int scales_block_offset_X_tr = ctaid_Y * SCALES_PER_CHUNK_Y; - if (leading_thread) { - ptx::mbarrier_arrive_expect_tx_cta_relaxed_shared_cta(&workID_mbar, workID_response_size); - ptx::clusterlaunchcontrol_try_cancel_async_shared_cta_mbarrier_complete_tx_bytes( - &workID_mbar, &workID_response); + if constexpr (TunableConfig::PERSISTENT) { + if (leading_thread) { + ptx::mbarrier_arrive_expect_tx_cta_relaxed_shared_cta(&workID_mbar, workID_response_size); + ptx::try_cancel_cta(&workID_mbar, &workID_response); + } } #pragma unroll @@ -479,19 +478,24 @@ __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_persiste const int stage_offset_Y = stage_Y * TILE_DIM_Y; const int stage_offset_X = stage_X * TILE_DIM_X; - if (stage == STAGES - PREFETCH_STAGES) { - ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&workID_mbar, ctaid_parity); - ptx::get_cancelled_cta_2D_id(&workID_response, ctaid_X, ctaid_Y); + if (stage == STAGES - TunableConfig::PREFETCH_STAGES) { + if constexpr 
(TunableConfig::PERSISTENT) { + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&workID_mbar, ctaid_parity); + ptx::get_cancelled_cta_id_2D(&workID_response, ctaid_X, ctaid_Y); + ctaid_parity ^= 1; + } else { + ctaid_X = -1; + ctaid_Y = -1; + } if (ctaid_X == -1 && ctaid_Y == -1) { job_finished = true; } - ctaid_parity ^= 1; } // Prefetch next stage Input data - if (!job_finished || (stage < STAGES - PREFETCH_STAGES)) { - const int next_prefetch_buff = (buff_in + PREFETCH_STAGES) % BUFFS_NUM; - const int next_prefetch_stage = (stage + PREFETCH_STAGES) % STAGES; + if (!job_finished || (stage < STAGES - TunableConfig::PREFETCH_STAGES)) { + const int next_prefetch_buff = (buff_in + TunableConfig::PREFETCH_STAGES) % BUFFS_NUM; + const int next_prefetch_stage = (stage + TunableConfig::PREFETCH_STAGES) % STAGES; const int next_prefetch_stage_Y = next_prefetch_stage / STAGES_X; const int next_prefetch_stage_X = next_prefetch_stage % STAGES_X; @@ -499,8 +503,8 @@ __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_persiste const int next_prefetch_stage_offset_X = next_prefetch_stage_X * TILE_DIM_X; // Offsets change, because coordinates of the next "to-be-prefetched" CTA do also chage - const int block_offset_Y = ctaid_Y * CHUNK_DIM_Y; - const int block_offset_X = ctaid_X * CHUNK_DIM_X; + const int block_offset_Y = ctaid_Y * TunableConfig::CHUNK_DIM_Y; + const int block_offset_X = ctaid_X * TunableConfig::CHUNK_DIM_X; const int global_offset_Y = block_offset_Y + next_prefetch_stage_offset_Y; const int global_offset_X = block_offset_X + next_prefetch_stage_offset_X; @@ -527,7 +531,7 @@ __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_persiste // Wait for TMA transfer to have finished reading shared memory // I.e. the OUT buffer is ready to be written to - ptx::cp_async_bulk_wait_group_read(); + ptx::cp_async_bulk_wait_group_read(); // NVFP4 Quantization rowwise_scaling(sIn_ptr, sOut_ptr, sSFrowwise_ptr, S_enc_rowwise, @@ -579,7 +583,7 @@ __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_persiste // number of scales in X dimension of this chunk const int count = min(SCALES_PER_CHUNK_X, chunk_cols / SCALE_DIM); - for (size_t row = threadIdx.x; row < CHUNK_DIM_Y; row += THREADS_NUM) { + for (size_t row = threadIdx.x; row < TunableConfig::CHUNK_DIM_Y; row += THREADS_NUM) { const size_t row_global = scales_block_offset_Y_rowwise + row; if (row_global < rows) { ScalesVec &scales_vec = *reinterpret_cast(sSFrowwise[row]); @@ -596,7 +600,7 @@ __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_persiste // number of scales in Y dimension of this chunk const int count = min(SCALES_PER_CHUNK_Y, chunk_rows / SCALE_DIM); - for (size_t row_tr = threadIdx.x; row_tr < CHUNK_DIM_X; row_tr += THREADS_NUM) { + for (size_t row_tr = threadIdx.x; row_tr < TunableConfig::CHUNK_DIM_X; row_tr += THREADS_NUM) { const size_t row_tr_global = scales_block_offset_Y_tr + row_tr; if (row_tr_global < cols) { ScalesVec &scales_vec = *reinterpret_cast(sSFcolwise[row_tr]); @@ -627,13 +631,13 @@ __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_persiste } #endif // FP4_TYPE_SUPPORTED -} // namespace quantize_transpose_persistent_kernel +} // namespace quantize_transpose_tuned_kernel -inline void quantize_transpose_persistent_1D(const Tensor &input, const Tensor *noop, - Tensor *output, const QuantizationConfig *quant_config, - cudaStream_t stream) { +inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, 
+ Tensor *output, const QuantizationConfig *quant_config, + cudaStream_t stream) { #if FP4_TYPE_SUPPORTED - using namespace quantize_transpose_persistent_kernel; + using namespace quantize_transpose_tuned_kernel; using namespace ptx; const bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; @@ -666,8 +670,8 @@ inline void quantize_transpose_persistent_1D(const Tensor &input, const Tensor * NVTE_CHECK(cols % 32 == 0, "Number of tensor cols must be a multiple of 32"); // 16B alignment for TMA - const int blocks_Y = DIVUP(rows, static_cast(CHUNK_DIM_Y)); - const int blocks_X = DIVUP(cols, static_cast(CHUNK_DIM_X)); + const int blocks_Y = DIVUP(rows, static_cast(TunableConfig::CHUNK_DIM_Y)); + const int blocks_X = DIVUP(cols, static_cast(TunableConfig::CHUNK_DIM_X)); const dim3 grid(blocks_X, blocks_Y); const int block_size = THREADS_NUM; @@ -719,9 +723,9 @@ inline void quantize_transpose_persistent_1D(const Tensor &input, const Tensor * DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_T_SIZE, TMA_SHMEM_ALIGNMENT); constexpr int buff_size_scales = DIVUP_TO_MULTIPLE( - CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); + TunableConfig::CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); constexpr int buff_size_scales_transpose = DIVUP_TO_MULTIPLE( - CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); + TunableConfig::CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); const int in_mem = buff_size_aligned_in; @@ -739,7 +743,7 @@ inline void quantize_transpose_persistent_1D(const Tensor &input, const Tensor * use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { auto kernel = - quantize_transpose_nvfp4_persistent_kernel; + quantize_transpose_nvfp4_tuned_1D_kernel; cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); kernel<<>>( @@ -756,4 +760,4 @@ inline void quantize_transpose_persistent_1D(const Tensor &input, const Tensor * } // namespace dispatch } // namespace transformer_engine -#endif // TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_PERSISTENT_1D_CUH_ +#endif // TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_TUNED_1D_CUH_ diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 5ec5ffe69a..67fd9ec234 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -317,9 +317,7 @@ __device__ __forceinline__ void mbarrier_wait_parity_relaxed_cta_shared_cta(uint #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } -__device__ __forceinline__ void -clusterlaunchcontrol_try_cancel_async_shared_cta_mbarrier_complete_tx_bytes( - uint64_t *mbar, __uint128_t *response_data_ptr) { +__device__ __forceinline__ void try_cancel_cta(uint64_t *mbar, __uint128_t *response_data_ptr) { constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; if constexpr (is_blackwell) { uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); @@ -336,7 +334,7 @@ clusterlaunchcontrol_try_cancel_async_shared_cta_mbarrier_complete_tx_bytes( } } -__device__ __forceinline__ void get_cancelled_cta_2D_id(__uint128_t *response_data_ptr, +__device__ __forceinline__ void get_cancelled_cta_id_2D(__uint128_t *response_data_ptr, int32_t &ctaid_X, int32_t &ctaid_Y) { constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; if constexpr (is_blackwell) { From 3eb453bf73dfb84be45d4ea46134dc7e93320069 Mon Sep 17 00:00:00 
2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 8 Dec 2025 17:47:55 +0000 Subject: [PATCH 12/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../specialized/quantize_transpose_nvfp4_tuned_1D.cuh | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh index e1bf1f8d10..aa6b2ebff3 100644 --- a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh +++ b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh @@ -600,7 +600,8 @@ __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D // number of scales in Y dimension of this chunk const int count = min(SCALES_PER_CHUNK_Y, chunk_rows / SCALE_DIM); - for (size_t row_tr = threadIdx.x; row_tr < TunableConfig::CHUNK_DIM_X; row_tr += THREADS_NUM) { + for (size_t row_tr = threadIdx.x; row_tr < TunableConfig::CHUNK_DIM_X; + row_tr += THREADS_NUM) { const size_t row_tr_global = scales_block_offset_Y_tr + row_tr; if (row_tr_global < cols) { ScalesVec &scales_vec = *reinterpret_cast(sSFcolwise[row_tr]); @@ -633,8 +634,8 @@ __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D #endif // FP4_TYPE_SUPPORTED } // namespace quantize_transpose_tuned_kernel -inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, - Tensor *output, const QuantizationConfig *quant_config, +inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, Tensor *output, + const QuantizationConfig *quant_config, cudaStream_t stream) { #if FP4_TYPE_SUPPORTED using namespace quantize_transpose_tuned_kernel; From 7b11f002bd4865dab545d3d489ce732f5ba14fe2 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Fri, 12 Dec 2025 15:41:26 +0000 Subject: [PATCH 13/23] Modified reference CPU implementation in C++ unit tests to match GPU (numerical truncation). Tightened the numerical tolerance Signed-off-by: Oleg Goncharov --- .../cpp/operator/test_cast_nvfp4_transpose.cu | 50 ++++++++----------- .../quantize_transpose_nvfp4_tuned_1D.cuh | 19 +++---- 2 files changed, 29 insertions(+), 40 deletions(-) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index afd7927da2..72950b9bf5 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -114,14 +114,16 @@ void quantize_nvfp4_1d(float (*OP)(const float), const float S_dec_b = block_amax / 6.0f; // Scale & Store per-block decoding scaling factor - const float S_dec_b_fp8 = S_dec_b * S_enc; + const fp8e4m3 S_dec_b_fp8 = static_cast(S_dec_b * S_enc); + const float S_dec_b_fp32 = static_cast(S_dec_b_fp8); // Compute "correct" per-block encoding scaling factor - const float S_enc_b_fp8 = S_dec_b_fp8 == 0 ? 0.f : S_enc / S_dec_b_fp8; + const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 
0.f : S_enc / S_dec_b_fp32; const size_t scale_idx = i * scales_stride + block_X; - scales[scale_idx] = static_cast(S_dec_b_fp8); - const float scale_reciprocal = S_enc_b_fp8; + scales[scale_idx] = S_dec_b_fp8; + // Numercial truncation to match GPU implementation, which uses mixed precision FMA instruction + const float scale_reciprocal = static_cast(static_cast(S_enc_b_fp8)); for (size_t j = j_min; j < j_max; j += 2) { const int idx_pair = (i * cols + j) / 2; @@ -349,6 +351,8 @@ void compare_nvfp4_tensors(const std::string& name, const fp4e2m1 *test_data, const fp4e2m1 *ref_data, const int rows, const int cols, double atol = 1e-5, double rtol = 1e-8) { + constexpr int max_mismatches_to_print = 3; + std::vector mismatch_messages; size_t total_mismatches = 0; @@ -362,29 +366,16 @@ void compare_nvfp4_tensors(const std::string& name, const double t = (k == 0 ? test_data_pair.x : test_data_pair.y); const double r = (k == 0 ? ref_data_pair.x : ref_data_pair.y); - bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); - /* For Float32 the floating point comparison is enough to error out */ - bool assertion = false; - if (mismatch && !assertion) { - /* Check if it is just a failure of round to nearest choosing different - side of the real value */ - const double mean = (t + r) / 2; - const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6); - const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6); - const double cast_mean_p = static_cast(static_cast(mean_p)); - const double cast_mean_m = static_cast(static_cast(mean_m)); - assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); - } - if (assertion) { + const bool mismatch = fabs(t - r) > (atol + fabs(r) * rtol); + if (mismatch) { total_mismatches++; - std::string msg = "Mismatch at place (" + std::to_string(idx + k) + "): " + - std::to_string(t) + " vs " + std::to_string(r) + - " (abs_diff: " + std::to_string(fabs(t - r)) + - ", rel_diff: " + std::to_string(r == 0 ? 0.0 : fabs((t - r) / r)) + ")"; - mismatch_messages.push_back(msg); - // Optional: limit number of detailed messages to avoid overwhelming output - if (mismatch_messages.size() <= 100) { + if (total_mismatches <= max_mismatches_to_print) { + std::string msg = "Mismatch at place (" + std::to_string(idx + k) + "): " + + std::to_string(t) + " vs " + std::to_string(r) + + " (abs_diff: " + std::to_string(fabs(t - r)) + + ", rel_diff: " + std::to_string(r == 0 ? 0.0 : fabs((t - r) / r)) + ")"; + mismatch_messages.push_back(msg); std::cout << "Error in tensor " << name << ": " << msg << std::endl; } } @@ -400,8 +391,9 @@ void compare_nvfp4_tensors(const std::string& name, std::cout << "STATUS: FAILED for output" << std::endl; std::cout << "Total mismatches found: " << total_mismatches << std::endl; std::cout << "Mismatch rate: " << (100.0 * total_mismatches) / (rows * cols) << "%" << std::endl; - if (mismatch_messages.size() > 100) { - std::cout << "... and " << (mismatch_messages.size() - 100) << " more mismatches (showing first 100)" << std::endl; + if (mismatch_messages.size() > max_mismatches_to_print) { + std::cout << "... 
and " << (mismatch_messages.size() - max_mismatches_to_print) + << " more mismatches (showing first " << max_mismatches_to_print << ")" << std::endl; } std::cout << "============================" << std::endl; @@ -619,8 +611,8 @@ void performTest(float (*OP)(const float), } ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); - const double atol = 0.05; - const double rtol = 0.1; + const double atol = 1.0E-6; + const double rtol = 1.0E-6; // Set dump_data=true to enable dumping tensor data to files for analysis compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, true, false); diff --git a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh index aa6b2ebff3..344b88a092 100644 --- a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh +++ b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh @@ -150,9 +150,9 @@ __device__ __forceinline__ float get_amax_of_pair(const IType2 pair) { // Compute "correct" per-block encoding scaling factor __device__ __forceinline__ bf16 compute_nvfp4_scaling_coefficient(const nvfp4_scale_t S_dec_block, - const float S_dec) { + const float S_enc) { constexpr float float_max = detail::TypeExtrema::max; - const float scale_rcp = fminf(1.0f / (static_cast(S_dec_block) * S_dec), float_max); + const float scale_rcp = fminf(S_enc / static_cast(S_dec_block), float_max); return static_cast(scale_rcp); } @@ -160,7 +160,7 @@ template __device__ __forceinline__ void colwise_scaling( const IType *__restrict__ sIn_ptr, fp4e2m1x2 *__restrict__ sOut_tr_ptr, nvfp4_scale_t *__restrict__ sSFcolwise_ptr, const float S_enc_colwise, - const float S_dec_colwise, const int stage_Y, const int stage_X, const int buff_in, + const int stage_Y, const int stage_X, const int buff_in, const int buff_out_tr, RNG_t &rng, uint4 &random_uint4, int &rnd_idx) { const auto &sIn2x = *reinterpret_cast(sIn_ptr); auto &sOut_tr = *reinterpret_cast(sOut_tr_ptr); @@ -204,7 +204,7 @@ __device__ __forceinline__ void colwise_scaling( // Store scaling factors to SMEM buffer (R2S) sSFcolwise[scale_tr_offset_Y + w][scale_tr_offset_X] = S_dec_b_fp8; - const bf16 SFcoefficient = compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_dec_colwise); + const bf16 SFcoefficient = compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_enc_colwise); // Scale elements __align__(8) uint32_t rOut[SCALE_DIM / 8]; @@ -231,7 +231,7 @@ template __device__ __forceinline__ void rowwise_scaling( const IType *__restrict__ sIn_ptr, fp4e2m1x2 *__restrict__ sOut_ptr, nvfp4_scale_t *__restrict__ sSFrowwise_ptr, const float S_enc_rowwise, - const float S_dec_rowwise, const int stage_Y, const int stage_X, const int buff_in, + const int stage_Y, const int stage_X, const int buff_in, const int buff_out, RNG_t &rng, uint4 &random_uint4, int &rnd_idx) { const auto &sIn = *reinterpret_cast(sIn_ptr); auto &sOut = *reinterpret_cast(sOut_ptr); @@ -278,7 +278,7 @@ __device__ __forceinline__ void rowwise_scaling( const float block_amax = get_amax_of_pair(thread_amax_2x); const nvfp4_scale_t S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc_rowwise); - const bf16 SFcoefficient = compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_dec_rowwise); + const bf16 SFcoefficient = compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_enc_rowwise); // Store scaling factors to SMEM buffer (R2S) if (SF_storing_thread) { @@ -379,14 
+379,11 @@ __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D (amax_rowwise_ptr == nullptr) ? 1.0f : core::compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); - // NOTE: This is to match with how emulation code was written. - const float S_dec_rowwise = 1.0 / S_enc_rowwise; const float S_enc_colwise = (amax_colwise_ptr == nullptr) ? S_enc_rowwise : core::compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); - const float S_dec_colwise = 1.0 / S_enc_colwise; __shared__ uint64_t workID_mbar; __shared__ __uint128_t workID_response; @@ -535,12 +532,12 @@ __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D // NVFP4 Quantization rowwise_scaling(sIn_ptr, sOut_ptr, sSFrowwise_ptr, S_enc_rowwise, - S_dec_rowwise, stage_Y, stage_X, buff_in, buff_out, + stage_Y, stage_X, buff_in, buff_out, rng, random_uint4, rnd_idx); if constexpr (RETURN_TRANSPOSE) { colwise_scaling(sIn_ptr, sOut_tr_ptr, sSFcolwise_ptr, - S_enc_colwise, S_dec_colwise, stage_Y, stage_X, + S_enc_colwise, stage_Y, stage_X, buff_in, buff_out_tr, rng, random_uint4, rnd_idx); } From a38eefff29a37dd56ead14ec29759f1ceb6645bb Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Fri, 12 Dec 2025 15:43:52 +0000 Subject: [PATCH 14/23] Disabled persistency by default, as non-persistent kernel is more performant when inputs are large Signed-off-by: Oleg Goncharov --- .../nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh index 344b88a092..df22316e75 100644 --- a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh +++ b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh @@ -38,7 +38,7 @@ struct TunableConfig { static constexpr int CHUNK_DIM_Y = 128; static constexpr int CHUNK_DIM_X = 128; static constexpr int PREFETCH_STAGES = 1; - static constexpr bool PERSISTENT = true; + static constexpr bool PERSISTENT = false; }; constexpr int SCALE_DIM = 16; // NVFP4 block (x16 elts) From a7015f8226c00ca438d8657d2c657ddfabebaa51 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 12 Dec 2025 15:46:45 +0000 Subject: [PATCH 15/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../quantize_transpose_nvfp4_tuned_1D.cuh | 32 +++++++++++-------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh index df22316e75..e331a7793e 100644 --- a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh +++ b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh @@ -157,11 +157,13 @@ __device__ __forceinline__ bf16 compute_nvfp4_scaling_coefficient(const nvfp4_sc } template -__device__ __forceinline__ void colwise_scaling( - const IType *__restrict__ sIn_ptr, fp4e2m1x2 *__restrict__ sOut_tr_ptr, - nvfp4_scale_t *__restrict__ sSFcolwise_ptr, const float S_enc_colwise, - const int stage_Y, const int stage_X, const int buff_in, - const int buff_out_tr, RNG_t &rng, uint4 &random_uint4, int &rnd_idx) { +__device__ 
__forceinline__ void colwise_scaling(const IType *__restrict__ sIn_ptr, + fp4e2m1x2 *__restrict__ sOut_tr_ptr, + nvfp4_scale_t *__restrict__ sSFcolwise_ptr, + const float S_enc_colwise, const int stage_Y, + const int stage_X, const int buff_in, + const int buff_out_tr, RNG_t &rng, + uint4 &random_uint4, int &rnd_idx) { const auto &sIn2x = *reinterpret_cast(sIn_ptr); auto &sOut_tr = *reinterpret_cast(sOut_tr_ptr); auto &sSFcolwise = *reinterpret_cast(sSFcolwise_ptr); @@ -228,11 +230,13 @@ __device__ __forceinline__ void colwise_scaling( } template -__device__ __forceinline__ void rowwise_scaling( - const IType *__restrict__ sIn_ptr, fp4e2m1x2 *__restrict__ sOut_ptr, - nvfp4_scale_t *__restrict__ sSFrowwise_ptr, const float S_enc_rowwise, - const int stage_Y, const int stage_X, const int buff_in, - const int buff_out, RNG_t &rng, uint4 &random_uint4, int &rnd_idx) { +__device__ __forceinline__ void rowwise_scaling(const IType *__restrict__ sIn_ptr, + fp4e2m1x2 *__restrict__ sOut_ptr, + nvfp4_scale_t *__restrict__ sSFrowwise_ptr, + const float S_enc_rowwise, const int stage_Y, + const int stage_X, const int buff_in, + const int buff_out, RNG_t &rng, uint4 &random_uint4, + int &rnd_idx) { const auto &sIn = *reinterpret_cast(sIn_ptr); auto &sOut = *reinterpret_cast(sOut_ptr); auto &sSFrowwise = *reinterpret_cast(sSFrowwise_ptr); @@ -532,13 +536,13 @@ __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D // NVFP4 Quantization rowwise_scaling(sIn_ptr, sOut_ptr, sSFrowwise_ptr, S_enc_rowwise, - stage_Y, stage_X, buff_in, buff_out, - rng, random_uint4, rnd_idx); + stage_Y, stage_X, buff_in, buff_out, rng, + random_uint4, rnd_idx); if constexpr (RETURN_TRANSPOSE) { colwise_scaling(sIn_ptr, sOut_tr_ptr, sSFcolwise_ptr, - S_enc_colwise, stage_Y, stage_X, - buff_in, buff_out_tr, rng, random_uint4, rnd_idx); + S_enc_colwise, stage_Y, stage_X, buff_in, + buff_out_tr, rng, random_uint4, rnd_idx); } // Wait for shared memory writes to be visible to TMA engine From b8a2c60ca61dab8edecb377b1514b1bd7f011186 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Mon, 15 Dec 2025 16:24:38 +0000 Subject: [PATCH 16/23] Use the tuned kernel also for the rowwise only quantization Signed-off-by: Oleg Goncharov --- .../common/cast/nvfp4/quantize_transpose_nvfp4.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index 8243bef781..d3eb541f5b 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -1168,7 +1168,7 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, // TODO(Frank): Is there a better way to do this? 
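// Editorial sketch (not part of the patch): host-side illustration of the per-block NVFP4
// scale-factor convention that PATCH 13 aligns between the CPU reference and the tuned kernel
// this dispatcher routes to. The global encode scale S_enc and the E4M3-quantized per-block
// decode scale mirror the reference in test_cast_nvfp4_transpose.cu; the BF16 truncation of the
// per-block coefficient models the kernel's mixed-precision FMA path. Nvfp4BlockScale and
// nvfp4_block_scale are names local to this sketch; the real code additionally clamps S_enc to
// the float maximum. Compiles as host code with the CUDA toolkit headers.
#include <cuda_bf16.h>
#include <cuda_fp8.h>

struct Nvfp4BlockScale {
  __nv_fp8_e4m3 stored_scale;  // per-block decoding scale, written to the scale tensor
  float quant_coefficient;     // multiplier applied to the 16 elements of the block
};

inline Nvfp4BlockScale nvfp4_block_scale(const float block_amax, const float global_amax) {
  constexpr float fp8_max = 448.0f;  // E4M3 max
  constexpr float fp4_max = 6.0f;    // E2M1 max
  // Global encode scale shared by all blocks of the tensor.
  const float S_enc = (global_amax == 0.0f) ? 1.0f : fp8_max * fp4_max / global_amax;
  // Per-block decode scale, quantized to E4M3 exactly as it will be stored.
  const __nv_fp8_e4m3 S_dec_b_fp8 = __nv_fp8_e4m3((block_amax / fp4_max) * S_enc);
  const float S_dec_b = static_cast<float>(S_dec_b_fp8);
  // Per-block encode coefficient; the kernel truncates it to BF16 before scaling the elements,
  // which is why the CPU reference casts through BF16 to match the GPU bit-for-bit.
  const float S_enc_b = (S_dec_b == 0.0f) ? 0.0f : S_enc / S_dec_b;
  const float coeff = static_cast<float>(__float2bfloat16(S_enc_b));
  return {S_dec_b_fp8, coeff};
}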
bool return_transpose = output->has_columnwise_data(); - if (!use_2d_quantization && (input.dtype() == DType::kBFloat16) && return_transpose) { + if (!use_2d_quantization && (input.dtype() == DType::kBFloat16)) { quantize_transpose_tuned_1D(input, noop, output, quant_config, stream); return; } From 08a82d7256c1100a7633c621de2891530667fe42 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Mon, 15 Dec 2025 16:40:09 +0000 Subject: [PATCH 17/23] Fixed typo Signed-off-by: Oleg Goncharov --- tests/cpp/operator/test_cast_nvfp4_transpose.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 72950b9bf5..5fa1d042c6 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -122,7 +122,7 @@ void quantize_nvfp4_1d(float (*OP)(const float), const size_t scale_idx = i * scales_stride + block_X; scales[scale_idx] = S_dec_b_fp8; - // Numercial truncation to match GPU implementation, which uses mixed precision FMA instruction + // Numerical truncation to match GPU implementation, which uses mixed precision FMA instruction const float scale_reciprocal = static_cast(static_cast(S_enc_b_fp8)); for (size_t j = j_min; j < j_max; j += 2) { From b6728bd37c7d321a02edc8ebe8acc99cfe2f1e85 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Mon, 12 Jan 2026 18:07:22 +0000 Subject: [PATCH 18/23] Initial version of the grouped MXFP8 kernel. Work in progress. Signed-off-by: Oleg Goncharov --- tests/cpp/CMakeLists.txt | 6 +- tests/cpp/operator/CMakeLists.txt | 55 +- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 897 ++++++++++++++++ .../test_cast_nvfp4_transpose_grouped.cu | 733 +++++++++++++ transformer_engine/common/cast/cast.cu | 27 +- .../common/cast/dispatch/quantize_grouped.cuh | 121 +++ .../cast/mxfp8/quantize_grouped_mxfp8.cuh | 966 ++++++++++++++++++ .../common/include/transformer_engine/cast.h | 11 + 8 files changed, 2778 insertions(+), 38 deletions(-) create mode 100644 tests/cpp/operator/test_cast_mxfp8_grouped.cu create mode 100644 tests/cpp/operator/test_cast_nvfp4_transpose_grouped.cu create mode 100644 transformer_engine/common/cast/dispatch/quantize_grouped.cuh create mode 100644 transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index c2c9d0d915..c64bfbc53f 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -6,9 +6,11 @@ cmake_minimum_required(VERSION 3.18) if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8) - set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120) + set(CMAKE_CUDA_ARCHITECTURES 100) + # set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120) else () - set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90) + set(CMAKE_CUDA_ARCHITECTURES 100) + # set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90) endif() endif() diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index b2f14b1892..ef088c1151 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -3,33 +3,34 @@ # See LICENSE for license information. 
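# NOTE (editorial, not part of the patch): the source list below is temporarily narrowed to
# test_cast_mxfp8_grouped.cu only, and CMAKE_CUDA_ARCHITECTURES is pinned to 100 in the sibling
# tests/cpp/CMakeLists.txt hunk above. Given the "work in progress" commit message, this looks
# like local development scaffolding to shorten build and test cycles while the grouped MXFP8
# kernel is brought up, rather than an intended permanent change; the commented-out tests
# presumably return in a later patch.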
add_executable(test_operator - test_cast.cu - test_cast_current_scaling.cu - test_cast_dbias.cu - test_cast_dbias_dgelu.cu - test_cast_gated_swiglu.cu - test_cast_mxfp8_gated_swiglu.cu - test_qdq.cu - test_cast_mxfp8.cu - test_cast_nvfp4_transpose.cu - test_cast_float8blockwise.cu - test_dequantize_mxfp8.cu - test_transpose.cu - test_cast_transpose.cu - test_cast_transpose_current_scaling.cu - test_cast_transpose_dbias.cu - test_cast_transpose_dbias_dgelu.cu - test_cast_transpose_dgeglu.cu - test_act.cu - test_normalization.cu - test_normalization_mxfp8.cu - test_memset.cu - test_multi_cast_transpose.cu - test_multi_padding.cu - test_multi_unpadding.cu - test_causal_softmax.cu - test_swizzle.cu - test_swap_first_dims.cu + # test_cast.cu + # test_cast_current_scaling.cu + # test_cast_dbias.cu + # test_cast_dbias_dgelu.cu + # test_cast_gated_swiglu.cu + # test_cast_mxfp8_gated_swiglu.cu + # test_qdq.cu + # test_cast_mxfp8.cu + test_cast_mxfp8_grouped.cu + # test_cast_nvfp4_transpose.cu + # test_cast_float8blockwise.cu + # test_dequantize_mxfp8.cu + # test_transpose.cu + # test_cast_transpose.cu + # test_cast_transpose_current_scaling.cu + # test_cast_transpose_dbias.cu + # test_cast_transpose_dbias_dgelu.cu + # test_cast_transpose_dgeglu.cu + # test_act.cu + # test_normalization.cu + # test_normalization_mxfp8.cu + # test_memset.cu + # test_multi_cast_transpose.cu + # test_multi_padding.cu + # test_multi_unpadding.cu + # test_causal_softmax.cu + # test_swizzle.cu + # test_swap_first_dims.cu ../test_common.cu) # Find required packages diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu new file mode 100644 index 0000000000..9c89c916b0 --- /dev/null +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -0,0 +1,897 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include +#include + +#include +#include +#include "../test_common.h" +#include "transformer_engine/transformer_engine.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +enum ProcessingMethod { + CAST_ONLY, + CAST_DBIAS, + CAST_DBIAS_DACT, + CAST_DACT, + CAST_ACT +}; + +enum ActivationKind { + Identity, + GeLU, + SiLU, + ReLU, + QGeLU, + SReLU +}; + +enum ShapeRepresentation { + SAME_MK = 0, + VARYING_M = 1, + VARYING_K = 2, + VARYING_MK = 3 +}; + +template +void compute_ref(const ProcessingMethod processing_method, + float (*OP)(const float), + const bool rowwise, + const bool colwise, + const InputType* input, + const InputType* grad, + OutputType* output_rowwise, + OutputType* output_colwise, + fp8e8m0* output_scales_rowwise, + fp8e8m0* output_scales_colwise, + InputType* output_dbias, + const size_t rows, + const size_t cols, + const size_t scales_stride_rowwise, + const size_t scales_stride_colwise) +{ + const size_t tile_size_Y = 32; + const size_t tile_size_X = 32; + const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y; + const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X; + + std::vector output_dbias_fp32(cols, 0); + #pragma omp parallel proc_bind(spread) + { + // Buffers to cache intermediate computations + std::vector cache_buffer(tile_size_Y * tile_size_X); + + std::vector thread_dbias(cols, 0); + #pragma omp for schedule(static) + for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) { + const size_t tile_Y = t / tiles_num_X; + const size_t tile_X = t % tiles_num_X; + const size_t tile_offset_Y = tile_Y * tile_size_Y; + const size_t tile_offset_X = tile_X * tile_size_X; + + const size_t i_min = tile_offset_Y; + const size_t i_max = std::min(i_min + tile_size_Y, rows); + + const size_t j_min = tile_offset_X; + const size_t j_max = std::min(j_min + tile_size_X, cols); + + // Cache computations + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + + const size_t idx = i * cols + j; + const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); + + float elt = static_cast(input[idx]); + // if (processing_method == ProcessingMethod::CAST_DBIAS) { + // // grad is the input + // elt = static_cast(grad[idx]); + // } + if (processing_method != ProcessingMethod::CAST_ONLY + && processing_method != ProcessingMethod::CAST_DBIAS) { + elt = OP(elt); + } + // if (processing_method == ProcessingMethod::CAST_DACT || + // processing_method == ProcessingMethod::CAST_DBIAS_DACT) { + // elt *= static_cast(grad[idx]); + // } + thread_dbias[j] += elt; + + // Numerical truncation: after downcast to InputType (BF16/FP16), upcast it back to FP32 + elt = static_cast(static_cast(elt)); + + cache_buffer[cache_idx] = elt; + if (isinf(elt) || isnan(elt)) { + continue; + } + } + } + + if (rowwise) { + for (size_t i = i_min; i < i_max; ++i) { + float block_amax = 0.0f; + + for (size_t j = j_min; j < j_max; ++j) { + const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); + block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx])); + } + + const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits::max_reciprocal()); + const size_t scale_idx = i * scales_stride_rowwise + tile_X; + output_scales_rowwise[scale_idx] = biased_exponent; + const float scale_reciprocal = exp2f_rcp(biased_exponent); + + for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; + const 
size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); + output_rowwise[idx] = static_cast(cache_buffer[cache_idx] * scale_reciprocal); + } + } + } + if (colwise) { + for (size_t j = j_min; j < j_max; ++j) { + float block_amax = 0.0f; + + for (size_t i = i_min; i < i_max; ++i) { + const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); + block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx])); + } + + const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits::max_reciprocal()); + const size_t scale_idx = tile_Y * scales_stride_colwise + j; + output_scales_colwise[scale_idx] = biased_exponent; + const float scale_reciprocal = exp2f_rcp(biased_exponent); + + for (size_t i = i_min; i < i_max; ++i) { + const size_t idx = i * cols + j; + const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); + output_colwise[idx] = static_cast(cache_buffer[cache_idx] * scale_reciprocal); + } + } + } + } + #pragma omp critical + { + for (size_t j = 0; j < cols; ++j) { + output_dbias_fp32[j] += thread_dbias[j]; + } + } + } + // for (size_t j = 0; j < cols; ++j) { + // output_dbias[j] = static_cast(output_dbias_fp32[j]); + // } +} + +/** + * Scaling along single dimension (either rows or columns) + * Produces one set of output data and the corresponding data of the fused operation (dbias): + * 1) Scaled rows + row-wise scaling factors + * OR + * 2) Scaled columns + column-wise scaling factors + */ + +template +void performTest_x1(const ProcessingMethod processing_method, + float (*OP)(const float), + const size_t num_tensors, + const std::vector& logical_shape_vec, + const bool rowwise, + const bool colwise) { + using namespace test; + using EncodingType = fp32; + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const size_t rows = logical_shape_vec[0]; + const size_t cols = logical_shape_vec[1]; + + const size_t M = rows / num_tensors; + const size_t K = cols; + + std::vector scales_rowwise_shape = {rows, cols / 32}; + std::vector scales_colwise_shape = {rows / 32, cols}; + + const size_t elts_num = rows * cols; + const size_t sfs_num = (rows * cols) / 32; + + std::mt19937 gen; + std::uniform_real_distribution<> dis(-2.0, 1.0); + + std::vector in_data(elts_num); + + std::vector out_data_rowwise_h(rowwise ? elts_num : 0); + std::vector out_data_colwise_h(colwise ? elts_num : 0); + std::vector out_scales_rowwise_h(rowwise ? sfs_num : 0); + std::vector out_scales_colwise_h(colwise ? sfs_num : 0); + + std::vector out_data_rowwise_ref(rowwise ? elts_num : 0); + std::vector out_data_colwise_ref(colwise ? elts_num : 0); + std::vector out_scales_rowwise_ref(rowwise ? sfs_num : 0); + std::vector out_scales_colwise_ref(colwise ? 
sfs_num : 0); + + size_t tensor_elts[2] = {128 * 128, 128 * 128}; + std::vector offsets_h(num_tensors); + offsets_h[0] = 0; + for (size_t t = 1; t < num_tensors; ++t) { + offsets_h[t] = offsets_h[t-1] + tensor_elts[t-1]; + } + + for (size_t i = 0; i < elts_num; ++i) { + const float val = dis(gen); + in_data[i] = static_cast(val); + } + + if (rowwise) { + for (size_t i = 0; i < elts_num; ++i) { + out_data_rowwise_h[i] = static_cast(0.0f); + out_data_rowwise_ref[i] = static_cast(0.0f); + } + for (size_t i = 0; i < sfs_num; ++i) { + out_scales_rowwise_h[i] = static_cast(0.0f); + out_scales_rowwise_ref[i] = static_cast(0.0f); + } + } + if (colwise) { + for (size_t i = 0; i < elts_num; ++i) { + out_data_colwise_h[i] = static_cast(0.0f); + out_data_colwise_ref[i] = static_cast(0.0f); + } + for (size_t i = 0; i < sfs_num; ++i) { + out_scales_colwise_h[i] = static_cast(0.0f); + out_scales_colwise_ref[i] = static_cast(0.0f); + } + } + + const size_t in_data_size = elts_num * sizeof(InputType); + const size_t out_data_size = elts_num * sizeof(OutputType); + const size_t out_scales_size = sfs_num * sizeof(fp8e8m0); + + InputType* in_data_d; + OutputType* out_data_rowwise_d; + OutputType* out_data_colwise_d; + fp8e8m0* out_scales_rowwise_d; + fp8e8m0* out_scales_colwise_d; + size_t* offsets_d; + + cudaMalloc((void**)&in_data_d, in_data_size); + cudaMemcpy(in_data_d, in_data.data(), in_data_size, cudaMemcpyHostToDevice); + + cudaMalloc((void**)&offsets_d, in_data_size); + cudaMemcpy(offsets_d, offsets_h.data(), num_tensors * sizeof(size_t), cudaMemcpyHostToDevice); + + NVTEShape logical_shape_ = nvte_make_shape(logical_shape_vec.data(), logical_shape_vec.size()); + NVTEShape offsets_shape_; + offsets_shape_.data[0] = num_tensors; + offsets_shape_.ndim = 1; + + NVTEGroupedTensor in_group_tensor = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, logical_shape_); + NVTEGroupedTensor out_group_tensor = nvte_create_grouped_tensor(NVTE_MXFP8_1D_SCALING, num_tensors, logical_shape_); + + NVTEBasicTensor in_data_tensor = {in_data_d, static_cast(itype), logical_shape_}; + nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &in_data_tensor); + + NVTEBasicTensor offsets_tensor = {offsets_d, kNVTEInt64, offsets_shape_}; + nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, &offsets_tensor); + + if (rowwise) { + cudaMalloc((void**)&out_data_rowwise_d, out_data_size); + cudaMalloc((void**)&out_scales_rowwise_d, out_scales_size); + cudaMemset(out_data_rowwise_d, 0, out_data_size); + cudaMemset(out_scales_rowwise_d, 0, out_scales_size); + NVTEBasicTensor out_data_rowwise_tensor = {out_data_rowwise_d, static_cast(otype), logical_shape_}; + NVTEShape scales_rowwise_shape_ = nvte_make_shape(scales_rowwise_shape.data(), scales_rowwise_shape.size()); + NVTEBasicTensor out_scales_rowwise_tensor = {out_scales_rowwise_d, NVTEDType::kNVTEFloat8E8M0, scales_rowwise_shape_}; + nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &out_data_rowwise_tensor); + nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseScaleInv, &out_scales_rowwise_tensor); + } + + if (colwise) { + cudaMalloc((void**)&out_data_colwise_d, out_data_size); + cudaMalloc((void**)&out_scales_colwise_d, out_scales_size); + cudaMemset(out_data_colwise_d, 0, out_data_size); + cudaMemset(out_scales_colwise_d, 0, out_scales_size); + NVTEBasicTensor out_data_colwise_tensor = 
{out_data_colwise_d, static_cast(otype), logical_shape_}; + NVTEShape scales_colwise_shape_ = nvte_make_shape(scales_colwise_shape.data(), scales_colwise_shape.size()); + NVTEBasicTensor out_scales_colwise_tensor = {out_scales_colwise_d, NVTEDType::kNVTEFloat8E8M0, scales_colwise_shape_}; + nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedColumnwiseData, &out_data_colwise_tensor); + nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedColumnwiseScaleInv, &out_scales_colwise_tensor); + } + + /* DO STUFF */ + // Reference (CPU) + for (size_t t = 0; t < num_tensors; ++t) { + const size_t scales_stride_rowwise = K / 32; + const size_t scales_stride_colwise = K; + const size_t data_offset = t * (M * K); + const size_t sfs_offset = t * (M * K / 32); + + const InputType* const in_ptr = in_data.data() + data_offset; + OutputType* const out_data_rowwise_ptr = out_data_rowwise_ref.data() + data_offset; + OutputType* const out_data_colwise_ptr = out_data_colwise_ref.data() + data_offset; + fp8e8m0* const out_scales_rowwise_ptr = out_scales_rowwise_ref.data() + sfs_offset; + fp8e8m0* const out_scales_colwise_ptr = out_scales_colwise_ref.data() + sfs_offset; + + compute_ref( + processing_method, OP, rowwise, colwise, in_ptr, /*grad=*/ nullptr, + out_data_rowwise_ptr, out_data_colwise_ptr, + out_scales_rowwise_ptr, out_scales_colwise_ptr, + /*output_dbias=*/ nullptr, M, K, + scales_stride_rowwise, + scales_stride_colwise); + } + + // GPU + nvte_quantize_grouped(in_group_tensor, out_group_tensor, 0); + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + if (rowwise) { + cudaMemcpy(out_data_rowwise_h.data(), out_data_rowwise_d, out_data_size, cudaMemcpyDeviceToHost); + cudaMemcpy(out_scales_rowwise_h.data(), out_scales_rowwise_d, out_scales_size, cudaMemcpyDeviceToHost); + } + + if (colwise) { + cudaMemcpy(out_data_colwise_h.data(), out_data_colwise_d, out_data_size, cudaMemcpyDeviceToHost); + cudaMemcpy(out_scales_colwise_h.data(), out_scales_colwise_d, out_scales_size, cudaMemcpyDeviceToHost); + } + + + cudaFree(in_data_d); + cudaFree(offsets_d); + if (rowwise) { + cudaFree(out_data_rowwise_d); + cudaFree(out_scales_rowwise_d); + } + if (colwise) { + cudaFree(out_data_colwise_d); + cudaFree(out_scales_colwise_d); + } + + // const size_t block_size_rows = rowwise ? 1 : 32; + // const size_t block_size_cols = colwise ? 
1 : 32; + + // const std::array scale_dims = get_scale_tensor_dims(rows, cols, block_size_rows, + // block_size_cols); + + // const size_t unpadded_blocks_Y = scale_dims[0]; + // const size_t unpadded_blocks_X = scale_dims[1]; + // const size_t blocks_Y = scale_dims[2]; + // const size_t blocks_X = scale_dims[3]; + // const size_t scales_stride = blocks_X; + + // Tensor input("input", shape, itype); + // Tensor grad("grad", shape, itype); + // Tensor output_c("output_c", shape, otype, rowwise, colwise, NVTE_MXFP8_1D_SCALING); + // Tensor output_dbias("output_dbias", std::vector{ cols }, itype); + + // std::unique_ptr ref_output_c = std::make_unique(rows * cols); + // std::unique_ptr ref_output_dbias = std::make_unique(cols); + // std::unique_ptr ref_output_scales = std::make_unique(blocks_Y * blocks_X); + + // fillCase(&input, InputsFillCase::uniform); + // fillUniform(&grad); + + // Tensor workspace; + // switch (processing_method) { + // case ProcessingMethod::CAST_ONLY: { + // nvte_quantize(input.data(), output_c.data(), 0); + // break; + // } + // case ProcessingMethod::CAST_DBIAS: { + // nvte_quantize_dbias(grad.data(), + // output_c.data(), + // output_dbias.data(), + // workspace.data(), + // 0); + // workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + + // nvte_quantize_dbias(grad.data(), + // output_c.data(), + // output_dbias.data(), + // workspace.data(), + // 0); + // break; + // } + // case ProcessingMethod::CAST_DBIAS_DACT: { + // auto nvte_quantize_dbias_dact = &nvte_quantize_dbias_dgelu; + // if (OP == &dsilu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsilu; } + // else if (OP == &drelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_drelu; } + // else if (OP == &dqgelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dqgelu; } + // else if (OP == &dsrelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsrelu; } + + // nvte_quantize_dbias_dact(grad.data(), + // input.data(), + // output_c.data(), + // output_dbias.data(), + // workspace.data(), + // 0); + // workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + + // nvte_quantize_dbias_dact(grad.data(), + // input.data(), + // output_c.data(), + // output_dbias.data(), + // workspace.data(), + // 0); + // break; + // } + // case ProcessingMethod::CAST_DACT: { + // auto nvte_dact = &nvte_dgelu; + // if (OP == &dsilu) { nvte_dact = &nvte_dsilu; } + // else if (OP == &drelu) { nvte_dact = &nvte_drelu; } + // else if (OP == &dqgelu) { nvte_dact = &nvte_dqgelu; } + // else if (OP == &dsrelu) { nvte_dact = &nvte_dsrelu; } + + // nvte_dact(grad.data(), input.data(), output_c.data(), 0); + // break; + // } + // case ProcessingMethod::CAST_ACT: { + // auto nvte_act = &nvte_gelu; + // if (OP == &silu) { nvte_act = &nvte_silu; } + // else if (OP == &relu) { nvte_act = &nvte_relu; } + // else if (OP == &qgelu) { nvte_act = &nvte_qgelu; } + // else if (OP == &srelu) { nvte_act = &nvte_srelu; } + + // nvte_act(input.data(), output_c.data(), 0); + // break; + // } + // } + + // cudaDeviceSynchronize(); + // auto err = cudaGetLastError(); + // ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + // compute_ref(processing_method, + // OP, + // rowwise, + // colwise, + // input.rowwise_cpu_dptr(), + // grad.rowwise_cpu_dptr(), + // ref_output_c.get(), + // ref_output_c.get(), + // ref_output_scales.get(), + // ref_output_scales.get(), + // ref_output_dbias.get(), + // rows, + // cols, + // scales_stride, + // scales_stride); + + // const uint8_t * const 
gpu_scales_ptr = rowwise + // ? output_c.rowwise_cpu_scale_inv_ptr() + // : output_c.columnwise_cpu_scale_inv_ptr(); + + // const size_t scale_diff_abs_tolerance = 0; + // const double abs_tolerable_mismatches_limit = 0.0; + // const double rel_tolerable_mismatches_limit = 0.0; + + // size_t mismatches_scales = 0; + + // compare_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(), + // unpadded_blocks_Y, unpadded_blocks_X, scales_stride, + // mismatches_scales, + // scale_diff_abs_tolerance, + // abs_tolerable_mismatches_limit, + // rel_tolerable_mismatches_limit); + + // const size_t mismatches_elts = 32 * mismatches_scales; + // auto [atol, rtol] = getTolerances(otype); + // compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol, true, mismatches_elts); + + // if (processing_method == ProcessingMethod::CAST_DBIAS + // || processing_method == ProcessingMethod::CAST_DBIAS_DACT) + // { + // auto [atol_dbias, rtol_dbias] = getTolerances(itype); + // if (itype == DType::kFloat32) { + // atol_dbias = 1e-4; + // rtol_dbias *= sqrt(static_cast(rows)) ; + // } else { + // rtol_dbias *= 4; + // } + // compareResults("output_dbias", output_dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias); + // } +} + +/** + * Scaling along both dimensions (rows and columns) + * Produces two sets of scaled output data and the corresponding data of the fused operation (dbias): + * 1) Scaled rows + row-wise scaling factors + * AND + * 2) Scaled columns + column-wise scaling factors + */ +/* +template +void performTest_x2(const ProcessingMethod processing_method, + float (*OP)(const float), + const std::pair& shape, + const std::vector& M_i, + const std::vector& Offset_i) { + using namespace test; + using EncodingType = fp32; + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const size_t rows = shape.first; + const size_t cols = shape.second; + + const std::array scale_dims_rowwise = get_scale_tensor_dims(rows, cols, 1, 32); + const std::array scale_dims_colwise = get_scale_tensor_dims(rows, cols, 32, 1); + + const size_t unpadded_blocks_Y_rowwise = scale_dims_rowwise[0]; + const size_t unpadded_blocks_X_rowwise = scale_dims_rowwise[1]; + const size_t blocks_Y_rowwise = scale_dims_rowwise[2]; + const size_t blocks_X_rowwise = scale_dims_rowwise[3]; + const size_t scales_stride_rowwise = blocks_X_rowwise; + + const size_t unpadded_blocks_Y_colwise = scale_dims_colwise[0]; + const size_t unpadded_blocks_X_colwise = scale_dims_colwise[1]; + const size_t blocks_Y_colwise = scale_dims_colwise[2]; + const size_t blocks_X_colwise = scale_dims_colwise[3]; + const size_t scales_stride_colwise = blocks_X_colwise; + + Tensor input("input", shape, itype); + Tensor grad("grad", shape, itype); + Tensor output("output", shape, otype, true, true, NVTE_MXFP8_1D_SCALING); + Tensor output_dbias("output_dbias", std::vector{ cols }, itype); + + std::unique_ptr ref_output_c_rowwise = std::make_unique(rows * cols); + std::unique_ptr ref_output_c_colwise = std::make_unique(rows * cols); + std::unique_ptr ref_scales_rowwise = std::make_unique(blocks_Y_rowwise * blocks_X_rowwise); + std::unique_ptr ref_scales_colwise = std::make_unique(blocks_Y_colwise * blocks_X_colwise); + std::unique_ptr ref_output_dbias = std::make_unique(cols); + + fillCase(&input, InputsFillCase::uniform); + fillUniform(&grad); + + Tensor workspace; + switch (processing_method) { + case ProcessingMethod::CAST_ONLY: { + nvte_quantize(input.data(), output.data(), 0); + break; + } + case 
ProcessingMethod::CAST_DBIAS: { + nvte_quantize_dbias(grad.data(), + output.data(), + output_dbias.data(), + workspace.data(), + 0); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + + nvte_quantize_dbias(grad.data(), + output.data(), + output_dbias.data(), + workspace.data(), + 0); + break; + } + case ProcessingMethod::CAST_DBIAS_DACT: { + auto nvte_quantize_dbias_dact = &nvte_quantize_dbias_dgelu; + if (OP == &dsilu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsilu; } + else if (OP == &drelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_drelu; } + else if (OP == &dqgelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dqgelu; } + else if (OP == &dsrelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsrelu; } + + nvte_quantize_dbias_dact(grad.data(), + input.data(), + output.data(), + output_dbias.data(), + workspace.data(), + 0); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + + nvte_quantize_dbias_dact(grad.data(), + input.data(), + output.data(), + output_dbias.data(), + workspace.data(), + 0); + break; + } + case ProcessingMethod::CAST_DACT: { + auto nvte_dact = &nvte_dgelu; + if (OP == &dsilu) { nvte_dact = &nvte_dsilu; } + else if (OP == &drelu) { nvte_dact = &nvte_drelu; } + else if (OP == &dqgelu) { nvte_dact = &nvte_dqgelu; } + else if (OP == &dsrelu) { nvte_dact = &nvte_dsrelu; } + + nvte_dact(grad.data(), input.data(), output.data(), 0); + break; + } + case ProcessingMethod::CAST_ACT: { + auto nvte_act = &nvte_gelu; + if (OP == &silu) { nvte_act = &nvte_silu; } + else if (OP == &relu) { nvte_act = &nvte_relu; } + else if (OP == &qgelu) { nvte_act = &nvte_qgelu; } + else if (OP == &srelu) { nvte_act = &nvte_srelu; } + + nvte_act(input.data(), output.data(), 0); + break; + } + } + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + compute_ref(processing_method, + OP, + true, + true, + input.rowwise_cpu_dptr(), + grad.rowwise_cpu_dptr(), + ref_output_c_rowwise.get(), + ref_output_c_colwise.get(), + ref_scales_rowwise.get(), + ref_scales_colwise.get(), + ref_output_dbias.get(), + rows, + cols, + scales_stride_rowwise, + scales_stride_colwise); + + const size_t scale_diff_abs_tolerance = 0; + const double abs_tolerable_mismatches_limit = 0.0; + const double rel_tolerable_mismatches_limit = 0.0; + + size_t mismatches_scales_rowwise = 0; + compare_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), + ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, + unpadded_blocks_X_rowwise, scales_stride_rowwise, + mismatches_scales_rowwise, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); + + size_t mismatches_scales_colwise = 0; + compare_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), + ref_scales_colwise.get(), unpadded_blocks_Y_colwise, + unpadded_blocks_X_colwise, scales_stride_colwise, + mismatches_scales_colwise, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); + + const size_t mismatches_elts_rowwise = 32 * mismatches_scales_rowwise; + const size_t mismatches_elts_colwise = 32 * mismatches_scales_colwise; + + auto [atol, rtol] = getTolerances(otype); + compareResults("output_c_rowwise", output, ref_output_c_rowwise.get(), true, atol, rtol, true, mismatches_elts_rowwise); + compareResults("output_c_colwise", output, ref_output_c_colwise.get(), false, atol, rtol, true, mismatches_elts_colwise); + 
+ if (processing_method == ProcessingMethod::CAST_DBIAS + || processing_method == ProcessingMethod::CAST_DBIAS_DACT) + { + auto [atol_dbias, rtol_dbias] = getTolerances(itype); + if (itype == DType::kFloat32) { + atol_dbias = 1e-4; + rtol_dbias *= sqrt(static_cast(rows)) ; + } else { + rtol_dbias *= 4; + } + compareResults("output_dbias", output_dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias); + } +} +*/ + +std::vector processing_methods = { + ProcessingMethod::CAST_ONLY, + // ProcessingMethod::CAST_DBIAS, + // ProcessingMethod::CAST_DBIAS_DACT, + // ProcessingMethod::CAST_DACT, + // ProcessingMethod::CAST_ACT, +}; + +// Only GeLU activation tests are supported +std::vector activation_kinds = { + ActivationKind::Identity, + // ActivationKind::GeLU, + // ActivationKind::SiLU, + // ActivationKind::ReLU, + // ActivationKind::QGeLU, + // ActivationKind::SReLU, +}; + +enum ScalingDirection { + ROWWISE = 0, + COLWISE = 1, + BOTH = 2 +}; + +std::vector scaling_directions = { + ScalingDirection::ROWWISE, + // ScalingDirection::COLWISE, + // ScalingDirection::BOTH, +}; + +// {num_tensors, logical_shape_M, logical_shape_K, [M_i], [K_i], [Offset_i]} +std::vector> input_config = { + {1, 128, 128}, + {2, 256, 128}, + // {3, 128 * 3, 256}, + // {5, 256 * 5, 256}, +}; + +} // namespace + +class GroupedFusedCastMXFP8TestSuite : public ::testing::TestWithParam + , // Config + transformer_engine::DType, // InputType + transformer_engine::DType // OutputType + >> {}; + +TEST_P(GroupedFusedCastMXFP8TestSuite, Test) { + // Skip tests for pre-Blackwell architectures + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + using namespace transformer_engine; + using namespace test; + + const ProcessingMethod processing_method = std::get<0>(GetParam()); + const ActivationKind activation = std::get<1>(GetParam()); + const ScalingDirection scaling_direction = std::get<2>(GetParam()); + const std::vector input_config = std::get<3>(GetParam()); + + const size_t num_tensors = input_config[0]; + const std::vector logical_shape = {input_config[1], input_config[2]}; + + // Skips non Act tests if the Activation type is not an identity + if ((processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS) + && activation != ActivationKind::Identity) { + GTEST_SKIP(); + } + // Skips Act tests if the Activation is an identity + if ((processing_method == ProcessingMethod::CAST_DBIAS_DACT + || processing_method == ProcessingMethod::CAST_DACT + || processing_method == ProcessingMethod::CAST_ACT) && (activation == ActivationKind::Identity)) { + GTEST_SKIP(); + } + + bool rowwise = false; + bool colwise = false; + switch (scaling_direction) { + case ScalingDirection::ROWWISE: rowwise = true; break; + case ScalingDirection::COLWISE: colwise = true; break; + case ScalingDirection::BOTH: rowwise = true; colwise = true; break; + } + + auto OP = &identity; + performTest_x1(processing_method, OP, num_tensors, logical_shape, rowwise, colwise); + + // if (processing_method == ProcessingMethod::CAST_ACT) { + // // Forward activations + // auto OP = &identity; + // switch (activation) { + // case ActivationKind::GeLU: OP = &gelu; break; + // case ActivationKind::SiLU: OP = &silu; break; + // case ActivationKind::ReLU: OP = &relu; break; + // case ActivationKind::QGeLU: OP = &qgelu; break; + // case ActivationKind::SReLU: OP = &srelu; break; + // } + + // TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, + // 
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, + // if (scaling_direction == ScalingDirection::BOTH) { + // performTest_x2( + // processing_method, OP, tensor_logical_shape, M_i, Offset_i); + // } else { + // performTest_x1( + // processing_method, OP, tensor_logical_shape, M_i, Offset_i, rowwise, colwise); + // } + // ); + // ); + // } else { + // auto OP = &identity; + // switch (activation) { + // case ActivationKind::GeLU: OP = &dgelu; break; + // case ActivationKind::SiLU: OP = &dsilu; break; + // case ActivationKind::ReLU: OP = &drelu; break; + // case ActivationKind::QGeLU: OP = &dqgelu; break; + // case ActivationKind::SReLU: OP = &dsrelu; break; + // } + // TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, + // TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, + // if (scaling_direction == ScalingDirection::BOTH) { + // performTest_x2( + // processing_method, OP, tensor_logical_shape, M_i, Offset_i); + // } else { + // performTest_x1( + // processing_method, OP, tensor_logical_shape, M_i, Offset_i, rowwise, colwise); + // } + // ); + // ); + // } +} + +std::string to_string(const ProcessingMethod method) { + switch (method) { + case ProcessingMethod::CAST_ONLY: return "CAST_ONLY"; + case ProcessingMethod::CAST_DBIAS: return "CAST_DBIAS"; + case ProcessingMethod::CAST_DBIAS_DACT: return "CAST_DBIAS_DACT"; + case ProcessingMethod::CAST_DACT: return "CAST_DACT"; + case ProcessingMethod::CAST_ACT: return "CAST_ACT"; + default: return ""; + } +} + +std::string to_string(const ActivationKind activation) { + switch (activation) { + case ActivationKind::Identity: return "Identity"; + case ActivationKind::GeLU: return "GeLU"; + case ActivationKind::SiLU: return "SiLU"; + case ActivationKind::ReLU: return "ReLU"; + case ActivationKind::QGeLU: return "QGeLU"; + case ActivationKind::SReLU: return "SReLU"; + default: return ""; + } +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + GroupedFusedCastMXFP8TestSuite, + ::testing::Combine( + ::testing::ValuesIn(processing_methods), + ::testing::ValuesIn(activation_kinds), + ::testing::ValuesIn(scaling_directions), + ::testing::ValuesIn(input_config), + ::testing::Values(DType::kBFloat16), + ::testing::Values(DType::kFloat8E4M3)), + // ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + // ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2)), + [](const testing::TestParamInfo& info) { + const ProcessingMethod method = std::get<0>(info.param); + std::string name = to_string(method); + if (method != ProcessingMethod::CAST_ONLY && method != ProcessingMethod::CAST_DBIAS) { + name += "X" + to_string(std::get<1>(info.param)); + } + + switch (std::get<2>(info.param)) { + case ScalingDirection::ROWWISE: name += "_ROWWISE"; break; + case ScalingDirection::COLWISE: name += "_COLWISE"; break; + case ScalingDirection::BOTH: name += "_BOTH"; break; + } + + const std::vector input = std::get<3>(info.param); + name += "_N_" + std::to_string(input[0]); + + name += "_Shape_" + + std::to_string(input[1]) + + "X" + std::to_string(input[2]); + + // name += "_DimsM_"; + // const auto& M_i_ = std::get<5>(info.param); + // for (size_t i = 0; i < M_i_.size(); ++i) { + // const size_t m = M_i_[i]; + // name += std::to_string(m); + // if (i < M_i_.size() - 1) { + // name += "X"; + // } + // } + // name += "_Offsets_"; + // const auto& Offset_i_ = std::get<6>(info.param); + // for (size_t i = 0; i < Offset_i_.size(); ++i) { + // const size_t offset = Offset_i_[i]; + // name += 
std::to_string(offset); + // if (i < Offset_i_.size() - 1) { + // name += "X"; + // } + // } + name += "_" + test::typeName(std::get<4>(info.param)) + + "_" + test::typeName(std::get<5>(info.param)); + return name; + }); diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose_grouped.cu b/tests/cpp/operator/test_cast_nvfp4_transpose_grouped.cu new file mode 100644 index 0000000000..1abc9e8b7a --- /dev/null +++ b/tests/cpp/operator/test_cast_nvfp4_transpose_grouped.cu @@ -0,0 +1,733 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include + +#include +#include +#include "../test_common.h" +#include "transformer_engine/transformer_engine.h" +#include + +using namespace transformer_engine; +using namespace test; + +namespace { + +double2 cvt_fp4x2_to_double2(fp4e2m1x2 fp4_pair) { + const __half2_raw raw_truncated_to_fp4e2m1_pair = + __nv_cvt_fp4x2_to_halfraw2(*reinterpret_cast<__nv_fp4x2_storage_t*>(&fp4_pair), __NV_E2M1); + + const __half2 truncated_to_fp4e2m1_pair(raw_truncated_to_fp4e2m1_pair); + const double truncated_to_fp4e2m1_x = static_cast(truncated_to_fp4e2m1_pair.x); + const double truncated_to_fp4e2m1_y = static_cast(truncated_to_fp4e2m1_pair.y); + return {truncated_to_fp4e2m1_x, truncated_to_fp4e2m1_y}; +} + +template +std::vector create_transpose(const InputType* const input, const size_t rows, size_t cols) { + std::vector input_t(cols * rows); + for (size_t i = 0; i < rows; ++i) { + for (size_t j = 0; j < cols; ++j) { + const size_t idx = i * cols + j; + const size_t idx_t = j * rows + i; + input_t[idx_t] = input[idx]; + } + } + return input_t; +} + +// Compute the global encode scale factor for a given global amax +float compute_global_encode_scaling_factor_FP4(const float global_amax) { + constexpr float fp8_max = 448.0f; // 448.0f; + constexpr float fp4_max = 6.0f; // 6.0f; + float global_encode_scale = fp8_max * fp4_max / global_amax; + // If scale is infinity, return max value of float32 + global_encode_scale = fminf(global_encode_scale, Numeric_Traits::maxNorm); + // If global amax is 0 or infinity, return 1 + if (global_amax == 0.0f || global_encode_scale == 0.0f) { + return 1.0f; + } + return global_encode_scale; +} + +// 1D Scaling: Original implementation with 1x16 blocks +template +void quantize_nvfp4_1d(float (*OP)(const float), + const InputType* const input, + fp4e2m1x2* const output, + fp8e4m3* const scales, + const size_t rows, + const size_t cols, + const size_t scales_stride, + const float global_amax) { + + // Compute a global encoding/decoding scaling factor for all S_dec_b + const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax); + + constexpr size_t block_size_X = 16; + const size_t blocks_X = divide_round_up(cols, block_size_X); + + std::array cache_buffer; + for (size_t i = 0; i < block_size_X; ++i) { + cache_buffer[i] = 0.0f; + } + + for (size_t i = 0; i < rows; ++i) { + for (size_t block_X = 0; block_X < blocks_X; ++block_X) { + const size_t j_min = block_X * block_size_X; + const size_t j_max = j_min + block_size_X; + + // Find block amax + float block_amax = 0.0f; + for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; + const size_t cache_idx = j - j_min; + + const float input_elt = static_cast(input[idx]); + const float act_elt = 
OP(input_elt); + + // Numerical truncation: after downcast to InputType (BF16/FP16), upcast it back to FP32 + const float elt = static_cast(static_cast(act_elt)); + cache_buffer[cache_idx] = elt; + block_amax = std::max(block_amax, std::abs(elt)); + } + + // 2. Compute E4M3 scaling factor + // Compute per-block encoding/decoding scaling factor + const float S_dec_b = block_amax / 6.0f; + + // Scale & Store per-block decoding scaling factor + const fp8e4m3 S_dec_b_fp8 = static_cast(S_dec_b * S_enc); + const float S_dec_b_fp32 = static_cast(S_dec_b_fp8); + + // Compute "correct" per-block encoding scaling factor + const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f : S_enc / S_dec_b_fp32; + + const size_t scale_idx = i * scales_stride + block_X; + scales[scale_idx] = S_dec_b_fp8; + // Numerical truncation to match GPU implementation, which uses mixed precision FMA instruction + const float scale_reciprocal = static_cast(static_cast(S_enc_b_fp8)); + + for (size_t j = j_min; j < j_max; j += 2) { + const int idx_pair = (i * cols + j) / 2; + const int cache_idx_x = j - j_min; + const int cache_idx_y = cache_idx_x + 1; + const float cached_x = cache_buffer[cache_idx_x]; + const float cached_y = cache_buffer[cache_idx_y]; + const float scaled_elt_x = cached_x * scale_reciprocal; + const float scaled_elt_y = cached_y * scale_reciprocal; + const float2 scaled_elt_pair = {scaled_elt_x, scaled_elt_y}; + + fp4e2m1x2 casted_to_e2m1_pair(scaled_elt_pair); + output[idx_pair] = casted_to_e2m1_pair; + + // const double2 truncated_pair = cvt_fp4x2_to_double2(casted_to_e2m1_pair); + } + } + } +} + +// Compute 2D mathematical scaling factors (8x8 for 128x128 input) +template +void compute_2d_mathematical_scales(float (*OP)(const float), + const InputType* const input, + const size_t rows, + const size_t cols, + const float global_amax, + std::vector>& math_scales) { + + const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax); + constexpr size_t block_size_Y = 16; + constexpr size_t block_size_X = 16; + const size_t blocks_Y = divide_round_up(rows, block_size_Y); + const size_t blocks_X = divide_round_up(cols, block_size_X); + + math_scales.resize(blocks_Y, std::vector(blocks_X)); + + for (size_t block_Y = 0; block_Y < blocks_Y; ++block_Y) { + for (size_t block_X = 0; block_X < blocks_X; ++block_X) { + const size_t i_min = block_Y * block_size_Y; + const size_t i_max = std::min(i_min + block_size_Y, rows); + const size_t j_min = block_X * block_size_X; + const size_t j_max = std::min(j_min + block_size_X, cols); + + // Find 2D block amax over entire 16x16 region + float block_amax = 0.0f; + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; + const float input_elt = static_cast(input[idx]); + const float act_elt = OP(input_elt); + const float elt = static_cast(static_cast(act_elt)); + block_amax = std::max(block_amax, std::abs(elt)); + } + } + + // Compute E4M3 scaling factor for this 16x16 block + const float S_dec_b = block_amax / 6.0f; + const fp8e4m3 S_dec_b_fp8 = static_cast(S_dec_b * S_enc); + math_scales[block_Y][block_X] = S_dec_b_fp8; + } + } +} + +// 2D Scaling: NEW implementation with proper replication +template +void quantize_nvfp4_2d(float (*OP)(const float), + const InputType* const input, + fp4e2m1x2* const output, + fp8e4m3* const scales, + const size_t rows, + const size_t cols, + const size_t scales_stride, + const float global_amax) { + + // Step 1: Compute mathematical 8x8 scaling factors + std::vector> 
math_scales; + compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales); + + const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax); + constexpr size_t block_size_Y = 16; + constexpr size_t block_size_X = 16; + const size_t blocks_Y = divide_round_up(rows, block_size_Y); + const size_t blocks_X = divide_round_up(cols, block_size_X); + + // Step 2: Replicate scaling factors row-wise (128×8 storage) - only if scales is not nullptr + if (scales != nullptr) { + // Each of the 128 rows gets scaling factors from its corresponding 16×16 block + for (size_t i = 0; i < rows; ++i) { + const size_t block_Y = i / block_size_Y; + for (size_t block_X = 0; block_X < blocks_X; ++block_X) { + const size_t scale_idx = i * scales_stride + block_X; + scales[scale_idx] = math_scales[block_Y][block_X]; + } + } + } + + // Step 3: Apply quantization using the mathematical scaling factors + std::array, block_size_Y> cache_buffer; + + for (size_t block_Y = 0; block_Y < blocks_Y; ++block_Y) { + for (size_t block_X = 0; block_X < blocks_X; ++block_X) { + const size_t i_min = block_Y * block_size_Y; + const size_t i_max = std::min(i_min + block_size_Y, rows); + const size_t j_min = block_X * block_size_X; + const size_t j_max = std::min(j_min + block_size_X, cols); + + // Get the scaling factor for this block + const float S_dec_b_fp8 = static_cast(math_scales[block_Y][block_X]); + const float S_enc_b_fp8 = S_dec_b_fp8 == 0 ? 0.f : S_enc / S_dec_b_fp8; + const float scale_reciprocal = S_enc_b_fp8; + + // Process and cache data for this 16x16 block + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; + const size_t cache_idx_y = i - i_min; + const size_t cache_idx_x = j - j_min; + + const float input_elt = static_cast(input[idx]); + const float act_elt = OP(input_elt); + const float elt = static_cast(static_cast(act_elt)); + cache_buffer[cache_idx_y][cache_idx_x] = elt; + } + } + + // Apply scaling to all elements in this 16x16 block + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; j += 2) { + const int idx_pair = (i * cols + j) / 2; + const size_t cache_idx_y = i - i_min; + const size_t cache_idx_x1 = j - j_min; + const size_t cache_idx_x2 = std::min(cache_idx_x1 + 1, block_size_X - 1); + + const float cached_x = cache_buffer[cache_idx_y][cache_idx_x1]; + const float cached_y = ((j + 1) < j_max && cache_idx_x2 < block_size_X) ? 
+ cache_buffer[cache_idx_y][cache_idx_x2] : 0.0f; + + const float scaled_elt_x = cached_x * scale_reciprocal; + const float scaled_elt_y = cached_y * scale_reciprocal; + const float2 scaled_elt_pair = {scaled_elt_x, scaled_elt_y}; + + fp4e2m1x2 casted_to_e2m1_pair(scaled_elt_pair); + output[idx_pair] = casted_to_e2m1_pair; + } + } + } + } +} + +// Wrapper function that calls appropriate implementation based on 2D flag +template +void quantize_nvfp4(float (*OP)(const float), + const InputType* const input, + fp4e2m1x2* const output, + fp8e4m3* const scales, + const size_t rows, + const size_t cols, + const size_t scales_stride, + const float global_amax, + const bool use_2d_quantization = false) { + if (use_2d_quantization) { + quantize_nvfp4_2d(OP, input, output, scales, rows, cols, scales_stride, global_amax); + } else { + quantize_nvfp4_1d(OP, input, output, scales, rows, cols, scales_stride, global_amax); + } +} + +template +void compute_ref(float (*OP)(const float), + const InputType* input, + fp4e2m1x2* output, + fp4e2m1x2* output_t, + fp8e4m3* scales, + fp8e4m3* scales_t, + const float global_amax, + const size_t rows, + const size_t cols, + const size_t scales_stride, + const size_t scales_stride_t, + const bool use_2d_quantization = false) +{ + std::vector input_t = create_transpose(input, rows, cols); + + if (use_2d_quantization) { + // Step 1: Compute mathematical 8×8 scaling factors + std::vector> math_scales; + compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales); + + constexpr size_t block_size_Y = 16; + constexpr size_t block_size_X = 16; + const size_t blocks_Y = divide_round_up(rows, block_size_Y); + const size_t blocks_X = divide_round_up(cols, block_size_X); + + // Step 2: Generate scales (128×8) by replicating row-wise + for (size_t i = 0; i < rows; ++i) { + const size_t block_Y = i / block_size_Y; + for (size_t block_X = 0; block_X < blocks_X; ++block_X) { + const size_t scale_idx = i * scales_stride + block_X; + scales[scale_idx] = math_scales[block_Y][block_X]; + } + } + + // Step 3: Generate scales_t (128×8) with proper transposed block mapping + for (size_t i = 0; i < cols; ++i) { // cols = 128, which becomes rows of transposed data + const size_t block_X_orig = i / block_size_X; // i was column index in original, so maps to block_X + for (size_t block_Y_new = 0; block_Y_new < blocks_Y; ++block_Y_new) { // block in transposed coordinate + const size_t scale_idx = i * scales_stride_t + block_Y_new; + scales_t[scale_idx] = math_scales[block_Y_new][block_X_orig]; + } + } + + // Step 4: Process quantized outputs using the same algorithm as quantize_nvfp4_2d + // (This part processes the actual FP4 data using the mathematical scaling factors) + quantize_nvfp4_2d(OP, input, output, nullptr, rows, cols, scales_stride, global_amax); // scales already filled + quantize_nvfp4_2d(OP, input_t.data(), output_t, nullptr, cols, rows, scales_stride_t, global_amax); // scales_t already filled + + } else { + quantize_nvfp4(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_2d_quantization); + quantize_nvfp4(OP, input_t.data(), output_t, scales_t, cols, rows, scales_stride_t, global_amax, use_2d_quantization); + } +} + +void compare_nvfp4_tensors(const std::string& name, + const fp4e2m1 *test_data, const fp4e2m1 *ref_data, + const int rows, const int cols, + double atol = 1e-5, double rtol = 1e-8) { + constexpr int max_mismatches_to_print = 3; + + std::vector mismatch_messages; + size_t total_mismatches = 0; + + for (int i = 0; i < rows; 
++i) { + for (int j = 0; j < cols; j += 2) { + const int idx = i * cols + j; + double2 test_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&test_data[idx/2])); + double2 ref_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&ref_data[idx/2])); + + for (int k = 0; k < 2; ++k) { + const double t = (k == 0 ? test_data_pair.x : test_data_pair.y); + const double r = (k == 0 ? ref_data_pair.x : ref_data_pair.y); + + const bool mismatch = fabs(t - r) > (atol + fabs(r) * rtol); + if (mismatch) { + total_mismatches++; + // Optional: limit number of detailed messages to avoid overwhelming output + if (total_mismatches <= max_mismatches_to_print) { + std::string msg = "Mismatch at place (" + std::to_string(idx + k) + "): " + + std::to_string(t) + " vs " + std::to_string(r) + + " (abs_diff: " + std::to_string(fabs(t - r)) + + ", rel_diff: " + std::to_string(r == 0 ? 0.0 : fabs((t - r) / r)) + ")"; + mismatch_messages.push_back(msg); + std::cout << "Error in tensor " << name << ": " << msg << std::endl; + } + } + } + } + } + + // Always report summary - either success or failure + std::cout << "=== SUMMARY for tensor " << name << " ===" << std::endl; + std::cout << "Total elements checked: " << (rows * cols) << std::endl; + + if (total_mismatches > 0) { + std::cout << "STATUS: FAILED for output" << std::endl; + std::cout << "Total mismatches found: " << total_mismatches << std::endl; + std::cout << "Mismatch rate: " << (100.0 * total_mismatches) / (rows * cols) << "%" << std::endl; + if (mismatch_messages.size() > max_mismatches_to_print) { + std::cout << "... and " << (mismatch_messages.size() - max_mismatches_to_print) + << " more mismatches (showing first " << max_mismatches_to_print << ")" << std::endl; + } + std::cout << "============================" << std::endl; + + GTEST_FAIL() << "Found " << total_mismatches << " mismatches in tensor " << name; + } else { + std::cout << "STATUS: PASSED for output" << std::endl; + std::cout << "All elements match within tolerance!" << std::endl; + std::cout << "Tensor " << name << " is IDENTICAL to reference" << std::endl; + std::cout << "============================" << std::endl; + } +} + +// Optional: Function to dump tensor data to files for detailed analysis +void dump_nvfp4_tensor_data(const std::string& prefix, + const fp4e2m1 *test_data, const fp4e2m1 *ref_data, + const int rows, const int cols) { + std::string test_file = prefix + "_test.txt"; + std::string ref_file = prefix + "_ref.txt"; + std::string diff_file = prefix + "_diff.txt"; + + std::ofstream test_out(test_file); + std::ofstream ref_out(ref_file); + std::ofstream diff_out(diff_file); + + if (test_out.is_open() && ref_out.is_open() && diff_out.is_open()) { + for (int i = 0; i < rows; ++i) { + for (int j = 0; j < cols; j += 2) { + const int idx = i * cols + j; + double2 test_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&test_data[idx/2])); + double2 ref_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&ref_data[idx/2])); + + for (int k = 0; k < 2; ++k) { + const double t = (k == 0 ? test_data_pair.x : test_data_pair.y); + const double r = (k == 0 ? ref_data_pair.x : ref_data_pair.y); + const int pos = idx + k; + + test_out << "pos[" << pos << "] = " << t << std::endl; + ref_out << "pos[" << pos << "] = " << r << std::endl; + diff_out << "pos[" << pos << "] test=" << t << " ref=" << r + << " abs_diff=" << fabs(t - r) + << " rel_diff=" << (r == 0 ? 
0.0 : fabs((t - r) / r)) << std::endl; + } + } + } + std::cout << "DEBUG: Dumped tensor data to files: " << test_file << ", " << ref_file << ", " << diff_file << std::endl; + } else { + std::cout << "WARNING: Could not open files for tensor data dump" << std::endl; + } +} + +void print_detailed_tensor_comparison(const std::string& name, + const fp4e2m1 *test_data, const fp4e2m1 *ref_data, + const int rows, const int cols) { + printf("\n=== DETAILED COMPARISON for %s (%d×%d = %d elements) ===\n", + name.c_str(), rows, cols, rows * cols); + + const int total_elements = rows * cols; + const int check_count = 128; + + printf("--- FIRST %d ELEMENTS ---\n", check_count); + printf("Index | Test_Value | Ref_Value | Match\n"); + printf("------|---------------|---------------|-------\n"); + for (int i = 0; i < std::min(check_count, total_elements); ++i) { + double2 test_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&test_data[i/2])); + double2 ref_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&ref_data[i/2])); + + double t = (i % 2 == 0) ? test_pair.x : test_pair.y; + double r = (i % 2 == 0) ? ref_pair.x : ref_pair.y; + bool match = (fabs(t - r) < 1e-6); + + printf("%5d | %13.6f | %13.6f | %s\n", i, t, r, match ? "✓" : "✗"); + } + + if (total_elements > 2 * check_count) { + printf("\n--- LAST %d ELEMENTS ---\n", check_count); + printf("Index | Test_Value | Ref_Value | Match\n"); + printf("------|---------------|---------------|-------\n"); + for (int i = total_elements - check_count; i < total_elements; ++i) { + double2 test_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&test_data[i/2])); + double2 ref_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&ref_data[i/2])); + + double t = (i % 2 == 0) ? test_pair.x : test_pair.y; + double r = (i % 2 == 0) ? ref_pair.x : ref_pair.y; + bool match = (fabs(t - r) < 1e-6); + + printf("%5d | %13.6f | %13.6f | %s\n", i, t, r, match ? 
"✓" : "✗"); + } + } + printf("==================================\n"); +} + +void compareResults_nvfp4(const Tensor &test, + const void *ref, const void *ref_t, const int rows, const int cols, + double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true, bool dump_data = false) { + if (if_on_gpus) test.to_cpu(); + + const fp4e2m1 *test_data = test.rowwise_cpu_dptr(); + const fp4e2m1 *test_data_t = test.columnwise_cpu_dptr(); + const fp4e2m1 *ref_data = reinterpret_cast(ref); + const fp4e2m1 *ref_data_t = reinterpret_cast(ref_t); + + // Print detailed element-by-element comparison + // print_detailed_tensor_comparison("output", test_data, ref_data, rows, cols); + // print_detailed_tensor_comparison("output_t", test_data_t, ref_data_t, cols, rows); + + // Optionally dump tensor data to files for detailed analysis + if (dump_data) { + dump_nvfp4_tensor_data("output", test_data, ref_data, rows, cols); + dump_nvfp4_tensor_data("output_t", test_data_t, ref_data_t, cols, rows); + } + + compare_nvfp4_tensors("output", test_data, ref_data, rows, cols, atol, rtol); + compare_nvfp4_tensors("output_t", test_data_t, ref_data_t, cols, rows, atol, rtol); +} + +template +void performTest(float (*OP)(const float), + const size_t K, + const size_t M, + const std::vector& M_i, + const std::vector& Offset_i, + const bool stochastic_rounding = false) { + using namespace test; + + DType itype = TypeInfo::dtype; + DType otype = DType::kFloat4E2M1; + + const size_t rows = first_dimension(shape); + const size_t cols = last_dimension(shape); + + const std::array scale_dims = get_scale_tensor_dims(rows, cols, 1, 16); + const std::array scale_dims_t = get_scale_tensor_dims(cols, rows, 1, 16); + + const size_t unpadded_blocks_Y = scale_dims[0]; + const size_t unpadded_blocks_X = scale_dims[1]; + const size_t blocks_Y = scale_dims[2]; + const size_t blocks_X = scale_dims[3]; + const size_t scales_stride = blocks_X; + + const size_t unpadded_blocks_Y_t = scale_dims_t[0]; + const size_t unpadded_blocks_X_t = scale_dims_t[1]; + const size_t blocks_Y_t = scale_dims_t[2]; + const size_t blocks_X_t = scale_dims_t[3]; + const size_t scales_stride_t = blocks_X_t; + + Tensor input("input", shape, itype); + Tensor output("output", shape, otype, true, true, NVTE_NVFP4_1D_SCALING); + + std::unique_ptr ref_output = std::make_unique(rows * (cols / 2)); + std::unique_ptr ref_output_t = std::make_unique(cols * (rows / 2)); + std::unique_ptr ref_scales = std::make_unique(blocks_Y * blocks_X); + std::unique_ptr ref_scales_t = std::make_unique(blocks_Y_t * blocks_X_t); + + fillCase(&input, InputsFillCase::uniform); + + // Find global amax + float amax = 0.0f; + const InputType* input_dptr = input.rowwise_cpu_dptr(); + for (size_t i = 0; i < rows; ++i) { + for (size_t j = 0; j < cols; ++j) { + const size_t idx = i * cols + j; + amax = fmaxf(amax, static_cast(input_dptr[idx])); + } + } + // Set 2nd stage NVFP4 scaling factor + output.set_scale(amax); + + bool use_2d_quantization = false; + + compute_ref(OP, + input.rowwise_cpu_dptr(), + ref_output.get(), + ref_output_t.get(), + ref_scales.get(), + ref_scales_t.get(), + output.scale(), + rows, + cols, + scales_stride, + scales_stride_t, + use_2d_quantization); + + QuantizationConfigWrapper quant_config; + + // Initialize stochastic rounding + Tensor rng_state("rng_state", std::vector{2}, DType::kInt64); + rng_state.rowwise_cpu_dptr()[0] = 123; // rng_seed + rng_state.rowwise_cpu_dptr()[1] = 321; // rng_sequence + rng_state.from_cpu(); + 
quant_config.set_stochastic_rounding(stochastic_rounding); + quant_config.set_rng_state(rng_state.data()); + + // Set 2D quantization based on compile-time flag + quant_config.set_nvfp4_2d_quantization(use_2d_quantization); + + nvte_quantize_v2(input.data(), output.data(), quant_config, 0); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("DEBUG: CUDA error detected: %s\n", cudaGetErrorString(err)); + } + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + const double atol = 1.0E-6; + const double rtol = 1.0E-6; + + // Set dump_data=true to enable dumping tensor data to files for analysis + compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, true, false); + + const fp8e4m3* kernel_scales = output.rowwise_cpu_scale_inv_ptr(); + const fp8e4m3* ref_scales_ptr = ref_scales.get(); + const fp8e4m3* kernel_scales_t = output.columnwise_cpu_scale_inv_ptr(); + const fp8e4m3* ref_scales_t_ptr = ref_scales_t.get(); + + size_t scale_mismatches_num = 0; + compare_scaling_factors("scales", output.rowwise_cpu_scale_inv_ptr(), + ref_scales.get(), + unpadded_blocks_Y, unpadded_blocks_X, scales_stride, + scale_mismatches_num); + + compare_scaling_factors("scales_t", output.columnwise_cpu_scale_inv_ptr(), + ref_scales_t.get(), + unpadded_blocks_Y_t, unpadded_blocks_X_t, scales_stride_t, + scale_mismatches_num); +} + +// K = Hidden dim +std::vector K = { + 32, + 64, + 128, + 512 + 32, + 1024 + 64, +}; + +// Logical tensor dim M = Batch Size * Sequence Length +std::vector M = { + 32, + 64, + 256 + 64, + 1024 + 32, +}; + +// Dim M of i-th tensor in a group +std::vector> M_i = { + {32}, + {32, 32}, + {64, 32, 128}, + {32, 96, 160, 64}, + {320, 32, 288, 128}, +}; + +// Offset of i-th tensor in a group +std::vector> Offset_i = { + {0}, + {0, 32}, + {0, 64, 96}, + {0, 32, 128, 288}, + {0, 320, 352, 640}, +}; + +std::vector stochastic_rounding = { + false, + // true +}; + +} // namespace + +class GroupedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam + , + std::vector, + transformer_engine::DType, + bool>> {}; + +TEST_P(GroupedCastTransposeNVFP4TestSuite, TestGroupedCastTransposeNVFP4) { + // Skip tests for pre-Blackwell architectures + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + using namespace transformer_engine; + using namespace test; + + const size_t K = std::get<0>(GetParam()); + const size_t M = std::get<1>(GetParam()); + const std::vector M_i = std::get<2>(GetParam()); + const std::vector Offset_i = std::get<3>(GetParam()); + const DType input_type = std::get<4>(GetParam()); + const bool is_stochastic_rounding = std::get<5>(GetParam()); + + if (M_i.size() != Offset_i.size()) { + GTEST_SKIP(); + } + + const size_t group_size = M_i.size(); + // Skip tests if tensors overlap with each other + for (size_t i = 0; i < group_size - 1; ++i) { + if (Offset_i[i] + M_i[i] > Offset_i[i+1]) { + GTEST_SKIP(); + } + } + // Last tensor must be within the allocated group tensor + if (Offset_i.back() + M_i.back() > M) { + GTEST_SKIP(); + } + + // Forward activations + auto OP = &identity; + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, + performTest(OP, K, M, M_i, Offset_i, is_stochastic_rounding); + ); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + GroupedCastTransposeNVFP4TestSuite, + ::testing::Combine( + ::testing::ValuesIn(K), + ::testing::ValuesIn(M), + ::testing::ValuesIn(M_i), + ::testing::ValuesIn(Offset_i), + 
::testing::Values(DType::kBFloat16) + ::testing::ValuesIn(stochastic_rounding)), + [](const testing::TestParamInfo& info) { + std::string name = ""; + name += "K" + std::to_string(std::get<0>(info.param)) + "X"; + name += "M" + std::to_string(std::get<1>(info.param)) + "X"; + + name += "Group"; + const auto& M_i_ = std::get<2>(info.param); + for (const auto& m: M_i_) { + name += "X" + std::to_string(m); + } + name += "Offset"; + const auto& Offset_i_ = std::get<3>(info.param); + for (const auto& offset: Offset_i_) { + name += "X" + std::to_string(offset); + } + name += "X" + test::typeName(std::get<4>(info.param)); + name += "X" + (std::get<5>(info.param) ? "SR" : "RN"); + return name; + } +); diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index 1ed46a3359..8a185c22c2 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -16,6 +16,7 @@ #include "../utils.cuh" #include "dispatch/dequantize.cuh" #include "dispatch/quantize.cuh" +#include "dispatch/quantize_grouped.cuh" #include "transformer_engine/transpose.h" void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { @@ -23,7 +24,15 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea using namespace transformer_engine; constexpr bool IS_ACT = false; - dispatch::quantize_fwd_helper(input, output, nullptr, stream); + // dispatch::quantize_fwd_helper(input, output, nullptr, stream); +} + +void nvte_quantize_grouped(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_grouped); + using namespace transformer_engine; + + constexpr bool IS_ACT = false; + dispatch::quantize_grouped_fwd_helper(input, output, nullptr, stream); } void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, @@ -35,7 +44,7 @@ void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor no QuantizationConfig quant_config; quant_config.noop_tensor = noop; - nvte_quantize_v2(input, output, reinterpret_cast(&quant_config), stream); + // nvte_quantize_v2(input, output, reinterpret_cast(&quant_config), stream); } void nvte_quantize_v2(const NVTETensor input, NVTETensor output, @@ -44,7 +53,7 @@ void nvte_quantize_v2(const NVTETensor input, NVTETensor output, using namespace transformer_engine; constexpr bool IS_ACT = false; - dispatch::quantize_fwd_helper(input, output, quant_config, stream); + // dispatch::quantize_fwd_helper(input, output, quant_config, stream); } void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, @@ -56,15 +65,15 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d constexpr bool IS_DACT = false; constexpr const NVTETensor activation_input = nullptr; - dispatch::quantize_bwd_helper( - input, activation_input, output, dbias, workspace, nullptr, stream); + // dispatch::quantize_bwd_helper( + // input, activation_input, output, dbias, workspace, nullptr, stream); } void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dequantize); using namespace transformer_engine; - dispatch::dequantize_helper(*convertNVTETensorCheck(input), convertNVTETensorCheck(output), - stream); + // dispatch::dequantize_helper(*convertNVTETensorCheck(input), convertNVTETensorCheck(output), + // stream); } void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs, @@ -86,8 +95,8 @@ void 
nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs, } for (int i = 0; i < num_tensors; i++) { - dispatch::quantize_fwd_helper( - inputs[i], outputs[i], quant_configs, detail::get_compute_stream(i % num_streams)); + // dispatch::quantize_fwd_helper( + // inputs[i], outputs[i], quant_configs, detail::get_compute_stream(i % num_streams)); } // record events on compute streams diff --git a/transformer_engine/common/cast/dispatch/quantize_grouped.cuh b/transformer_engine/common/cast/dispatch/quantize_grouped.cuh new file mode 100644 index 0000000000..1c1884cf93 --- /dev/null +++ b/transformer_engine/common/cast/dispatch/quantize_grouped.cuh @@ -0,0 +1,121 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_grouped.cuh + * \brief Quantize Grouped Tensor dispatcher. + */ + +#ifndef TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_GROUPED_CUH_ +#define TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_GROUPED_CUH_ + +#include + +#include "../../common.h" +#include "../../transpose/cast_transpose.h" +#include "../../util/vectorized_pointwise.h" +#include "../core/common.cuh" +#include "../mxfp8/quantize_grouped_mxfp8.cuh" + +namespace transformer_engine { +namespace dispatch { + +template +void quantize_grouped_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor output, + const NVTEQuantizationConfig quant_config, cudaStream_t stream) { + using namespace detail; + + NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output); + + // Quantization config + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + + // Noop flag + Tensor dummy_tensor; + Tensor *noop_tensor = &dummy_tensor; + if (quant_config_cpp.noop_tensor != nullptr) { + noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + } + + // NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), + // "Either rowwise or columnwise output data need to be allocated."); + + // Dispatch to quantization kernel depending on data format + switch (scaling_mode) { + case NVTE_MXFP8_1D_SCALING: { + const NVTEGroupedTensor activation = nullptr; + NVTEGroupedTensor dbias = nullptr; + NVTEGroupedTensor workspace = nullptr; + + const GroupedTensor *input_tensor = convertNVTEGroupedTensorCheck(input); + GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output); + const GroupedTensor *activations_tensor = convertNVTEGroupedTensor(activation); + GroupedTensor *dbias_tensor = convertNVTEGroupedTensor(dbias); + GroupedTensor *workspace_tensor = convertNVTEGroupedTensor(workspace); + + mxfp8::quantize_grouped( + input_tensor, activations_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); + } +} + +// template +// void quantize_grouped_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, NVTEGroupedTensor output, +// NVTEGroupedTensor dbias, NVTEGroupedTensor workspace, +// const NVTEQuantizationConfig quant_config, cudaStream_t stream) { +// using namespace detail; + +// const Tensor *grad_tensor = convertNVTETensorCheck(grad); +// const Tensor *input_tensor = convertNVTETensor(input); + +// Tensor *output_tensor = 
convertNVTETensorCheck(output); +// Tensor *dbias_tensor = convertNVTETensor(dbias); +// Tensor *workspace_tensor = convertNVTETensor(workspace); + +// // Quantization config +// QuantizationConfig quant_config_cpp; +// if (quant_config != nullptr) { +// quant_config_cpp = *reinterpret_cast(quant_config); +// } + +// // Noop flag +// Tensor dummy_tensor; +// Tensor *noop_tensor = &dummy_tensor; +// if (quant_config_cpp.noop_tensor != nullptr) { +// noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); +// } + +// // Check for unsupported options +// if (quant_config_cpp.stochastic_rounding) { +// NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, +// "Stochastic rounding is only supported for NVFP4 quantization."); +// } + +// NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), +// "Either rowwise or columnwise output data need to be allocated."); + +// // Dispatch to quantization kernel depending on data format +// switch (output_tensor->scaling_mode) { +// case NVTE_MXFP8_1D_SCALING: { +// mxfp8::quantize( +// *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, +// stream); +// break; +// } +// default: +// NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); +// } +// } + +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_GROUPED_CUH_ diff --git a/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh new file mode 100644 index 0000000000..a71fcd1366 --- /dev/null +++ b/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh @@ -0,0 +1,966 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_grouped_mxfp8.cuh + * \brief CUDA kernels to quantize grouped tensors to MXFP8. 
+ */ + +#ifndef TRANSFORMER_ENGINE_QUANTIZE_GROUPED_MXFP8_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_GROUPED_MXFP8_CUH_ + +#include +#include +#include +#include + +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" +#include "../core/common.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace mxfp8 { +namespace quantize_grouped_kernel { + + +constexpr int MAX_SUPPORTED_DESCRIPTORS = 64; +__device__ alignas(128) CUtensorMap g_tensor_maps_input[MAX_SUPPORTED_DESCRIPTORS]; +__device__ alignas(128) CUtensorMap g_tensor_maps_act_input[MAX_SUPPORTED_DESCRIPTORS]; +__device__ alignas(128) CUtensorMap g_tensor_maps_output_rowwise[MAX_SUPPORTED_DESCRIPTORS]; +__device__ alignas(128) CUtensorMap g_tensor_maps_output_colwise[MAX_SUPPORTED_DESCRIPTORS]; + +enum ShapeRepresentation { + SAME_BOTH_DIMS = 0, + VARYING_FIRST_DIM = 1, + VARYING_LAST_DIM = 2, + VARYING_BOTH_DIMS = 3 +}; + +constexpr size_t SCALE_DIM_Y = 32; +constexpr size_t SCALE_DIM_X = 32; + +constexpr size_t BUFFS_NUM = 2; +constexpr size_t PACK_SIZE = 4; +constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; + +constexpr size_t CHUNK_DIM_Y = 128; +constexpr size_t CHUNK_DIM_X = 128; +constexpr size_t THREADS_PER_CHUNK = 128; + +constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; +constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; + +constexpr size_t BUFF_DIM_Y = THREADS_Y; +constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; +constexpr size_t BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X; +static_assert(BUFF_DIM_Y == 32); + +constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; +static_assert(STAGES >= 1); + +// Number of 1-byte elements that span 32 banks (4-byte each) of shared memory +constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32 + +__device__ __forceinline__ size_t +get_current_tensor_id(const ShapeRepresentation shape_rep, + const size_t num_tensors, + const size_t current_offset, + const size_t first_logical_dim, + const size_t last_logical_dim, + const int64_t* const __restrict__ first_dims_ptr, + const int64_t* const __restrict__ last_dims_ptr, + const int64_t* const __restrict__ offsets_ptr) { + if (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) { + const size_t current_row = current_offset / last_logical_dim; + const size_t rows_per_tensor = first_logical_dim / num_tensors; + return current_row / rows_per_tensor; + } else { + // upper_bound(offsets, current_offset) - 1 in range i in [0..num_tensors) + size_t low = 0; + size_t hi = num_tensors; // half-open [low, hi) + + while (low < hi) { + const size_t mid = low + (hi - low) / 2; + const size_t mid_offset = static_cast(offsets_ptr[mid]); + + if (mid_offset <= current_offset) { + low = mid + 1; + } else { + hi = mid; + } + } + + // low = first index where offsets[low] > current_offset (or low == num_tensors) + // id = low - 1, but need to evaluate if current_offset < offsets[0] + return (low == 0) ? 
0 : (low - 1); + } +} + +__device__ __forceinline__ size_t +get_tensor_rows_num(const size_t tensor_id, + const ShapeRepresentation shape_rep, + const size_t first_logical_dim, + const int64_t* const __restrict__ first_dims_ptr, + const size_t num_tensors) { + size_t rows_num = first_logical_dim; + switch (shape_rep) { + case ShapeRepresentation::SAME_BOTH_DIMS: rows_num = first_logical_dim / num_tensors; break; + case ShapeRepresentation::VARYING_LAST_DIM: rows_num = first_logical_dim; break; + case ShapeRepresentation::VARYING_FIRST_DIM: + case ShapeRepresentation::VARYING_BOTH_DIMS: rows_num = static_cast(first_dims_ptr[tensor_id]); break; + } + return rows_num; +} + +__device__ __forceinline__ size_t +get_tensor_cols_num(const size_t tensor_id, + const ShapeRepresentation shape_rep, + const size_t last_logical_dim, + const int64_t* const __restrict__ last_dims_ptr) { + size_t cols_num = last_logical_dim; + switch (shape_rep) { + case ShapeRepresentation::SAME_BOTH_DIMS: + case ShapeRepresentation::VARYING_FIRST_DIM: cols_num = last_logical_dim; break; + case ShapeRepresentation::VARYING_LAST_DIM: + case ShapeRepresentation::VARYING_BOTH_DIMS: cols_num = static_cast(last_dims_ptr[tensor_id]); break; + } + return cols_num; +} + +// Copies the base tensor map to shmem, modifies the copy, stores the modified tensor map at index +__device__ __forceinline__ void +modify_base_tensor_map(const CUtensorMap base_tensor_map, + CUtensorMap* global_tensor_map, + const uintptr_t global_data_ptr, + const size_t global_dim_Y, + const size_t global_dim_X) { + const size_t global_stride = global_dim_X; + + __shared__ CUtensorMap shared_tensor_map; + shared_tensor_map = base_tensor_map; // Copy the base tensor map into shmem + + asm volatile( + "{\n\t" + ".reg.b64 tensor_map_ptr; \n\t" + "mov.b64 tensor_map_ptr, %0; \n\t" + "tensormap.replace.tile.global_address.b1024.b64 [tensor_map_ptr], %1; \n\t" + "tensormap.replace.tile.global_dim.b1024.b32 [tensor_map_ptr], 1, %2; \n\t" // DIM Y + "tensormap.replace.tile.global_dim.b1024.b32 [tensor_map_ptr], 0, %3; \n\t" // DIM X + "tensormap.replace.tile.global_stride.b1024.b64 [tensor_map_ptr], 0, %4; \n" + "}\n" + :: "l"(reinterpret_cast(&shared_tensor_map)), + "l"(global_data_ptr), + "r"(static_cast(global_dim_Y)), + "r"(static_cast(global_dim_X)), + "l"(static_cast(global_stride)) + : "memory" + ); + *global_tensor_map = shared_tensor_map; +} + +template +__global__ void +init_tma_descriptors(const __grid_constant__ CUtensorMap base_tensor_map_input, + const __grid_constant__ CUtensorMap base_tensor_map_act_input, + const __grid_constant__ CUtensorMap base_tensor_map_output_rowwise, + const __grid_constant__ CUtensorMap base_tensor_map_output_colwise, + const IType* const __restrict__ input_data_ptr, + const IType* const __restrict__ act_input_data_ptr, + const OType* const __restrict__ output_rowwise_data_ptr, + const OType* const __restrict__ output_colwise_data_ptr, + const ShapeRepresentation shape_rep, + const size_t num_tensors, + const size_t first_logical_dim, + const size_t last_logical_dim, + const int64_t* const __restrict__ offsets_ptr, + const int64_t* const __restrict__ first_dims_ptr, + const int64_t* const __restrict__ last_dims_ptr, + const bool rowwise, + const bool colwise, + const bool compute_activations) { + const bool leading_thread = (threadIdx.x == 0); + const size_t tensor_id = blockIdx.x; + + const size_t rows = get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); + const size_t cols = 
get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); + + const size_t offset_elts = offsets_ptr[tensor_id]; + + if (leading_thread && (tensor_id < num_tensors)) { + { + const uintptr_t global_data_ptr = reinterpret_cast(input_data_ptr + offset_elts); + modify_base_tensor_map(base_tensor_map_input, &g_tensor_maps_input[tensor_id], + global_data_ptr, rows, cols); + } + if (compute_activations) { + const uintptr_t global_data_ptr = reinterpret_cast(act_input_data_ptr + offset_elts); + modify_base_tensor_map(base_tensor_map_act_input, &g_tensor_maps_act_input[tensor_id], + global_data_ptr, rows, cols); + } + if (rowwise) { + const uintptr_t global_data_ptr = reinterpret_cast(output_rowwise_data_ptr + offset_elts); + modify_base_tensor_map(base_tensor_map_output_rowwise, &g_tensor_maps_output_rowwise[tensor_id], + global_data_ptr, rows, cols); + } + if (colwise) { + const uintptr_t global_data_ptr = reinterpret_cast(output_colwise_data_ptr + offset_elts); + modify_base_tensor_map(base_tensor_map_output_colwise, &g_tensor_maps_output_colwise[tensor_id], + global_data_ptr, rows, cols); + } + } +} + +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) + quantize_grouped_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_input_static, + const __grid_constant__ CUtensorMap tensor_map_act_input_static, + const __grid_constant__ CUtensorMap tensor_map_output_rowwise_static, + const __grid_constant__ CUtensorMap tensor_map_output_colwise_static, + const ShapeRepresentation shape_rep, + const size_t num_tensors, + const size_t first_logical_dim, + const size_t last_logical_dim, + const int64_t* const __restrict__ offsets_ptr, + const int64_t* const __restrict__ first_dims_ptr, + const int64_t* const __restrict__ last_dims_ptr, + e8m0_t *const __restrict__ scales_rowwise, + e8m0_t *const __restrict__ scales_colwise, + const float * __restrict__ noop, + float *const __restrict__ dbias_workspace, + float *const __restrict__ amax_ptr) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; + constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; + + using IType2 = typename ptx::FPx2; + using OType2 = typename ptx::FPx2; + + if constexpr (NO_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + } + + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; + + const size_t block_global_offset = blockIdx.x * CHUNK_DIM_Y * CHUNK_DIM_X; + + const size_t tensor_id = get_current_tensor_id(shape_rep, num_tensors, block_global_offset, + first_logical_dim, last_logical_dim, + first_dims_ptr, last_dims_ptr, offsets_ptr); + + const size_t rows = get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); + const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); + const size_t scale_stride_rowwise = cols / SCALE_DIM_X; + const size_t scale_stride_colwise = rows / SCALE_DIM_Y; + + const size_t offset_within_tensor = (shape_rep == SAME_BOTH_DIMS) + ? (block_global_offset - tensor_id * rows * cols) + : (block_global_offset - offsets_ptr[tensor_id]); + + if (threadIdx.x == 0) { + printf("Current tensor ID: %lu \n", tensor_id); + printf("Current tensor ROWS: %lu \n", rows); + printf("Current tensor COLS: %lu \n", cols); + } + + const bool is_const_last_dim = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); + + const CUtensorMap& tensor_map_input = is_const_last_dim + ? 
tensor_map_input_static + : g_tensor_maps_input[tensor_id]; + const CUtensorMap& tensor_map_act_input = is_const_last_dim + ? tensor_map_act_input_static + : g_tensor_maps_act_input[tensor_id]; + const CUtensorMap& tensor_map_output_rowwise = is_const_last_dim + ? tensor_map_output_rowwise_static + : g_tensor_maps_output_rowwise[tensor_id]; + const CUtensorMap& tensor_map_output_colwise = is_const_last_dim + ? tensor_map_output_colwise_static + : g_tensor_maps_output_colwise[tensor_id]; + + const size_t block_offset_Y = offset_within_tensor / cols; + const size_t block_offset_X = offset_within_tensor % cols; + const size_t blockIdxY = block_offset_Y / CHUNK_DIM_Y; + const size_t blockIdxX = block_offset_X / CHUNK_DIM_X; + const size_t scales_block_offset_Y_rowwise = blockIdxY * CHUNK_DIM_Y; + const size_t scales_block_offset_X_rowwise = blockIdxX * CHUNK_DIM_X / SCALE_DIM_X; + const size_t scales_block_offset_Y_colwise = blockIdxY * CHUNK_DIM_Y / SCALE_DIM_Y; + const size_t scales_block_offset_X_colwise = blockIdxX * CHUNK_DIM_X; + + const size_t tid_Y_rowwise = threadIdx.x / THREADS_X; + const size_t tid_X_rowwise = threadIdx.x % THREADS_X; + const size_t tid_Y_colwise = 0; + const size_t tid_X_colwise = threadIdx.x; + + const size_t thread_offset_Y_rowwise = tid_Y_rowwise; + const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; + const size_t thread_offset_Y_colwise = tid_Y_colwise; + const size_t thread_offset_X_colwise = tid_X_colwise; + + const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; + const size_t row_base_colwise = block_offset_Y + thread_offset_Y_colwise; + const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise; + + const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); + + const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; + const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; + + const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols; + + // helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + + constexpr size_t elt_input_mem = buff_size_aligned_in; + constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); + constexpr size_t in_mem = elt_input_mem + act_input_mem; + + constexpr size_t out_mem_rowwise = (ROWWISE_SCALING ? buff_size_aligned_out : 0); + + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! 
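+  // Illustrative example of the round-up below, assuming TMA_SHMEM_ALIGNMENT == 128:
+  //   base_shmem_ptr = ...0x0010  ->  (0x0010 + 127) & ~127 = ...0x0080 (next 128-byte boundary)
+  // The launch configuration is assumed to reserve up to TMA_SHMEM_ALIGNMENT - 1 extra bytes
+  // of dynamic shared memory so that this padding never overruns the allocation.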
+ uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + IType *act_in_sh = reinterpret_cast(dshmem + elt_input_mem); + + OType *out_rowwise_data_sh = reinterpret_cast(dshmem + in_mem); + OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + + constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + + float partial_dbias_colwise = 0.0f; + float thread_dbias_rowwise[SCALE_DIM_X]; + if constexpr (IS_DBIAS) { +#pragma unroll + for (int j = 0; j < SCALE_DIM_X; ++j) { + thread_dbias_rowwise[j] = 0.0f; + } + } + + float block_amax = 0.0f; + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[STAGES]; + + initialize_barriers(mbar, is_master_thread); + + int parity = 0; + + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, &act_in_sh[0], + &tensor_map_act_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + } else { + copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + } + +#pragma unroll + for (int stage = 0; stage < STAGES; ++stage) { + const size_t buff = stage % BUFFS_NUM; + const size_t next_stage = stage + 1; + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + + const size_t next_buff = next_stage % BUFFS_NUM; + const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t next_buff_offset = next_buff * BUFF_DIM; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input, + global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage], + is_master_thread); + } else { + copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); + } + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], parity); + + float thread_amax = 0.0f; + if constexpr (COLWISE_SCALING) { + const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; + thread_amax = 0.0f; + float in_compute_colwise[BUFF_DIM_Y]; + IType in_colwise_IType[BUFF_DIM_Y]; + + // 1. Read/Compute elements. 
Find MXFP8-block AMAX + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + IType thread_amax_f16 = static_cast(0.0f); +#pragma unroll + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); + } + thread_amax = static_cast(thread_amax_f16); + } else { +#pragma unroll + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in_sh[shmem_offset_colwise]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS) { + partial_dbias_colwise += elt; + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows); + const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); + if (!out_of_bounds) { + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + in_compute_colwise[i] = elt; + } + } + + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + + const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; + const size_t global_scales_offset_X = scales_offset_X_colwise; + const size_t scale_idx = + global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + scales_colwise[scale_idx] = biased_exponent; + + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; + +// 3. Scale elements +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + float in; + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = static_cast(in_colwise_IType[i]); + } else { + in = in_compute_colwise[i]; + } + const float scaled_out = in * block_scale_inverse; + + const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; + out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); + } + } + + if constexpr (ROWWISE_SCALING) { + const size_t shmem_offset_base_rowwise = + buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; + thread_amax = 0.0f; + float in_compute_rowwise[SCALE_DIM_X]; + Vec in_cached[WAVES]; + + // used as an IType container for BF16/FP16 --> MXFP8 CAST ONLY + Vec in_IType[WAVES]; + + // 1. Read/Compute elements. 
Find MXFP8-block AMAX + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + // Load elements + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } + } + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + + // Load cached elements + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) + // only single check (w.r.t. 
column direction) is sufficient to be sure the entire wave is inside the boundaries + if (!out_of_bounds) { + if constexpr (std::is_same_v) { +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { +#pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], + in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } + } + } + } + if constexpr (!std::is_same_v) { + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); + if constexpr (IS_DACT) { + act_in.load_from(&act_in_sh[shmem_offset_rowwise]); + } +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + // Compute element + float elt = static_cast(in.data.elt[e]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in.data.elt[e]); + elt *= OP(act_in_elt, {}); + } + + // If DBIAS was computed in the 1st pass (COLWISE) then no need to compute it again + if constexpr (IS_DBIAS && (!COLWISE_SCALING)) { + thread_dbias_rowwise[j] += elt; + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = + (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + if (!out_of_bounds) { + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + in_compute_rowwise[j] = elt; + } + } + } + + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; + const int stage_scales_offset_X = scales_offset_X_rowwise; + const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + if (rowwise_scale_is_within_bounds) { + scales_rowwise[scale_idx] = biased_exponent; + } + + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; + + // 3. 
Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + IType2 in; + OType2 &out_pair = reinterpret_cast(out.data.elt[e]); + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = in_IType[w].data.elt[e]; + } else if constexpr (IS_CACHED_ACT_OP) { + in.x = in_cached[w].data.elt[2 * e]; + in.y = in_cached[w].data.elt[2 * e + 1]; + } else { + const int j = w * PACK_SIZE + 2 * e; + in.x = in_compute_rowwise[j]; + in.y = in_compute_rowwise[j + 1]; + } + ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); + } + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; + out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); + } + } + + __builtin_assume(block_amax >= 0); + __builtin_assume(thread_amax >= 0); + block_amax = fmaxf(block_amax, thread_amax); + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X; + const int buff_offset = buff * BUFF_DIM; + + if constexpr (ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset])); + } + if constexpr (COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_colwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset])); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. 
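The commit/wait pairing referenced in the comment above is what lets the kernel recycle its double-buffered staging area: each batch of bulk stores is folded into one async-group, and a later wait bounds how many groups may still be reading shared memory before a buffer is overwritten. A minimal sketch of that pair, assuming an sm_90+ target and the raw PTX instructions behind the ptx::cp_async_bulk_commit_group / ptx::cp_async_bulk_wait_group_read wrappers used in this file (illustrative helpers, not part of the patch):

__device__ __forceinline__ void commit_bulk_store_group() {
  // Folds all bulk copies issued so far by this thread into one "bulk async-group".
  asm volatile("cp.async.bulk.commit_group;" ::: "memory");
}

template <int kAllowedPending>
__device__ __forceinline__ void wait_bulk_store_groups_read() {
  // Returns once at most kAllowedPending committed groups are still reading
  // shared memory, so the remaining staging buffers may be reused.
  asm volatile("cp.async.bulk.wait_group.read %0;" : : "n"(kAllowedPending) : "memory");
}

Waiting with a pending count of 1 before refilling the next buffer, as done earlier in this kernel, guarantees that the buffer about to be overwritten is no longer being read by the TMA engine.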
+ ptx::cp_async_bulk_commit_group(); + } + } + + parity ^= 1; + + if constexpr (IS_DBIAS) { + if (is_const_last_dim) { + float thread_partial_dbias = 0.0f; + if constexpr (COLWISE_SCALING) { + thread_partial_dbias = partial_dbias_colwise; + } else { + // Reusing dshmem (in_sh) as dbias buffer [HEIGHT x WIDTH] + // HEIGHT = THREADS_Y + // WIDTH = THREADS_X * (SCALE_DIM_X + 1) + // Added extra 1-element padding per thread_X to reduce bank conflicts + float *partial_dbias_rowwise = reinterpret_cast(dshmem); + + constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); + + const int shmem_thread_offset = + tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); + #pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; + #pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + const int shmem_elt_idx = swizzled_group_offset + e; + partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; + } + } + __syncthreads(); + #pragma unroll + for (int i = 0; i < THREADS_Y; ++i) { + // Add extra element offset per MXFP8 scaling block [1x32] + const int scaling_block = threadIdx.x / SCALE_DIM_X; + thread_partial_dbias += + partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; + } + } + const int dbias_stride = cols; + const int dbias_offset_Y = blockIdxY; + const int dbias_offset_X = blockIdxX * CHUNK_DIM_X + threadIdx.x; + const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; + const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); + if (!col_out_of_bounds_dbias) { + dbias_workspace[dbias_idx] = thread_partial_dbias; + } + } + } + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + block_amax = reduce_max(block_amax, warp_id); + } + + if (is_master_thread && amax_ptr != nullptr) { + atomicMaxFloat(amax_ptr, block_amax); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +} // namespace quantize_grouped_kernel + +template +void quantize_grouped(const GroupedTensor* input, + const GroupedTensor* activations, + const Tensor *noop, + GroupedTensor* output, + GroupedTensor* dbias, + GroupedTensor* workspace, + cudaStream_t stream) +{ + using namespace quantize_grouped_kernel; + + checkCuDriverContext(stream); + + NVTE_CHECK(input->num_tensors == output->num_tensors, "Number of input and output tensors must be same."); + NVTE_CHECK(input->has_data(), "Cannot quantize tensor without rowwise data."); + NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); + + const size_t num_tensors = input->num_tensors; + NVTE_CHECK(num_tensors < MAX_SUPPORTED_DESCRIPTORS, + "Number of tensors in a group is larger than the MAX number of supported descriptors (64)."); + + const size_t first_logical_dim = input->logical_shape.data[0]; + const size_t last_logical_dim = input->logical_shape.data[1]; + const size_t elts_total = first_logical_dim * last_logical_dim; + NVTE_CHECK(first_logical_dim % 128 == 0, "First dimension of a grouped tensor should be divisible by 128."); + NVTE_CHECK(last_logical_dim % 128 == 0, "Last dimension of a grouped tensor should be divisible by 128."); + + e8m0_t *const scales_rowwise_ptr = reinterpret_cast(output->scale_inv.dptr); + e8m0_t *const scales_colwise_ptr = 
reinterpret_cast(output->columnwise_scale_inv.dptr); + + const int64_t* const offsets_ptr = reinterpret_cast(input->tensor_offsets.dptr); + const int64_t* const first_dims_ptr = reinterpret_cast(input->first_dims.dptr); + const int64_t* const last_dims_ptr = reinterpret_cast(input->last_dims.dptr); + + const bool use_rowwise_scaling = output->has_data(); + const bool use_colwise_scaling = output->has_columnwise_data(); + if (use_rowwise_scaling) { + NVTE_CHECK(scales_rowwise_ptr != nullptr, "Scaling tensor must be allocated"); + } + if (use_colwise_scaling) { + NVTE_CHECK(scales_colwise_ptr != nullptr, "Columnwise scaling tensor must be allocated"); + } + + CheckNoopTensor(*noop, "cast_noop"); + + const size_t blocks = DIVUP(elts_total, CHUNK_DIM_Y * CHUNK_DIM_X); + const dim3 grid(blocks); + const size_t block_size = THREADS_PER_CHUNK; + + const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; + const size_t scale_stride_colwise = use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; + + // const size_t dbias_rows = blocks_Y; + // const size_t dbias_cols = cols; + + ScalingType scaling_type; + if (use_rowwise_scaling && (!use_colwise_scaling)) { + scaling_type = ScalingType::ROWWISE; + } else if ((!use_rowwise_scaling) && use_colwise_scaling) { + scaling_type = ScalingType::COLWISE; + } else if (use_rowwise_scaling && use_colwise_scaling) { + scaling_type = ScalingType::BIDIMENSIONAL; + } + + ShapeRepresentation shape_rep; + if (output->all_same_shape()) { + shape_rep = ShapeRepresentation::SAME_BOTH_DIMS; + } else if (output->all_same_first_dim()) { + shape_rep = ShapeRepresentation::VARYING_LAST_DIM; + } else if (output->all_same_last_dim()) { + shape_rep = ShapeRepresentation::VARYING_FIRST_DIM; + } else if (output->varying_both_dims()) { + shape_rep = ShapeRepresentation::VARYING_BOTH_DIMS; + } + + const bool is_const_last_dim = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); + + // if constexpr (IS_DBIAS) { + // NVTE_CHECK(dbias->data.dtype == input_tensor.dtype(), "DBias must have the same type as input_tensor."); + // NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); + // NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); + + // if (workspace->data.dptr == nullptr) { + // workspace->data.shape = {dbias_rows, dbias_cols}; + // workspace->data.dtype = DType::kFloat32; + // return; + // } + // } + + float *const workspace_ptr = IS_DBIAS ? 
reinterpret_cast(workspace->data.dptr) : nullptr; + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + const float *noop_ptr = reinterpret_cast(noop->data.dptr); + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input->dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(output->dtype(), OType, + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_act_input{}; + alignas(64) CUtensorMap tensor_map_output_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_colwise{}; + + constexpr size_t input_type_bit_size = TypeInfo::size; + constexpr size_t output_type_bit_size = TypeInfo::size; + + create_2D_tensor_map(tensor_map_input, input->data, + first_logical_dim, last_logical_dim, + BUFF_DIM_Y, BUFF_DIM_X, + last_logical_dim, 0, input_type_bit_size); + + if constexpr (IS_DACT) { + create_2D_tensor_map(tensor_map_act_input, activations->data, + first_logical_dim, last_logical_dim, + BUFF_DIM_Y, BUFF_DIM_X, + last_logical_dim, 0, input_type_bit_size); + } + + if (use_rowwise_scaling) { + create_2D_tensor_map(tensor_map_output_rowwise, output->data, + first_logical_dim, last_logical_dim, + BUFF_DIM_Y, BUFF_DIM_X, + last_logical_dim, 0, output_type_bit_size); + } + + if (use_colwise_scaling) { + create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, + first_logical_dim, last_logical_dim, + BUFF_DIM_Y, BUFF_DIM_X, + last_logical_dim, 0, output_type_bit_size); + } + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; + constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; + constexpr size_t buff_size_aligned_in = DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); + + constexpr size_t elt_input_mem = buff_size_aligned_in; + constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); + constexpr size_t in_mem = elt_input_mem + act_input_mem; + + const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0); + const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0); + const size_t out_mem = out_rowwise_mem + out_colwise_mem; + + const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; + + auto kernel = quantize_grouped_mxfp8_kernel + ; + switch (scaling_type) { + case ScalingType::ROWWISE: { + kernel = quantize_grouped_mxfp8_kernel + ; + break; + } + case ScalingType::COLWISE: { + kernel = quantize_grouped_mxfp8_kernel + ; + break; + } + case ScalingType::BIDIMENSIONAL: { + kernel = quantize_grouped_mxfp8_kernel + ; + break; + } + } + + NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); + + const IType* const input_dptr = reinterpret_cast(input->data.dptr); + + const IType* const act_input_dptr = (IS_DACT || IS_ACT) + ? reinterpret_cast(activations->data.dptr) + : nullptr; + + OType* const output_rowwise_dptr = use_rowwise_scaling + ? reinterpret_cast(output->data.dptr) + : nullptr; + + OType* const output_colwise_dptr = use_colwise_scaling + ? 
reinterpret_cast(output->columnwise_data.dptr) + : nullptr; + + init_tma_descriptors<<>> + (tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, tensor_map_output_colwise, + input_dptr, act_input_dptr, output_rowwise_dptr, output_colwise_dptr, + shape_rep, num_tensors, first_logical_dim, last_logical_dim, + offsets_ptr, first_dims_ptr, last_dims_ptr, use_rowwise_scaling, use_colwise_scaling, IS_ACT); + + kernel<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, tensor_map_output_colwise, + shape_rep, num_tensors, first_logical_dim, last_logical_dim, + offsets_ptr, first_dims_ptr, last_dims_ptr, + scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr); + + NVTE_CHECK_CUDA(cudaGetLastError()); + + + // if constexpr (IS_DBIAS) { + // if (is_const_last_dim) { + // common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); + // } + // } + ); // NOLINT(*) + ); // NOLINT(*) +} + +} // namespace mxfp8 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_QUANTIZE_GROUPED_MXFP8_CUH_ diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index a3235e84f1..335bc58c80 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -89,6 +89,17 @@ extern "C" { */ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Casts input grouped tensor to MXFP8. + * The type of quantized tensor in the output depends on the scaling mode of the output + * tensor. See file level comments. + * + * \param[in] input Input grouped tensor to be cast. + * \param[in,out] output Output grouped MXFP8 tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_quantize_grouped(const NVTEGroupedTensor input, NVTEGroupedTensor output, + cudaStream_t stream); + /*! \brief Casts input tensor to FP8/MXFP8/BlockwiseFP8, providing the option to immediately exit the kernel * based on the value of the 'noop' tensor. * The type of quantized tensor in the output depends on the scaling mode of the output From 81ededa0b432f35488dd54824d12fe6f3fd4ca72 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 14 Jan 2026 13:21:17 +0000 Subject: [PATCH 19/23] Added support for all shapes. Fixed bugs. Work in progress. 
Signed-off-by: Oleg Goncharov --- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 558 +++++++----------- .../cast/mxfp8/quantize_grouped_mxfp8.cuh | 93 +-- 2 files changed, 252 insertions(+), 399 deletions(-) diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index 9c89c916b0..7cc6ce56aa 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -37,10 +37,10 @@ enum ActivationKind { }; enum ShapeRepresentation { - SAME_MK = 0, - VARYING_M = 1, - VARYING_K = 2, - VARYING_MK = 3 + SAME_BOTH_DIMS = 0, + VARYING_FIRST_DIM = 1, + VARYING_LAST_DIM = 2, + VARYING_BOTH_DIMS = 3 }; template @@ -172,6 +172,57 @@ void compute_ref(const ProcessingMethod processing_method, // } } +template +void compare_scaled_elts(const std::string &name, + const T* ref_data, + const T* test_data, + const size_t rows, + const size_t cols, + const bool rowwise, + const size_t tolerable_mismatches_limit = 0, + const double atol = 1e-5, + const double rtol = 1e-8) { + size_t mismatches_num = 0; + int first_mismatch_idx = -1; + + for (size_t i = 0; i < rows * cols; ++i) { + double t = static_cast(test_data[i]); + double r = static_cast(ref_data[i]); + bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); + /* For Float32 the floating point comparison is enough to error out */ + bool assertion = false; + if (mismatch && !assertion) { + /* Check if it is just a failure of round to nearest choosing different + side of the real value */ + const double mean = (t + r) / 2; + const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6); + const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6); + const double cast_mean_p = static_cast(static_cast(mean_p)); + const double cast_mean_m = static_cast(static_cast(mean_m)); + assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); + } + std::string direction = rowwise ? "rowwise" : "columnwise"; + if (assertion) { + mismatches_num++; + if (first_mismatch_idx == -1) { + first_mismatch_idx = i; + } + } + if (mismatches_num > tolerable_mismatches_limit) { + const double first_mismatch_t = static_cast(test_data[first_mismatch_idx]); + const double first_mismatch_r = static_cast(ref_data[first_mismatch_idx]); + + GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of " + << tolerable_mismatches_limit << "." << std::endl + << "Error in tensor " << name << " in " + << direction << " direction." 
<< std::endl + << "First mismatch at place " << first_mismatch_idx + << " (" << std::to_string(first_mismatch_idx) << "): " + << first_mismatch_t << " vs " << first_mismatch_r; + } + } +} + /** * Scaling along single dimension (either rows or columns) * Produces one set of output data and the corresponding data of the fused operation (dbias): @@ -179,12 +230,14 @@ void compute_ref(const ProcessingMethod processing_method, * OR * 2) Scaled columns + column-wise scaling factors */ - template void performTest_x1(const ProcessingMethod processing_method, float (*OP)(const float), + const ShapeRepresentation shape_rep, const size_t num_tensors, const std::vector& logical_shape_vec, + const std::vector& first_dims_h, + const std::vector& last_dims_h, const bool rowwise, const bool colwise) { using namespace test; @@ -195,12 +248,12 @@ void performTest_x1(const ProcessingMethod processing_method, const size_t rows = logical_shape_vec[0]; const size_t cols = logical_shape_vec[1]; - const size_t M = rows / num_tensors; - const size_t K = cols; - std::vector scales_rowwise_shape = {rows, cols / 32}; std::vector scales_colwise_shape = {rows / 32, cols}; + const size_t scales_stride_rowwise = scales_rowwise_shape[1]; + const size_t scales_stride_colwise = scales_colwise_shape[1]; + const size_t elts_num = rows * cols; const size_t sfs_num = (rows * cols) / 32; @@ -219,11 +272,13 @@ void performTest_x1(const ProcessingMethod processing_method, std::vector out_scales_rowwise_ref(rowwise ? sfs_num : 0); std::vector out_scales_colwise_ref(colwise ? sfs_num : 0); - size_t tensor_elts[2] = {128 * 128, 128 * 128}; - std::vector offsets_h(num_tensors); - offsets_h[0] = 0; - for (size_t t = 1; t < num_tensors; ++t) { - offsets_h[t] = offsets_h[t-1] + tensor_elts[t-1]; + std::vector offsets_h(num_tensors + 1); + for (size_t t = 0; t < num_tensors + 1; ++t) { + if (t == 0) { + offsets_h[t] = 0; + } else { + offsets_h[t] = offsets_h[t-1] + (first_dims_h[t-1] * last_dims_h[t-1]); + } } for (size_t i = 0; i < elts_num; ++i) { @@ -256,32 +311,66 @@ void performTest_x1(const ProcessingMethod processing_method, const size_t out_data_size = elts_num * sizeof(OutputType); const size_t out_scales_size = sfs_num * sizeof(fp8e8m0); + const size_t first_dims_size = num_tensors * sizeof(size_t); + const size_t last_dims_size = num_tensors * sizeof(size_t); + const size_t offsets_size = (num_tensors + 1) * sizeof(size_t); + InputType* in_data_d; OutputType* out_data_rowwise_d; OutputType* out_data_colwise_d; fp8e8m0* out_scales_rowwise_d; fp8e8m0* out_scales_colwise_d; + size_t* first_dims_d; + size_t* last_dims_d; size_t* offsets_d; cudaMalloc((void**)&in_data_d, in_data_size); + cudaMalloc((void**)&first_dims_d, first_dims_size); + cudaMalloc((void**)&last_dims_d, last_dims_size); + cudaMalloc((void**)&offsets_d, offsets_size); + cudaMemcpy(in_data_d, in_data.data(), in_data_size, cudaMemcpyHostToDevice); - - cudaMalloc((void**)&offsets_d, in_data_size); - cudaMemcpy(offsets_d, offsets_h.data(), num_tensors * sizeof(size_t), cudaMemcpyHostToDevice); + cudaMemcpy(first_dims_d, first_dims_h.data(), first_dims_size, cudaMemcpyHostToDevice); + cudaMemcpy(last_dims_d, last_dims_h.data(), last_dims_size, cudaMemcpyHostToDevice); + cudaMemcpy(offsets_d, offsets_h.data(), offsets_size, cudaMemcpyHostToDevice); NVTEShape logical_shape_ = nvte_make_shape(logical_shape_vec.data(), logical_shape_vec.size()); + + NVTEShape first_dims_shape_; + NVTEShape last_dims_shape_; NVTEShape offsets_shape_; - offsets_shape_.data[0] = num_tensors; 
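The updated test derives the group offsets from the per-tensor shapes instead of assuming fixed 128x128 tensors: offsets_h is an exclusive prefix sum of first_dims_h[t] * last_dims_h[t] with num_tensors + 1 entries, so offsets_h[t] is where tensor t starts and the final entry is the total element count. A small standalone sketch of that layout (function and variable names here are illustrative, not taken from the patch):

#include <cstddef>
#include <vector>

std::vector<size_t> make_group_offsets(const std::vector<size_t>& first_dims,
                                       const std::vector<size_t>& last_dims) {
  // Exclusive prefix sum over per-tensor element counts; one extra trailing
  // entry holds the total number of elements in the group.
  std::vector<size_t> offsets(first_dims.size() + 1, 0);
  for (size_t t = 0; t < first_dims.size(); ++t) {
    offsets[t + 1] = offsets[t] + first_dims[t] * last_dims[t];
  }
  return offsets;
}

For the VARYING_BOTH_DIMS configuration used further down in this test (first dims {128, 256}, last dims {128, 256}) this yields {0, 16384, 81920}, matching the logical shape [1, 128*128 + 256*256].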
+ + first_dims_shape_.ndim = 1; + last_dims_shape_.ndim = 1; offsets_shape_.ndim = 1; + first_dims_shape_.data[0] = num_tensors; + last_dims_shape_.data[0] = num_tensors; + offsets_shape_.data[0] = num_tensors + 1; + NVTEGroupedTensor in_group_tensor = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, logical_shape_); NVTEGroupedTensor out_group_tensor = nvte_create_grouped_tensor(NVTE_MXFP8_1D_SCALING, num_tensors, logical_shape_); NVTEBasicTensor in_data_tensor = {in_data_d, static_cast(itype), logical_shape_}; nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &in_data_tensor); - - NVTEBasicTensor offsets_tensor = {offsets_d, kNVTEInt64, offsets_shape_}; - nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, &offsets_tensor); + + if ((shape_rep == VARYING_FIRST_DIM) || (shape_rep == VARYING_BOTH_DIMS)) { + NVTEBasicTensor first_dims_tensor = {first_dims_d, kNVTEInt64, first_dims_shape_}; + nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, &first_dims_tensor); + nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, &first_dims_tensor); + } + + if ((shape_rep == VARYING_LAST_DIM) || (shape_rep == VARYING_BOTH_DIMS)) { + NVTEBasicTensor last_dims_tensor = {last_dims_d, kNVTEInt64, last_dims_shape_}; + nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims, &last_dims_tensor); + nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims, &last_dims_tensor); + } + + if (shape_rep != SAME_BOTH_DIMS) { + NVTEBasicTensor offsets_tensor = {offsets_d, kNVTEInt64, offsets_shape_}; + nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, &offsets_tensor); + nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, &offsets_tensor); + } if (rowwise) { cudaMalloc((void**)&out_data_rowwise_d, out_data_size); @@ -307,13 +396,15 @@ void performTest_x1(const ProcessingMethod processing_method, nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedColumnwiseScaleInv, &out_scales_colwise_tensor); } - /* DO STUFF */ // Reference (CPU) for (size_t t = 0; t < num_tensors; ++t) { + const size_t M = first_dims_h[t]; + const size_t K = last_dims_h[t]; + const size_t scales_stride_rowwise = K / 32; const size_t scales_stride_colwise = K; - const size_t data_offset = t * (M * K); - const size_t sfs_offset = t * (M * K / 32); + const size_t data_offset = offsets_h[t]; + const size_t sfs_offset = data_offset / 32; const InputType* const in_ptr = in_data.data() + data_offset; OutputType* const out_data_rowwise_ptr = out_data_rowwise_ref.data() + data_offset; @@ -336,18 +427,50 @@ void performTest_x1(const ProcessingMethod processing_method, auto err = cudaGetLastError(); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + auto [atol, rtol] = getTolerances(otype); + const size_t scale_diff_abs_tolerance = 0; + const double abs_tolerable_mismatches_limit = 0.0; + const double rel_tolerable_mismatches_limit = 0.0; + if (rowwise) { cudaMemcpy(out_data_rowwise_h.data(), out_data_rowwise_d, out_data_size, cudaMemcpyDeviceToHost); cudaMemcpy(out_scales_rowwise_h.data(), out_scales_rowwise_d, out_scales_size, cudaMemcpyDeviceToHost); + + size_t mismatches_scales = 0; + compare_scaling_factors("rowwise_scales", out_scales_rowwise_h.data(), 
out_scales_rowwise_ref.data(), + scales_rowwise_shape[0], scales_rowwise_shape[1], scales_stride_rowwise, + mismatches_scales, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); + + const size_t mismatches_elts = 32 * mismatches_scales; + + compare_scaled_elts("rowwise_output", out_data_rowwise_ref.data(), + out_data_rowwise_h.data(), rows, cols, true, mismatches_elts); } if (colwise) { cudaMemcpy(out_data_colwise_h.data(), out_data_colwise_d, out_data_size, cudaMemcpyDeviceToHost); cudaMemcpy(out_scales_colwise_h.data(), out_scales_colwise_d, out_scales_size, cudaMemcpyDeviceToHost); - } + size_t mismatches_scales = 0; + compare_scaling_factors("colwise_scales", out_scales_colwise_h.data(), out_scales_colwise_ref.data(), + scales_colwise_shape[0], scales_colwise_shape[1], scales_stride_colwise, + mismatches_scales, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); + + const size_t mismatches_elts = 32 * mismatches_scales; + + compare_scaled_elts("colwise_output", out_data_colwise_ref.data(), + out_data_colwise_h.data(), rows, cols, false, mismatches_elts); + } cudaFree(in_data_d); + cudaFree(first_dims_d); + cudaFree(last_dims_d); cudaFree(offsets_d); if (rowwise) { cudaFree(out_data_rowwise_d); @@ -357,332 +480,7 @@ void performTest_x1(const ProcessingMethod processing_method, cudaFree(out_data_colwise_d); cudaFree(out_scales_colwise_d); } - - // const size_t block_size_rows = rowwise ? 1 : 32; - // const size_t block_size_cols = colwise ? 1 : 32; - - // const std::array scale_dims = get_scale_tensor_dims(rows, cols, block_size_rows, - // block_size_cols); - - // const size_t unpadded_blocks_Y = scale_dims[0]; - // const size_t unpadded_blocks_X = scale_dims[1]; - // const size_t blocks_Y = scale_dims[2]; - // const size_t blocks_X = scale_dims[3]; - // const size_t scales_stride = blocks_X; - - // Tensor input("input", shape, itype); - // Tensor grad("grad", shape, itype); - // Tensor output_c("output_c", shape, otype, rowwise, colwise, NVTE_MXFP8_1D_SCALING); - // Tensor output_dbias("output_dbias", std::vector{ cols }, itype); - - // std::unique_ptr ref_output_c = std::make_unique(rows * cols); - // std::unique_ptr ref_output_dbias = std::make_unique(cols); - // std::unique_ptr ref_output_scales = std::make_unique(blocks_Y * blocks_X); - - // fillCase(&input, InputsFillCase::uniform); - // fillUniform(&grad); - - // Tensor workspace; - // switch (processing_method) { - // case ProcessingMethod::CAST_ONLY: { - // nvte_quantize(input.data(), output_c.data(), 0); - // break; - // } - // case ProcessingMethod::CAST_DBIAS: { - // nvte_quantize_dbias(grad.data(), - // output_c.data(), - // output_dbias.data(), - // workspace.data(), - // 0); - // workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); - - // nvte_quantize_dbias(grad.data(), - // output_c.data(), - // output_dbias.data(), - // workspace.data(), - // 0); - // break; - // } - // case ProcessingMethod::CAST_DBIAS_DACT: { - // auto nvte_quantize_dbias_dact = &nvte_quantize_dbias_dgelu; - // if (OP == &dsilu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsilu; } - // else if (OP == &drelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_drelu; } - // else if (OP == &dqgelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dqgelu; } - // else if (OP == &dsrelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsrelu; } - - // nvte_quantize_dbias_dact(grad.data(), - // input.data(), - // output_c.data(), - // 
output_dbias.data(), - // workspace.data(), - // 0); - // workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); - - // nvte_quantize_dbias_dact(grad.data(), - // input.data(), - // output_c.data(), - // output_dbias.data(), - // workspace.data(), - // 0); - // break; - // } - // case ProcessingMethod::CAST_DACT: { - // auto nvte_dact = &nvte_dgelu; - // if (OP == &dsilu) { nvte_dact = &nvte_dsilu; } - // else if (OP == &drelu) { nvte_dact = &nvte_drelu; } - // else if (OP == &dqgelu) { nvte_dact = &nvte_dqgelu; } - // else if (OP == &dsrelu) { nvte_dact = &nvte_dsrelu; } - - // nvte_dact(grad.data(), input.data(), output_c.data(), 0); - // break; - // } - // case ProcessingMethod::CAST_ACT: { - // auto nvte_act = &nvte_gelu; - // if (OP == &silu) { nvte_act = &nvte_silu; } - // else if (OP == &relu) { nvte_act = &nvte_relu; } - // else if (OP == &qgelu) { nvte_act = &nvte_qgelu; } - // else if (OP == &srelu) { nvte_act = &nvte_srelu; } - - // nvte_act(input.data(), output_c.data(), 0); - // break; - // } - // } - - // cudaDeviceSynchronize(); - // auto err = cudaGetLastError(); - // ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); - - // compute_ref(processing_method, - // OP, - // rowwise, - // colwise, - // input.rowwise_cpu_dptr(), - // grad.rowwise_cpu_dptr(), - // ref_output_c.get(), - // ref_output_c.get(), - // ref_output_scales.get(), - // ref_output_scales.get(), - // ref_output_dbias.get(), - // rows, - // cols, - // scales_stride, - // scales_stride); - - // const uint8_t * const gpu_scales_ptr = rowwise - // ? output_c.rowwise_cpu_scale_inv_ptr() - // : output_c.columnwise_cpu_scale_inv_ptr(); - - // const size_t scale_diff_abs_tolerance = 0; - // const double abs_tolerable_mismatches_limit = 0.0; - // const double rel_tolerable_mismatches_limit = 0.0; - - // size_t mismatches_scales = 0; - - // compare_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(), - // unpadded_blocks_Y, unpadded_blocks_X, scales_stride, - // mismatches_scales, - // scale_diff_abs_tolerance, - // abs_tolerable_mismatches_limit, - // rel_tolerable_mismatches_limit); - - // const size_t mismatches_elts = 32 * mismatches_scales; - // auto [atol, rtol] = getTolerances(otype); - // compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol, true, mismatches_elts); - - // if (processing_method == ProcessingMethod::CAST_DBIAS - // || processing_method == ProcessingMethod::CAST_DBIAS_DACT) - // { - // auto [atol_dbias, rtol_dbias] = getTolerances(itype); - // if (itype == DType::kFloat32) { - // atol_dbias = 1e-4; - // rtol_dbias *= sqrt(static_cast(rows)) ; - // } else { - // rtol_dbias *= 4; - // } - // compareResults("output_dbias", output_dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias); - // } -} - -/** - * Scaling along both dimensions (rows and columns) - * Produces two sets of scaled output data and the corresponding data of the fused operation (dbias): - * 1) Scaled rows + row-wise scaling factors - * AND - * 2) Scaled columns + column-wise scaling factors - */ -/* -template -void performTest_x2(const ProcessingMethod processing_method, - float (*OP)(const float), - const std::pair& shape, - const std::vector& M_i, - const std::vector& Offset_i) { - using namespace test; - using EncodingType = fp32; - DType itype = TypeInfo::dtype; - DType otype = TypeInfo::dtype; - - const size_t rows = shape.first; - const size_t cols = shape.second; - - const std::array scale_dims_rowwise = get_scale_tensor_dims(rows, cols, 1, 32); - const 
std::array scale_dims_colwise = get_scale_tensor_dims(rows, cols, 32, 1); - - const size_t unpadded_blocks_Y_rowwise = scale_dims_rowwise[0]; - const size_t unpadded_blocks_X_rowwise = scale_dims_rowwise[1]; - const size_t blocks_Y_rowwise = scale_dims_rowwise[2]; - const size_t blocks_X_rowwise = scale_dims_rowwise[3]; - const size_t scales_stride_rowwise = blocks_X_rowwise; - - const size_t unpadded_blocks_Y_colwise = scale_dims_colwise[0]; - const size_t unpadded_blocks_X_colwise = scale_dims_colwise[1]; - const size_t blocks_Y_colwise = scale_dims_colwise[2]; - const size_t blocks_X_colwise = scale_dims_colwise[3]; - const size_t scales_stride_colwise = blocks_X_colwise; - - Tensor input("input", shape, itype); - Tensor grad("grad", shape, itype); - Tensor output("output", shape, otype, true, true, NVTE_MXFP8_1D_SCALING); - Tensor output_dbias("output_dbias", std::vector{ cols }, itype); - - std::unique_ptr ref_output_c_rowwise = std::make_unique(rows * cols); - std::unique_ptr ref_output_c_colwise = std::make_unique(rows * cols); - std::unique_ptr ref_scales_rowwise = std::make_unique(blocks_Y_rowwise * blocks_X_rowwise); - std::unique_ptr ref_scales_colwise = std::make_unique(blocks_Y_colwise * blocks_X_colwise); - std::unique_ptr ref_output_dbias = std::make_unique(cols); - - fillCase(&input, InputsFillCase::uniform); - fillUniform(&grad); - - Tensor workspace; - switch (processing_method) { - case ProcessingMethod::CAST_ONLY: { - nvte_quantize(input.data(), output.data(), 0); - break; - } - case ProcessingMethod::CAST_DBIAS: { - nvte_quantize_dbias(grad.data(), - output.data(), - output_dbias.data(), - workspace.data(), - 0); - workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); - - nvte_quantize_dbias(grad.data(), - output.data(), - output_dbias.data(), - workspace.data(), - 0); - break; - } - case ProcessingMethod::CAST_DBIAS_DACT: { - auto nvte_quantize_dbias_dact = &nvte_quantize_dbias_dgelu; - if (OP == &dsilu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsilu; } - else if (OP == &drelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_drelu; } - else if (OP == &dqgelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dqgelu; } - else if (OP == &dsrelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsrelu; } - - nvte_quantize_dbias_dact(grad.data(), - input.data(), - output.data(), - output_dbias.data(), - workspace.data(), - 0); - workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); - - nvte_quantize_dbias_dact(grad.data(), - input.data(), - output.data(), - output_dbias.data(), - workspace.data(), - 0); - break; - } - case ProcessingMethod::CAST_DACT: { - auto nvte_dact = &nvte_dgelu; - if (OP == &dsilu) { nvte_dact = &nvte_dsilu; } - else if (OP == &drelu) { nvte_dact = &nvte_drelu; } - else if (OP == &dqgelu) { nvte_dact = &nvte_dqgelu; } - else if (OP == &dsrelu) { nvte_dact = &nvte_dsrelu; } - - nvte_dact(grad.data(), input.data(), output.data(), 0); - break; - } - case ProcessingMethod::CAST_ACT: { - auto nvte_act = &nvte_gelu; - if (OP == &silu) { nvte_act = &nvte_silu; } - else if (OP == &relu) { nvte_act = &nvte_relu; } - else if (OP == &qgelu) { nvte_act = &nvte_qgelu; } - else if (OP == &srelu) { nvte_act = &nvte_srelu; } - - nvte_act(input.data(), output.data(), 0); - break; - } - } - - cudaDeviceSynchronize(); - auto err = cudaGetLastError(); - ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); - - compute_ref(processing_method, - OP, - true, - true, - input.rowwise_cpu_dptr(), - 
grad.rowwise_cpu_dptr(), - ref_output_c_rowwise.get(), - ref_output_c_colwise.get(), - ref_scales_rowwise.get(), - ref_scales_colwise.get(), - ref_output_dbias.get(), - rows, - cols, - scales_stride_rowwise, - scales_stride_colwise); - - const size_t scale_diff_abs_tolerance = 0; - const double abs_tolerable_mismatches_limit = 0.0; - const double rel_tolerable_mismatches_limit = 0.0; - - size_t mismatches_scales_rowwise = 0; - compare_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), - ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, - unpadded_blocks_X_rowwise, scales_stride_rowwise, - mismatches_scales_rowwise, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); - - size_t mismatches_scales_colwise = 0; - compare_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), - ref_scales_colwise.get(), unpadded_blocks_Y_colwise, - unpadded_blocks_X_colwise, scales_stride_colwise, - mismatches_scales_colwise, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); - - const size_t mismatches_elts_rowwise = 32 * mismatches_scales_rowwise; - const size_t mismatches_elts_colwise = 32 * mismatches_scales_colwise; - - auto [atol, rtol] = getTolerances(otype); - compareResults("output_c_rowwise", output, ref_output_c_rowwise.get(), true, atol, rtol, true, mismatches_elts_rowwise); - compareResults("output_c_colwise", output, ref_output_c_colwise.get(), false, atol, rtol, true, mismatches_elts_colwise); - - if (processing_method == ProcessingMethod::CAST_DBIAS - || processing_method == ProcessingMethod::CAST_DBIAS_DACT) - { - auto [atol_dbias, rtol_dbias] = getTolerances(itype); - if (itype == DType::kFloat32) { - atol_dbias = 1e-4; - rtol_dbias *= sqrt(static_cast(rows)) ; - } else { - rtol_dbias *= 4; - } - compareResults("output_dbias", output_dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias); - } } -*/ std::vector processing_methods = { ProcessingMethod::CAST_ONLY, @@ -716,10 +514,10 @@ std::vector scaling_directions = { // {num_tensors, logical_shape_M, logical_shape_K, [M_i], [K_i], [Offset_i]} std::vector> input_config = { - {1, 128, 128}, - {2, 256, 128}, - // {3, 128 * 3, 256}, - // {5, 256 * 5, 256}, + {0, 1, 128, 128}, + {0, 2, 256, 128}, + {1, 2, 512, 128, 128, 512-128}, + {3, 2, 1, 128 * 128 + 256 * 256, 128, 256, 128, 256}, }; } // namespace @@ -747,9 +545,40 @@ TEST_P(GroupedFusedCastMXFP8TestSuite, Test) { const ScalingDirection scaling_direction = std::get<2>(GetParam()); const std::vector input_config = std::get<3>(GetParam()); - const size_t num_tensors = input_config[0]; - const std::vector logical_shape = {input_config[1], input_config[2]}; - + const ShapeRepresentation shape_rep = static_cast(input_config[0]); + const size_t num_tensors = input_config[1]; + const std::vector logical_shape = {input_config[2], input_config[3]}; + std::vector first_dims(num_tensors); + std::vector last_dims(num_tensors); + for (size_t t = 0; t < num_tensors; ++t) { + switch (shape_rep) { + case SAME_BOTH_DIMS: { + first_dims[t] = logical_shape[0] / num_tensors; + last_dims[t] = logical_shape[1]; + break; + } + case VARYING_FIRST_DIM: { + first_dims[t] = input_config[t + 4]; + last_dims[t] = logical_shape[1]; + break; + } + case VARYING_LAST_DIM: { + first_dims[t] = logical_shape[0]; + last_dims[t] = input_config[t + (4 + num_tensors)]; + break; + } + case VARYING_BOTH_DIMS: { + first_dims[t] = input_config[t + 4]; + last_dims[t] = input_config[t + (4 + 
num_tensors)]; + break; + } + } + // Skips tests if tensor dims are not multiples of 128 + if ((first_dims[t] % 128 != 0) || (last_dims[t] % 128 != 0)) { + GTEST_SKIP(); + } + } + // Skips non Act tests if the Activation type is not an identity if ((processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS) && activation != ActivationKind::Identity) { @@ -771,7 +600,8 @@ TEST_P(GroupedFusedCastMXFP8TestSuite, Test) { } auto OP = &identity; - performTest_x1(processing_method, OP, num_tensors, logical_shape, rowwise, colwise); + performTest_x1(processing_method, OP, shape_rep, num_tensors, logical_shape, + first_dims, last_dims, rowwise, colwise); // if (processing_method == ProcessingMethod::CAST_ACT) { // // Forward activations @@ -867,11 +697,19 @@ INSTANTIATE_TEST_SUITE_P( } const std::vector input = std::get<3>(info.param); - name += "_N_" + std::to_string(input[0]); + name += "_Shape_"; + switch(static_cast(input[0])) { + case ShapeRepresentation::SAME_BOTH_DIMS: name += "SAME_BOTH_DIMS"; break; + case ShapeRepresentation::VARYING_FIRST_DIM: name += "VARYING_FIRST_DIM"; break; + case ShapeRepresentation::VARYING_LAST_DIM: name += "VARYING_LAST_DIM"; break; + case ShapeRepresentation::VARYING_BOTH_DIMS: name += "VARYING_BOTH_DIMS"; break; + }; + + name += "_N_" + std::to_string(input[1]); name += "_Shape_" + - std::to_string(input[1]) + - "X" + std::to_string(input[2]); + std::to_string(input[2]) + + "X" + std::to_string(input[3]); // name += "_DimsM_"; // const auto& M_i_ = std::get<5>(info.param); diff --git a/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh index a71fcd1366..dd666d0e34 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh @@ -35,9 +35,9 @@ __device__ alignas(128) CUtensorMap g_tensor_maps_output_rowwise[MAX_SUPPORTED_D __device__ alignas(128) CUtensorMap g_tensor_maps_output_colwise[MAX_SUPPORTED_DESCRIPTORS]; enum ShapeRepresentation { - SAME_BOTH_DIMS = 0, + SAME_BOTH_DIMS = 0, VARYING_FIRST_DIM = 1, - VARYING_LAST_DIM = 2, + VARYING_LAST_DIM = 2, VARYING_BOTH_DIMS = 3 }; @@ -112,8 +112,8 @@ get_tensor_rows_num(const size_t tensor_id, const size_t num_tensors) { size_t rows_num = first_logical_dim; switch (shape_rep) { - case ShapeRepresentation::SAME_BOTH_DIMS: rows_num = first_logical_dim / num_tensors; break; - case ShapeRepresentation::VARYING_LAST_DIM: rows_num = first_logical_dim; break; + case ShapeRepresentation::SAME_BOTH_DIMS: // rows_num = first_logical_dim / num_tensors; break; + case ShapeRepresentation::VARYING_LAST_DIM: rows_num = first_logical_dim; break; case ShapeRepresentation::VARYING_FIRST_DIM: case ShapeRepresentation::VARYING_BOTH_DIMS: rows_num = static_cast(first_dims_ptr[tensor_id]); break; } @@ -251,6 +251,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } } + // TODO: Add "acquire" semantics for Tensor Map, once it has been modified constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; const size_t block_global_offset = blockIdx.x * CHUNK_DIM_Y * CHUNK_DIM_X; @@ -262,10 +263,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t rows = get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); const size_t scale_stride_rowwise = cols 
/ SCALE_DIM_X; - const size_t scale_stride_colwise = rows / SCALE_DIM_Y; + const size_t scale_stride_colwise = cols; const size_t offset_within_tensor = (shape_rep == SAME_BOTH_DIMS) - ? (block_global_offset - tensor_id * rows * cols) + ? block_global_offset // grouped tensor can be treated as continuous tensor for MXFP8 : (block_global_offset - offsets_ptr[tensor_id]); if (threadIdx.x == 0) { @@ -293,6 +294,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t block_offset_X = offset_within_tensor % cols; const size_t blockIdxY = block_offset_Y / CHUNK_DIM_Y; const size_t blockIdxX = block_offset_X / CHUNK_DIM_X; + + // Early exit if the border of the chunk goes over the + if (block_offset_Y >= rows) { + return; + } + const size_t scales_block_offset_Y_rowwise = blockIdxY * CHUNK_DIM_Y; const size_t scales_block_offset_X_rowwise = blockIdxX * CHUNK_DIM_X / SCALE_DIM_X; const size_t scales_block_offset_Y_colwise = blockIdxY * CHUNK_DIM_Y / SCALE_DIM_Y; @@ -766,6 +773,29 @@ void quantize_grouped(const GroupedTensor* input, checkCuDriverContext(stream); + const bool use_rowwise_scaling = output->has_data(); + const bool use_colwise_scaling = output->has_columnwise_data(); + + ScalingType scaling_type; + if (use_rowwise_scaling && (!use_colwise_scaling)) { + scaling_type = ScalingType::ROWWISE; + } else if ((!use_rowwise_scaling) && use_colwise_scaling) { + scaling_type = ScalingType::COLWISE; + } else if (use_rowwise_scaling && use_colwise_scaling) { + scaling_type = ScalingType::BIDIMENSIONAL; + } + + ShapeRepresentation shape_rep; + if (output->all_same_shape()) { + shape_rep = ShapeRepresentation::SAME_BOTH_DIMS; + } else if (output->all_same_first_dim()) { + shape_rep = ShapeRepresentation::VARYING_LAST_DIM; + } else if (output->all_same_last_dim()) { + shape_rep = ShapeRepresentation::VARYING_FIRST_DIM; + } else if (output->varying_both_dims()) { + shape_rep = ShapeRepresentation::VARYING_BOTH_DIMS; + } + NVTE_CHECK(input->num_tensors == output->num_tensors, "Number of input and output tensors must be same."); NVTE_CHECK(input->has_data(), "Cannot quantize tensor without rowwise data."); NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); @@ -777,25 +807,27 @@ void quantize_grouped(const GroupedTensor* input, const size_t first_logical_dim = input->logical_shape.data[0]; const size_t last_logical_dim = input->logical_shape.data[1]; const size_t elts_total = first_logical_dim * last_logical_dim; - NVTE_CHECK(first_logical_dim % 128 == 0, "First dimension of a grouped tensor should be divisible by 128."); + + // Logical shape of a tensor with varying all dims is [1, M*K] + if (shape_rep != ShapeRepresentation::VARYING_BOTH_DIMS) { + NVTE_CHECK(first_logical_dim % 128 == 0, "First dimension of a grouped tensor should be divisible by 128."); + } NVTE_CHECK(last_logical_dim % 128 == 0, "Last dimension of a grouped tensor should be divisible by 128."); e8m0_t *const scales_rowwise_ptr = reinterpret_cast(output->scale_inv.dptr); e8m0_t *const scales_colwise_ptr = reinterpret_cast(output->columnwise_scale_inv.dptr); - const int64_t* const offsets_ptr = reinterpret_cast(input->tensor_offsets.dptr); - const int64_t* const first_dims_ptr = reinterpret_cast(input->first_dims.dptr); - const int64_t* const last_dims_ptr = reinterpret_cast(input->last_dims.dptr); - - const bool use_rowwise_scaling = output->has_data(); - const bool use_colwise_scaling = output->has_columnwise_data(); if (use_rowwise_scaling) { NVTE_CHECK(scales_rowwise_ptr != nullptr, "Scaling 
tensor must be allocated"); } if (use_colwise_scaling) { NVTE_CHECK(scales_colwise_ptr != nullptr, "Columnwise scaling tensor must be allocated"); } - + + const int64_t* const offsets_ptr = reinterpret_cast(input->tensor_offsets.dptr); + const int64_t* const first_dims_ptr = reinterpret_cast(input->first_dims.dptr); + const int64_t* const last_dims_ptr = reinterpret_cast(input->last_dims.dptr); + CheckNoopTensor(*noop, "cast_noop"); const size_t blocks = DIVUP(elts_total, CHUNK_DIM_Y * CHUNK_DIM_X); @@ -808,26 +840,7 @@ void quantize_grouped(const GroupedTensor* input, // const size_t dbias_rows = blocks_Y; // const size_t dbias_cols = cols; - ScalingType scaling_type; - if (use_rowwise_scaling && (!use_colwise_scaling)) { - scaling_type = ScalingType::ROWWISE; - } else if ((!use_rowwise_scaling) && use_colwise_scaling) { - scaling_type = ScalingType::COLWISE; - } else if (use_rowwise_scaling && use_colwise_scaling) { - scaling_type = ScalingType::BIDIMENSIONAL; - } - - ShapeRepresentation shape_rep; - if (output->all_same_shape()) { - shape_rep = ShapeRepresentation::SAME_BOTH_DIMS; - } else if (output->all_same_first_dim()) { - shape_rep = ShapeRepresentation::VARYING_LAST_DIM; - } else if (output->all_same_last_dim()) { - shape_rep = ShapeRepresentation::VARYING_FIRST_DIM; - } else if (output->varying_both_dims()) { - shape_rep = ShapeRepresentation::VARYING_BOTH_DIMS; - } - + printf("Shape #: %d \n", static_cast(shape_rep)); const bool is_const_last_dim = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); // if constexpr (IS_DBIAS) { @@ -935,11 +948,13 @@ void quantize_grouped(const GroupedTensor* input, ? reinterpret_cast(output->columnwise_data.dptr) : nullptr; - init_tma_descriptors<<>> - (tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, tensor_map_output_colwise, - input_dptr, act_input_dptr, output_rowwise_dptr, output_colwise_dptr, - shape_rep, num_tensors, first_logical_dim, last_logical_dim, - offsets_ptr, first_dims_ptr, last_dims_ptr, use_rowwise_scaling, use_colwise_scaling, IS_ACT); + if (!is_const_last_dim) { + init_tma_descriptors<<>> + (tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, tensor_map_output_colwise, + input_dptr, act_input_dptr, output_rowwise_dptr, output_colwise_dptr, + shape_rep, num_tensors, first_logical_dim, last_logical_dim, + offsets_ptr, first_dims_ptr, last_dims_ptr, use_rowwise_scaling, use_colwise_scaling, IS_ACT); + } kernel<<>>( tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, tensor_map_output_colwise, From fb83f087e6925bb334f1d362a52531b18a2da746 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 14 Jan 2026 15:32:36 +0000 Subject: [PATCH 20/23] Added acquire memory fence for tensor map Signed-off-by: Oleg Goncharov --- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 10 +- .../cast/mxfp8/quantize_grouped_mxfp8.cuh | 153 ++++++++---------- 2 files changed, 70 insertions(+), 93 deletions(-) diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index 7cc6ce56aa..52f25facae 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -512,12 +512,12 @@ std::vector scaling_directions = { // ScalingDirection::BOTH, }; -// {num_tensors, logical_shape_M, logical_shape_K, [M_i], [K_i], [Offset_i]} +// {shape_representation, num_tensors, [logical_shape_M, logical_shape_K], [M_i], [K_i]} std::vector> input_config = { - {0, 1, 128, 128}, - {0, 2, 256, 128}, - {1, 2, 512, 128, 
128, 512-128}, - {3, 2, 1, 128 * 128 + 256 * 256, 128, 256, 128, 256}, + {SAME_BOTH_DIMS, 1, 128,128}, + {SAME_BOTH_DIMS, 2, 256,128}, + {VARYING_FIRST_DIM, 2, 512,128, 128,512-128}, + {VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256}, }; } // namespace diff --git a/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh index dd666d0e34..f74cbd13cc 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh @@ -218,6 +218,11 @@ init_tma_descriptors(const __grid_constant__ CUtensorMap base_tensor_map_input, } } +__device__ __forceinline__ void +fence_acquire_tensormap(const CUtensorMap* tensor_map) { + asm volatile("fence.proxy.tensormap::generic.acquire.cta [%0], 128;" :: "l"(tensor_map)); +} + template @@ -233,8 +238,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const int64_t* const __restrict__ offsets_ptr, const int64_t* const __restrict__ first_dims_ptr, const int64_t* const __restrict__ last_dims_ptr, - e8m0_t *const __restrict__ scales_rowwise, - e8m0_t *const __restrict__ scales_colwise, + e8m0_t *const __restrict__ scales_rowwise_ptr, + e8m0_t *const __restrict__ scales_colwise_ptr, const float * __restrict__ noop, float *const __restrict__ dbias_workspace, float *const __restrict__ amax_ptr) { @@ -251,7 +256,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } } - // TODO: Add "acquire" semantics for Tensor Map, once it has been modified constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; const size_t block_global_offset = blockIdx.x * CHUNK_DIM_Y * CHUNK_DIM_X; @@ -265,18 +269,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t scale_stride_rowwise = cols / SCALE_DIM_X; const size_t scale_stride_colwise = cols; - const size_t offset_within_tensor = (shape_rep == SAME_BOTH_DIMS) + const bool is_const_last_dim = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); + + const size_t offset_within_tensor = is_const_last_dim ? block_global_offset // grouped tensor can be treated as continuous tensor for MXFP8 : (block_global_offset - offsets_ptr[tensor_id]); - if (threadIdx.x == 0) { - printf("Current tensor ID: %lu \n", tensor_id); - printf("Current tensor ROWS: %lu \n", rows); - printf("Current tensor COLS: %lu \n", cols); - } - - const bool is_const_last_dim = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); - const CUtensorMap& tensor_map_input = is_const_last_dim ? tensor_map_input_static : g_tensor_maps_input[tensor_id]; @@ -290,16 +288,33 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ? 
tensor_map_output_colwise_static : g_tensor_maps_output_colwise[tensor_id]; + const bool leading_thread = (threadIdx.x == 0); + + if (leading_thread && (!is_const_last_dim)) { + fence_acquire_tensormap(&tensor_map_input); + if constexpr (COMPUTE_ACTIVATIONS) { fence_acquire_tensormap(&tensor_map_act_input); } + if constexpr (ROWWISE_SCALING) { fence_acquire_tensormap(&tensor_map_output_rowwise); } + if constexpr (COLWISE_SCALING) { fence_acquire_tensormap(&tensor_map_output_colwise); } + } + const size_t block_offset_Y = offset_within_tensor / cols; const size_t block_offset_X = offset_within_tensor % cols; const size_t blockIdxY = block_offset_Y / CHUNK_DIM_Y; const size_t blockIdxX = block_offset_X / CHUNK_DIM_X; + if (leading_thread) { + printf("Current tensor ID: %2lu Rows: %4lu Cols: %4lu BLOCK IdxY: %2lu IdxX %2lu Offset Y: %4lu Offset X: %4lu\n", + tensor_id, rows, cols, blockIdxY, blockIdxX, block_offset_Y, block_offset_X); + } + // Early exit if the border of the chunk goes over the if (block_offset_Y >= rows) { return; } + e8m0_t *const scales_rowwise = scales_rowwise_ptr + (is_const_last_dim ? 0 : offsets_ptr[tensor_id] / SCALE_DIM_X); + e8m0_t *const scales_colwise = scales_colwise_ptr + (is_const_last_dim ? 0 : offsets_ptr[tensor_id] / SCALE_DIM_Y); + const size_t scales_block_offset_Y_rowwise = blockIdxY * CHUNK_DIM_Y; const size_t scales_block_offset_X_rowwise = blockIdxX * CHUNK_DIM_X / SCALE_DIM_X; const size_t scales_block_offset_Y_colwise = blockIdxY * CHUNK_DIM_Y / SCALE_DIM_Y; @@ -315,12 +330,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t thread_offset_Y_colwise = tid_Y_colwise; const size_t thread_offset_X_colwise = tid_X_colwise; - const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; - const size_t row_base_colwise = block_offset_Y + thread_offset_Y_colwise; - const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise; - - const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); - const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; @@ -343,12 +352,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) constexpr size_t out_mem_rowwise = (ROWWISE_SCALING ? buff_size_aligned_out : 0); - extern __shared__ char dynamic_shmem[]; - uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); - // Manually align dynamic SHMEM per TMA requirements using padding - // __align__(128) Does not guarantee the pointer to be aligned! 
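The comment above is why the kernel rounds the raw dynamic shared-memory pointer up by hand instead of relying on an alignment attribute; the hunk below replaces that inline arithmetic with a call to common::align_smem_ptr_per_TMA_requirements. A minimal sketch of the round-up idiom, assuming the 128-byte TMA alignment the comment refers to (the bit trick requires the alignment to be a power of two):

#include <cstdint>

constexpr uintptr_t kTmaSmemAlignment = 128;  // illustrative constant, power of two

__device__ __forceinline__ unsigned char *round_up_for_tma(unsigned char *p) {
  uintptr_t addr = reinterpret_cast<uintptr_t>(p);
  // Adding (alignment - 1) pushes a misaligned address past the next boundary;
  // masking with ~(alignment - 1) clears the low bits, i.e. rounds back down.
  addr = (addr + kTmaSmemAlignment - 1) & ~(kTmaSmemAlignment - 1);
  return reinterpret_cast<unsigned char *>(addr);
}

For example, an address whose low byte is 0x48 is moved up to the next 0x80 boundary, while an already aligned address is returned unchanged.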
- uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & - ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + extern __shared__ unsigned char dynamic_shmem[]; + unsigned char *dshmem = common::align_smem_ptr_per_TMA_requirements(dynamic_shmem); // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned IType *in_sh = reinterpret_cast(dshmem); @@ -360,8 +366,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; - const bool is_master_thread = (threadIdx.x == 0); - float partial_dbias_colwise = 0.0f; float thread_dbias_rowwise[SCALE_DIM_X]; if constexpr (IS_DBIAS) { @@ -373,24 +377,24 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float block_amax = 0.0f; -// Initialize shared memory barrier with the number of threads participating in the barrier. -#pragma nv_diag_suppress static_var_with_dynamic_init + // Initialize shared memory barrier with the number of threads participating in the barrier. + #pragma nv_diag_suppress static_var_with_dynamic_init __shared__ alignas(8) uint64_t mbar[STAGES]; - initialize_barriers(mbar, is_master_thread); + initialize_barriers(mbar, leading_thread); int parity = 0; if constexpr (IS_DACT) { copy_2d_to_sharedx2(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, &act_in_sh[0], &tensor_map_act_input, block_offset_X, block_offset_Y, shmem_buff_size, - &mbar[0], is_master_thread); + &mbar[0], leading_thread); } else { copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, - &mbar[0], is_master_thread); + &mbar[0], leading_thread); } -#pragma unroll + #pragma unroll for (int stage = 0; stage < STAGES; ++stage) { const size_t buff = stage % BUFFS_NUM; const size_t next_stage = stage + 1; @@ -410,10 +414,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input, global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage], - is_master_thread); + leading_thread); } else { copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, - global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); + global_offset_Y, shmem_buff_size, &mbar[next_stage], leading_thread); } } @@ -432,7 +436,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // 1. Read/Compute elements. 
Find MXFP8-block AMAX if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { IType thread_amax_f16 = static_cast(0.0f); -#pragma unroll + #pragma unroll for (int i = 0; i < BUFF_DIM_Y; ++i) { const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; in_colwise_IType[i] = in_sh[shmem_offset_colwise]; @@ -440,7 +444,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } thread_amax = static_cast(thread_amax_f16); } else { -#pragma unroll + #pragma unroll for (int i = 0; i < BUFF_DIM_Y; ++i) { const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; @@ -463,17 +467,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if constexpr (IS_CACHED_ACT_OP) { cached_act_sh[shmem_offset_colwise] = static_cast(elt); } - - if constexpr (COMPUTE_ACTIVATIONS) { - const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows); - const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); - if (!out_of_bounds) { - thread_amax = fmaxf(thread_amax, fabsf(elt)); - } - } else { - // If no activation, elt is 0 so we can safely do this - thread_amax = fmaxf(thread_amax, fabsf(elt)); - } + thread_amax = fmaxf(thread_amax, fabsf(elt)); in_compute_colwise[i] = elt; } } @@ -491,8 +485,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; -// 3. Scale elements -#pragma unroll + // 3. Scale elements + #pragma unroll for (int i = 0; i < SCALE_DIM_Y; ++i) { float in; if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { @@ -520,14 +514,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // 1. Read/Compute elements. Find MXFP8-block AMAX if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#pragma unroll + #pragma unroll for (int w = 0; w < WAVES; ++w) { const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; // Load elements in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); -#pragma unroll + #pragma unroll for (int e = 0; e < PACK_SIZE / 2; ++e) { ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); } @@ -538,33 +532,27 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // ensures that all writes to cache made in the section above are visible to all threads __syncthreads(); IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#pragma unroll + #pragma unroll for (int w = 0; w < WAVES; ++w) { const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); - const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); - const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); - // Load cached elements in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) // only single check (w.r.t. 
column direction) is sufficient to be sure the entire wave is inside the boundaries - if (!out_of_bounds) { - if constexpr (std::is_same_v) { -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); - } - } else { -#pragma unroll - for (int e = 0; e < PACK_SIZE; e += 2) { - const IType2 in_cached_2x = {in_cached[w].data.elt[e], - in_cached[w].data.elt[e + 1]}; - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); - } + if constexpr (std::is_same_v) { + #pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { + #pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], + in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); } } } @@ -573,7 +561,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); } } else { -#pragma unroll + #pragma unroll for (int w = 0; w < WAVES; ++w) { const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; @@ -586,7 +574,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if constexpr (IS_DACT) { act_in.load_from(&act_in_sh[shmem_offset_rowwise]); } -#pragma unroll + #pragma unroll for (int e = 0; e < PACK_SIZE; ++e) { const int j = w * PACK_SIZE + e; // Compute element @@ -607,18 +595,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if constexpr (!std::is_same_v) { elt = static_cast(static_cast(elt)); } - if constexpr (COMPUTE_ACTIVATIONS) { - const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); - const bool swizzled_col_out_of_bounds = - (block_offset_X + swizzled_thread_idx >= cols); - const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); - if (!out_of_bounds) { - thread_amax = fmaxf(thread_amax, fabsf(elt)); - } - } else { - // If no activation, elt is 0 so we can safely do this - thread_amax = fmaxf(thread_amax, fabsf(elt)); - } + thread_amax = fmaxf(thread_amax, fabsf(elt)); in_compute_rowwise[j] = elt; } } @@ -638,10 +615,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; // 3. Scale elements -#pragma unroll + #pragma unroll for (int w = 0; w < WAVES; ++w) { Vec out; -#pragma unroll + #pragma unroll for (int e = 0; e < PACK_SIZE / 2; ++e) { IType2 in; OType2 &out_pair = reinterpret_cast(out.data.elt[e]); @@ -674,7 +651,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // After syncthreads, writes by all threads are visible to TMA engine. 
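// [Illustrative note, not part of the patch] Worked example of the swizzled rowwise indexing
// used above, swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X, assuming
// PACK_SIZE == 8 (an assumption; the real value comes from the kernel's compile-time constants)
// with SCALE_DIM_X == 32 and THREADS_PER_BANK == 4 as defined in this file:
//
//   lanes  0-3  -> bank_group 0 -> wave w = 0 starts its packed load at element offset  0
//   lanes  4-7  -> bank_group 1 -> wave w = 0 starts its packed load at element offset  8
//   lanes  8-11 -> bank_group 2 -> wave w = 0 starts its packed load at element offset 16
//   lanes 12-15 -> bank_group 3 -> wave w = 0 starts its packed load at element offset 24
//   (the pattern repeats for bank_group 4-7, and incrementing w rotates each group through
//    the remaining offsets of the 32-element scaling block)
//
// so neighbouring thread groups begin their packed accesses at different offsets within the
// 32-wide scaling block rather than all at column 0, which is what the "helps resolving bank
// conflicts in shmem" comment refers to.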
// Initiate TMA transfer to copy shared memory to global memory - if (is_master_thread) { + if (leading_thread) { const int global_offset_Y = block_offset_Y + stage_offset_Y; const int global_offset_X = block_offset_X; const int buff_offset = buff * BUFF_DIM; @@ -750,11 +727,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) block_amax = reduce_max(block_amax, warp_id); } - if (is_master_thread && amax_ptr != nullptr) { + if (leading_thread && amax_ptr != nullptr) { atomicMaxFloat(amax_ptr, block_amax); } - destroy_barriers(mbar, is_master_thread); + destroy_barriers(mbar, leading_thread); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } } // namespace quantize_grouped_kernel From cd92f7257dcd981c5164b13308df859802013294 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 14 Jan 2026 23:26:42 +0000 Subject: [PATCH 21/23] Fixed stride values in TMA descriptors (should be in bytes) Signed-off-by: Oleg Goncharov --- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 91 ++++++------------- .../cast/mxfp8/quantize_grouped_mxfp8.cuh | 87 +++++++++--------- 2 files changed, 71 insertions(+), 107 deletions(-) diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index 52f25facae..60cc693847 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -238,6 +238,7 @@ void performTest_x1(const ProcessingMethod processing_method, const std::vector& logical_shape_vec, const std::vector& first_dims_h, const std::vector& last_dims_h, + const std::vector& offsets_h, const bool rowwise, const bool colwise) { using namespace test; @@ -248,15 +249,11 @@ void performTest_x1(const ProcessingMethod processing_method, const size_t rows = logical_shape_vec[0]; const size_t cols = logical_shape_vec[1]; - std::vector scales_rowwise_shape = {rows, cols / 32}; - std::vector scales_colwise_shape = {rows / 32, cols}; - - const size_t scales_stride_rowwise = scales_rowwise_shape[1]; - const size_t scales_stride_colwise = scales_colwise_shape[1]; - const size_t elts_num = rows * cols; const size_t sfs_num = (rows * cols) / 32; + std::vector scales_shape = {sfs_num}; + std::mt19937 gen; std::uniform_real_distribution<> dis(-2.0, 1.0); @@ -272,39 +269,24 @@ void performTest_x1(const ProcessingMethod processing_method, std::vector out_scales_rowwise_ref(rowwise ? sfs_num : 0); std::vector out_scales_colwise_ref(colwise ? 
sfs_num : 0); - std::vector offsets_h(num_tensors + 1); - for (size_t t = 0; t < num_tensors + 1; ++t) { - if (t == 0) { - offsets_h[t] = 0; - } else { - offsets_h[t] = offsets_h[t-1] + (first_dims_h[t-1] * last_dims_h[t-1]); - } - } - for (size_t i = 0; i < elts_num; ++i) { const float val = dis(gen); in_data[i] = static_cast(val); } + const OutputType zero_elt = static_cast(0.0f); + const fp8e8m0 zero_SF = static_cast(0.0f); if (rowwise) { - for (size_t i = 0; i < elts_num; ++i) { - out_data_rowwise_h[i] = static_cast(0.0f); - out_data_rowwise_ref[i] = static_cast(0.0f); - } - for (size_t i = 0; i < sfs_num; ++i) { - out_scales_rowwise_h[i] = static_cast(0.0f); - out_scales_rowwise_ref[i] = static_cast(0.0f); - } + std::fill(out_data_rowwise_h.begin(), out_data_rowwise_h.end(), zero_elt); + std::fill(out_data_rowwise_ref.begin(), out_data_rowwise_ref.end(), zero_elt); + std::fill(out_scales_rowwise_h.begin(), out_scales_rowwise_h.end(), zero_SF); + std::fill(out_scales_rowwise_ref.begin(), out_scales_rowwise_ref.end(), zero_SF); } if (colwise) { - for (size_t i = 0; i < elts_num; ++i) { - out_data_colwise_h[i] = static_cast(0.0f); - out_data_colwise_ref[i] = static_cast(0.0f); - } - for (size_t i = 0; i < sfs_num; ++i) { - out_scales_colwise_h[i] = static_cast(0.0f); - out_scales_colwise_ref[i] = static_cast(0.0f); - } + std::fill(out_data_colwise_h.begin(), out_data_colwise_h.end(), zero_elt); + std::fill(out_data_colwise_ref.begin(), out_data_colwise_ref.end(), zero_elt); + std::fill(out_scales_colwise_h.begin(), out_scales_colwise_h.end(), zero_SF); + std::fill(out_scales_colwise_ref.begin(), out_scales_colwise_ref.end(), zero_SF); } const size_t in_data_size = elts_num * sizeof(InputType); @@ -378,7 +360,7 @@ void performTest_x1(const ProcessingMethod processing_method, cudaMemset(out_data_rowwise_d, 0, out_data_size); cudaMemset(out_scales_rowwise_d, 0, out_scales_size); NVTEBasicTensor out_data_rowwise_tensor = {out_data_rowwise_d, static_cast(otype), logical_shape_}; - NVTEShape scales_rowwise_shape_ = nvte_make_shape(scales_rowwise_shape.data(), scales_rowwise_shape.size()); + NVTEShape scales_rowwise_shape_ = nvte_make_shape(scales_shape.data(), scales_shape.size()); NVTEBasicTensor out_scales_rowwise_tensor = {out_scales_rowwise_d, NVTEDType::kNVTEFloat8E8M0, scales_rowwise_shape_}; nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &out_data_rowwise_tensor); nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseScaleInv, &out_scales_rowwise_tensor); @@ -390,7 +372,7 @@ void performTest_x1(const ProcessingMethod processing_method, cudaMemset(out_data_colwise_d, 0, out_data_size); cudaMemset(out_scales_colwise_d, 0, out_scales_size); NVTEBasicTensor out_data_colwise_tensor = {out_data_colwise_d, static_cast(otype), logical_shape_}; - NVTEShape scales_colwise_shape_ = nvte_make_shape(scales_colwise_shape.data(), scales_colwise_shape.size()); + NVTEShape scales_colwise_shape_ = nvte_make_shape(scales_shape.data(), scales_shape.size()); NVTEBasicTensor out_scales_colwise_tensor = {out_scales_colwise_d, NVTEDType::kNVTEFloat8E8M0, scales_colwise_shape_}; nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedColumnwiseData, &out_data_colwise_tensor); nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedColumnwiseScaleInv, &out_scales_colwise_tensor); @@ -438,11 +420,8 @@ void performTest_x1(const ProcessingMethod processing_method, size_t 
mismatches_scales = 0; compare_scaling_factors("rowwise_scales", out_scales_rowwise_h.data(), out_scales_rowwise_ref.data(), - scales_rowwise_shape[0], scales_rowwise_shape[1], scales_stride_rowwise, - mismatches_scales, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); + 1, sfs_num, sfs_num, mismatches_scales, scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit); const size_t mismatches_elts = 32 * mismatches_scales; @@ -456,11 +435,8 @@ void performTest_x1(const ProcessingMethod processing_method, size_t mismatches_scales = 0; compare_scaling_factors("colwise_scales", out_scales_colwise_h.data(), out_scales_colwise_ref.data(), - scales_colwise_shape[0], scales_colwise_shape[1], scales_stride_colwise, - mismatches_scales, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); + 1, sfs_num, sfs_num, mismatches_scales, scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit); const size_t mismatches_elts = 32 * mismatches_scales; @@ -516,8 +492,9 @@ std::vector scaling_directions = { std::vector> input_config = { {SAME_BOTH_DIMS, 1, 128,128}, {SAME_BOTH_DIMS, 2, 256,128}, - {VARYING_FIRST_DIM, 2, 512,128, 128,512-128}, + {VARYING_FIRST_DIM, 2, 512,128, 128,384}, {VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256}, + {VARYING_BOTH_DIMS, 2, 1,(256*128)+(512*640), 256,512, 128,640}, }; } // namespace @@ -550,11 +527,12 @@ TEST_P(GroupedFusedCastMXFP8TestSuite, Test) { const std::vector logical_shape = {input_config[2], input_config[3]}; std::vector first_dims(num_tensors); std::vector last_dims(num_tensors); + std::vector offsets(num_tensors + 1, 0); for (size_t t = 0; t < num_tensors; ++t) { switch (shape_rep) { case SAME_BOTH_DIMS: { first_dims[t] = logical_shape[0] / num_tensors; - last_dims[t] = logical_shape[1]; + last_dims[t] = logical_shape[1]; break; } case VARYING_FIRST_DIM: { @@ -564,7 +542,7 @@ TEST_P(GroupedFusedCastMXFP8TestSuite, Test) { } case VARYING_LAST_DIM: { first_dims[t] = logical_shape[0]; - last_dims[t] = input_config[t + (4 + num_tensors)]; + last_dims[t] = input_config[t + 4]; break; } case VARYING_BOTH_DIMS: { @@ -573,6 +551,7 @@ TEST_P(GroupedFusedCastMXFP8TestSuite, Test) { break; } } + offsets[t+1] = offsets[t] + first_dims[t] * last_dims[t]; // Skips tests if tensor dims are not multiples of 128 if ((first_dims[t] % 128 != 0) || (last_dims[t] % 128 != 0)) { GTEST_SKIP(); @@ -601,7 +580,7 @@ TEST_P(GroupedFusedCastMXFP8TestSuite, Test) { auto OP = &identity; performTest_x1(processing_method, OP, shape_rep, num_tensors, logical_shape, - first_dims, last_dims, rowwise, colwise); + first_dims, last_dims, offsets, rowwise, colwise); // if (processing_method == ProcessingMethod::CAST_ACT) { // // Forward activations @@ -711,24 +690,6 @@ INSTANTIATE_TEST_SUITE_P( std::to_string(input[2]) + "X" + std::to_string(input[3]); - // name += "_DimsM_"; - // const auto& M_i_ = std::get<5>(info.param); - // for (size_t i = 0; i < M_i_.size(); ++i) { - // const size_t m = M_i_[i]; - // name += std::to_string(m); - // if (i < M_i_.size() - 1) { - // name += "X"; - // } - // } - // name += "_Offsets_"; - // const auto& Offset_i_ = std::get<6>(info.param); - // for (size_t i = 0; i < Offset_i_.size(); ++i) { - // const size_t offset = Offset_i_[i]; - // name += std::to_string(offset); - // if (i < Offset_i_.size() - 1) { - // name += "X"; - // } - // } name += "_" + test::typeName(std::get<4>(info.param)) + "_" + 
test::typeName(std::get<5>(info.param)); return name; diff --git a/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh index f74cbd13cc..89fb4aa88c 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh @@ -52,6 +52,8 @@ constexpr size_t CHUNK_DIM_Y = 128; constexpr size_t CHUNK_DIM_X = 128; constexpr size_t THREADS_PER_CHUNK = 128; +constexpr size_t ELTS_PER_CHUNK = CHUNK_DIM_Y * CHUNK_DIM_X; + constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; @@ -83,24 +85,24 @@ get_current_tensor_id(const ShapeRepresentation shape_rep, const size_t rows_per_tensor = first_logical_dim / num_tensors; return current_row / rows_per_tensor; } else { - // upper_bound(offsets, current_offset) - 1 in range i in [0..num_tensors) - size_t low = 0; - size_t hi = num_tensors; // half-open [low, hi) + // upper_bound(offsets, current_offset) - 1 in range i in [0..num_tensors] + size_t low = 1; + size_t hi = num_tensors; // [low, hi] - while (low < hi) { + while (low <= hi) { const size_t mid = low + (hi - low) / 2; const size_t mid_offset = static_cast(offsets_ptr[mid]); if (mid_offset <= current_offset) { low = mid + 1; } else { - hi = mid; + hi = mid - 1; } } // low = first index where offsets[low] > current_offset (or low == num_tensors) // id = low - 1, but need to evaluate if current_offset < offsets[0] - return (low == 0) ? 0 : (low - 1); + return low - 1; } } @@ -110,7 +112,7 @@ get_tensor_rows_num(const size_t tensor_id, const size_t first_logical_dim, const int64_t* const __restrict__ first_dims_ptr, const size_t num_tensors) { - size_t rows_num = first_logical_dim; + size_t rows_num = 0; switch (shape_rep) { case ShapeRepresentation::SAME_BOTH_DIMS: // rows_num = first_logical_dim / num_tensors; break; case ShapeRepresentation::VARYING_LAST_DIM: rows_num = first_logical_dim; break; @@ -125,7 +127,7 @@ get_tensor_cols_num(const size_t tensor_id, const ShapeRepresentation shape_rep, const size_t last_logical_dim, const int64_t* const __restrict__ last_dims_ptr) { - size_t cols_num = last_logical_dim; + size_t cols_num = 0; switch (shape_rep) { case ShapeRepresentation::SAME_BOTH_DIMS: case ShapeRepresentation::VARYING_FIRST_DIM: cols_num = last_logical_dim; break; @@ -135,14 +137,15 @@ get_tensor_cols_num(const size_t tensor_id, return cols_num; } -// Copies the base tensor map to shmem, modifies the copy, stores the modified tensor map at index +// Copies the base tensor map to shmem, modifies the copy, stores the modified tensor map at index +template __device__ __forceinline__ void modify_base_tensor_map(const CUtensorMap base_tensor_map, CUtensorMap* global_tensor_map, const uintptr_t global_data_ptr, const size_t global_dim_Y, const size_t global_dim_X) { - const size_t global_stride = global_dim_X; + const size_t global_stride_bytes = global_dim_X * sizeof(T); __shared__ CUtensorMap shared_tensor_map; shared_tensor_map = base_tensor_map; // Copy the base tensor map into shmem @@ -160,7 +163,7 @@ modify_base_tensor_map(const CUtensorMap base_tensor_map, "l"(global_data_ptr), "r"(static_cast(global_dim_Y)), "r"(static_cast(global_dim_X)), - "l"(static_cast(global_stride)) + "l"(static_cast(global_stride_bytes)) : "memory" ); *global_tensor_map = shared_tensor_map; @@ -197,22 +200,22 @@ init_tma_descriptors(const __grid_constant__ CUtensorMap base_tensor_map_input, if 
(leading_thread && (tensor_id < num_tensors)) { { const uintptr_t global_data_ptr = reinterpret_cast(input_data_ptr + offset_elts); - modify_base_tensor_map(base_tensor_map_input, &g_tensor_maps_input[tensor_id], + modify_base_tensor_map(base_tensor_map_input, &g_tensor_maps_input[tensor_id], global_data_ptr, rows, cols); } if (compute_activations) { const uintptr_t global_data_ptr = reinterpret_cast(act_input_data_ptr + offset_elts); - modify_base_tensor_map(base_tensor_map_act_input, &g_tensor_maps_act_input[tensor_id], + modify_base_tensor_map(base_tensor_map_act_input, &g_tensor_maps_act_input[tensor_id], global_data_ptr, rows, cols); } if (rowwise) { const uintptr_t global_data_ptr = reinterpret_cast(output_rowwise_data_ptr + offset_elts); - modify_base_tensor_map(base_tensor_map_output_rowwise, &g_tensor_maps_output_rowwise[tensor_id], + modify_base_tensor_map(base_tensor_map_output_rowwise, &g_tensor_maps_output_rowwise[tensor_id], global_data_ptr, rows, cols); } if (colwise) { const uintptr_t global_data_ptr = reinterpret_cast(output_colwise_data_ptr + offset_elts); - modify_base_tensor_map(base_tensor_map_output_colwise, &g_tensor_maps_output_colwise[tensor_id], + modify_base_tensor_map(base_tensor_map_output_colwise, &g_tensor_maps_output_colwise[tensor_id], global_data_ptr, rows, cols); } } @@ -258,7 +261,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; - const size_t block_global_offset = blockIdx.x * CHUNK_DIM_Y * CHUNK_DIM_X; + const size_t block_global_offset = blockIdx.x * ELTS_PER_CHUNK; const size_t tensor_id = get_current_tensor_id(shape_rep, num_tensors, block_global_offset, first_logical_dim, last_logical_dim, @@ -270,10 +273,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t scale_stride_colwise = cols; const bool is_const_last_dim = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); + const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); - const size_t offset_within_tensor = is_const_last_dim - ? block_global_offset // grouped tensor can be treated as continuous tensor for MXFP8 - : (block_global_offset - offsets_ptr[tensor_id]); + // grouped tensor can be treated as continuous tensor for MXFP8 + const size_t tensor_base = is_single_tensor ? 0 : static_cast(offsets_ptr[tensor_id]); + const size_t offset_within_tensor = block_global_offset - tensor_base; const CUtensorMap& tensor_map_input = is_const_last_dim ? tensor_map_input_static @@ -297,28 +301,31 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if constexpr (COLWISE_SCALING) { fence_acquire_tensormap(&tensor_map_output_colwise); } } - const size_t block_offset_Y = offset_within_tensor / cols; - const size_t block_offset_X = offset_within_tensor % cols; - const size_t blockIdxY = block_offset_Y / CHUNK_DIM_Y; - const size_t blockIdxX = block_offset_X / CHUNK_DIM_X; + const size_t blocks_X_num_in_current_tensor = cols / CHUNK_DIM_X; + const size_t block_id_in_current_tensor = is_single_tensor + ? 
blockIdx.x + : (blockIdx.x - tensor_base / ELTS_PER_CHUNK); + + const size_t block_id_Y = block_id_in_current_tensor / blocks_X_num_in_current_tensor; + const size_t block_id_X = block_id_in_current_tensor % blocks_X_num_in_current_tensor; + const size_t block_offset_Y = block_id_Y * CHUNK_DIM_Y; + const size_t block_offset_X = block_id_X * CHUNK_DIM_X; + if (leading_thread) { + printf("Current tensor ID: %2lu offset_within_tensor: %4lu CHUNK_DIM_Y: %4lu CHUNK_DIM_X: %4lu\n", + tensor_id, offset_within_tensor, CHUNK_DIM_Y, CHUNK_DIM_X); printf("Current tensor ID: %2lu Rows: %4lu Cols: %4lu BLOCK IdxY: %2lu IdxX %2lu Offset Y: %4lu Offset X: %4lu\n", - tensor_id, rows, cols, blockIdxY, blockIdxX, block_offset_Y, block_offset_X); + tensor_id, rows, cols, block_id_Y, block_id_X, block_offset_Y, block_offset_X); } - // Early exit if the border of the chunk goes over the - if (block_offset_Y >= rows) { - return; - } + e8m0_t *const scales_rowwise = scales_rowwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_X); + e8m0_t *const scales_colwise = scales_colwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_Y); - e8m0_t *const scales_rowwise = scales_rowwise_ptr + (is_const_last_dim ? 0 : offsets_ptr[tensor_id] / SCALE_DIM_X); - e8m0_t *const scales_colwise = scales_colwise_ptr + (is_const_last_dim ? 0 : offsets_ptr[tensor_id] / SCALE_DIM_Y); - - const size_t scales_block_offset_Y_rowwise = blockIdxY * CHUNK_DIM_Y; - const size_t scales_block_offset_X_rowwise = blockIdxX * CHUNK_DIM_X / SCALE_DIM_X; - const size_t scales_block_offset_Y_colwise = blockIdxY * CHUNK_DIM_Y / SCALE_DIM_Y; - const size_t scales_block_offset_X_colwise = blockIdxX * CHUNK_DIM_X; + const size_t scales_block_offset_Y_rowwise = block_id_Y * CHUNK_DIM_Y; + const size_t scales_block_offset_X_rowwise = block_id_X * CHUNK_DIM_X / SCALE_DIM_X; + const size_t scales_block_offset_Y_colwise = block_id_Y * CHUNK_DIM_Y / SCALE_DIM_Y; + const size_t scales_block_offset_X_colwise = block_id_X * CHUNK_DIM_X; const size_t tid_Y_rowwise = threadIdx.x / THREADS_X; const size_t tid_X_rowwise = threadIdx.x % THREADS_X; @@ -335,8 +342,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; - const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols; - // helps resolving bank conflicts in shmem const int thread_lane = threadIdx.x % THREADS_PER_WARP; const int bank_group = thread_lane / THREADS_PER_BANK; @@ -607,9 +612,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; const int stage_scales_offset_X = scales_offset_X_rowwise; const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; - if (rowwise_scale_is_within_bounds) { - scales_rowwise[scale_idx] = biased_exponent; - } + scales_rowwise[scale_idx] = biased_exponent; const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; @@ -711,8 +714,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } } const int dbias_stride = cols; - const int dbias_offset_Y = blockIdxY; - const int dbias_offset_X = blockIdxX * CHUNK_DIM_X + threadIdx.x; + const int dbias_offset_Y = block_id_Y; + const int dbias_offset_X = block_id_X * CHUNK_DIM_X + threadIdx.x; const int dbias_idx = dbias_offset_Y 
* dbias_stride + dbias_offset_X; const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); if (!col_out_of_bounds_dbias) { From fc2a53ff484d67efdff562daf3eb93fb987f1688 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Thu, 15 Jan 2026 16:15:21 +0000 Subject: [PATCH 22/23] Clean up. Small fixes. Signed-off-by: Oleg Goncharov --- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 90 +++------ .../common/cast/dispatch/quantize_grouped.cuh | 8 +- .../cast/mxfp8/quantize_grouped_mxfp8.cuh | 175 ++++++++---------- 3 files changed, 109 insertions(+), 164 deletions(-) diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index 60cc693847..fb9d6a9768 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -231,18 +231,18 @@ void compare_scaled_elts(const std::string &name, * 2) Scaled columns + column-wise scaling factors */ template -void performTest_x1(const ProcessingMethod processing_method, - float (*OP)(const float), - const ShapeRepresentation shape_rep, - const size_t num_tensors, - const std::vector& logical_shape_vec, - const std::vector& first_dims_h, - const std::vector& last_dims_h, - const std::vector& offsets_h, - const bool rowwise, - const bool colwise) { +void performTest(const ProcessingMethod processing_method, + float (*OP)(const float), + const ShapeRepresentation shape_rep, + const size_t num_tensors, + const std::vector& logical_shape_vec, + const std::vector& first_dims_h, + const std::vector& last_dims_h, + const std::vector& offsets_h, + const bool rowwise, + const bool colwise) { using namespace test; - using EncodingType = fp32; + DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; @@ -484,8 +484,8 @@ enum ScalingDirection { std::vector scaling_directions = { ScalingDirection::ROWWISE, - // ScalingDirection::COLWISE, - // ScalingDirection::BOTH, + ScalingDirection::COLWISE, + ScalingDirection::BOTH, }; // {shape_representation, num_tensors, [logical_shape_M, logical_shape_K], [M_i], [K_i]} @@ -493,6 +493,8 @@ std::vector> input_config = { {SAME_BOTH_DIMS, 1, 128,128}, {SAME_BOTH_DIMS, 2, 256,128}, {VARYING_FIRST_DIM, 2, 512,128, 128,384}, + {VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304}, + {VARYING_LAST_DIM, 3, 256,896, 128,256,512}, {VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256}, {VARYING_BOTH_DIMS, 2, 1,(256*128)+(512*640), 256,512, 128,640}, }; @@ -521,6 +523,8 @@ TEST_P(GroupedFusedCastMXFP8TestSuite, Test) { const ActivationKind activation = std::get<1>(GetParam()); const ScalingDirection scaling_direction = std::get<2>(GetParam()); const std::vector input_config = std::get<3>(GetParam()); + const DType input_type = std::get<4>(GetParam()); + const DType output_type = std::get<5>(GetParam()); const ShapeRepresentation shape_rep = static_cast(input_config[0]); const size_t num_tensors = input_config[1]; @@ -579,52 +583,14 @@ TEST_P(GroupedFusedCastMXFP8TestSuite, Test) { } auto OP = &identity; - performTest_x1(processing_method, OP, shape_rep, num_tensors, logical_shape, - first_dims, last_dims, offsets, rowwise, colwise); - - // if (processing_method == ProcessingMethod::CAST_ACT) { - // // Forward activations - // auto OP = &identity; - // switch (activation) { - // case ActivationKind::GeLU: OP = &gelu; break; - // case ActivationKind::SiLU: OP = &silu; break; - // case ActivationKind::ReLU: OP = &relu; break; - // case ActivationKind::QGeLU: OP = &qgelu; break; - // case ActivationKind::SReLU: OP = &srelu; break; 
- // } - - // TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, - // TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, - // if (scaling_direction == ScalingDirection::BOTH) { - // performTest_x2( - // processing_method, OP, tensor_logical_shape, M_i, Offset_i); - // } else { - // performTest_x1( - // processing_method, OP, tensor_logical_shape, M_i, Offset_i, rowwise, colwise); - // } - // ); - // ); - // } else { - // auto OP = &identity; - // switch (activation) { - // case ActivationKind::GeLU: OP = &dgelu; break; - // case ActivationKind::SiLU: OP = &dsilu; break; - // case ActivationKind::ReLU: OP = &drelu; break; - // case ActivationKind::QGeLU: OP = &dqgelu; break; - // case ActivationKind::SReLU: OP = &dsrelu; break; - // } - // TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, - // TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, - // if (scaling_direction == ScalingDirection::BOTH) { - // performTest_x2( - // processing_method, OP, tensor_logical_shape, M_i, Offset_i); - // } else { - // performTest_x1( - // processing_method, OP, tensor_logical_shape, M_i, Offset_i, rowwise, colwise); - // } - // ); - // ); - // } + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, + performTest(processing_method, OP, shape_rep, num_tensors, + logical_shape, first_dims, last_dims, offsets, + rowwise, colwise); + ); + ); } std::string to_string(const ProcessingMethod method) { @@ -658,10 +624,8 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(activation_kinds), ::testing::ValuesIn(scaling_directions), ::testing::ValuesIn(input_config), - ::testing::Values(DType::kBFloat16), - ::testing::Values(DType::kFloat8E4M3)), - // ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), - // ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2)), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2)), [](const testing::TestParamInfo& info) { const ProcessingMethod method = std::get<0>(info.param); std::string name = to_string(method); diff --git a/transformer_engine/common/cast/dispatch/quantize_grouped.cuh b/transformer_engine/common/cast/dispatch/quantize_grouped.cuh index 1c1884cf93..207811b395 100644 --- a/transformer_engine/common/cast/dispatch/quantize_grouped.cuh +++ b/transformer_engine/common/cast/dispatch/quantize_grouped.cuh @@ -49,14 +49,14 @@ void quantize_grouped_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTenso switch (scaling_mode) { case NVTE_MXFP8_1D_SCALING: { const NVTEGroupedTensor activation = nullptr; - NVTEGroupedTensor dbias = nullptr; - NVTEGroupedTensor workspace = nullptr; + NVTETensor dbias = nullptr; + NVTETensor workspace = nullptr; const GroupedTensor *input_tensor = convertNVTEGroupedTensorCheck(input); GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output); const GroupedTensor *activations_tensor = convertNVTEGroupedTensor(activation); - GroupedTensor *dbias_tensor = convertNVTEGroupedTensor(dbias); - GroupedTensor *workspace_tensor = convertNVTEGroupedTensor(workspace); + Tensor *dbias_tensor = convertNVTETensor(dbias); + Tensor *workspace_tensor = convertNVTETensor(workspace); mxfp8::quantize_grouped( input_tensor, activations_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); diff --git a/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh 
b/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh index 89fb4aa88c..644b94c8c8 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh @@ -28,11 +28,11 @@ namespace mxfp8 { namespace quantize_grouped_kernel { -constexpr int MAX_SUPPORTED_DESCRIPTORS = 64; -__device__ alignas(128) CUtensorMap g_tensor_maps_input[MAX_SUPPORTED_DESCRIPTORS]; -__device__ alignas(128) CUtensorMap g_tensor_maps_act_input[MAX_SUPPORTED_DESCRIPTORS]; -__device__ alignas(128) CUtensorMap g_tensor_maps_output_rowwise[MAX_SUPPORTED_DESCRIPTORS]; -__device__ alignas(128) CUtensorMap g_tensor_maps_output_colwise[MAX_SUPPORTED_DESCRIPTORS]; +constexpr int MAX_SUPPORTED_TENSOR_DESCRIPTORS = 64; +__device__ alignas(128) CUtensorMap g_tensor_maps_input[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; +__device__ alignas(128) CUtensorMap g_tensor_maps_act_input[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; +__device__ alignas(128) CUtensorMap g_tensor_maps_output_rowwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; +__device__ alignas(128) CUtensorMap g_tensor_maps_output_colwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; enum ShapeRepresentation { SAME_BOTH_DIMS = 0, @@ -77,31 +77,25 @@ get_current_tensor_id(const ShapeRepresentation shape_rep, const size_t current_offset, const size_t first_logical_dim, const size_t last_logical_dim, - const int64_t* const __restrict__ first_dims_ptr, - const int64_t* const __restrict__ last_dims_ptr, const int64_t* const __restrict__ offsets_ptr) { if (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) { const size_t current_row = current_offset / last_logical_dim; const size_t rows_per_tensor = first_logical_dim / num_tensors; return current_row / rows_per_tensor; } else { - // upper_bound(offsets, current_offset) - 1 in range i in [0..num_tensors] - size_t low = 1; + size_t low = 0; size_t hi = num_tensors; // [low, hi] - while (low <= hi) { + while (low < hi) { const size_t mid = low + (hi - low) / 2; const size_t mid_offset = static_cast(offsets_ptr[mid]); if (mid_offset <= current_offset) { low = mid + 1; } else { - hi = mid - 1; + hi = mid; } } - - // low = first index where offsets[low] > current_offset (or low == num_tensors) - // id = low - 1, but need to evaluate if current_offset < offsets[0] return low - 1; } } @@ -171,24 +165,24 @@ modify_base_tensor_map(const CUtensorMap base_tensor_map, template __global__ void -init_tma_descriptors(const __grid_constant__ CUtensorMap base_tensor_map_input, - const __grid_constant__ CUtensorMap base_tensor_map_act_input, - const __grid_constant__ CUtensorMap base_tensor_map_output_rowwise, - const __grid_constant__ CUtensorMap base_tensor_map_output_colwise, - const IType* const __restrict__ input_data_ptr, - const IType* const __restrict__ act_input_data_ptr, - const OType* const __restrict__ output_rowwise_data_ptr, - const OType* const __restrict__ output_colwise_data_ptr, - const ShapeRepresentation shape_rep, - const size_t num_tensors, - const size_t first_logical_dim, - const size_t last_logical_dim, - const int64_t* const __restrict__ offsets_ptr, - const int64_t* const __restrict__ first_dims_ptr, - const int64_t* const __restrict__ last_dims_ptr, - const bool rowwise, - const bool colwise, - const bool compute_activations) { +update_tma_descriptors(const __grid_constant__ CUtensorMap base_tensor_map_input, + const __grid_constant__ CUtensorMap base_tensor_map_act_input, + const __grid_constant__ CUtensorMap base_tensor_map_output_rowwise, + const __grid_constant__ 
CUtensorMap base_tensor_map_output_colwise, + const IType* const __restrict__ input_data_ptr, + const IType* const __restrict__ act_input_data_ptr, + const OType* const __restrict__ output_rowwise_data_ptr, + const OType* const __restrict__ output_colwise_data_ptr, + const ShapeRepresentation shape_rep, + const size_t num_tensors, + const size_t first_logical_dim, + const size_t last_logical_dim, + const int64_t* const __restrict__ offsets_ptr, + const int64_t* const __restrict__ first_dims_ptr, + const int64_t* const __restrict__ last_dims_ptr, + const bool rowwise, + const bool colwise, + const bool compute_activations) { const bool leading_thread = (threadIdx.x == 0); const size_t tensor_id = blockIdx.x; @@ -264,37 +258,35 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t block_global_offset = blockIdx.x * ELTS_PER_CHUNK; const size_t tensor_id = get_current_tensor_id(shape_rep, num_tensors, block_global_offset, - first_logical_dim, last_logical_dim, - first_dims_ptr, last_dims_ptr, offsets_ptr); + first_logical_dim, last_logical_dim, offsets_ptr); const size_t rows = get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); const size_t scale_stride_rowwise = cols / SCALE_DIM_X; const size_t scale_stride_colwise = cols; - const bool is_const_last_dim = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); // grouped tensor can be treated as continuous tensor for MXFP8 const size_t tensor_base = is_single_tensor ? 0 : static_cast(offsets_ptr[tensor_id]); const size_t offset_within_tensor = block_global_offset - tensor_base; - const CUtensorMap& tensor_map_input = is_const_last_dim + const CUtensorMap& tensor_map_input = is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[tensor_id]; - const CUtensorMap& tensor_map_act_input = is_const_last_dim + const CUtensorMap& tensor_map_act_input = is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[tensor_id]; - const CUtensorMap& tensor_map_output_rowwise = is_const_last_dim + const CUtensorMap& tensor_map_output_rowwise = is_single_tensor ? tensor_map_output_rowwise_static : g_tensor_maps_output_rowwise[tensor_id]; - const CUtensorMap& tensor_map_output_colwise = is_const_last_dim + const CUtensorMap& tensor_map_output_colwise = is_single_tensor ? 
tensor_map_output_colwise_static : g_tensor_maps_output_colwise[tensor_id]; const bool leading_thread = (threadIdx.x == 0); - if (leading_thread && (!is_const_last_dim)) { + if (leading_thread && (!is_single_tensor)) { fence_acquire_tensormap(&tensor_map_input); if constexpr (COMPUTE_ACTIVATIONS) { fence_acquire_tensormap(&tensor_map_act_input); } if constexpr (ROWWISE_SCALING) { fence_acquire_tensormap(&tensor_map_output_rowwise); } @@ -312,13 +304,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t block_offset_Y = block_id_Y * CHUNK_DIM_Y; const size_t block_offset_X = block_id_X * CHUNK_DIM_X; - if (leading_thread) { - printf("Current tensor ID: %2lu offset_within_tensor: %4lu CHUNK_DIM_Y: %4lu CHUNK_DIM_X: %4lu\n", - tensor_id, offset_within_tensor, CHUNK_DIM_Y, CHUNK_DIM_X); - printf("Current tensor ID: %2lu Rows: %4lu Cols: %4lu BLOCK IdxY: %2lu IdxX %2lu Offset Y: %4lu Offset X: %4lu\n", - tensor_id, rows, cols, block_id_Y, block_id_X, block_offset_Y, block_offset_X); - } - e8m0_t *const scales_rowwise = scales_rowwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_X); e8m0_t *const scales_colwise = scales_colwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_Y); @@ -678,7 +663,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) parity ^= 1; if constexpr (IS_DBIAS) { - if (is_const_last_dim) { + if (is_single_tensor) { float thread_partial_dbias = 0.0f; if constexpr (COLWISE_SCALING) { thread_partial_dbias = partial_dbias_colwise; @@ -741,12 +726,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) template -void quantize_grouped(const GroupedTensor* input, - const GroupedTensor* activations, +void quantize_grouped(const GroupedTensor *input, + const GroupedTensor *activations, const Tensor *noop, - GroupedTensor* output, - GroupedTensor* dbias, - GroupedTensor* workspace, + GroupedTensor *output, + Tensor *dbias, + Tensor *workspace, cudaStream_t stream) { using namespace quantize_grouped_kernel; @@ -776,12 +761,15 @@ void quantize_grouped(const GroupedTensor* input, shape_rep = ShapeRepresentation::VARYING_BOTH_DIMS; } + // Treat a grouped tensor with const last dims as a single tensor + const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); + NVTE_CHECK(input->num_tensors == output->num_tensors, "Number of input and output tensors must be same."); NVTE_CHECK(input->has_data(), "Cannot quantize tensor without rowwise data."); NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); const size_t num_tensors = input->num_tensors; - NVTE_CHECK(num_tensors < MAX_SUPPORTED_DESCRIPTORS, + NVTE_CHECK(num_tensors < MAX_SUPPORTED_TENSOR_DESCRIPTORS, "Number of tensors in a group is larger than the MAX number of supported descriptors (64)."); const size_t first_logical_dim = input->logical_shape.data[0]; @@ -816,24 +804,21 @@ void quantize_grouped(const GroupedTensor* input, const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; const size_t scale_stride_colwise = use_colwise_scaling ? 
output->columnwise_scale_inv.shape[1] : 1; - - // const size_t dbias_rows = blocks_Y; - // const size_t dbias_cols = cols; - - printf("Shape #: %d \n", static_cast(shape_rep)); - const bool is_const_last_dim = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); - - // if constexpr (IS_DBIAS) { - // NVTE_CHECK(dbias->data.dtype == input_tensor.dtype(), "DBias must have the same type as input_tensor."); - // NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); - // NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); - - // if (workspace->data.dptr == nullptr) { - // workspace->data.shape = {dbias_rows, dbias_cols}; - // workspace->data.dtype = DType::kFloat32; - // return; - // } - // } + + const size_t dbias_rows = DIVUP(first_logical_dim, CHUNK_DIM_Y); + const size_t dbias_cols = last_logical_dim; + if constexpr (IS_DBIAS) { + NVTE_CHECK(is_single_tensor, "DBias is only supported for tensors with the const last dimension."); + NVTE_CHECK(dbias->data.dtype == input->dtype(), "DBias must have the same type as input_tensor."); + NVTE_CHECK(dbias->data.shape == std::vector{last_logical_dim}, "Wrong shape of DBias."); + NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); + + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {dbias_rows, dbias_cols}; + workspace->data.dtype = DType::kFloat32; + return; + } + } float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; float *const amax_ptr = reinterpret_cast(output->amax.dptr); @@ -912,44 +897,40 @@ void quantize_grouped(const GroupedTensor* input, } } - NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - - const IType* const input_dptr = reinterpret_cast(input->data.dptr); - - const IType* const act_input_dptr = (IS_DACT || IS_ACT) - ? reinterpret_cast(activations->data.dptr) - : nullptr; - - OType* const output_rowwise_dptr = use_rowwise_scaling - ? reinterpret_cast(output->data.dptr) - : nullptr; - - OType* const output_colwise_dptr = use_colwise_scaling - ? reinterpret_cast(output->columnwise_data.dptr) - : nullptr; - - if (!is_const_last_dim) { - init_tma_descriptors<<>> + // Update tensor descriptors before launching the kernel + if (!is_single_tensor) { + const IType* const input_dptr = reinterpret_cast(input->data.dptr); + + const IType* const act_input_dptr = (IS_DACT || IS_ACT) + ? reinterpret_cast(activations->data.dptr) + : nullptr; + + OType* const output_rowwise_dptr = use_rowwise_scaling + ? reinterpret_cast(output->data.dptr) + : nullptr; + + OType* const output_colwise_dptr = use_colwise_scaling + ? 
reinterpret_cast(output->columnwise_data.dptr) + : nullptr; + update_tma_descriptors<<>> (tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, tensor_map_output_colwise, input_dptr, act_input_dptr, output_rowwise_dptr, output_colwise_dptr, shape_rep, num_tensors, first_logical_dim, last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, use_rowwise_scaling, use_colwise_scaling, IS_ACT); } + NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); + kernel<<>>( tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, tensor_map_output_colwise, shape_rep, num_tensors, first_logical_dim, last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr); + if constexpr (IS_DBIAS) { + common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); + } NVTE_CHECK_CUDA(cudaGetLastError()); - - - // if constexpr (IS_DBIAS) { - // if (is_const_last_dim) { - // common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); - // } - // } ); // NOLINT(*) ); // NOLINT(*) } From 74a79175d0940a62270bc2adba32ee5156e97143 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 15 Jan 2026 16:16:25 +0000 Subject: [PATCH 23/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 10 +- transformer_engine/common/cast/cast.cu | 3 +- .../common/cast/dispatch/quantize_grouped.cuh | 3 +- .../cast/mxfp8/quantize_grouped_mxfp8.cuh | 451 +++++++++--------- 4 files changed, 230 insertions(+), 237 deletions(-) diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index fb9d6a9768..d3a137c56d 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -263,7 +263,7 @@ void performTest(const ProcessingMethod processing_method, std::vector out_data_colwise_h(colwise ? elts_num : 0); std::vector out_scales_rowwise_h(rowwise ? sfs_num : 0); std::vector out_scales_colwise_h(colwise ? sfs_num : 0); - + std::vector out_data_rowwise_ref(rowwise ? elts_num : 0); std::vector out_data_colwise_ref(colwise ? elts_num : 0); std::vector out_scales_rowwise_ref(rowwise ? 
sfs_num : 0); @@ -310,7 +310,7 @@ void performTest(const ProcessingMethod processing_method, cudaMalloc((void**)&first_dims_d, first_dims_size); cudaMalloc((void**)&last_dims_d, last_dims_size); cudaMalloc((void**)&offsets_d, offsets_size); - + cudaMemcpy(in_data_d, in_data.data(), in_data_size, cudaMemcpyHostToDevice); cudaMemcpy(first_dims_d, first_dims_h.data(), first_dims_size, cudaMemcpyHostToDevice); cudaMemcpy(last_dims_d, last_dims_h.data(), last_dims_size, cudaMemcpyHostToDevice); @@ -393,7 +393,7 @@ void performTest(const ProcessingMethod processing_method, OutputType* const out_data_colwise_ptr = out_data_colwise_ref.data() + data_offset; fp8e8m0* const out_scales_rowwise_ptr = out_scales_rowwise_ref.data() + sfs_offset; fp8e8m0* const out_scales_colwise_ptr = out_scales_colwise_ref.data() + sfs_offset; - + compute_ref( processing_method, OP, rowwise, colwise, in_ptr, /*grad=*/ nullptr, out_data_rowwise_ptr, out_data_colwise_ptr, @@ -518,7 +518,7 @@ TEST_P(GroupedFusedCastMXFP8TestSuite, Test) { using namespace transformer_engine; using namespace test; - + const ProcessingMethod processing_method = std::get<0>(GetParam()); const ActivationKind activation = std::get<1>(GetParam()); const ScalingDirection scaling_direction = std::get<2>(GetParam()); @@ -536,7 +536,7 @@ TEST_P(GroupedFusedCastMXFP8TestSuite, Test) { switch (shape_rep) { case SAME_BOTH_DIMS: { first_dims[t] = logical_shape[0] / num_tensors; - last_dims[t] = logical_shape[1]; + last_dims[t] = logical_shape[1]; break; } case VARYING_FIRST_DIM: { diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index 8a185c22c2..46cf8b8127 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -27,7 +27,8 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea // dispatch::quantize_fwd_helper(input, output, nullptr, stream); } -void nvte_quantize_grouped(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) { +void nvte_quantize_grouped(const NVTEGroupedTensor input, NVTEGroupedTensor output, + cudaStream_t stream) { NVTE_API_CALL(nvte_quantize_grouped); using namespace transformer_engine; diff --git a/transformer_engine/common/cast/dispatch/quantize_grouped.cuh b/transformer_engine/common/cast/dispatch/quantize_grouped.cuh index 207811b395..cc3b886e67 100644 --- a/transformer_engine/common/cast/dispatch/quantize_grouped.cuh +++ b/transformer_engine/common/cast/dispatch/quantize_grouped.cuh @@ -59,7 +59,8 @@ void quantize_grouped_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTenso Tensor *workspace_tensor = convertNVTETensor(workspace); mxfp8::quantize_grouped( - input_tensor, activations_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); + input_tensor, activations_tensor, noop_tensor, output_tensor, dbias_tensor, + workspace_tensor, stream); break; } default: diff --git a/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh index 644b94c8c8..a279b56ac0 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh @@ -27,7 +27,6 @@ namespace dispatch { namespace mxfp8 { namespace quantize_grouped_kernel { - constexpr int MAX_SUPPORTED_TENSOR_DESCRIPTORS = 64; __device__ alignas(128) CUtensorMap g_tensor_maps_input[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; __device__ alignas(128) CUtensorMap 
g_tensor_maps_act_input[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; @@ -35,9 +34,9 @@ __device__ alignas(128) CUtensorMap g_tensor_maps_output_rowwise[MAX_SUPPORTED_T __device__ alignas(128) CUtensorMap g_tensor_maps_output_colwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; enum ShapeRepresentation { - SAME_BOTH_DIMS = 0, + SAME_BOTH_DIMS = 0, VARYING_FIRST_DIM = 1, - VARYING_LAST_DIM = 2, + VARYING_LAST_DIM = 2, VARYING_BOTH_DIMS = 3 }; @@ -71,25 +70,22 @@ constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 // Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32 -__device__ __forceinline__ size_t -get_current_tensor_id(const ShapeRepresentation shape_rep, - const size_t num_tensors, - const size_t current_offset, - const size_t first_logical_dim, - const size_t last_logical_dim, - const int64_t* const __restrict__ offsets_ptr) { +__device__ __forceinline__ size_t get_current_tensor_id( + const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t current_offset, + const size_t first_logical_dim, const size_t last_logical_dim, + const int64_t *const __restrict__ offsets_ptr) { if (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) { const size_t current_row = current_offset / last_logical_dim; - const size_t rows_per_tensor = first_logical_dim / num_tensors; + const size_t rows_per_tensor = first_logical_dim / num_tensors; return current_row / rows_per_tensor; } else { size_t low = 0; - size_t hi = num_tensors; // [low, hi] - + size_t hi = num_tensors; // [low, hi] + while (low < hi) { const size_t mid = low + (hi - low) / 2; const size_t mid_offset = static_cast(offsets_ptr[mid]); - + if (mid_offset <= current_offset) { low = mid + 1; } else { @@ -100,93 +96,86 @@ get_current_tensor_id(const ShapeRepresentation shape_rep, } } -__device__ __forceinline__ size_t -get_tensor_rows_num(const size_t tensor_id, - const ShapeRepresentation shape_rep, - const size_t first_logical_dim, - const int64_t* const __restrict__ first_dims_ptr, - const size_t num_tensors) { +__device__ __forceinline__ size_t get_tensor_rows_num( + const size_t tensor_id, const ShapeRepresentation shape_rep, const size_t first_logical_dim, + const int64_t *const __restrict__ first_dims_ptr, const size_t num_tensors) { size_t rows_num = 0; switch (shape_rep) { - case ShapeRepresentation::SAME_BOTH_DIMS: // rows_num = first_logical_dim / num_tensors; break; - case ShapeRepresentation::VARYING_LAST_DIM: rows_num = first_logical_dim; break; + case ShapeRepresentation::SAME_BOTH_DIMS: // rows_num = first_logical_dim / num_tensors; break; + case ShapeRepresentation::VARYING_LAST_DIM: + rows_num = first_logical_dim; + break; case ShapeRepresentation::VARYING_FIRST_DIM: - case ShapeRepresentation::VARYING_BOTH_DIMS: rows_num = static_cast(first_dims_ptr[tensor_id]); break; + case ShapeRepresentation::VARYING_BOTH_DIMS: + rows_num = static_cast(first_dims_ptr[tensor_id]); + break; } return rows_num; } -__device__ __forceinline__ size_t -get_tensor_cols_num(const size_t tensor_id, - const ShapeRepresentation shape_rep, - const size_t last_logical_dim, - const int64_t* const __restrict__ last_dims_ptr) { +__device__ __forceinline__ size_t get_tensor_cols_num( + const size_t tensor_id, const ShapeRepresentation shape_rep, const size_t last_logical_dim, + const int64_t *const __restrict__ last_dims_ptr) { size_t cols_num = 0; switch (shape_rep) { case ShapeRepresentation::SAME_BOTH_DIMS: - case 
ShapeRepresentation::VARYING_FIRST_DIM: cols_num = last_logical_dim; break; + case ShapeRepresentation::VARYING_FIRST_DIM: + cols_num = last_logical_dim; + break; case ShapeRepresentation::VARYING_LAST_DIM: - case ShapeRepresentation::VARYING_BOTH_DIMS: cols_num = static_cast(last_dims_ptr[tensor_id]); break; + case ShapeRepresentation::VARYING_BOTH_DIMS: + cols_num = static_cast(last_dims_ptr[tensor_id]); + break; } return cols_num; } // Copies the base tensor map to shmem, modifies the copy, stores the modified tensor map at index template -__device__ __forceinline__ void -modify_base_tensor_map(const CUtensorMap base_tensor_map, - CUtensorMap* global_tensor_map, - const uintptr_t global_data_ptr, - const size_t global_dim_Y, - const size_t global_dim_X) { +__device__ __forceinline__ void modify_base_tensor_map(const CUtensorMap base_tensor_map, + CUtensorMap *global_tensor_map, + const uintptr_t global_data_ptr, + const size_t global_dim_Y, + const size_t global_dim_X) { const size_t global_stride_bytes = global_dim_X * sizeof(T); __shared__ CUtensorMap shared_tensor_map; shared_tensor_map = base_tensor_map; // Copy the base tensor map into shmem asm volatile( - "{\n\t" - ".reg.b64 tensor_map_ptr; \n\t" - "mov.b64 tensor_map_ptr, %0; \n\t" - "tensormap.replace.tile.global_address.b1024.b64 [tensor_map_ptr], %1; \n\t" - "tensormap.replace.tile.global_dim.b1024.b32 [tensor_map_ptr], 1, %2; \n\t" // DIM Y - "tensormap.replace.tile.global_dim.b1024.b32 [tensor_map_ptr], 0, %3; \n\t" // DIM X - "tensormap.replace.tile.global_stride.b1024.b64 [tensor_map_ptr], 0, %4; \n" - "}\n" - :: "l"(reinterpret_cast(&shared_tensor_map)), - "l"(global_data_ptr), - "r"(static_cast(global_dim_Y)), - "r"(static_cast(global_dim_X)), - "l"(static_cast(global_stride_bytes)) - : "memory" - ); + "{\n\t" + ".reg.b64 tensor_map_ptr; \n\t" + "mov.b64 tensor_map_ptr, %0; \n\t" + "tensormap.replace.tile.global_address.b1024.b64 [tensor_map_ptr], %1; \n\t" + "tensormap.replace.tile.global_dim.b1024.b32 [tensor_map_ptr], 1, %2; \n\t" // DIM Y + "tensormap.replace.tile.global_dim.b1024.b32 [tensor_map_ptr], 0, %3; \n\t" // DIM X + "tensormap.replace.tile.global_stride.b1024.b64 [tensor_map_ptr], 0, %4; \n" + "}\n" ::"l"(reinterpret_cast(&shared_tensor_map)), + "l"(global_data_ptr), "r"(static_cast(global_dim_Y)), + "r"(static_cast(global_dim_X)), "l"(static_cast(global_stride_bytes)) + : "memory"); *global_tensor_map = shared_tensor_map; -} +} template -__global__ void -update_tma_descriptors(const __grid_constant__ CUtensorMap base_tensor_map_input, - const __grid_constant__ CUtensorMap base_tensor_map_act_input, - const __grid_constant__ CUtensorMap base_tensor_map_output_rowwise, - const __grid_constant__ CUtensorMap base_tensor_map_output_colwise, - const IType* const __restrict__ input_data_ptr, - const IType* const __restrict__ act_input_data_ptr, - const OType* const __restrict__ output_rowwise_data_ptr, - const OType* const __restrict__ output_colwise_data_ptr, - const ShapeRepresentation shape_rep, - const size_t num_tensors, - const size_t first_logical_dim, - const size_t last_logical_dim, - const int64_t* const __restrict__ offsets_ptr, - const int64_t* const __restrict__ first_dims_ptr, - const int64_t* const __restrict__ last_dims_ptr, - const bool rowwise, - const bool colwise, - const bool compute_activations) { +__global__ void update_tma_descriptors( + const __grid_constant__ CUtensorMap base_tensor_map_input, + const __grid_constant__ CUtensorMap base_tensor_map_act_input, + const __grid_constant__ 
CUtensorMap base_tensor_map_output_rowwise, + const __grid_constant__ CUtensorMap base_tensor_map_output_colwise, + const IType *const __restrict__ input_data_ptr, + const IType *const __restrict__ act_input_data_ptr, + const OType *const __restrict__ output_rowwise_data_ptr, + const OType *const __restrict__ output_colwise_data_ptr, const ShapeRepresentation shape_rep, + const size_t num_tensors, const size_t first_logical_dim, const size_t last_logical_dim, + const int64_t *const __restrict__ offsets_ptr, const int64_t *const __restrict__ first_dims_ptr, + const int64_t *const __restrict__ last_dims_ptr, const bool rowwise, const bool colwise, + const bool compute_activations) { const bool leading_thread = (threadIdx.x == 0); const size_t tensor_id = blockIdx.x; - const size_t rows = get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); + const size_t rows = + get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); const size_t offset_elts = offsets_ptr[tensor_id]; @@ -195,51 +184,49 @@ update_tma_descriptors(const __grid_constant__ CUtensorMap base_tensor_map_input { const uintptr_t global_data_ptr = reinterpret_cast(input_data_ptr + offset_elts); modify_base_tensor_map(base_tensor_map_input, &g_tensor_maps_input[tensor_id], - global_data_ptr, rows, cols); + global_data_ptr, rows, cols); } if (compute_activations) { - const uintptr_t global_data_ptr = reinterpret_cast(act_input_data_ptr + offset_elts); + const uintptr_t global_data_ptr = + reinterpret_cast(act_input_data_ptr + offset_elts); modify_base_tensor_map(base_tensor_map_act_input, &g_tensor_maps_act_input[tensor_id], - global_data_ptr, rows, cols); + global_data_ptr, rows, cols); } if (rowwise) { - const uintptr_t global_data_ptr = reinterpret_cast(output_rowwise_data_ptr + offset_elts); - modify_base_tensor_map(base_tensor_map_output_rowwise, &g_tensor_maps_output_rowwise[tensor_id], - global_data_ptr, rows, cols); + const uintptr_t global_data_ptr = + reinterpret_cast(output_rowwise_data_ptr + offset_elts); + modify_base_tensor_map(base_tensor_map_output_rowwise, + &g_tensor_maps_output_rowwise[tensor_id], global_data_ptr, rows, + cols); } if (colwise) { - const uintptr_t global_data_ptr = reinterpret_cast(output_colwise_data_ptr + offset_elts); - modify_base_tensor_map(base_tensor_map_output_colwise, &g_tensor_maps_output_colwise[tensor_id], - global_data_ptr, rows, cols); + const uintptr_t global_data_ptr = + reinterpret_cast(output_colwise_data_ptr + offset_elts); + modify_base_tensor_map(base_tensor_map_output_colwise, + &g_tensor_maps_output_colwise[tensor_id], global_data_ptr, rows, + cols); } } } -__device__ __forceinline__ void -fence_acquire_tensormap(const CUtensorMap* tensor_map) { - asm volatile("fence.proxy.tensormap::generic.acquire.cta [%0], 128;" :: "l"(tensor_map)); +__device__ __forceinline__ void fence_acquire_tensormap(const CUtensorMap *tensor_map) { + asm volatile("fence.proxy.tensormap::generic.acquire.cta [%0], 128;" ::"l"(tensor_map)); } template -__global__ void __launch_bounds__(THREADS_PER_CHUNK) - quantize_grouped_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_input_static, - const __grid_constant__ CUtensorMap tensor_map_act_input_static, - const __grid_constant__ CUtensorMap tensor_map_output_rowwise_static, - const __grid_constant__ CUtensorMap tensor_map_output_colwise_static, - const ShapeRepresentation shape_rep, 
- const size_t num_tensors, - const size_t first_logical_dim, - const size_t last_logical_dim, - const int64_t* const __restrict__ offsets_ptr, - const int64_t* const __restrict__ first_dims_ptr, - const int64_t* const __restrict__ last_dims_ptr, - e8m0_t *const __restrict__ scales_rowwise_ptr, - e8m0_t *const __restrict__ scales_colwise_ptr, - const float * __restrict__ noop, - float *const __restrict__ dbias_workspace, - float *const __restrict__ amax_ptr) { +__global__ void __launch_bounds__(THREADS_PER_CHUNK) quantize_grouped_mxfp8_kernel( + const __grid_constant__ CUtensorMap tensor_map_input_static, + const __grid_constant__ CUtensorMap tensor_map_act_input_static, + const __grid_constant__ CUtensorMap tensor_map_output_rowwise_static, + const __grid_constant__ CUtensorMap tensor_map_output_colwise_static, + const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t first_logical_dim, + const size_t last_logical_dim, const int64_t *const __restrict__ offsets_ptr, + const int64_t *const __restrict__ first_dims_ptr, + const int64_t *const __restrict__ last_dims_ptr, e8m0_t *const __restrict__ scales_rowwise_ptr, + e8m0_t *const __restrict__ scales_colwise_ptr, const float *__restrict__ noop, + float *const __restrict__ dbias_workspace, float *const __restrict__ amax_ptr) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; @@ -260,7 +247,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t tensor_id = get_current_tensor_id(shape_rep, num_tensors, block_global_offset, first_logical_dim, last_logical_dim, offsets_ptr); - const size_t rows = get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); + const size_t rows = + get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); const size_t scale_stride_rowwise = cols / SCALE_DIM_X; const size_t scale_stride_colwise = cols; @@ -271,41 +259,44 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t tensor_base = is_single_tensor ? 0 : static_cast(offsets_ptr[tensor_id]); const size_t offset_within_tensor = block_global_offset - tensor_base; - const CUtensorMap& tensor_map_input = is_single_tensor - ? tensor_map_input_static - : g_tensor_maps_input[tensor_id]; - const CUtensorMap& tensor_map_act_input = is_single_tensor - ? tensor_map_act_input_static - : g_tensor_maps_act_input[tensor_id]; - const CUtensorMap& tensor_map_output_rowwise = is_single_tensor - ? tensor_map_output_rowwise_static - : g_tensor_maps_output_rowwise[tensor_id]; - const CUtensorMap& tensor_map_output_colwise = is_single_tensor - ? tensor_map_output_colwise_static - : g_tensor_maps_output_colwise[tensor_id]; + const CUtensorMap &tensor_map_input = + is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[tensor_id]; + const CUtensorMap &tensor_map_act_input = + is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[tensor_id]; + const CUtensorMap &tensor_map_output_rowwise = + is_single_tensor ? tensor_map_output_rowwise_static : g_tensor_maps_output_rowwise[tensor_id]; + const CUtensorMap &tensor_map_output_colwise = + is_single_tensor ? 
tensor_map_output_colwise_static : g_tensor_maps_output_colwise[tensor_id]; const bool leading_thread = (threadIdx.x == 0); if (leading_thread && (!is_single_tensor)) { fence_acquire_tensormap(&tensor_map_input); - if constexpr (COMPUTE_ACTIVATIONS) { fence_acquire_tensormap(&tensor_map_act_input); } - if constexpr (ROWWISE_SCALING) { fence_acquire_tensormap(&tensor_map_output_rowwise); } - if constexpr (COLWISE_SCALING) { fence_acquire_tensormap(&tensor_map_output_colwise); } + if constexpr (COMPUTE_ACTIVATIONS) { + fence_acquire_tensormap(&tensor_map_act_input); + } + if constexpr (ROWWISE_SCALING) { + fence_acquire_tensormap(&tensor_map_output_rowwise); + } + if constexpr (COLWISE_SCALING) { + fence_acquire_tensormap(&tensor_map_output_colwise); + } } const size_t blocks_X_num_in_current_tensor = cols / CHUNK_DIM_X; - const size_t block_id_in_current_tensor = is_single_tensor - ? blockIdx.x - : (blockIdx.x - tensor_base / ELTS_PER_CHUNK); + const size_t block_id_in_current_tensor = + is_single_tensor ? blockIdx.x : (blockIdx.x - tensor_base / ELTS_PER_CHUNK); const size_t block_id_Y = block_id_in_current_tensor / blocks_X_num_in_current_tensor; const size_t block_id_X = block_id_in_current_tensor % blocks_X_num_in_current_tensor; const size_t block_offset_Y = block_id_Y * CHUNK_DIM_Y; const size_t block_offset_X = block_id_X * CHUNK_DIM_X; - - e8m0_t *const scales_rowwise = scales_rowwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_X); - e8m0_t *const scales_colwise = scales_colwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_Y); + + e8m0_t *const scales_rowwise = + scales_rowwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_X); + e8m0_t *const scales_colwise = + scales_colwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_Y); const size_t scales_block_offset_Y_rowwise = block_id_Y * CHUNK_DIM_Y; const size_t scales_block_offset_X_rowwise = block_id_X * CHUNK_DIM_X / SCALE_DIM_X; @@ -333,8 +324,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; - constexpr size_t buff_size_aligned_in = DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out = DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); constexpr size_t elt_input_mem = buff_size_aligned_in; constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); @@ -367,8 +360,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float block_amax = 0.0f; - // Initialize shared memory barrier with the number of threads participating in the barrier. - #pragma nv_diag_suppress static_var_with_dynamic_init +// Initialize shared memory barrier with the number of threads participating in the barrier. 
+#pragma nv_diag_suppress static_var_with_dynamic_init __shared__ alignas(8) uint64_t mbar[STAGES]; initialize_barriers(mbar, leading_thread); @@ -384,7 +377,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) &mbar[0], leading_thread); } - #pragma unroll +#pragma unroll for (int stage = 0; stage < STAGES; ++stage) { const size_t buff = stage % BUFFS_NUM; const size_t next_stage = stage + 1; @@ -426,7 +419,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // 1. Read/Compute elements. Find MXFP8-block AMAX if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { IType thread_amax_f16 = static_cast(0.0f); - #pragma unroll +#pragma unroll for (int i = 0; i < BUFF_DIM_Y; ++i) { const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; in_colwise_IType[i] = in_sh[shmem_offset_colwise]; @@ -434,7 +427,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } thread_amax = static_cast(thread_amax_f16); } else { - #pragma unroll +#pragma unroll for (int i = 0; i < BUFF_DIM_Y; ++i) { const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; @@ -475,8 +468,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; - // 3. Scale elements - #pragma unroll +// 3. Scale elements +#pragma unroll for (int i = 0; i < SCALE_DIM_Y; ++i) { float in; if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { @@ -504,14 +497,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // 1. Read/Compute elements. Find MXFP8-block AMAX if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; - #pragma unroll +#pragma unroll for (int w = 0; w < WAVES; ++w) { const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; // Load elements in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); - #pragma unroll +#pragma unroll for (int e = 0; e < PACK_SIZE / 2; ++e) { ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); } @@ -522,7 +515,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // ensures that all writes to cache made in the section above are visible to all threads __syncthreads(); IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; - #pragma unroll +#pragma unroll for (int w = 0; w < WAVES; ++w) { const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; @@ -533,15 +526,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) // only single check (w.r.t. 
column direction) is sufficient to be sure the entire wave is inside the boundaries if constexpr (std::is_same_v) { - #pragma unroll +#pragma unroll for (int e = 0; e < PACK_SIZE; ++e) { thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); } } else { - #pragma unroll +#pragma unroll for (int e = 0; e < PACK_SIZE; e += 2) { - const IType2 in_cached_2x = {in_cached[w].data.elt[e], - in_cached[w].data.elt[e + 1]}; + const IType2 in_cached_2x = {in_cached[w].data.elt[e], in_cached[w].data.elt[e + 1]}; ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); } } @@ -551,7 +543,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); } } else { - #pragma unroll +#pragma unroll for (int w = 0; w < WAVES; ++w) { const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; @@ -564,7 +556,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if constexpr (IS_DACT) { act_in.load_from(&act_in_sh[shmem_offset_rowwise]); } - #pragma unroll +#pragma unroll for (int e = 0; e < PACK_SIZE; ++e) { const int j = w * PACK_SIZE + e; // Compute element @@ -602,11 +594,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; - // 3. Scale elements - #pragma unroll +// 3. Scale elements +#pragma unroll for (int w = 0; w < WAVES; ++w) { Vec out; - #pragma unroll +#pragma unroll for (int e = 0; e < PACK_SIZE / 2; ++e) { IType2 in; OType2 &out_pair = reinterpret_cast(out.data.elt[e]); @@ -673,16 +665,16 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // WIDTH = THREADS_X * (SCALE_DIM_X + 1) // Added extra 1-element padding per thread_X to reduce bank conflicts float *partial_dbias_rowwise = reinterpret_cast(dshmem); - + constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); - + const int shmem_thread_offset = tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); - #pragma unroll +#pragma unroll for (int w = 0; w < WAVES; ++w) { const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; - #pragma unroll +#pragma unroll for (int e = 0; e < PACK_SIZE; ++e) { const int j = w * PACK_SIZE + e; const int shmem_elt_idx = swizzled_group_offset + e; @@ -690,7 +682,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } } __syncthreads(); - #pragma unroll +#pragma unroll for (int i = 0; i < THREADS_Y; ++i) { // Add extra element offset per MXFP8 scaling block [1x32] const int scaling_block = threadIdx.x / SCALE_DIM_X; @@ -726,14 +718,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) template -void quantize_grouped(const GroupedTensor *input, - const GroupedTensor *activations, - const Tensor *noop, - GroupedTensor *output, - Tensor *dbias, - Tensor *workspace, - cudaStream_t stream) -{ +void quantize_grouped(const GroupedTensor *input, const GroupedTensor *activations, + const Tensor *noop, GroupedTensor *output, Tensor *dbias, Tensor *workspace, + cudaStream_t stream) { using namespace quantize_grouped_kernel; checkCuDriverContext(stream); @@ -764,27 +751,31 @@ void quantize_grouped(const GroupedTensor *input, // Treat a grouped tensor with const last dims as a single tensor const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS || 
shape_rep == VARYING_FIRST_DIM); - NVTE_CHECK(input->num_tensors == output->num_tensors, "Number of input and output tensors must be same."); + NVTE_CHECK(input->num_tensors == output->num_tensors, + "Number of input and output tensors must be same."); NVTE_CHECK(input->has_data(), "Cannot quantize tensor without rowwise data."); NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); const size_t num_tensors = input->num_tensors; - NVTE_CHECK(num_tensors < MAX_SUPPORTED_TENSOR_DESCRIPTORS, - "Number of tensors in a group is larger than the MAX number of supported descriptors (64)."); - + NVTE_CHECK( + num_tensors < MAX_SUPPORTED_TENSOR_DESCRIPTORS, + "Number of tensors in a group is larger than the MAX number of supported descriptors (64)."); + const size_t first_logical_dim = input->logical_shape.data[0]; const size_t last_logical_dim = input->logical_shape.data[1]; const size_t elts_total = first_logical_dim * last_logical_dim; // Logical shape of a tensor with varying all dims is [1, M*K] if (shape_rep != ShapeRepresentation::VARYING_BOTH_DIMS) { - NVTE_CHECK(first_logical_dim % 128 == 0, "First dimension of a grouped tensor should be divisible by 128."); + NVTE_CHECK(first_logical_dim % 128 == 0, + "First dimension of a grouped tensor should be divisible by 128."); } - NVTE_CHECK(last_logical_dim % 128 == 0, "Last dimension of a grouped tensor should be divisible by 128."); - + NVTE_CHECK(last_logical_dim % 128 == 0, + "Last dimension of a grouped tensor should be divisible by 128."); + e8m0_t *const scales_rowwise_ptr = reinterpret_cast(output->scale_inv.dptr); e8m0_t *const scales_colwise_ptr = reinterpret_cast(output->columnwise_scale_inv.dptr); - + if (use_rowwise_scaling) { NVTE_CHECK(scales_rowwise_ptr != nullptr, "Scaling tensor must be allocated"); } @@ -792,10 +783,10 @@ void quantize_grouped(const GroupedTensor *input, NVTE_CHECK(scales_colwise_ptr != nullptr, "Columnwise scaling tensor must be allocated"); } - const int64_t* const offsets_ptr = reinterpret_cast(input->tensor_offsets.dptr); - const int64_t* const first_dims_ptr = reinterpret_cast(input->first_dims.dptr); - const int64_t* const last_dims_ptr = reinterpret_cast(input->last_dims.dptr); - + const int64_t *const offsets_ptr = reinterpret_cast(input->tensor_offsets.dptr); + const int64_t *const first_dims_ptr = reinterpret_cast(input->first_dims.dptr); + const int64_t *const last_dims_ptr = reinterpret_cast(input->last_dims.dptr); + CheckNoopTensor(*noop, "cast_noop"); const size_t blocks = DIVUP(elts_total, CHUNK_DIM_Y * CHUNK_DIM_X); @@ -803,13 +794,16 @@ void quantize_grouped(const GroupedTensor *input, const size_t block_size = THREADS_PER_CHUNK; const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; - const size_t scale_stride_colwise = use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; - + const size_t scale_stride_colwise = + use_colwise_scaling ? 
output->columnwise_scale_inv.shape[1] : 1; + const size_t dbias_rows = DIVUP(first_logical_dim, CHUNK_DIM_Y); const size_t dbias_cols = last_logical_dim; if constexpr (IS_DBIAS) { - NVTE_CHECK(is_single_tensor, "DBias is only supported for tensors with the const last dimension."); - NVTE_CHECK(dbias->data.dtype == input->dtype(), "DBias must have the same type as input_tensor."); + NVTE_CHECK(is_single_tensor, + "DBias is only supported for tensors with the const last dimension."); + NVTE_CHECK(dbias->data.dtype == input->dtype(), + "DBias must have the same type as input_tensor."); NVTE_CHECK(dbias->data.shape == std::vector{last_logical_dim}, "Wrong shape of DBias."); NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); @@ -824,8 +818,10 @@ void quantize_grouped(const GroupedTensor *input, float *const amax_ptr = reinterpret_cast(output->amax.dptr); const float *noop_ptr = reinterpret_cast(noop->data.dptr); - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input->dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(output->dtype(), OType, + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input->dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, alignas(64) CUtensorMap tensor_map_input{}; alignas(64) CUtensorMap tensor_map_act_input{}; @@ -835,37 +831,33 @@ void quantize_grouped(const GroupedTensor *input, constexpr size_t input_type_bit_size = TypeInfo::size; constexpr size_t output_type_bit_size = TypeInfo::size; - create_2D_tensor_map(tensor_map_input, input->data, - first_logical_dim, last_logical_dim, - BUFF_DIM_Y, BUFF_DIM_X, - last_logical_dim, 0, input_type_bit_size); + create_2D_tensor_map(tensor_map_input, input->data, first_logical_dim, last_logical_dim, + BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, input_type_bit_size); if constexpr (IS_DACT) { - create_2D_tensor_map(tensor_map_act_input, activations->data, - first_logical_dim, last_logical_dim, - BUFF_DIM_Y, BUFF_DIM_X, - last_logical_dim, 0, input_type_bit_size); + create_2D_tensor_map(tensor_map_act_input, activations->data, first_logical_dim, + last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, + input_type_bit_size); } if (use_rowwise_scaling) { - create_2D_tensor_map(tensor_map_output_rowwise, output->data, - first_logical_dim, last_logical_dim, - BUFF_DIM_Y, BUFF_DIM_X, - last_logical_dim, 0, output_type_bit_size); + create_2D_tensor_map(tensor_map_output_rowwise, output->data, first_logical_dim, + last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, + output_type_bit_size); } if (use_colwise_scaling) { create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, - first_logical_dim, last_logical_dim, - BUFF_DIM_Y, BUFF_DIM_X, + first_logical_dim, last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, output_type_bit_size); - } - constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + } constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; - constexpr size_t buff_size_aligned_in = DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out = DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(output_buff_size, 
TMA_SHMEM_ALIGNMENT); constexpr size_t elt_input_mem = buff_size_aligned_in; constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); @@ -877,62 +869,61 @@ void quantize_grouped(const GroupedTensor *input, const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; - auto kernel = quantize_grouped_mxfp8_kernel - ; + auto kernel = quantize_grouped_mxfp8_kernel; switch (scaling_type) { case ScalingType::ROWWISE: { - kernel = quantize_grouped_mxfp8_kernel - ; + kernel = quantize_grouped_mxfp8_kernel; break; } case ScalingType::COLWISE: { - kernel = quantize_grouped_mxfp8_kernel - ; + kernel = quantize_grouped_mxfp8_kernel; break; } case ScalingType::BIDIMENSIONAL: { - kernel = quantize_grouped_mxfp8_kernel - ; + kernel = quantize_grouped_mxfp8_kernel; break; } } - // Update tensor descriptors before launching the kernel + // Update tensor descriptors before launching the kernel if (!is_single_tensor) { - const IType* const input_dptr = reinterpret_cast(input->data.dptr); - - const IType* const act_input_dptr = (IS_DACT || IS_ACT) - ? reinterpret_cast(activations->data.dptr) - : nullptr; - - OType* const output_rowwise_dptr = use_rowwise_scaling - ? reinterpret_cast(output->data.dptr) - : nullptr; - - OType* const output_colwise_dptr = use_colwise_scaling - ? reinterpret_cast(output->columnwise_data.dptr) - : nullptr; - update_tma_descriptors<<>> - (tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, tensor_map_output_colwise, - input_dptr, act_input_dptr, output_rowwise_dptr, output_colwise_dptr, - shape_rep, num_tensors, first_logical_dim, last_logical_dim, - offsets_ptr, first_dims_ptr, last_dims_ptr, use_rowwise_scaling, use_colwise_scaling, IS_ACT); + const IType *const input_dptr = reinterpret_cast(input->data.dptr); + + const IType *const act_input_dptr = + (IS_DACT || IS_ACT) ? reinterpret_cast(activations->data.dptr) + : nullptr; + + OType *const output_rowwise_dptr = + use_rowwise_scaling ? reinterpret_cast(output->data.dptr) : nullptr; + + OType *const output_colwise_dptr = + use_colwise_scaling ? 
reinterpret_cast(output->columnwise_data.dptr) + : nullptr; + update_tma_descriptors<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, input_dptr, act_input_dptr, output_rowwise_dptr, + output_colwise_dptr, shape_rep, num_tensors, first_logical_dim, last_logical_dim, + offsets_ptr, first_dims_ptr, last_dims_ptr, use_rowwise_scaling, + use_colwise_scaling, IS_ACT); } - - NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); + + NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + dshmem_size)); kernel<<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, tensor_map_output_colwise, - shape_rep, num_tensors, first_logical_dim, last_logical_dim, - offsets_ptr, first_dims_ptr, last_dims_ptr, - scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr); + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, shape_rep, num_tensors, first_logical_dim, + last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr, + scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr); if constexpr (IS_DBIAS) { common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); - } - NVTE_CHECK_CUDA(cudaGetLastError()); - ); // NOLINT(*) - ); // NOLINT(*) + } NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) + ); // NOLINT(*) } } // namespace mxfp8
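Reviewer note (illustrative, not part of the patch): the grouped-MXFP8 dispatch above resolves which member tensor a given chunk belongs to by binary-searching the prefix-sum tensor_offsets array, falling back to a plain division when all tensors share the same shape (SAME_BOTH_DIMS). The following is a minimal host-side C++ sketch of that lookup, under the assumption that offsets holds element prefix sums with offsets[0] == 0; the final "return low - 1;" step is assumed because the hunk above is cut before the function's return statement, and get_current_tensor_id_host and main are hypothetical names introduced here purely for illustration.

#include <cassert>
#include <cstdint>
#include <vector>

enum ShapeRepresentation { SAME_BOTH_DIMS, VARYING_FIRST_DIM, VARYING_LAST_DIM, VARYING_BOTH_DIMS };

// Maps a flattened element offset to the index of the tensor that owns it.
size_t get_current_tensor_id_host(ShapeRepresentation shape_rep, size_t num_tensors,
                                  size_t current_offset, size_t first_logical_dim,
                                  size_t last_logical_dim,
                                  const std::vector<int64_t> &offsets) {
  if (shape_rep == SAME_BOTH_DIMS) {
    // Uniform shapes: every tensor owns first_logical_dim / num_tensors rows.
    const size_t current_row = current_offset / last_logical_dim;
    const size_t rows_per_tensor = first_logical_dim / num_tensors;
    return current_row / rows_per_tensor;
  }
  // Upper-bound binary search over the prefix sums, mirroring the kernel loop.
  size_t low = 0;
  size_t hi = num_tensors;
  while (low < hi) {
    const size_t mid = low + (hi - low) / 2;
    if (static_cast<size_t>(offsets[mid]) <= current_offset) {
      low = mid + 1;
    } else {
      hi = mid;
    }
  }
  return low - 1;  // Last tensor whose starting offset is <= current_offset (assumed).
}

int main() {
  // Three flattened tensors of 128x128, 256x128 and 128x128 elements.
  const std::vector<int64_t> offsets = {0, 128 * 128, 128 * 128 + 256 * 128,
                                        128 * 128 + 256 * 128 + 128 * 128};
  assert(get_current_tensor_id_host(VARYING_FIRST_DIM, 3, 0, 512, 128, offsets) == 0);
  assert(get_current_tensor_id_host(VARYING_FIRST_DIM, 3, 128 * 128, 512, 128, offsets) == 1);
  assert(get_current_tensor_id_host(VARYING_FIRST_DIM, 3, 65535, 512, 128, offsets) == 2);
  return 0;
}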