diff --git a/examples/llm_ptq/scripts/huggingface_example.sh b/examples/llm_ptq/scripts/huggingface_example.sh
index 6ca99c7f963..693506929d9 100755
--- a/examples/llm_ptq/scripts/huggingface_example.sh
+++ b/examples/llm_ptq/scripts/huggingface_example.sh
@@ -49,18 +49,7 @@ dense | sparsegpt) ;;
         ;;
 esac
-#Iterate over list of qformats provided and check if they are valid
-IFS=","
-for qformat in $QFORMAT; do
-    case $qformat in
-        fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | nvfp4_mse | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | nvfp4_experts_only | nvfp4_mlp_only | nvfp4_omlp_only | nvfp4_svdquant | mxfp8 | nvfp4_local_hessian) ;;
-        *)
-            echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, nvfp4_mse, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, nvfp4_experts_only, nvfp4_mlp_only, nvfp4_omlp_only, nvfp4_svdquant, mxfp8, nvfp4_local_hessian]" >&2
-            exit 1
-            ;;
-    esac
-done
-IFS=" "
+# Quant format / recipe validation is delegated to hf_ptq.py.
 
 script_dir="$(dirname "$(readlink -f "$0")")"
 
@@ -72,7 +61,14 @@ fi
 
 QFORMAT_MODIFIED="${QFORMAT//,/_}"
 
-MODEL_NAME=$(basename $MODEL_PATH | sed 's/[^0-9a-zA-Z\-]/_/g')_${QFORMAT_MODIFIED}${KV_CACHE_QUANT:+_kv_${KV_CACHE_QUANT}}
+# When using --recipe, build the model name from the recipe basename (without
+# directory or .yaml suffix) so each recipe gets its own SAVE_PATH.
+if [ -n "$RECIPE" ]; then
+    RECIPE_TAG=$(basename "$RECIPE" .yaml | sed 's/[^0-9a-zA-Z\-]/_/g')
+    MODEL_NAME=$(basename $MODEL_PATH | sed 's/[^0-9a-zA-Z\-]/_/g')_recipe_${RECIPE_TAG}
+else
+    MODEL_NAME=$(basename $MODEL_PATH | sed 's/[^0-9a-zA-Z\-]/_/g')_${QFORMAT_MODIFIED}${KV_CACHE_QUANT:+_kv_${KV_CACHE_QUANT}}
+fi
 
 SAVE_PATH=${ROOT_SAVE_PATH}/saved_models_${MODEL_NAME}
 
@@ -177,11 +173,16 @@ if [[ $TASKS =~ "quant" ]] || [[ ! -d "$SAVE_PATH" ]] || [[ ! $(ls -A $SAVE_PATH
 
     if [[ "$MODEL_CONFIG_EXIST" == false ]]; then
         echo "Quantizing original model..."
+        if [ -n "$RECIPE" ]; then
+            QUANT_SPEC_ARGS="--recipe=$RECIPE"
+        else
+            QUANT_SPEC_ARGS="--qformat=${QFORMAT// /,}"
+        fi
         python hf_ptq.py \
             --pyt_ckpt_path=$MODEL_PATH \
             --export_path=$SAVE_PATH \
             --sparsity_fmt=$SPARSITY_FMT \
-            --qformat="${QFORMAT// /,}" \
+            $QUANT_SPEC_ARGS \
             --calib_size=$CALIB_SIZE \
             --batch_size=$CALIB_BATCH_SIZE \
             --inference_tensor_parallel=$TP \
@@ -203,7 +204,7 @@ if [[ $TASKS =~ "quant" ]] || [[ ! -d "$SAVE_PATH" ]] || [[ ! $(ls -A $SAVE_PATH
         exit 0
     fi
 
-    if [[ "$QFORMAT" == *"nvfp4"* ]] || [[ "$KV_CACHE_QUANT" == *"nvfp4"* ]]; then
+    if [[ "$QFORMAT" == *"nvfp4"* ]] || [[ "$KV_CACHE_QUANT" == *"nvfp4"* ]] || [[ "$RECIPE" == *"nvfp4"* ]]; then
         cuda_major=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader -i 0 | cut -d. -f1)
 
         if [ "$cuda_major" -lt 10 ]; then
@@ -212,6 +213,11 @@ if [[ $TASKS =~ "quant" ]] || [[ ! -d "$SAVE_PATH" ]] || [[ ! $(ls -A $SAVE_PATH
         fi
     fi
 
+    if [ -n "$RECIPE" ]; then
+        echo "Recipe $RECIPE used. Please deploy with TensorRT-LLM directly. Checkpoint export_path: $SAVE_PATH"
+        exit 0
+    fi
+
     if [[ ! " fp8 nvfp4 bf16 fp16 " =~ " ${QFORMAT} " ]]; then
         echo "Quant $QFORMAT specified. Please read TensorRT-LLM quantization support matrix https://nvidia.github.io/TensorRT-LLM/features/quantization.html#quantization-in-tensorrt-llm and use TensorRT-LLM for deployment. Checkpoint export_path: $SAVE_PATH"
         exit 0
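Example invocation of the new flow (hypothetical model path; the recipe file is one of the two added later in this PR): `scripts/huggingface_example.sh --model /path/to/model --recipe modelopt_recipes/general/ptq/nvfp4_mlp_only_mse-fp8_cast_kv.yaml --tasks quant`. With `--recipe`, the qformat allowlist is skipped, SAVE_PATH is keyed on the recipe basename, hf_ptq.py receives `--recipe=...` instead of `--qformat=...`, and the script exits after checkpoint export with a pointer to TensorRT-LLM for deployment.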
diff --git a/examples/llm_ptq/scripts/parser.sh b/examples/llm_ptq/scripts/parser.sh
index 3817c1dee7c..2a9a28b3566 100644
--- a/examples/llm_ptq/scripts/parser.sh
+++ b/examples/llm_ptq/scripts/parser.sh
@@ -20,6 +20,7 @@ parse_options() {
     # Default values
     MODEL_PATH=""
     QFORMAT=""
+    RECIPE=""
     KV_CACHE_QUANT=""
     TP=1
     PP=1
@@ -37,13 +38,14 @@ parse_options() {
     CAST_MXFP4_TO_NVFP4=false
 
     # Parse command-line options
-    ARGS=$(getopt -o "" -l "model:,quant:,kv_cache_quant:,tp:,pp:,sparsity:,awq_block_size:,calib:,calib_batch_size:,auto_quantize_bits:,output:,batch:,tasks:,lm_eval_tasks:,lm_eval_limit:,simple_eval_tasks:,trust_remote_code,use_seq_device_map,gpu_max_mem_percentage:,kv_cache_free_gpu_memory_fraction:,low_memory_mode,no-verbose,calib_dataset:,calib_seq:,auto_quantize_method:,auto_quantize_score_size:,auto_quantize_checkpoint:,moe_calib_experts_ratio:,cast_mxfp4_to_nvfp4" -n "$0" -- "$@")
+    ARGS=$(getopt -o "" -l "model:,quant:,recipe:,kv_cache_quant:,tp:,pp:,sparsity:,awq_block_size:,calib:,calib_batch_size:,auto_quantize_bits:,output:,batch:,tasks:,lm_eval_tasks:,lm_eval_limit:,simple_eval_tasks:,trust_remote_code,use_seq_device_map,gpu_max_mem_percentage:,kv_cache_free_gpu_memory_fraction:,low_memory_mode,no-verbose,calib_dataset:,calib_seq:,auto_quantize_method:,auto_quantize_score_size:,auto_quantize_checkpoint:,moe_calib_experts_ratio:,cast_mxfp4_to_nvfp4" -n "$0" -- "$@")
     eval set -- "$ARGS"
 
     while true; do
         case "$1" in
             --model ) MODEL_PATH="$2"; shift 2;;
             --quant ) QFORMAT="$2"; shift 2;;
+            --recipe ) RECIPE="$2"; shift 2;;
             --kv_cache_quant ) KV_CACHE_QUANT="$2"; shift 2;;
             --tp ) TP="$2"; shift 2;;
             --pp ) PP="$2"; shift 2;;
@@ -99,12 +101,19 @@ parse_options() {
     fi
 
     # Verify required options are provided
-    if [ -z "$MODEL_PATH" ] || [ -z "$QFORMAT" ] || [ -z "$TASKS" ]; then
-        echo "Usage: $0 --model=<model_path> --quant=<qformat> --tasks=<tasks>"
+    if [ -z "$MODEL_PATH" ] || [ -z "$TASKS" ] || ([ -z "$QFORMAT" ] && [ -z "$RECIPE" ]); then
+        echo "Usage: $0 --model=<model_path> (--quant=<qformat> | --recipe=<recipe.yaml>) --tasks=<tasks>"
         echo "Optional args: --sparsity=<sparsity_fmt> --awq_block_size=<block_size> --calib=<calib_size>"
         exit 1
     fi
 
+    # --quant and --recipe are mutually exclusive: --recipe is a full PTQ spec, while
+    # --quant selects a built-in qformat preset. Pick exactly one.
+    if [ -n "$QFORMAT" ] && [ -n "$RECIPE" ]; then
+        echo "Cannot specify both --quant and --recipe; pick one." >&2
+        exit 1
+    fi
+
     VALID_TASKS=("quant" "mmlu" "lm_eval" "livecodebench" "simple_eval")
 
     for task in $(echo "$TASKS" | tr ',' ' '); do
@@ -135,6 +144,7 @@ parse_options() {
     echo "================="
     echo "model: $MODEL_PATH"
     echo "quant: $QFORMAT"
+    echo "recipe: $RECIPE"
     echo "tp (TensorRT-LLM Checkpoint only): $TP"
     echo "pp (TensorRT-LLM Checkpoint only): $PP"
     echo "sparsity: $SPARSITY_FMT"
diff --git a/modelopt/torch/kernels/quantization/gemm/__init__.py b/modelopt/torch/kernels/quantization/gemm/__init__.py
index 39b07b4faa9..70f729cffb0 100644
--- a/modelopt/torch/kernels/quantization/gemm/__init__.py
+++ b/modelopt/torch/kernels/quantization/gemm/__init__.py
@@ -32,6 +32,7 @@
     # fp4_kernel works on any CUDA GPU with triton
     from .fp4_kernel import *
     from .fp8_kernel import *
+    from .nvfp4_fp8_sweep import *
 
     # fp4_kernel_hopper requires compute >= 8.9 (uses tl.float8e4nv)
     if torch.cuda.get_device_capability() >= (8, 9):
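Where the 126 in the new kernel module below comes from: FP8 E4M3 (the `fn` variant) has 128 non-negative bit patterns, and only zero and the single NaN encoding (0x7f) are excluded; there are no infinities. A standalone check, assuming a PyTorch build with `float8_e4m3fn` support:

```python
import torch

# All 128 bit patterns with the sign bit clear, reinterpreted as FP8 E4M3 (fn variant).
bits = torch.arange(0, 128, dtype=torch.uint8)
vals = bits.view(torch.float8_e4m3fn).float()

# E4M3-fn has no infinities; the only exclusions are 0.0 and NaN (bit pattern 0x7f).
valid = torch.isfinite(vals) & (vals > 0)
assert int(valid.sum()) == 126
assert vals[valid].max().item() == 448.0  # E4M3 max normal, hence the /448 normalization
```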
diff --git a/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py b/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py
new file mode 100644
index 00000000000..4b9f19837f2
--- /dev/null
+++ b/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py
@@ -0,0 +1,166 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Fused Triton kernel for the NVFP4 weight-MSE FP8 scale sweep.
+
+Replaces the 126-iteration Python sweep in :class:`NVFP4MSECalibrator` with a single
+kernel that, for each NVFP4 block, evaluates all 126 valid FP8 E4M3 scale candidates
+and emits the per-block ``best_amax`` directly.
+
+The 126 candidates are constructed as ``valid_fp8_e4m3_value / 448`` (see
+:func:`fp8_scale_candidates`). For these specific candidates, the FP8 round-trip on
+the per-block scale is the identity, so the kernel can use
+``scale = candidate * global_amax / 6.0`` without an explicit FP8 cast — making it
+runnable on any CUDA GPU with Triton (no ``tl.float8e4nv`` requirement).
+
+Tile shape (``BLOCKS_PER_PROGRAM``) and ``num_warps`` are autotuned per ``N_BLOCKS``.
+"""
+
+import torch
+import triton
+import triton.language as tl
+
+from .nvfp4_quant import fp4_round_magnitude
+
+__all__ = ["fp8_scale_candidates", "nvfp4_fp8_scale_sweep"]
+
+
+def fp8_scale_candidates(device: torch.device | str = "cpu") -> torch.Tensor:
+    """Return the 126 valid finite positive FP8 E4M3 values, each divided by 448."""
+    uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device)
+    fp8_values = uint8_values.view(torch.float8_e4m3fn).float()
+    valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
+    return fp8_values[valid_mask] / 448.0
+
+
+# Selected from a (BLOCKS_PER_PROGRAM, num_warps) sweep on B300:
+#   BPP=16,nw=2: 6.06 ms   BPP=32,nw=4: 6.06 ms   BPP=64,nw=8: 5.08 ms
+# The smaller-tile entries cover cases where N_BLOCKS is small enough that BPP=64
+# would underfill the SMs.
+_FP8_SWEEP_AUTOTUNE_CONFIGS = [
+    triton.Config({"BLOCKS_PER_PROGRAM": 16}, num_warps=2),
+    triton.Config({"BLOCKS_PER_PROGRAM": 32}, num_warps=4),
+    triton.Config({"BLOCKS_PER_PROGRAM": 64}, num_warps=8),
+]
+
+
+@triton.autotune(configs=_FP8_SWEEP_AUTOTUNE_CONFIGS, key=["N_BLOCKS"])
+@triton.jit
+def _fp8_scale_sweep_kernel(
+    x_ptr,  # [N_BLOCKS * BLOCK_SIZE], any float dtype (loaded as fp32)
+    candidates_ptr,  # [NUM_CANDIDATES] fp32
+    global_amax_ptr,  # scalar fp32
+    best_amax_ptr,  # [N_BLOCKS] fp32 output
+    N_BLOCKS,
+    BLOCK_SIZE: tl.constexpr,
+    NUM_CANDIDATES: tl.constexpr,
+    BLOCKS_PER_PROGRAM: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCKS_PER_PROGRAM
+    block_idx = block_start + tl.arange(0, BLOCKS_PER_PROGRAM)
+    block_mask = block_idx < N_BLOCKS
+
+    # Load weights for this tile and pre-compute their absolute values once.
+    # The squared error is sign-invariant since FP4 quant preserves sign:
+    #   (w - w_q)^2 = (|w| - |w_q|)^2 = (|w| - q_mag * scale)^2
+    # so we never need ``w`` itself again, dropping a tl.where + negation per element.
+    elem_offs = block_idx[:, None] * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[None, :]
+    elem_mask = block_mask[:, None]
+    w_abs = tl.abs(tl.load(x_ptr + elem_offs, mask=elem_mask, other=0.0).to(tl.float32))
+
+    global_amax = tl.load(global_amax_ptr).to(tl.float32)
+
+    best_loss = tl.full([BLOCKS_PER_PROGRAM], float("inf"), dtype=tl.float32)
+    best_idx = tl.zeros([BLOCKS_PER_PROGRAM], dtype=tl.int32)
+
+    # Loop over the 126 FP8 candidates (compile-time unrolled).
+    # Scales are guaranteed positive and finite (constructed from a positive candidate
+    # times a nonnegative global_amax), so the degenerate-scale guard from
+    # nvfp4_scalar_quant is unnecessary apart from the global_amax == 0 case handled below.
+    for k in tl.static_range(NUM_CANDIDATES):
+        c = tl.load(candidates_ptr + k).to(tl.float32)
+        scale = c * global_amax / 6.0
+        # Avoid divide-by-zero when global_amax == 0; the resulting err == w_abs² is
+        # the same for every candidate, so any best_idx is fine.
+        scale_safe = tl.where(scale == 0.0, 1.0, scale)
+        q_mag = fp4_round_magnitude(w_abs / scale_safe)
+        diff = w_abs - q_mag * scale
+        loss = tl.sum(diff * diff, axis=1)  # [BLOCKS_PER_PROGRAM]
+        is_better = loss < best_loss
+        best_loss = tl.where(is_better, loss, best_loss)
+        best_idx = tl.where(is_better, k, best_idx)
+
+    # Map each block's winning candidate index back to its amax = global_amax * c[best].
+    best_c = tl.load(candidates_ptr + best_idx, mask=block_mask, other=0.0).to(tl.float32)
+    best_amax = global_amax * best_c
+    tl.store(best_amax_ptr + block_idx, best_amax, mask=block_mask)
+
+
+def nvfp4_fp8_scale_sweep(
+    x: torch.Tensor,
+    global_amax: torch.Tensor,
+    block_size: int = 16,
+    candidates: torch.Tensor | None = None,
+) -> torch.Tensor:
+    """Find the per-block FP8 scale that minimizes NVFP4 quantization MSE.
+
+    Equivalent to the 126-step sweep in :class:`NVFP4MSECalibrator`, but fused into
+    a single Triton kernel: every block's weight elements are loaded once, all 126
+    candidates are evaluated in registers, and the running argmin is kept inline.
+
+    Args:
+        x: Weight tensor on CUDA. Total element count must be divisible by
+            ``block_size``; layout is treated as a flat ``[N_BLOCKS, BLOCK_SIZE]``.
+        global_amax: Scalar FP32 global amax (``= reduce_amax(per_block_amax)``).
+        block_size: NVFP4 block size (typically 16).
+        candidates: Optional precomputed candidate tensor of shape ``[126]`` (must
+            be the FP8 E4M3 valid values divided by 448). Built lazily if omitted.
+
+    Returns:
+        ``best_amax`` of shape ``[N_BLOCKS]``, fp32, on the same device as ``x``.
+    """
+    if not x.is_cuda:
+        raise ValueError("nvfp4_fp8_scale_sweep requires a CUDA tensor.")
+    if not isinstance(block_size, int) or block_size <= 0:
+        raise ValueError(f"block_size must be a positive int, got {block_size!r}.")
+    if x.numel() % block_size != 0:
+        raise ValueError(f"x.numel() ({x.numel()}) is not divisible by block_size ({block_size}).")
+
+    if candidates is None:
+        candidates = fp8_scale_candidates(x.device)
+    candidates = candidates.contiguous().to(device=x.device, dtype=torch.float32)
+    if candidates.ndim != 1 or candidates.numel() == 0:
+        raise ValueError(
+            f"candidates must be a non-empty 1-D tensor; got shape {tuple(candidates.shape)}."
+        )
+
+    n_blocks = x.numel() // block_size
+    x_flat = x.contiguous().view(-1)
+    global_amax_f32 = global_amax.detach().to(device=x.device, dtype=torch.float32).reshape(1)
+    best_amax = torch.empty(n_blocks, dtype=torch.float32, device=x.device)
+
+    grid = lambda meta: (triton.cdiv(n_blocks, meta["BLOCKS_PER_PROGRAM"]),)
+    with torch.cuda.device(x.device):
+        _fp8_scale_sweep_kernel[grid](
+            x_flat,
+            candidates,
+            global_amax_f32,
+            best_amax,
+            n_blocks,
+            BLOCK_SIZE=block_size,
+            NUM_CANDIDATES=int(candidates.numel()),
+        )
+    return best_amax
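A minimal usage sketch of the wrapper above (illustrative shapes; assumes a CUDA build of PyTorch with Triton available):

```python
import torch

from modelopt.torch.kernels.quantization.gemm import (
    fp8_scale_candidates,
    nvfp4_fp8_scale_sweep,
)

w = torch.randn(4096, 4096, device="cuda")   # any float dtype; numel % block_size == 0
global_amax = w.abs().amax().float()         # scalar fp32, as the calibrator supplies it

best_amax = nvfp4_fp8_scale_sweep(w, global_amax, block_size=16)
assert best_amax.shape == (w.numel() // 16,)

# Every returned amax is global_amax times one of the 126 FP8 E4M3 candidates.
cands = fp8_scale_candidates(device="cuda")
assert torch.isin(best_amax, global_amax * cands).all()
```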
+ """ + + def __init__( + self, + amax: torch.Tensor, + global_amax: torch.Tensor, + axis: int | tuple | list | None = None, + quant_func: Callable | None = None, + error_func: Callable | None = None, + ): + """Initialize the Triton-fused NVFP4 MSE calibrator. + + See :class:`NVFP4MSECalibrator`. ``quant_func``/``error_func`` are unused by + the kernel path but accepted for API parity. Tile shape and ``num_warps`` are + autotuned by the kernel per ``N_BLOCKS``. + """ + super().__init__( + amax=amax, + global_amax=global_amax, + axis=axis, + quant_func=quant_func, + error_func=error_func, + ) + # Stash shape metadata so collect() can keep working after reset() releases + # the (potentially large) _initial_amax buffer. + self._initial_amax_shape = tuple(amax.shape) + self._initial_amax_dtype = amax.dtype + self._n_blocks = int(amax.numel()) + self._best_amax: torch.Tensor | None = None + + @torch.no_grad() + def collect(self, x: torch.Tensor): + """Run the fused FP8 sweep kernel and store the resulting per-block amax.""" + from modelopt.torch.kernels.quantization.gemm import nvfp4_fp8_scale_sweep + + if self._best_amax is not None: + raise RuntimeError( + "TritonNVFP4MSECalibrator.collect() is one-shot; call reset() to " + "discard the previous result before collecting again." + ) + + x = x.detach() + # The weight quantizer reshapes its input to [n_blocks, block_size] before + # calling collect (see TensorQuantizer._process_for_blockquant). + assert x.ndim == 2, f"Expected x to be [n_blocks, block_size]; got shape {tuple(x.shape)}." + block_size = x.shape[-1] + n_blocks = x.numel() // block_size + if n_blocks != self._n_blocks: + raise ValueError( + f"initial amax.numel() ({self._n_blocks}) does not match the number " + f"of NVFP4 blocks in x ({n_blocks})." + ) + + best_amax_flat = nvfp4_fp8_scale_sweep( + x, + self._global_amax, + block_size=block_size, + ) + # Match the original shape/dtype of the initial amax so downstream + # load_calib_amax behaves identically to the reference path. + self._best_amax = best_amax_flat.reshape(self._initial_amax_shape).to( + self._initial_amax_dtype + ) + + @torch.no_grad() + def compute_amax(self, verbose: bool = False): + """Return the per-block amax computed during ``collect``.""" + return self._best_amax + + def reset(self): + """Reset the stored best amax. 
Subsequent ``collect`` calls are allowed.""" + self._best_amax = None + super().reset() diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 4ce0f62a75d..cd86ff1c72e 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -16,6 +16,7 @@ """Calibration utilities.""" import math +import os import time import warnings from collections.abc import Callable @@ -37,7 +38,7 @@ from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method -from .calib import MseCalibrator, NVFP4MSECalibrator, _Calibrator +from .calib import MseCalibrator, NVFP4MSECalibrator, TritonNVFP4MSECalibrator, _Calibrator from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context from .nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer from .utils import ( @@ -354,6 +355,11 @@ def mse_calibrate( weight_quantizers = [] seen_modules = set() + # Triton-fused FP8 sweep is on by default for NVFP4 static quant; set + # MODELOPT_NVFP4_TRITON_SWEEP=0 to fall back to the reference for debugging. + use_triton_fp8_sweep = os.environ.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") != "0" + nvfp4_calibrator_cls = TritonNVFP4MSECalibrator if use_triton_fp8_sweep else NVFP4MSECalibrator + for name, module in list(model.named_modules()): if isinstance(module, TensorQuantizer) and not module._disabled: if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): @@ -391,8 +397,7 @@ def mse_calibrate( continue if fp8_scale_sweep and is_nvfp4_static: - # Replace calibrator with NVFP4MSECalibrator - module._calibrator = NVFP4MSECalibrator( + module._calibrator = nvfp4_calibrator_cls( amax=initial_amax, axis=module._calibrator._axis, global_amax=module.global_amax, diff --git a/modelopt_recipes/general/ptq/nvfp4_experts_only_mse-fp8_cast_kv.yaml b/modelopt_recipes/general/ptq/nvfp4_experts_only_mse-fp8_cast_kv.yaml new file mode 100644 index 00000000000..749a875b368 --- /dev/null +++ b/modelopt_recipes/general/ptq/nvfp4_experts_only_mse-fp8_cast_kv.yaml @@ -0,0 +1,103 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +metadata: + recipe_type: ptq + description: >- + NVFP4 W4A4 for expert layers only with MSE-search FP8 scale calibration on + expert weights, FP8 KV cache with constant amax (skips KV calibration; amax + hardcoded to FP8 E4M3 max 448.0). Expert weight quantizers are static + (per-block amax fixed by the MSE FP8-scale sweep); input quantizers remain + dynamic. +quantize: + algorithm: + method: mse + fp8_scale_sweep: true + # Max calibration is fast and does not typically need checkpointing. 
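How the new calibrator is exercised by the flow above, reduced to a standalone sketch (random data for illustration; the one-shot contract and the env-var fallback are the parts specific to this PR):

```python
import os

import torch

from modelopt.torch.quantization.calib import TritonNVFP4MSECalibrator

# Set before mse_calibrate runs to fall back to the reference sweep:
# os.environ["MODELOPT_NVFP4_TRITON_SWEEP"] = "0"

x = torch.randn(1024, 16, device="cuda")           # [n_blocks, block_size]
per_block_amax = x.abs().amax(dim=-1)

cal = TritonNVFP4MSECalibrator(
    amax=per_block_amax,
    global_amax=per_block_amax.max(),
    axis=0,
)
cal.collect(x)                  # one-shot: a second collect() raises RuntimeError
best_amax = cal.compute_amax()  # same shape/dtype as the initial amax
cal.reset()                     # re-enables collect()
```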
diff --git a/modelopt_recipes/general/ptq/nvfp4_experts_only_mse-fp8_cast_kv.yaml b/modelopt_recipes/general/ptq/nvfp4_experts_only_mse-fp8_cast_kv.yaml
new file mode 100644
index 00000000000..749a875b368
--- /dev/null
+++ b/modelopt_recipes/general/ptq/nvfp4_experts_only_mse-fp8_cast_kv.yaml
@@ -0,0 +1,103 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+metadata:
+  recipe_type: ptq
+  description: >-
+    NVFP4 W4A4 for expert layers only with MSE-search FP8 scale calibration on
+    expert weights, FP8 KV cache with constant amax (skips KV calibration; amax
+    hardcoded to FP8 E4M3 max 448.0). Expert weight quantizers are static
+    (per-block amax fixed by the MSE FP8-scale sweep); input quantizers remain
+    dynamic.
+quantize:
+  algorithm:
+    method: mse
+    fp8_scale_sweep: true
+    # Max calibration is fast and does not typically need checkpointing.
+    # layerwise=false required for VLMs where the decoder layers are nested under
+    # `model.language_model.layers` (layerwise_calibrate can't find them otherwise).
+    layerwise: false
+  quant_cfg:
+    - quantizer_name: '*'
+      enable: false
+    - quantizer_name: '*mlp.experts*weight_quantizer'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+          type: static
+          scale_bits: e4m3
+        num_bits: e2m1
+    - quantizer_name: '*mlp.experts*input_quantizer'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+          type: dynamic
+          scale_bits: e4m3
+        num_bits: e2m1
+    - quantizer_name: '*block_sparse_moe*weight_quantizer'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+          type: static
+          scale_bits: e4m3
+        num_bits: e2m1
+    - quantizer_name: '*block_sparse_moe*input_quantizer'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+          type: dynamic
+          scale_bits: e4m3
+        num_bits: e2m1
+    - quantizer_name: '*[kv]_bmm_quantizer'
+      enable: true
+      cfg:
+        num_bits: e4m3
+        use_constant_amax: true
+    - quantizer_name: '*block_sparse_moe.gate*'
+      enable: false
+    - quantizer_name: '*linear_attn.conv1d*'
+      enable: false
+    - quantizer_name: '*lm_head*'
+      enable: false
+    - quantizer_name: '*mixer.conv1d*'
+      enable: false
+    - quantizer_name: '*mlp.gate.*'
+      enable: false
+    - quantizer_name: '*mlp.shared_expert_gate.*'
+      enable: false
+    - quantizer_name: '*output_layer*'
+      enable: false
+    - quantizer_name: '*proj_out.*'
+      enable: false
+    - quantizer_name: '*router*'
+      enable: false
+    - quantizer_name: 'output.*'
+      enable: false
+    - parent_class: 'nn.BatchNorm1d'
+      quantizer_name: '*'
+      enable: false
+    - parent_class: 'nn.BatchNorm2d'
+      quantizer_name: '*'
+      enable: false
+    - parent_class: 'nn.BatchNorm3d'
+      quantizer_name: '*'
+      enable: false
+    - parent_class: 'nn.LeakyReLU'
+      quantizer_name: '*'
+      enable: false
diff --git a/modelopt_recipes/general/ptq/nvfp4_mlp_only_mse-fp8_cast_kv.yaml b/modelopt_recipes/general/ptq/nvfp4_mlp_only_mse-fp8_cast_kv.yaml
new file mode 100644
index 00000000000..bca46608d59
--- /dev/null
+++ b/modelopt_recipes/general/ptq/nvfp4_mlp_only_mse-fp8_cast_kv.yaml
@@ -0,0 +1,103 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+metadata:
+  recipe_type: ptq
+  description: >-
+    NVFP4 W4A4 for all MLP layers (dense + MoE) with MSE-search FP8 scale
+    calibration on MLP weights, FP8 KV cache with constant amax (skips KV
+    calibration; amax hardcoded to FP8 E4M3 max 448.0). MLP weight quantizers
+    are static (per-block amax fixed by the MSE FP8-scale sweep); input
+    quantizers remain dynamic.
+quantize:
+  algorithm:
+    method: mse
+    fp8_scale_sweep: true
+    # Max calibration is fast and does not typically need checkpointing.
+    # layerwise=false required for VLMs where the decoder layers are nested under
+    # `model.language_model.layers` (layerwise_calibrate can't find them otherwise).
+    layerwise: false
+  quant_cfg:
+    - quantizer_name: '*'
+      enable: false
+    - quantizer_name: '*mlp*weight_quantizer'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+          type: static
+          scale_bits: e4m3
+        num_bits: e2m1
+    - quantizer_name: '*mlp*input_quantizer'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+          type: dynamic
+          scale_bits: e4m3
+        num_bits: e2m1
+    - quantizer_name: '*block_sparse_moe*weight_quantizer'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+          type: static
+          scale_bits: e4m3
+        num_bits: e2m1
+    - quantizer_name: '*block_sparse_moe*input_quantizer'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+          type: dynamic
+          scale_bits: e4m3
+        num_bits: e2m1
+    - quantizer_name: '*[kv]_bmm_quantizer'
+      enable: true
+      cfg:
+        num_bits: e4m3
+        use_constant_amax: true
+    - quantizer_name: '*block_sparse_moe.gate*'
+      enable: false
+    - quantizer_name: '*linear_attn.conv1d*'
+      enable: false
+    - quantizer_name: '*lm_head*'
+      enable: false
+    - quantizer_name: '*mixer.conv1d*'
+      enable: false
+    - quantizer_name: '*mlp.gate.*'
+      enable: false
+    - quantizer_name: '*mlp.shared_expert_gate.*'
+      enable: false
+    - quantizer_name: '*output_layer*'
+      enable: false
+    - quantizer_name: '*proj_out.*'
+      enable: false
+    - quantizer_name: '*router*'
+      enable: false
+    - quantizer_name: 'output.*'
+      enable: false
+    - parent_class: 'nn.BatchNorm1d'
+      quantizer_name: '*'
+      enable: false
+    - parent_class: 'nn.BatchNorm2d'
+      quantizer_name: '*'
+      enable: false
+    - parent_class: 'nn.BatchNorm3d'
+      quantizer_name: '*'
+      enable: false
+    - parent_class: 'nn.LeakyReLU'
+      quantizer_name: '*'
+      enable: false
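The recipes are plain YAML, so they can be inspected without any ModelOpt machinery; a small sketch (PyYAML assumed; path and schema are the ones added in this PR):

```python
import yaml

with open("modelopt_recipes/general/ptq/nvfp4_mlp_only_mse-fp8_cast_kv.yaml") as f:
    recipe = yaml.safe_load(f)

algo = recipe["quantize"]["algorithm"]
assert algo["method"] == "mse" and algo["fp8_scale_sweep"] is True

# List the quantizer patterns this recipe enables (everything else is disabled).
enabled = [e["quantizer_name"] for e in recipe["quantize"]["quant_cfg"] if e.get("enable")]
print(enabled)
# ['*mlp*weight_quantizer', '*mlp*input_quantizer', '*block_sparse_moe*weight_quantizer',
#  '*block_sparse_moe*input_quantizer', '*[kv]_bmm_quantizer']
```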
+""" + +import time + +import pytest +import torch +from conftest import requires_triton + +from modelopt.torch.quantization.calib import NVFP4MSECalibrator, TritonNVFP4MSECalibrator +from modelopt.torch.quantization.tensor_quant import static_blockwise_fp4_fake_quant + +BLOCK_SIZE = 16 + + +def _reference_quant_func(global_amax): + """Reference NVFP4 fake-quant matching what ``mse_calibrate`` plumbs in.""" + + def quant_func(x, amax): + return static_blockwise_fp4_fake_quant(x, amax, global_amax) + + return quant_func + + +def _run_reference(x, per_block_amax, global_amax): + cal = NVFP4MSECalibrator( + amax=per_block_amax, + axis=0, + global_amax=global_amax, + quant_func=_reference_quant_func(global_amax), + ) + cal.collect(x) + return cal.compute_amax() + + +def _run_triton(x, per_block_amax, global_amax): + cal = TritonNVFP4MSECalibrator( + amax=per_block_amax, + axis=0, + global_amax=global_amax, + quant_func=_reference_quant_func(global_amax), + ) + cal.collect(x) + return cal.compute_amax() + + +@requires_triton +@pytest.mark.parametrize("seed", [0, 1, 2]) +@pytest.mark.parametrize("num_blocks", [4, 64, 1024]) +def test_parity_random_weights(seed, num_blocks): + """Triton sweep must produce the exact same per-block amax as the reference.""" + torch.manual_seed(seed) + device = "cuda" + x = torch.randn(num_blocks, BLOCK_SIZE, device=device, dtype=torch.float32) + per_block_amax = x.abs().amax(dim=-1) + global_amax = per_block_amax.max() + + ref = _run_reference(x, per_block_amax, global_amax) + tri = _run_triton(x, per_block_amax, global_amax) + + assert ref.shape == tri.shape + # Both pick from the same 126-element discrete candidate set, so any disagreement + # would show up as a non-zero diff (not a small float epsilon). Demand exact match. + assert torch.equal(ref, tri), ( + f"Triton sweep diverged from reference: max |diff| = " + f"{(ref - tri).abs().max().item():.3e}, " + f"differing blocks = {(ref != tri).sum().item()} / {num_blocks}" + ) + + +@requires_triton +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +def test_parity_dtypes(dtype): + """Sweep must agree across the dtypes supported by the NVFP4 quantizer.""" + torch.manual_seed(42) + device = "cuda" + num_blocks = 256 + x = torch.randn(num_blocks, BLOCK_SIZE, device=device, dtype=dtype) + # Promote to fp32 for the per-block amax (matches what max_calibrate produces). 
+ per_block_amax = x.float().abs().amax(dim=-1) + global_amax = per_block_amax.max() + + ref = _run_reference(x, per_block_amax, global_amax) + tri = _run_triton(x, per_block_amax, global_amax) + assert torch.equal(ref, tri) + + +@requires_triton +def test_quantized_output_matches(): + """Round-tripping x through the chosen amax should give the same fake-quant result.""" + torch.manual_seed(7) + device = "cuda" + num_blocks = 128 + x = torch.randn(num_blocks, BLOCK_SIZE, device=device, dtype=torch.float32) + per_block_amax = x.abs().amax(dim=-1) + global_amax = per_block_amax.max() + + ref_amax = _run_reference(x, per_block_amax, global_amax) + tri_amax = _run_triton(x, per_block_amax, global_amax) + + ref_xq = static_blockwise_fp4_fake_quant(x, ref_amax, global_amax) + tri_xq = static_blockwise_fp4_fake_quant(x, tri_amax, global_amax) + assert torch.equal(ref_xq, tri_xq) + + +@requires_triton +def test_reset_allows_recollect(): + torch.manual_seed(0) + device = "cuda" + num_blocks = 32 + x = torch.randn(num_blocks, BLOCK_SIZE, device=device, dtype=torch.float32) + per_block_amax = x.abs().amax(dim=-1) + global_amax = per_block_amax.max() + + cal = TritonNVFP4MSECalibrator( + amax=per_block_amax, + axis=0, + global_amax=global_amax, + ) + cal.collect(x) + first = cal.compute_amax().clone() + + # collect() is one-shot per cycle until reset() is called. + with pytest.raises(RuntimeError, match="one-shot"): + cal.collect(x) + + cal.reset() + # After reset, the same calibrator instance can be re-used. + cal.collect(x) + assert torch.equal(first, cal.compute_amax()) + + +@requires_triton +def test_input_validation(): + """``nvfp4_fp8_scale_sweep`` should reject malformed inputs cleanly.""" + from modelopt.torch.kernels.quantization.gemm import fp8_scale_candidates, nvfp4_fp8_scale_sweep + + device = "cuda" + x = torch.randn(64, BLOCK_SIZE, device=device) + g = x.abs().amax() + + # CPU tensor → ValueError (not bare AssertionError). + with pytest.raises(ValueError, match="CUDA"): + nvfp4_fp8_scale_sweep(x.cpu(), g.cpu()) + + # block_size <= 0. + with pytest.raises(ValueError, match="block_size"): + nvfp4_fp8_scale_sweep(x, g, block_size=0) + with pytest.raises(ValueError, match="block_size"): + nvfp4_fp8_scale_sweep(x, g, block_size=-1) + + # Non-divisible numel. + with pytest.raises(ValueError, match="not divisible"): + nvfp4_fp8_scale_sweep(x, g, block_size=15) + + # Empty / wrong-rank candidates. + with pytest.raises(ValueError, match="non-empty 1-D"): + nvfp4_fp8_scale_sweep(x, g, candidates=torch.empty(0, device=device)) + with pytest.raises(ValueError, match="non-empty 1-D"): + nvfp4_fp8_scale_sweep(x, g, candidates=fp8_scale_candidates(device).reshape(2, -1)) + + +def _bench(fn, warmup=2, iters=5): + for _ in range(warmup): + fn() + torch.cuda.synchronize() + t0 = time.perf_counter() + for _ in range(iters): + fn() + torch.cuda.synchronize() + return (time.perf_counter() - t0) / iters + + +@requires_triton +def test_speedup_report(capsys): + """Sanity-check that the Triton path is meaningfully faster on a realistic weight. + + Uses an 8192 x 4096 weight (~33M elements, ~2M NVFP4 blocks) — roughly the size + of an LLM attention/MLP projection. Reports the speedup; does not gate on a + minimum factor (kernel timing is noisy on shared CI), but does require parity + on the chosen amax. 
+ """ + torch.manual_seed(123) + device = "cuda" + cout, cin = 8192, 4096 + x = torch.randn(cout, cin // BLOCK_SIZE, BLOCK_SIZE, device=device, dtype=torch.float32) + x = x.reshape(-1, BLOCK_SIZE) + per_block_amax = x.abs().amax(dim=-1) + global_amax = per_block_amax.max() + + ref_amax = _run_reference(x, per_block_amax, global_amax) + tri_amax = _run_triton(x, per_block_amax, global_amax) + # Bit-equality across millions of blocks isn't guaranteed: when two adjacent FP8 + # candidates yield near-identical per-block MSE (within fp32 noise), the reference's + # CUDA fake_e4m3fy path and our Triton inline math can break ties differently. Demand + # instead that the Triton choice produces a per-block MSE within fp32 epsilon of the + # reference's choice. + n_blocks = ref_amax.numel() + n_diff = int((ref_amax != tri_amax).sum()) + if n_diff: + ref_xq = static_blockwise_fp4_fake_quant(x, ref_amax, global_amax) + tri_xq = static_blockwise_fp4_fake_quant(x, tri_amax, global_amax) + per_block_mse_ref = (x - ref_xq).pow(2).sum(dim=-1) + per_block_mse_tri = (x - tri_xq).pow(2).sum(dim=-1) + # Reference is the formal argmin, so triton's loss should be ≥ reference's. + # Allow at most 1e-5 relative gap on differing blocks (observed ~1e-7 in practice). + rel_gap = (per_block_mse_tri - per_block_mse_ref).abs() / per_block_mse_ref.clamp_min(1e-12) + worst = rel_gap.max().item() + assert worst < 1e-5, ( + f"{n_diff}/{n_blocks} blocks disagree with worst relative MSE gap {worst:.3e} " + "— exceeds tie-break tolerance" + ) + + ref_t = _bench(lambda: _run_reference(x, per_block_amax, global_amax)) + tri_t = _bench(lambda: _run_triton(x, per_block_amax, global_amax)) + speedup = ref_t / tri_t + + # Force-print regardless of pytest capture mode. + with capsys.disabled(): + n_blocks = x.numel() // BLOCK_SIZE + print( + f"\n[NVFP4 FP8 sweep] weight=({cout},{cin}) " + f"n_blocks={n_blocks} block_size={BLOCK_SIZE}\n" + f" reference NVFP4MSECalibrator: {ref_t * 1e3:8.2f} ms\n" + f" triton TritonNVFP4MSECalibrator: {tri_t * 1e3:8.2f} ms\n" + f" speedup: {speedup:.1f}x" + )