diff --git a/modelopt/torch/export/plugins/hf_checkpoint_utils.py b/modelopt/torch/export/plugins/hf_checkpoint_utils.py index 4d9bc6fc29..63ac339348 100644 --- a/modelopt/torch/export/plugins/hf_checkpoint_utils.py +++ b/modelopt/torch/export/plugins/hf_checkpoint_utils.py @@ -22,9 +22,21 @@ import torch from huggingface_hub import snapshot_download +from huggingface_hub.errors import LocalEntryNotFoundError from safetensors.torch import safe_open from tqdm import tqdm +_HF_HUB_OFFLINE_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} + + +def _is_hf_hub_offline() -> bool: + return os.environ.get("HF_HUB_OFFLINE", "").strip().upper() in _HF_HUB_OFFLINE_TRUE_VALUES + + +def _copy_python_files(source_dir: Path, save_dir: Path) -> None: + for py_file in source_dir.glob("*.py"): + shutil.copy2(py_file, save_dir / py_file.name) + def copy_hf_ckpt_remote_code( pretrained_model_path: str | os.PathLike, save_directory: str | os.PathLike @@ -36,7 +48,10 @@ def copy_hf_ckpt_remote_code( frameworks. If ``pretrained_model_path`` is a local directory, Python files are copied directly. - If it's a HF Hub model ID (e.g. ``nvidia/NVIDIA-Nemotron-Nano-12B-v2``), files are downloaded from the Hub. + If it's a HF Hub model ID (e.g. ``nvidia/NVIDIA-Nemotron-Nano-12B-v2``), the Hub + snapshot is resolved first and Python files are copied from that snapshot. When + ``HF_HUB_OFFLINE`` is set, the snapshot must already be available in the local + Hugging Face cache. Args: pretrained_model_path: Local path to the pretrained model or HuggingFace Hub model ID. @@ -47,14 +62,28 @@ def copy_hf_ckpt_remote_code( save_dir.mkdir(parents=True, exist_ok=True) if hf_checkpoint_path.is_dir(): - for py_file in hf_checkpoint_path.glob("*.py"): - shutil.copy2(py_file, save_dir / py_file.name) + _copy_python_files(hf_checkpoint_path, save_dir) else: - snapshot_download( - repo_id=str(pretrained_model_path), - local_dir=str(save_dir), - allow_patterns=["*.py"], - ) + local_files_only = _is_hf_hub_offline() + try: + source_dir = Path( + snapshot_download( + repo_id=str(pretrained_model_path), + allow_patterns=["*.py"], + local_files_only=local_files_only, + ) + ) + except LocalEntryNotFoundError as exc: + if local_files_only: + raise RuntimeError( + f"Could not copy Python sidecar files for {pretrained_model_path!r} because " + "HF_HUB_OFFLINE is enabled and the files are not available in the local " + "Hugging Face cache. Populate the cache with the model's *.py files or pass " + "a local pretrained model directory." + ) from exc + raise + + _copy_python_files(source_dir, save_dir) def load_multimodal_components( diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index 23b8cfd163..862e2031e2 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -61,6 +61,7 @@ get_weight_block_size, get_weight_scaling_factor, get_weight_scaling_factor_2, + process_layer_quant_config, to_quantized_weight, ) @@ -169,6 +170,7 @@ def __init__( self.all_rules = self._populate_rule_book() self.rules = self.all_rules[self.arch] self.exclude_modules = [] + self.layer_config_dict = {} if not hasattr(model, "_modelopt_state"): return @@ -324,22 +326,32 @@ def save_pretrained( print(f"Successfully loaded {len(mtp_state_dict)} MTP tensors") combined_exclude_modules = self._gather_exclude_modules() + combined_layer_config_dict = self._gather_layer_config_dict() if is_last_stage_main_rank and quantization is not None: - self._hf_quant_config = { + if combined_layer_config_dict: + quantization_config = process_layer_quant_config(combined_layer_config_dict) + quantization_config["exclude_modules"] = combined_exclude_modules + else: + quantization_config = { + "quant_algo": quantization, + "exclude_modules": combined_exclude_modules, + } + if quantization == "NVFP4": # update block size + quantization_config["group_size"] = 16 + + if hasattr(self, "kv_cache_dtype"): + quantization_config["kv_cache_quant_algo"] = self.kv_cache_dtype + + raw_hf_quant_config = { "producer": { "name": "modelopt", "version": __version__, }, - "quantization": { - "quant_algo": quantization, - "exclude_modules": combined_exclude_modules, - }, + "quantization": quantization_config, } - if quantization == "NVFP4": # update block size - self._hf_quant_config["quantization"]["group_size"] = 16 - if hasattr(self, "kv_cache_dtype"): - self._hf_quant_config["quantization"]["kv_cache_quant_algo"] = self.kv_cache_dtype + # Use one serving-facing config for both hf_quant_config.json and config.json. + self._hf_quant_config = convert_hf_quant_config_format(raw_hf_quant_config) with open(save_directory + "/hf_quant_config.json", "w") as f: json.dump(self._hf_quant_config, f, indent=4) @@ -359,10 +371,9 @@ def save_pretrained( # Newer versions of VLLM expect config.json with hf_quant_config config_json_file = save_directory + "/config.json" if self._hf_quant_config and os.path.exists(config_json_file): - converted_quant_config = convert_hf_quant_config_format(self._hf_quant_config) with open(config_json_file) as f: config_dict = json.load(f) - config_dict["quantization_config"] = converted_quant_config + config_dict["quantization_config"] = self._hf_quant_config with open(config_json_file, "w") as f: json.dump(config_dict, f, indent=4) @@ -803,9 +814,7 @@ def _get_quantized_state( name_to_value = {} qformat: str = self._get_quantization_format(module) if qformat is None and "norm" not in prefix: - # Add exclude layers for hf_quant_config. Note that if the prefix is not an empty - # string then it usually ends with "." which needs to be removed. - self.exclude_modules.append(prefix.removesuffix(".")) + self._record_excluded_module(prefix) block_size = get_weight_block_size(module) name_to_value = self._get_weight_bias(module, dtype, name_to_value) @@ -850,6 +859,27 @@ def _get_weight_scales(self, quantized_state: dict[str, Any], qformat: str): return weight_scale, weight_scale_2 + def _record_layer_quant_config(self, prefix: str, qformat: str | None, block_size: int): + """Record per-HF-layer quantization metadata for mixed precision exports.""" + if qformat in (None, QUANTIZATION_NONE): + return + + layer_name = prefix.removesuffix(".") + if "{" in layer_name or not layer_name: + return + + self.layer_config_dict[layer_name + ".quantization"] = qformat + self.layer_config_dict[layer_name + ".awq_block_size"] = block_size + + def _record_excluded_module(self, prefix: str): + """Record an unquantized HF module prefix for hf_quant_config.""" + layer_name = prefix.removesuffix(".") + if "{" in layer_name or not layer_name: + return + + if layer_name not in self.exclude_modules: + self.exclude_modules.append(layer_name) + def _name_remapping( self, module: torch.nn.Module | torch.Tensor, @@ -866,6 +896,7 @@ def _name_remapping( return name_to_value, qformat, block_size = self._get_quantized_state(module, dtype, prefix=prefix) + self._record_layer_quant_config(prefix, qformat, block_size) weight = name_to_value.pop("weight") weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat) @@ -906,6 +937,8 @@ def _gated_mlp_slicing( gate_proj_prefix = prefix + gate_proj_name + "." up_proj_prefix = prefix + up_proj_name + "." + self._record_layer_quant_config(gate_proj_prefix, qformat, block_size) + self._record_layer_quant_config(up_proj_prefix, qformat, block_size) ffn_hidden_size = module.config.ffn_hidden_size gate_proj_weight = weight[:ffn_hidden_size, :] @@ -986,6 +1019,7 @@ def _grouped_mlp_slicing(self, module, prefix, parallel_config=None): for expert_id in range(num_experts): expert_prefix = prefix.format(expert_id) + "." + self._record_layer_quant_config(expert_prefix, qformat, block_size) weight_key = f"weight{expert_id}" if weight_key not in state_dict: @@ -1030,6 +1064,18 @@ def _qkv_slicing( q_proj_prefix = prefix + q_proj_name + "." k_proj_prefix = prefix + k_proj_name + "." v_proj_prefix = prefix + v_proj_name + "." + self._record_layer_quant_config(q_proj_prefix, qformat, block_size) + self._record_layer_quant_config(k_proj_prefix, qformat, block_size) + self._record_layer_quant_config(v_proj_prefix, qformat, block_size) + if qformat in (None, QUANTIZATION_NONE): + # MCore stores Q/K/V as one fused linear_qkv module, but HF exports them + # as separate q_proj/k_proj/v_proj modules. Record the HF names so + # runtime quant configs do not miss excluded fused-QKV projections. + fused_prefix = prefix.removesuffix(".") + self.exclude_modules = [m for m in self.exclude_modules if m != fused_prefix] + self._record_excluded_module(q_proj_prefix) + self._record_excluded_module(k_proj_prefix) + self._record_excluded_module(v_proj_prefix) config = module.config hidden_size = config.hidden_size @@ -1179,6 +1225,7 @@ def _pack_name_remapping(self, module, prefix, layer_type=None): weight_scale_list.append(weight_scale) weight_scale_2_list.append(weight_scale_2) input_scale_list.append(input_scale) + self._record_layer_quant_config(prefix, qformat, block_size) merged_weight = torch.stack(weight_list, dim=0) @@ -1247,6 +1294,7 @@ def _pack_name_remapping_gpt_oss(self, module, prefix, layer_type=None): weight_scale_2_list.append(weight_scale_2) input_scale_list.append(input_scale) bias_list.append(bias) + self._record_layer_quant_config(prefix, qformat, block_size) merged_weight = torch.stack(weight_list, dim=0) @@ -1349,6 +1397,19 @@ def _gather_exclude_modules(self): combined_exclude_modules.update(modules) return sorted(combined_exclude_modules) + def _gather_layer_config_dict(self): + """Get per-layer quantization metadata from all ranks for hf_quant_config.""" + if not torch.distributed.is_initialized(): + return dict(sorted(self.layer_config_dict.items())) + + all_layer_config_dicts = [None] * torch.distributed.get_world_size() + torch.distributed.all_gather_object(all_layer_config_dicts, self.layer_config_dict) + combined_layer_config_dict = {} + for layer_config_dict in all_layer_config_dicts: + if layer_config_dict: + combined_layer_config_dict.update(layer_config_dict) + return dict(sorted(combined_layer_config_dict.items())) + def export_mcore_gpt_to_hf( model: torch.nn.Module, diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 0c2033041d..69f4974ef7 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -110,26 +110,44 @@ def _has_expert_parallelism(module: nn.Module) -> bool: return ps is not None and ps.expert_model_parallel_group.is_initialized() -def _check_moe_calibration_complete(quantizer, parallel_state): - """Raise error if MoE calibration is incomplete (some ranks have amax, others don't).""" +def _is_dynamic_block_quantizer(quantizer) -> bool: + block_sizes = getattr(quantizer, "block_sizes", None) + if isinstance(block_sizes, dict): + return block_sizes.get("type") == "dynamic" + return getattr(block_sizes, "type", None) == "dynamic" + + +def _iter_leaf_quantizers(quantizer): if isinstance(quantizer, SequentialQuantizer): for _q in quantizer: - _check_moe_calibration_complete(_q, parallel_state) + yield from _iter_leaf_quantizers(_q) return - for group in [ - parallel_state.data_parallel_group, - parallel_state.expert_model_parallel_group, - parallel_state.tensor_parallel_group, - ]: - if not group.is_initialized(): + yield quantizer + + +def _check_moe_calibration_complete(quantizer, parallel_state): + """Raise error if MoE calibration is incomplete across distributed MoE ranks.""" + for leaf_quantizer in _iter_leaf_quantizers(quantizer): + if _is_dynamic_block_quantizer(leaf_quantizer): continue - has_amax = getattr(quantizer, "_amax", None) is not None - amax_states = DistributedProcessGroup.get_dist_syncd_obj(has_amax, group, lambda objs: objs) - if any(amax_states) and not all(amax_states): - raise RuntimeError( - "MoE calibration incomplete: some experts received no tokens during calibration. " - "Increase --calib-size to ensure all experts see calibration data." + + has_amax = getattr(leaf_quantizer, "_amax", None) is not None + for group in [ + parallel_state.data_parallel_group, + parallel_state.expert_model_parallel_group, + parallel_state.tensor_parallel_group, + ]: + if not group.is_initialized(): + continue + amax_states = DistributedProcessGroup.get_dist_syncd_obj( + has_amax, group, lambda objs: objs ) + if any(amax_states) and not all(amax_states): + raise RuntimeError( + "MoE calibration incomplete: some experts received no tokens during " + "calibration. Increase --calib-size to ensure all experts see calibration " + "data." + ) @torch.no_grad() @@ -175,13 +193,13 @@ def max_calibrate( def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state): """Synchronize the amax across all ranks in the data parallel and expert parallel groups.""" - if isinstance(quantizer, SequentialQuantizer): - for _q in quantizer: - sync_quantizer_amax_across_dp_ep(_q, parallel_state) - return - if getattr(quantizer, "_amax", None) is not None: - quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group) - quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group) + for leaf_quantizer in _iter_leaf_quantizers(quantizer): + if _is_dynamic_block_quantizer(leaf_quantizer): + continue + leaf_quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group) + leaf_quantizer.sync_amax_across_distributed_group( + parallel_state.expert_model_parallel_group + ) # TODO: create sync_bias_across_distributed_group # Step 2:Sync amax across data parallelism @@ -226,7 +244,7 @@ def sync_quantizer_amax_across_tp( ) # Skip amax sync for INT4 / W4A8 block quantization # Sync amax for NVFP4 (dynamic per-block, static per-tensor quantized scale) - if getattr(quantizer.block_sizes, "type", None) == "dynamic": + if _is_dynamic_block_quantizer(quantizer): return if quantizer.axis in axes_for_sync and quantizer.amax is not None: diff --git a/modelopt/torch/quantization/plugins/custom.py b/modelopt/torch/quantization/plugins/custom.py index 4200aadc73..139ff36a6a 100644 --- a/modelopt/torch/quantization/plugins/custom.py +++ b/modelopt/torch/quantization/plugins/custom.py @@ -24,7 +24,7 @@ from modelopt.torch.utils.distributed import ParallelState -from ..nn import QuantModule, SequentialQuantizer, TensorQuantizer +from ..nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer from ..nn.modules.quant_linear import _QuantLinear from ..utils import multi_context, replace_function @@ -126,7 +126,7 @@ def modelopt_post_restore(self, prefix: str = ""): def _check_unsupported_states(quantizer: TensorQuantizer): for k in quantizer.state_dict(): - if k not in ["_amax", "_pre_quant_scale"]: + if k not in ["_amax", "_pre_quant_scale", "_global_amax"]: warnings.warn( f"Restore of {k} for {prefix} is not supported. The restore of this layer might be " f"incorrect. Please implement a custom restore for {k}." @@ -137,6 +137,21 @@ def _has_state(quantizer, name): quantizer = quantizer[0] if isinstance(quantizer, SequentialQuantizer) else quantizer return hasattr(quantizer, name) + def _has_complete_static_nvfp4_weight_state(quantizer, weight): + quantizer = quantizer[0] if isinstance(quantizer, SequentialQuantizer) else quantizer + if not isinstance(quantizer, NVFP4StaticQuantizer): + return False + amax = getattr(quantizer, "_amax", None) + global_amax = getattr(quantizer, "global_amax", None) + if amax is None or global_amax is None: + return False + block_sizes = getattr(quantizer, "block_sizes", None) + block_size = block_sizes.get(-1) if isinstance(block_sizes, dict) else None + if block_size is None or weight.shape[-1] % block_size != 0: + return False + expected_blocks = weight.numel() // block_size + return amax.numel() == expected_blocks and global_amax.numel() == 1 + if self.weight is None: return @@ -144,7 +159,11 @@ def _has_state(quantizer, name): _check_unsupported_states( quantizer if isinstance(quantizer, TensorQuantizer) else quantizer[0] ) - if _has_state(self.weight_quantizer, "_amax"): + # Static NVFP4 weight _amax is saved calibration state. Preserve it when it + # matches the local shard; recalibrating would replace MSE scales with max scales. + if _has_state( + self.weight_quantizer, "_amax" + ) and not _has_complete_static_nvfp4_weight_state(self.weight_quantizer, self.weight): self.weight_quantizer.reset_amax() max_calibrate(self.weight_quantizer, lambda wq: wq(self.weight), distributed_sync=False) if _has_state(self.input_quantizer, "_pre_quant_scale"): diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 0b50fd937a..493dd9bd21 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -40,7 +40,7 @@ ) from modelopt.torch.utils.distributed import ParallelState -from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer +from ..nn import NVFP4StaticQuantizer, QuantModule, QuantModuleRegistry, TensorQuantizer from ..nn.modules.quant_linear import RealQuantLinear from ..qtensor import QTensorWrapper from ..utils import sync_moe_expert_amax @@ -190,7 +190,12 @@ def quant_module_set_extra_state(self, state: Any): if quantizer_state is not None: for name, module in self.named_modules(): if isinstance(module, TensorQuantizer): - module.set_from_modelopt_state(quantizer_state[name], properties_only=False) + quantizer_substate = quantizer_state[name] + if quantizer_substate.get("_is_nvfp4_static_quantizer") and not isinstance( + module, NVFP4StaticQuantizer + ): + NVFP4StaticQuantizer.from_tensor_quantizer(module) + module.set_from_modelopt_state(quantizer_substate, properties_only=False) self.modelopt_post_restore() # Handle real_quantizer_state and q_tensor_state @@ -399,6 +404,9 @@ def _get_shard_axis_dict(self, state_dict): """ shard_axis_dict = {} for k in state_dict: + # Static NVFP4 _global_amax is a replicated scalar; only per-block _amax shards. + if k.endswith("_global_amax"): + continue if "weight_quantizer." in k: weight_quantizer_axis = self.get_submodule(k.rsplit(".", 1)[0]).axis if weight_quantizer_axis is not None: @@ -427,6 +435,9 @@ def _get_shard_axis_dict(self, state_dict): """ shard_axis_dict = {} for k in state_dict: + # Static NVFP4 _global_amax is a replicated scalar; only per-block _amax shards. + if k.endswith("_global_amax"): + continue if "weight_quantizer." in k: weight_quantizer_axis = None if isinstance(self.weight_quantizer, TensorQuantizer): diff --git a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py index fe30e283c2..8cb0b66e45 100644 --- a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py +++ b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py @@ -55,7 +55,16 @@ def get_e2m1_bounds(cls, device): @classmethod def _is_static_quantizer(cls, weight_quantizer) -> bool: """Check if the weight quantizer is a static NVFP4 quantizer with pre-computed amax.""" - return hasattr(weight_quantizer, "global_amax") and weight_quantizer.global_amax is not None + global_amax = cls._get_static_global_amax(weight_quantizer) + return global_amax is not None + + @classmethod + def _get_static_global_amax(cls, weight_quantizer): + """Return global amax from live or restored static NVFP4 quantizers.""" + global_amax = getattr(weight_quantizer, "global_amax", None) + if global_amax is None: + global_amax = getattr(weight_quantizer, "_global_amax", None) + return global_amax @classmethod def get_weights_scaling_factor_2_from_quantizer(cls, weight_quantizer): @@ -70,8 +79,9 @@ def get_weights_scaling_factor_2_from_quantizer(cls, weight_quantizer): Returns: The global scaling factor as a float tensor. """ - if cls._is_static_quantizer(weight_quantizer): - return weight_quantizer.global_amax.float() / (6.0 * 448.0) + global_amax = cls._get_static_global_amax(weight_quantizer) + if global_amax is not None: + return global_amax.float() / (6.0 * 448.0) else: assert hasattr(weight_quantizer, "_amax"), ( "Weight quantizer does not have attribute amax" @@ -109,7 +119,7 @@ def get_weights_scaling_factor_from_quantizer( if cls._is_static_quantizer(weight_quantizer): # Static path: use pre-computed per-block amax values from quantizer - global_amax = weight_quantizer.global_amax.float() + global_amax = cls._get_static_global_amax(weight_quantizer).float() per_block_amax = weight_quantizer._amax.float() # Compute scales in float diff --git a/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-max-calib.yaml b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-max-calib.yaml new file mode 100644 index 0000000000..df2b30b8ed --- /dev/null +++ b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-max-calib.yaml @@ -0,0 +1,131 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Mirrors the published nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 hf_quant_config.json: +# but with ONE major difference: use max calibration instead of MSE +# - MoE routed experts: NVFP4 W4A4 weight, group_size 16 +# HF names: mixer.experts..{up,down}_proj +# Megatron-Core names: mlp.experts.local_experts..linear_fc{1,2} +# - MoE shared experts: FP8 per-tensor +# HF names: mixer.shared_experts.{up,down}_proj +# Megatron-Core names: mlp.shared_experts.linear_fc{1,2} +# - Mamba mixer linears (mixer.{in,out}_proj): FP8 per-tensor +# - KV cache: FP8 +# - Attention linears ({q,k,v}_proj): BF16 (not quantized) +# - MTP head, lm_head, output, mamba conv1d: BF16 (not quantized) +# - Latent MOE (fc1_latent_proj, fc2_latent_proj): BF16 (not quantized) +# - SSM cache: FP32 (can be set to FP16 in VLLM) +# +# Calibration: amax/max calibration comparison variant +metadata: + recipe_type: ptq + description: Super NVFP4 mixed precision — sparse MoE experts NVFP4 (W4A4, group_size 16); shared experts, mamba in/out_proj FP8 per-tensor; FP8 KV cache; + everything else(lm_head/MTP/Latent MOE) stay BF16. Amax calibration comparison variant. +quantize: + algorithm: + method: max + quant_cfg: + # Disable all layers by default so that these layers stay in original BF16 precision: + # lm_head, output projection, MoE routers/gates, Latent MOE, MTP head, mamba conv1d. + - quantizer_name: '*' + enable: false + + # MoE routed experts -> NVFP4 W4A4, block_size 16, e4m3 scale. + # HF/export names: backbone.layers.*.mixer.experts.*.{up,down}_proj. + - quantizer_name: '*mixer.experts.*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_name: '*mixer.experts.*input_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + # Megatron-Core/PTQ names: decoder.layers.*.mlp.experts.local_experts.*.linear_fc{1,2}. + - quantizer_name: '*mlp.experts*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_name: '*mlp.experts*input_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + + # MoE shared experts -> FP8 per-tensor. + # HF/export names: backbone.layers.*.mixer.shared_experts.{up,down}_proj. + - quantizer_name: '*mixer.shared_experts.*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.shared_experts.*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + # Megatron-Core/PTQ names: decoder.layers.*.mlp.shared_experts.linear_fc{1,2}. + - quantizer_name: '*mlp.shared_experts*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mlp.shared_experts*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + + # Mamba mixer linears -> FP8 per-tensor. + - quantizer_name: '*mixer.in_proj*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.in_proj*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.out_proj*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.out_proj*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + + # KV cache -> FP8. + - quantizer_name: '*[kv]_bmm_quantizer' + enable: true + cfg: + num_bits: e4m3 diff --git a/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml new file mode 100644 index 0000000000..729dcd12d5 --- /dev/null +++ b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml @@ -0,0 +1,134 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Mirrors the published nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 hf_quant_config.json: +# - MoE routed experts: NVFP4 W4A4 weight MSE, group_size 16 +# HF names: mixer.experts..{up,down}_proj +# Megatron-Core names: mlp.experts.local_experts..linear_fc{1,2} +# - MoE shared experts: FP8 per-tensor +# HF names: mixer.shared_experts.{up,down}_proj +# Megatron-Core names: mlp.shared_experts.linear_fc{1,2} +# - Mamba mixer linears (mixer.{in,out}_proj): FP8 per-tensor +# - KV cache: FP8 +# - Attention linears ({q,k,v}_proj): BF16 (not quantized) +# - MTP head, lm_head, output, mamba conv1d: BF16 (not quantized) +# - Latent MOE (fc1_latent_proj, fc2_latent_proj): BF16 (not quantized) +# - SSM cache: FP32 (can be set to FP16 in VLLM) +# +# Calibration: weight MSE with FP8-scale sweep over the 128 e4m3 scale values +# (NVFP4 weights use static block scales selected by MSE; FP8 per-tensor scales +# are also chosen via MSE search instead of plain amax). +metadata: + recipe_type: ptq + description: Super NVFP4 mixed precision — sparse MoE experts NVFP4 (W4A4, group_size 16); shared experts, mamba in/out_proj FP8 per-tensor; FP8 KV cache; + everything else(lm_head/MTP/latent MOE) stay BF16. Weight-MSE calibration with FP8 scale sweep. +quantize: + algorithm: + method: mse + fp8_scale_sweep: true + quant_cfg: + # Disable all layers by default so that these layers stay in original BF16 precision: + # lm_head, output projection, MoE routers/gates, Latent MOE, MTP head, mamba conv1d. + - quantizer_name: '*' + enable: false + + # MoE routed experts -> NVFP4 W4A4, block_size 16, e4m3 scale. + # Weight uses static block scales (chosen by MSE); activations stay dynamic. + # HF/export names: backbone.layers.*.mixer.experts.*.{up,down}_proj. + - quantizer_name: '*mixer.experts.*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: static + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_name: '*mixer.experts.*input_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + # Megatron-Core/PTQ names: decoder.layers.*.mlp.experts.local_experts.*.linear_fc{1,2}. + - quantizer_name: '*mlp.experts*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: static + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_name: '*mlp.experts*input_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + + # MoE shared experts -> FP8 per-tensor. + # HF/export names: backbone.layers.*.mixer.shared_experts.{up,down}_proj. + - quantizer_name: '*mixer.shared_experts.*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.shared_experts.*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + # Megatron-Core/PTQ names: decoder.layers.*.mlp.shared_experts.linear_fc{1,2}. + - quantizer_name: '*mlp.shared_experts*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mlp.shared_experts*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + + # Mamba mixer linears -> FP8 per-tensor. + - quantizer_name: '*mixer.in_proj*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.in_proj*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.out_proj*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.out_proj*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + + # KV cache -> FP8. + - quantizer_name: '*[kv]_bmm_quantizer' + enable: true + cfg: + num_bits: e4m3 diff --git a/tests/gpu_megatron/torch/export/test_unified_export_megatron.py b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py index 3fac8269cc..8dfbc0323c 100644 --- a/tests/gpu_megatron/torch/export/test_unified_export_megatron.py +++ b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py @@ -29,7 +29,7 @@ import modelopt.torch.quantization as mtq import modelopt.torch.speculative as mtsp -from modelopt.torch.export import KV_CACHE_FP8, export_mcore_gpt_to_hf, import_mcore_gpt_from_hf +from modelopt.torch.export import export_mcore_gpt_to_hf, import_mcore_gpt_from_hf from modelopt.torch.export.unified_export_megatron import GPTModelExporter from modelopt.torch.speculative.eagle.default_config import default_eagle_config from modelopt.torch.speculative.plugins.megatron_eagle import _DynamicEagleGPTModel @@ -42,15 +42,8 @@ def _verify_model_quant_config( """Verify config.json and hf_quant_config.json""" config_dict = json.load(open(export_dir / "config.json")) hf_quant_config_dict = json.load(open(export_dir / "hf_quant_config.json")) - # Make sure config.json and hf_quant_config.json are consistent - assert ( - config_dict["quantization_config"]["quant_algo"] - == hf_quant_config_dict["quantization"]["quant_algo"] - ) - assert ( - config_dict["quantization_config"]["ignore"] - == hf_quant_config_dict["quantization"]["exclude_modules"] - ) + # Make sure config.json and hf_quant_config.json use the same serving config. + assert config_dict["quantization_config"] == hf_quant_config_dict # Verify config.json if kv_cache_quant_cfg: @@ -58,17 +51,17 @@ def _verify_model_quant_config( # Verify hf_quant_config.json if quant_config: - quant_config_dict = hf_quant_config_dict["quantization"] + quant_config_dict = hf_quant_config_dict quant_type = quant_config_dict["quant_algo"] assert ( quant_type in quant_config ) # quant config str is subset of quant config e.g. NVFP4 -> NVFP4_DEFAULT_CFG - assert len(quant_config_dict["exclude_modules"]) > 1 # Dynamically added exclude modules + assert len(quant_config_dict["ignore"]) > 1 # Dynamically added exclude modules if quant_type == "NVFP4": - assert quant_config_dict["group_size"] == 16 + assert quant_config_dict["config_groups"]["group_0"]["weights"]["group_size"] == 16 if kv_cache_quant_cfg: - assert quant_config_dict["kv_cache_quant_algo"] == KV_CACHE_FP8 + assert quant_config_dict["kv_cache_scheme"]["num_bits"] == 8 def _test_unified_export_megatron( @@ -295,6 +288,44 @@ def test_qkv_slicing_gqa_tp2(dist_workers_size_2, tmp_path): dist_workers_size_2.run(partial(_test_qkv_slicing_gqa_tp2, tmp_path)) +def test_qkv_slicing_records_hf_excludes_for_unquantized_fused_qkv(): + """Unquantized fused MCore linear_qkv should become HF q/k/v excludes.""" + exporter = object.__new__(GPTModelExporter) + exporter.dtype = torch.bfloat16 + exporter.exclude_modules = ["backbone.layers.0.mixer"] + exporter.layer_config_dict = {} + exporter._state_dict = {} + + hidden_size = 8 + head_size = 4 + num_attention_heads = 2 + num_query_groups = 1 + qkv_dim = num_attention_heads + 2 * num_query_groups + weight = torch.arange(qkv_dim * head_size * hidden_size, dtype=torch.bfloat16).reshape( + qkv_dim * head_size, hidden_size + ) + + module = torch.nn.Module() + module.config = type( + "Config", + (), + { + "hidden_size": hidden_size, + "num_query_groups": num_query_groups, + "num_attention_heads": num_attention_heads, + "kv_channels": head_size, + }, + )() + exporter._get_quantized_state = lambda *args, **kwargs: ({"weight": weight}, None, 0) + + exporter._qkv_slicing(module, "backbone.layers.0.mixer.") + + assert "backbone.layers.0.mixer" not in exporter.exclude_modules + assert "backbone.layers.0.mixer.q_proj" in exporter.exclude_modules + assert "backbone.layers.0.mixer.k_proj" in exporter.exclude_modules + assert "backbone.layers.0.mixer.v_proj" in exporter.exclude_modules + + def _make_exporter_for_mtp(model_dir: Path) -> GPTModelExporter: """Create a minimal GPTModelExporter instance for testing _get_mtp_state_dict.""" exporter = object.__new__(GPTModelExporter) diff --git a/tests/unit/torch/export/test_hf_checkpoint_utils.py b/tests/unit/torch/export/test_hf_checkpoint_utils.py index f83cb35574..33d17eebb3 100644 --- a/tests/unit/torch/export/test_hf_checkpoint_utils.py +++ b/tests/unit/torch/export/test_hf_checkpoint_utils.py @@ -20,6 +20,8 @@ import pytest pytest.importorskip("huggingface_hub") +hf_hub_errors = pytest.importorskip("huggingface_hub.errors") +LocalEntryNotFoundError = hf_hub_errors.LocalEntryNotFoundError from modelopt.torch.export import copy_hf_ckpt_remote_code @@ -59,15 +61,60 @@ def test_copy_hf_ckpt_remote_code_local_dir_no_py_files(tmp_path): assert list(dst_dir.iterdir()) == [], "no files should be copied" -def test_copy_hf_ckpt_remote_code_hub_id(tmp_path): - """copy_hf_ckpt_remote_code delegates to snapshot_download for a Hub model ID.""" +def test_copy_hf_ckpt_remote_code_hub_id(tmp_path, monkeypatch): + """copy_hf_ckpt_remote_code copies .py files from the resolved Hub snapshot.""" dst_dir = tmp_path / "dst" - - with patch("modelopt.torch.export.plugins.hf_checkpoint_utils.snapshot_download") as mock_sd: + snapshot_dir = tmp_path / "snapshot" + snapshot_dir.mkdir() + (snapshot_dir / "modeling_custom.py").write_text("# custom model") + (snapshot_dir / "not_python.txt").write_text("not python") + + monkeypatch.delenv("HF_HUB_OFFLINE", raising=False) + with patch( + "modelopt.torch.export.plugins.hf_checkpoint_utils.snapshot_download", + return_value=str(snapshot_dir), + ) as mock_sd: copy_hf_ckpt_remote_code("nvidia/NVIDIA-Nemotron-Nano-12B-v2", dst_dir) mock_sd.assert_called_once_with( repo_id="nvidia/NVIDIA-Nemotron-Nano-12B-v2", - local_dir=str(dst_dir), allow_patterns=["*.py"], + local_files_only=False, + ) + assert (dst_dir / "modeling_custom.py").read_text() == "# custom model" + assert not (dst_dir / "not_python.txt").exists(), "non-.py files should not be copied" + + +def test_copy_hf_ckpt_remote_code_hub_id_offline_uses_cache(tmp_path, monkeypatch): + """copy_hf_ckpt_remote_code resolves cached Hub snapshots when HF_HUB_OFFLINE is set.""" + dst_dir = tmp_path / "dst" + snapshot_dir = tmp_path / "snapshot" + snapshot_dir.mkdir() + (snapshot_dir / "nemotron_reasoning_parser.py").write_text("# parser") + + monkeypatch.setenv("HF_HUB_OFFLINE", "1") + with patch( + "modelopt.torch.export.plugins.hf_checkpoint_utils.snapshot_download", + return_value=str(snapshot_dir), + ) as mock_sd: + copy_hf_ckpt_remote_code("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", dst_dir) + + mock_sd.assert_called_once_with( + repo_id="nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", + allow_patterns=["*.py"], + local_files_only=True, ) + assert (dst_dir / "nemotron_reasoning_parser.py").read_text() == "# parser" + + +def test_copy_hf_ckpt_remote_code_hub_id_offline_missing_cache_raises(tmp_path, monkeypatch): + """copy_hf_ckpt_remote_code raises a clear error when offline cache is missing.""" + monkeypatch.setenv("HF_HUB_OFFLINE", "1") + with ( + patch( + "modelopt.torch.export.plugins.hf_checkpoint_utils.snapshot_download", + side_effect=LocalEntryNotFoundError("missing"), + ), + pytest.raises(RuntimeError, match="HF_HUB_OFFLINE"), + ): + copy_hf_ckpt_remote_code("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", tmp_path / "dst")