diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index d660c1de4c8..647bd6b5e9b 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -137,6 +137,27 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None: mto.enable_huggingface_checkpointing() +_QWEN36_AUTOQ_DISABLED_LAYERS = ( + "*shared_expert_gate*", + "*linear_attn.in_proj_a*", + "*linear_attn.in_proj_b*", +) +_VLM_AUTOQ_DISABLED_LAYERS = ("*visual*", "*mtp*", "*vision_tower*") + + +def get_auto_quantize_disabled_layers(model) -> list[str]: + """Return layer patterns that should be excluded from AutoQuantize search.""" + disabled_layers = [ + entry["quantizer_name"] + for entry in _default_disabled_quantizer_cfg + if "parent_class" not in entry and entry["quantizer_name"] != "*lm_head*" + ] + disabled_layers.extend(p for p in _QWEN36_AUTOQ_DISABLED_LAYERS if p not in disabled_layers) + if is_multimodal_model(model): + disabled_layers.extend(p for p in _VLM_AUTOQ_DISABLED_LAYERS if p not in disabled_layers) + return disabled_layers + + def extract_and_prepare_language_model_from_vl(full_model): """Extract language model from VL model and disable quantization for non-language components. @@ -362,14 +383,11 @@ def forward_step(model, batch): len(calib_dataloader), max(auto_quantize_score_size // args.batch_size, 1) ), verbose=True, - # Disable all default disabled layers such as lm_head, mlp.gate, router etc. - disabled_layers=[ - entry["quantizer_name"] - for entry in _default_disabled_quantizer_cfg - if "parent_class" not in entry - ], + disabled_layers=get_auto_quantize_disabled_layers(language_model), method=auto_quantize_method, checkpoint=auto_quantize_checkpoint, + cost_model=args.auto_quantize_cost_model, + active_moe_expert_ratio=args.auto_quantize_active_moe_expert_ratio, ) calibrate_loop = create_forward_loop(dataloader=calib_dataloader) @@ -506,13 +524,15 @@ def load_model(args: argparse.Namespace): : len(args.dataset) ] - # We only quantize the language model for VLMs other than the type supported above. - extracted_lm, extracted_model_type = extract_and_prepare_language_model_from_vl( - full_model - ) - if extracted_lm is not None: - language_model = extracted_lm - model_type = extracted_model_type + # AutoQuantize walks the outer CausalLM so lm_head is visible to the + # search. Visual/MTP siblings are excluded by disabled-layer patterns. + if args.auto_quantize_bits is None: + extracted_lm, extracted_model_type = extract_and_prepare_language_model_from_vl( + full_model + ) + if extracted_lm is not None: + language_model = extracted_lm + model_type = extracted_model_type tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code) @@ -1018,10 +1038,18 @@ def quantize_main( "Auto quantization needs multiple quantization format." ) + # For VL models, autoquant must walk submodules of the OUTER CausalLM + # (which carries lm_head and the LM-head forward path) — otherwise + # lm_head and any sibling-of-language_model modules are silently + # invisible to the search. ``forward_step`` also needs the outer model + # to produce ``CausalLMOutputWithPast`` (for ``.loss`` / ``.logits``). + # Visual tower and MTP siblings are auto-excluded inside + # ``auto_quantize()`` via *visual* / *mtp* / *vision_tower* patterns. auto_quantize( args, - language_model, + full_model, calib_dataloader, + auto_quantize_method=args.auto_quantize_method, ) else: @@ -1326,6 +1354,27 @@ def parse_args() -> argparse.Namespace: "(sensitivity scores, costs, etc.). 
Only used when auto_quantize_bits is specified." ), ) + parser.add_argument( + "--auto_quantize_cost_model", + type=str, + default="weight", + choices=["weight", "active_moe"], + help=( + "Cost model for auto_quantize effective-bits accounting. 'weight' counts all " + "quantizable weights equally. 'active_moe' scales routed MoE expert weights by " + "--auto_quantize_active_moe_expert_ratio, or infers top_k/num_experts from model config." + ), + ) + parser.add_argument( + "--auto_quantize_active_moe_expert_ratio", + type=float, + default=None, + help=( + "Routed MoE expert active ratio for --auto_quantize_cost_model active_moe. " + "For top-k MoE this is top_k / num_experts. If omitted, common model config " + "fields such as num_experts_per_tok and num_experts are used when available." + ), + ) parser.add_argument( "--moe_calib_experts_ratio", type=float, @@ -1347,6 +1396,18 @@ def parse_args() -> argparse.Namespace: args = parser.parse_args() if args.moe_calib_experts_ratio is not None and not (0.0 < args.moe_calib_experts_ratio <= 1.0): parser.error("--moe_calib_experts_ratio must be in the range (0.0, 1.0].") + if args.auto_quantize_active_moe_expert_ratio is not None and not ( + 0.0 < args.auto_quantize_active_moe_expert_ratio <= 1.0 + ): + parser.error("--auto_quantize_active_moe_expert_ratio must be in the range (0.0, 1.0].") + if ( + args.auto_quantize_cost_model == "weight" + and args.auto_quantize_active_moe_expert_ratio is not None + ): + parser.error( + "--auto_quantize_active_moe_expert_ratio requires " + "--auto_quantize_cost_model active_moe." + ) if args.specdec_offline_dataset is not None and args.sparsity_fmt != "dense": parser.error("--specdec_offline_dataset is only supported with --sparsity_fmt dense (PTQ).") diff --git a/modelopt/torch/export/layer_utils.py b/modelopt/torch/export/layer_utils.py index e8ee5afd451..9825d50f141 100755 --- a/modelopt/torch/export/layer_utils.py +++ b/modelopt/torch/export/layer_utils.py @@ -1173,6 +1173,97 @@ def set_expert_quantizer_amax( _GATE_UP_PAIRS = [("gate_proj", "up_proj"), ("w1", "w3")] +_LINEAR_ATTN_FUSED_PAIRS = [ + ("in_proj_qkv", "in_proj_z"), + ("in_proj_b", "in_proj_a"), +] + + +def _tensor_values_equal(left: torch.Tensor | None, right: torch.Tensor | None) -> bool: + if left is None or right is None: + return left is right + if left.is_meta or right.is_meta: + return False + return torch.equal(left, right) + + +def _safe_quantizer_amax(quantizer) -> torch.Tensor | None: + try: + return getattr(quantizer, "amax", None) + except AssertionError: + return None + + +def _linear_fusion_scales_match(left: nn.Module, right: nn.Module) -> bool: + left_iq = getattr(left, "input_quantizer", None) + right_iq = getattr(right, "input_quantizer", None) + if ( + left_iq is not None + and right_iq is not None + and getattr(left_iq, "is_enabled", False) + and getattr(right_iq, "is_enabled", False) + and not _tensor_values_equal(_safe_quantizer_amax(left_iq), _safe_quantizer_amax(right_iq)) + ): + return False + + left_wq = getattr(left, "weight_quantizer", None) + right_wq = getattr(right, "weight_quantizer", None) + if left_wq is None or right_wq is None: + return True + + if isinstance(left_wq, SequentialQuantizer) and isinstance(right_wq, SequentialQuantizer): + if ( + len(left_wq) > 0 + and len(right_wq) > 0 + and getattr(left_wq[-1], "is_enabled", False) + and getattr(right_wq[-1], "is_enabled", False) + ): + return _tensor_values_equal( + _safe_quantizer_amax(left_wq[-1]), _safe_quantizer_amax(right_wq[-1]) + ) + return True + + if 
hasattr(left_wq, "global_amax") and hasattr(right_wq, "global_amax"): + return _tensor_values_equal(left_wq.global_amax, right_wq.global_amax) + + if getattr(left_wq, "is_enabled", False) and getattr(right_wq, "is_enabled", False): + return _tensor_values_equal(_safe_quantizer_amax(left_wq), _safe_quantizer_amax(right_wq)) + + return True + + +def sync_linear_attn_fused_projection_amax(model: nn.Module) -> int: + """Sync quantizer amaxes for GDN projections that serving engines fuse. + + Qwen3.5/Qwen3-Next GDN exports keep ``in_proj_qkv`` and ``in_proj_z`` as + separate HF tensors, but vLLM fuses them into ``in_proj_qkvz`` at load time. + Likewise ``in_proj_b`` and ``in_proj_a`` may be fused as ``in_proj_ba``. + Sharing the quantizer scale domains before export avoids serving-time fused + loaders having to reconcile different scalar/global scales. + + Returns: + Number of projection pairs whose scale state changed. + """ + changed = 0 + for _, sub_module in model.named_modules(): + for left_name, right_name in _LINEAR_ATTN_FUSED_PAIRS: + left = getattr(sub_module, left_name, None) + right = getattr(sub_module, right_name, None) + if left is None or right is None: + continue + left_format = get_quantization_format(left) + right_format = get_quantization_format(right) + if left_format != right_format or left_format is None: + continue + if left_format == QUANTIZATION_NONE: + continue + matched_before = _linear_fusion_scales_match(left, right) + preprocess_linear_fusion([left, right]) + if not matched_before: + changed += 1 + return changed + + def sync_moe_gate_up_amax(model: nn.Module) -> int: """Take element-wise max of gate and up weight quantizer amaxes per expert. diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index a76783ac172..31ba958cceb 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -73,6 +73,7 @@ is_moe, is_quantlinear, set_expert_quantizer_amax, + sync_linear_attn_fused_projection_amax, sync_moe_gate_up_amax, ) from .model_config import ( @@ -810,6 +811,15 @@ def _export_transformers_checkpoint( f"Taking element-wise max of amaxes for serving-engine fusion." ) + # Safety net for Qwen3.5/Qwen3-Next GDN projections. These remain separate + # HF tensors, but vLLM fuses qkv+z and b+a at load time. + synced = sync_linear_attn_fused_projection_amax(model) + if synced: + warnings.warn( + f"Synced quantizer amax/global_amax for {synced} linear-attention " + f"projection pair(s) that are fused by serving engines." + ) + # Process all quantized modules and export weights _process_quantized_modules(model, dtype, is_modelopt_qlora) diff --git a/modelopt/torch/quantization/algorithms.py b/modelopt/torch/quantization/algorithms.py index f1db2df9e84..1d89cf43258 100644 --- a/modelopt/torch/quantization/algorithms.py +++ b/modelopt/torch/quantization/algorithms.py @@ -46,6 +46,101 @@ from .utils import is_quantized_linear +def _is_fused_experts_module(module: nn.Module) -> bool: + """Return True if ``module`` is a quantized fused-MoE-experts container. + + These modules expose plural ``*_input_quantizer`` and ``*_weight_quantizers`` + (an ``nn.ModuleList`` of per-expert quantizers) instead of the singular + ``input_quantizer`` / ``weight_quantizer`` attrs found on standard + ``nn.Linear``-derived QuantModules. AutoQuantize hparam discovery and cost + accounting need to recognize this layout to enumerate fused experts as + search dimensions. 
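+    One quant recipe then covers every expert in the fused module.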
+    """
+    # Late import to avoid a circular import at module load time.
+    try:
+        from .plugins.huggingface import _QuantFusedExperts
+    except ImportError:
+        return False
+    return isinstance(module, _QuantFusedExperts)
+
+
+# Quantizer attribute names that participate in AutoQuantize snapshot/restore.
+_STD_QUANTIZER_ATTRS = ("input_quantizer", "weight_quantizer", "output_quantizer")
+_FUSED_EXPERTS_QUANTIZER_ATTRS = (
+    "gate_up_proj_input_quantizer",
+    "gate_up_proj_weight_quantizers",
+    "down_proj_input_quantizer",
+    "down_proj_weight_quantizers",
+)
+
+
+def _get_quantizer_attrs(module: nn.Module) -> tuple[str, ...]:
+    """Return the quantizer attribute names that AutoQuantize must snapshot/restore.
+
+    For fused MoE experts, this returns the four plural quantizer attrs (two
+    shared input quantizers + two ``ModuleList`` of per-expert weight quantizers).
+    For standard Linear-derived QuantModules, returns the canonical trio.
+    """
+    if _is_fused_experts_module(module):
+        return _FUSED_EXPERTS_QUANTIZER_ATTRS
+    return _STD_QUANTIZER_ATTRS
+
+
+def _make_fresh_quantizer_for_attr(module: nn.Module, attr_name: str) -> nn.Module:
+    """Return a fresh, default quantizer object suitable to overwrite ``module.<attr_name>``.
+
+    For ModuleList attrs (per-expert quantizers on fused-experts modules), the
+    returned ModuleList preserves the original list length so per-expert
+    enumeration stays consistent across recipes.
+    """
+    current = getattr(module, attr_name, None)
+    if isinstance(current, nn.ModuleList):
+        return nn.ModuleList(TensorQuantizer() for _ in range(len(current)))
+    return TensorQuantizer()
+
+
+def _get_module_weight_numel(module: nn.Module) -> int:
+    """Return the total parameter count of a module's quantizable weights.
+
+    Standard QuantLinear modules have a single ``weight`` parameter. Fused
+    experts modules have two 3-D fused parameters (``gate_up_proj`` and
+    ``down_proj``) instead — both contribute to the cost accounting.
+    """
+    if _is_fused_experts_module(module):
+        total = 0
+        for attr in ("gate_up_proj", "down_proj"):
+            param = getattr(module, attr, None)
+            if param is not None:
+                total += param.numel()
+        return total
+    weight = getattr(module, "weight", None)
+    return weight.numel() if weight is not None else 0
+
+
+_ROUTED_MOE_EXPERT_NAME_RE = re.compile(r"(^|\.)experts(\.|$)")
+
+
+def _is_routed_moe_module_name(name: str) -> bool:
+    """Return True for routed MoE expert modules, excluding shared experts."""
+    return "shared_expert" not in name and _ROUTED_MOE_EXPERT_NAME_RE.search(name) is not None
+
+
+def _get_active_moe_cost_weight(
+    module_names: Sequence[str], active_moe_expert_ratio: float | None
+) -> float:
+    """Return cost multiplier for the active-MoE cost model.
+
+    Routed experts are only partially active per decoded token, so their search
+    cost can be scaled by ``top_k / num_experts``. Non-MoE and shared-expert
+    modules stay at full weight.
+    """
+    if active_moe_expert_ratio is None:
+        return 1.0
+    if any(_is_routed_moe_module_name(n) for n in module_names):
+        return active_moe_expert_ratio
+    return 1.0
+
+
 def estimate_quant_compression(quant_cfg: QuantizeConfig) -> float:
     """Estimate the compression ratio of a quantization configuration.
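
[Reviewer note] The cost accounting added above reduces to three factors; a
minimal self-contained sketch for orientation (illustration only: ``group_cost``
and its arguments are stand-ins, not modelopt APIs):

    def group_cost(numel: int, compression: float, cost_weight: float = 1.0) -> float:
        # compression is effective_bits / 16 (0.5 for INT8, 0.25 for INT4);
        # cost_weight is top_k / num_experts for routed-expert groups under the
        # "active_moe" cost model and 1.0 for everything else.
        return numel * compression * cost_weight

    # A 4096x4096 routed-expert projection, INT8, top-2-of-8 routing:
    cost = group_cost(4096 * 4096, compression=0.5, cost_weight=2 / 8)

This mirrors what ``QuantRecipeHparam.get_cost`` and
``_get_active_moe_cost_weight`` compute below.
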
@@ -204,6 +299,7 @@ def __init__( score_modules: list[nn.Module] | None = None, name: str | None = None, quant_module_names: list[str] | None = None, + cost_weight: float = 1.0, ) -> None: """Initializes Hparam with original value and choices.""" choices = sorted({*(choices if choices else []), QuantRecipe(quant_cfg=None)}) @@ -211,6 +307,8 @@ def __init__( self.name = name self.quant_module_names = quant_module_names or [] + assert cost_weight > 0.0, "cost_weight must be positive." + self.cost_weight = cost_weight self.quant_modules = list(set(quant_modules or [])) self.score_modules = list(set(score_modules or self.quant_modules)) @@ -218,26 +316,26 @@ def __init__( # This is a hack; We dont want to make the input_quantizer, weight_quantizer, output_quantizer # a dynamic attribute for backward compatibility with the model_calib.py # TODO: Make input_quantizer, weight_quantizer, output_quantizer a dynamic attribute and get rid of this hack + # NOTE: For fused-experts modules, the relevant attrs are plural + # (``*_input_quantizer`` + ``*_weight_quantizers`` ModuleList) — see + # ``_get_quantizer_attrs``. Both layouts share the same snapshot dict + # shape so ``active.setter`` swaps the right child modules. self._all_quantizer_choices = {quant_recipe: {} for quant_recipe in self.choices} quant_recipe: QuantRecipe for quant_recipe in self.choices: for quant_module in self.quant_modules: - for quantizer_attr_name in [ - "input_quantizer", - "weight_quantizer", - "output_quantizer", - ]: - setattr(quant_module, quantizer_attr_name, TensorQuantizer()) + attr_names = _get_quantizer_attrs(quant_module) + for attr_name in attr_names: + setattr( + quant_module, + attr_name, + _make_fresh_quantizer_for_attr(quant_module, attr_name), + ) set_quantizer_by_cfg(quant_module, quant_recipe.config.quant_cfg) self._all_quantizer_choices[quant_recipe][quant_module] = { - quantizer_attr_name: getattr(quant_module, quantizer_attr_name) - for quantizer_attr_name in [ - "input_quantizer", - "weight_quantizer", - "output_quantizer", - ] + attr_name: getattr(quant_module, attr_name) for attr_name in attr_names } self.active = self.original @@ -303,15 +401,18 @@ def get_score(self, recipe: QuantRecipe) -> float: total_score += importance.item() return total_score - def get_cost(self, recipe: QuantRecipe) -> float: + def get_cost(self, recipe: QuantRecipe, cost_weight: float | None = None) -> float: """Get the cost for a given recipe. The cost is the total weight size of the quantizable modules multiplied by the compression ratio of the recipe. 
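+
+        An explicit ``cost_weight`` argument overrides the hparam's own
+        ``cost_weight``; ``None`` (the default) falls back to the stored value.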
""" + cost_weight = self.cost_weight if cost_weight is None else cost_weight cost = 0 for quant_module in self.quant_modules: - weight_size = _AutoQuantizeBaseSearcher._get_total_weight_size([quant_module]) + weight_size = ( + _AutoQuantizeBaseSearcher._get_total_weight_size([quant_module]) * cost_weight + ) parallel_state = getattr(quant_module, "parallel_state", None) if parallel_state is None: @@ -341,7 +442,21 @@ def get_cost(self, recipe: QuantRecipe) -> float: @property def attrs(self) -> list[str]: """Return the attributes of the hparam for repr.""" - return ["name", *super().attrs] + return ["name", "cost_weight", *super().attrs] + + +_LINEAR_ATTN_QKVZ_RE = re.compile(r"^(.*?\.linear_attn)\.(?:in_proj_qkv|in_proj_z)$") +_LINEAR_ATTN_BA_RE = re.compile(r"^(.*?\.linear_attn)\.(?:in_proj_a|in_proj_b)$") + + +def _linear_attn_qkvz_group_key(_model, name: str) -> str | None: + m = _LINEAR_ATTN_QKVZ_RE.match(name) + return f"{m.group(1)}/qkvz" if m else None + + +def _linear_attn_ba_group_key(_model, name: str) -> str | None: + m = _LINEAR_ATTN_BA_RE.match(name) + return f"{m.group(1)}/ba" if m else None class _AutoQuantizeBaseSearcher(BaseSearcher, ABC): @@ -365,6 +480,13 @@ class _AutoQuantizeBaseSearcher(BaseSearcher, ABC): r"^(.*?)\.(gate_proj|up_proj)$", # gate_proj, up_proj for llama like models r"^(.*?)\.(\d+\.(w1|w2|w3))$", # mixtral experts r"^(.*?)\.((w1_linear|w2_linear|w3_linear)\.\d+)$", # dbrx experts + # Qwen3.5/3.6 hybrid linear_attn: vLLM fuses (in_proj_qkv, in_proj_z) + # into ``in_proj_qkvz`` and (in_proj_a, in_proj_b) into ``in_proj_ba`` and + # requires fused shards to share quant_algo. Two callables (not one + # regex) so qkv+z and a+b produce DIFFERENT group keys; each pair + # stays with its own fusion partner. + _linear_attn_qkvz_group_key, + _linear_attn_ba_group_key, ] score_module_rules = [] @@ -381,6 +503,8 @@ def default_search_config(self): "disabled_layers": None, "verbose": is_master(), "checkpoint": None, + "cost_model": "weight", + "active_moe_expert_ratio": None, } @property @@ -388,6 +512,9 @@ def default_state_dict(self) -> SearchStateDict: """Get the default state dict for AutoQuantize.""" return { "method": self.method_name, + "cost_model": "weight", + "active_moe_expert_ratio": None, + "cost_denominator": None, "candidate_stats": defaultdict(dict), "quantizer_states": {}, "best": {"recipe": {}, "constraints": {}, "score": float("inf"), "is_satisfied": False}, @@ -403,6 +530,18 @@ def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig: assert config["forward_step"] is not None, ( "`forward_step` must be provided for `auto_quantize`." ) + if config["cost_model"] not in ("weight", "active_moe"): + raise ValueError( + f"Invalid cost_model: {config['cost_model']}. " + "Valid options are 'weight' and 'active_moe'." + ) + active_moe_expert_ratio = config["active_moe_expert_ratio"] + if active_moe_expert_ratio is not None and not (0.0 < active_moe_expert_ratio <= 1.0): + raise ValueError("active_moe_expert_ratio must be in the range (0.0, 1.0].") + if config["cost_model"] == "active_moe" and active_moe_expert_ratio is None: + raise ValueError( + "active_moe_expert_ratio must be set when using active_moe cost accounting." 
+ ) return config def load_search_checkpoint(self) -> bool: @@ -410,9 +549,15 @@ def load_search_checkpoint(self) -> bool: @staticmethod def _is_auto_quantize_module(module): - return ( - is_quantized_linear(module) or isinstance(module, QuantLinearConvBase) - ) and isinstance(module, QuantModule) + if (is_quantized_linear(module) or isinstance(module, QuantLinearConvBase)) and isinstance( + module, QuantModule + ): + return True + # Fused MoE experts: a single ``QuantModule`` that owns N per-expert + # weight quantizers in an ``nn.ModuleList`` plus shared input quantizers. + # All N experts in a layer share one search dimension (one recipe per + # fused module). + return _is_fused_experts_module(module) and isinstance(module, QuantModule) @staticmethod def _get_search_recipes(quantization_formats): @@ -490,7 +635,9 @@ def _get_score_module_from_name( ) return quant_module - def insert_hparams_after_merge_rules(self, model, quant_recipes, disabled_layers=None): + def insert_hparams_after_merge_rules( + self, model, quant_recipes, disabled_layers=None, active_moe_expert_ratio=None + ): """Restrict the search space using the merge rules and insert the hparams for the model.""" # TRTLLM fuses linear layers such as q_proj, k_proj, v_proj into same layer # Hence we need to restrict the search space so that all these layers share the same recipe @@ -545,6 +692,8 @@ def insert_hparams_after_merge_rules(self, model, quant_recipes, disabled_layers quant_modules = [module for module, _, _, _ in module_info_list] disabled = any(disabled for _, _, disabled, _ in module_info_list) score_modules = [score_module for _, _, _, score_module in module_info_list] + quant_module_names = [name for _, name, _, _ in module_info_list] + cost_weight = _get_active_moe_cost_weight(quant_module_names, active_moe_expert_ratio) _quant_recipes = None if disabled else quant_recipes hparam = QuantRecipeHparam( @@ -552,7 +701,8 @@ def insert_hparams_after_merge_rules(self, model, quant_recipes, disabled_layers quant_modules=quant_modules, score_modules=score_modules, name=str(group_key), - quant_module_names=[name for _, name, _, _ in module_info_list], + quant_module_names=quant_module_names, + cost_weight=cost_weight, ) for module in quant_modules: @@ -584,23 +734,30 @@ def initialize_candidate_stats(self): if not isinstance(hparam, QuantRecipeHparam): continue - formats, scores, costs = [], [], [] + formats, scores, costs, active_costs = [], [], [], [] prev_score = float("inf") + constraint_cost_weight = ( + hparam.cost_weight if self.config["cost_model"] == "active_moe" else 1.0 + ) for recipe in hparam.choices: formats.append(recipe) score = hparam.get_score(recipe) # type: ignore [arg-type] - cost = hparam.get_cost(recipe) # type: ignore [arg-type] + cost = hparam.get_cost(recipe, cost_weight=constraint_cost_weight) # type: ignore [arg-type] + active_cost = hparam.get_cost(recipe, cost_weight=hparam.cost_weight) # type: ignore [arg-type] score = min(score, prev_score) # TODO: Should we get rid of this? 
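+                # The clamp keeps per-recipe scores non-increasing in the sorted
+                # recipe order, so the search sees a monotonic score/cost trade-off.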
scores.append(score) costs.append(cost) + active_costs.append(active_cost) prev_score = score self.candidate_stats[name]["formats"] = formats self.candidate_stats[name]["scores"] = scores self.candidate_stats[name]["costs"] = costs + self.candidate_stats[name]["active_costs"] = active_costs self.candidate_stats[name]["module_names"] = hparam.quant_module_names + self.candidate_stats[name]["cost_weight"] = hparam.cost_weight def _run_func(self, func, num_iters=1, desc=""): for i, data in tqdm( @@ -625,12 +782,30 @@ def before_search(self): f"Checkpoint method '{restored_method}' does not match current method " f"'{self.method_name}'. Use a different checkpoint path." ) + restored_cost_model = getattr(self, "cost_model", "weight") + restored_active_moe_expert_ratio = getattr(self, "active_moe_expert_ratio", None) + if self.candidate_stats and ( + restored_cost_model != self.config["cost_model"] + or restored_active_moe_expert_ratio != self.config["active_moe_expert_ratio"] + ): + raise ValueError( + "Checkpoint AutoQuantize cost model does not match current search config: " + f"checkpoint=({restored_cost_model}, {restored_active_moe_expert_ratio}), " + f"current=({self.config['cost_model']}, {self.config['active_moe_expert_ratio']}). " + "Use a different checkpoint path." + ) self.method = self.method_name + self.cost_model = self.config["cost_model"] + self.active_moe_expert_ratio = self.config["active_moe_expert_ratio"] + self.cost_denominator = getattr(self, "cost_denominator", None) search_recipes = self._get_search_recipes(self.config["quantization_formats"]) self._verify_constraint(search_recipes) self.insert_hparams_after_merge_rules( - self.model, search_recipes, self.config["disabled_layers"] + self.model, + search_recipes, + self.config["disabled_layers"], + self.config["active_moe_expert_ratio"], ) QuantRecipe.disable_folding_pqs_to_weights() @@ -712,14 +887,23 @@ def _print_recipe_summary(best_recipe, total_cost, total_weight_size, prefix="Au @staticmethod def _get_total_weight_size(modules): return sum( - ( - module.weight.numel() - if _AutoQuantizeBaseSearcher._is_auto_quantize_module(module) - else 0 - ) + _get_module_weight_numel(module) + if _AutoQuantizeBaseSearcher._is_auto_quantize_module(module) + else 0 for module in modules ) + @staticmethod + def _get_total_weight_size_from_named_modules(named_modules, active_moe_expert_ratio=None): + total_weight_size = 0.0 + for name, module in named_modules: + if not _AutoQuantizeBaseSearcher._is_auto_quantize_module(module): + continue + total_weight_size += _get_module_weight_numel(module) * _get_active_moe_cost_weight( + [name], active_moe_expert_ratio + ) + return total_weight_size + def _get_constraints_for_search(self, max_weight_size, lower_bound=None): constraints = { "weight_size_after_compression": ( @@ -729,6 +913,12 @@ def _get_constraints_for_search(self, max_weight_size, lower_bound=None): } return constraints, "weight_size_after_compression" + def _get_search_lower_bounds(self): + cost_model = getattr(self, "cost_model", getattr(self, "config", {}).get("cost_model")) + if cost_model == "active_moe": + return [0.99, 0.90, None] + return [None, 0.99, 0.90] + @abstractmethod def run_search_with_stats(self, max_weight_size, verbose=False): """Run the search with stats to get the best recipe and whether the constraints are satisfied.""" @@ -742,8 +932,24 @@ def run_search(self): ) compression = self._get_formatted_weight_compression_constraint() - total_weight_size = self._get_total_weight_size(self.model.modules()) + if 
self.config["cost_model"] == "active_moe": + total_weight_size = self._get_total_weight_size_from_named_modules( + self.model.named_modules(), self.config["active_moe_expert_ratio"] + ) + else: + total_weight_size = self._get_total_weight_size(self.model.modules()) + self.cost_denominator = total_weight_size max_weight_size = total_weight_size * compression + if verbose: + print_rank_0( + "AutoQuantize cost model: " + f"{self.config['cost_model']}" + + ( + f" (active_moe_expert_ratio={self.config['active_moe_expert_ratio']})" + if self.config["cost_model"] == "active_moe" + else "" + ) + ) # Run the search with stats to get the best recipe and whether the constraints are satisfied best_recipe_info, is_satisfied = self.run_search_with_stats(max_weight_size, verbose) @@ -1048,7 +1254,7 @@ def run_search_with_stats(self, max_weight_size, verbose=False): """ # TODO: Do this only for rank 0 in the respective pipeline group - for lower_bound in [None, 0.99, 0.90]: + for lower_bound in self._get_search_lower_bounds(): # The LP solver for auto_quantize sometimes fails to find a solution if a lower bound is not # specified. I dont know why this happens. # As a workaround, lets specify a lower bound for the weight compression if previous @@ -1329,6 +1535,7 @@ def _cfg_to_dict(v): return v quant_cfg: list[dict] = [{"quantizer_name": "*", "enable": False}] + per_module_entries: list[dict] = [] _per_module_attrs = ("input_quantizer", "weight_quantizer", "output_quantizer") # Track global (non per-module) recipe entries. Last recipe wins for each pattern. global_entries: dict[str, dict] = {} @@ -1349,7 +1556,7 @@ def _cfg_to_dict(v): } if matched_cfg is not None: entry["cfg"] = _cfg_to_dict(matched_cfg) - quant_cfg.append(entry) + per_module_entries.append(entry) # Collect non-per-module entries (e.g. *[kv]_bmm_quantizer) from winning recipes. for recipe_entry in recipe.config.quant_cfg: @@ -1366,7 +1573,10 @@ def _cfg_to_dict(v): ge["cfg"] = _cfg_to_dict(cfg) global_entries[pattern] = ge + # Keep path-scoped recipe entries before explicit module entries so selected + # modules override default disables such as ``*lm_head*``. quant_cfg.extend(global_entries.values()) + quant_cfg.extend(per_module_entries) warnings.warn( "get_auto_quantize_config: returned config uses algorithm='max'. " "Per-recipe calibration algorithms (e.g. smoothquant, awq) are not preserved. 
" @@ -1379,7 +1589,9 @@ def _resolve_best_recipe(search_state, constraints, verbose=False): effective_bits = constraints["effective_bits"] compression = effective_bits / 16.0 candidate_stats = search_state["candidate_stats"] - total_weight_size = sum(s["costs"][-1] for s in candidate_stats.values()) + total_weight_size = search_state.get("cost_denominator") or sum( + s["costs"][-1] for s in candidate_stats.values() + ) max_weight_size = total_weight_size * compression method = search_state["method"] @@ -1393,6 +1605,10 @@ def _resolve_best_recipe(search_state, constraints, verbose=False): ) searcher.candidate_stats = candidate_stats + searcher.cost_model = search_state.get("cost_model", "weight") + searcher.config = { + "cost_model": searcher.cost_model, + } best_recipe_info, _ = searcher.run_search_with_stats(max_weight_size, verbose=verbose) best_recipe = {name: info["format"] for name, info in best_recipe_info.items()} @@ -1413,6 +1629,8 @@ def _match_quantizer_cfg(quant_cfg, quantizer_attr): matched = None matched_enable = None for entry in quant_cfg: + if "parent_class" in entry: + continue pattern = entry["quantizer_name"] cfg = entry.get("cfg") enable = entry.get("enable", True) diff --git a/modelopt/torch/quantization/model_quant.py b/modelopt/torch/quantization/model_quant.py index 5e65f9cc1d4..7e2a8bac6d1 100644 --- a/modelopt/torch/quantization/model_quant.py +++ b/modelopt/torch/quantization/model_quant.py @@ -265,6 +265,54 @@ def forward_loop(model) -> None: "awq_clip", } +_ACTIVE_MOE_TOP_K_ATTRS = ( + "num_experts_per_tok", + "num_experts_per_token", + "moe_top_k", + "top_k", + "num_selected_experts", +) +_ACTIVE_MOE_NUM_EXPERTS_ATTRS = ( + "num_experts", + "num_local_experts", + "n_routed_experts", + "moe_num_experts", + "num_routed_experts", +) + + +def _iter_model_configs(model: nn.Module): + seen = set() + for obj in (model, getattr(model, "model", None), getattr(model, "language_model", None)): + config = getattr(obj, "config", None) + if config is None or id(config) in seen: + continue + seen.add(id(config)) + yield config + for nested_attr in ("text_config", "language_config"): + nested_config = getattr(config, nested_attr, None) + if nested_config is None or id(nested_config) in seen: + continue + seen.add(id(nested_config)) + yield nested_config + + +def _get_first_numeric_config_attr(model: nn.Module, attr_names: tuple[str, ...]) -> float | None: + for config in _iter_model_configs(model): + for attr_name in attr_names: + value = getattr(config, attr_name, None) + if isinstance(value, (int, float)): + return float(value) + return None + + +def _infer_active_moe_expert_ratio(model: nn.Module) -> float | None: + num_active_experts = _get_first_numeric_config_attr(model, _ACTIVE_MOE_TOP_K_ATTRS) + num_experts = _get_first_numeric_config_attr(model, _ACTIVE_MOE_NUM_EXPERTS_ATTRS) + if num_active_experts is None or num_experts is None or num_experts <= 0: + return None + return min(num_active_experts / num_experts, 1.0) + def auto_quantize( model: nn.Module, @@ -283,6 +331,8 @@ def auto_quantize( verbose: bool = False, method: str = "gradient", checkpoint: str | None = None, + cost_model: str = "weight", + active_moe_expert_ratio: float | None = None, ): r"""Perform optimal per-layer quantization by searching for the best quantization formats per-layer. @@ -433,6 +483,12 @@ def forward_backward_step(model, batch) -> None: checkpoint: (Optional) Path to checkpoint file for saving/restoring auto_quantize search state. 
If the checkpoint file exists, the search state will be restored from it, skipping the expensive score estimation step. + cost_model: Cost metric used for the effective-bits constraint. ``"weight"`` (default) counts + all quantizable weights equally. ``"active_moe"`` scales routed MoE expert weights by + ``active_moe_expert_ratio`` so the budget approximates active decode weight traffic. + active_moe_expert_ratio: Ratio of routed MoE experts active per token, normally + ``num_experts_per_tok / num_experts``. If omitted with ``cost_model="active_moe"``, the + ratio is inferred from common model config fields when available. Returns: A tuple (model, state_dict) where ``model`` is the searched and quantized model and ``state_dict`` contains the history and detailed stats of the search procedure. @@ -514,6 +570,22 @@ def forward_backward_step(model, batch) -> None: else: raise ValueError(f"Invalid method: {method}. Valid options are 'gradient' or 'kl_div'.") + if cost_model not in ("weight", "active_moe"): + raise ValueError( + f"Invalid cost_model: {cost_model}. Valid options are 'weight' and 'active_moe'." + ) + if active_moe_expert_ratio is not None and not (0.0 < active_moe_expert_ratio <= 1.0): + raise ValueError("active_moe_expert_ratio must be in the range (0.0, 1.0].") + if cost_model == "weight" and active_moe_expert_ratio is not None: + raise ValueError("active_moe_expert_ratio requires cost_model='active_moe'.") + if cost_model == "active_moe" and active_moe_expert_ratio is None: + active_moe_expert_ratio = _infer_active_moe_expert_ratio(model) + if active_moe_expert_ratio is None: + raise ValueError( + "Could not infer active_moe_expert_ratio from model.config. " + "Pass active_moe_expert_ratio explicitly." + ) + model = apply_mode( model, mode="auto_quantize", @@ -530,6 +602,8 @@ def forward_backward_step(model, batch) -> None: "disabled_layers": disabled_layers, "verbose": verbose, "checkpoint": checkpoint, + "cost_model": cost_model, + "active_moe_expert_ratio": active_moe_expert_ratio, } # Disable all quantizers; AutoQuantize will enable the needed ones set_quantizer_by_cfg(model, [{"quantizer_name": "*", "enable": False}]) diff --git a/tests/unit/torch/export/test_export_weight.py b/tests/unit/torch/export/test_export_weight.py index 13617994bf0..3829777c77b 100644 --- a/tests/unit/torch/export/test_export_weight.py +++ b/tests/unit/torch/export/test_export_weight.py @@ -19,7 +19,13 @@ from _test_utils.torch.export.utils import ToyModel, partial_fp8_config, partial_w4a8_config import modelopt.torch.quantization as mtq +from modelopt.torch.export.layer_utils import sync_linear_attn_fused_projection_amax from modelopt.torch.export.unified_export_hf import _export_quantized_weight +from modelopt.torch.quantization.config import QuantizerAttributeConfig +from modelopt.torch.quantization.nn.modules.tensor_quantizer import ( + NVFP4StaticQuantizer, + TensorQuantizer, +) from modelopt.torch.quantization.utils import quantizer_attr_names @@ -96,3 +102,74 @@ def test_export_per_block_quantized_weight(): assert hasattr(model.linears[2], quantizer_attrs.output_quantizer) assert not getattr(model.linears[2], quantizer_attrs.output_quantizer).is_enabled assert not hasattr(model.linears[2], quantizer_attrs.output_scale) + + +class _GatedDeltaNetProjectionToy(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear_attn = torch.nn.Module() + self.linear_attn.in_proj_qkv = torch.nn.Linear(16, 48, bias=False) + self.linear_attn.in_proj_z = torch.nn.Linear(16, 16, bias=False) 
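+        # The (in_proj_qkv, in_proj_z) pair mirrors what vLLM fuses into
+        # in_proj_qkvz at load time; the export path syncs their scales.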
+ + +def _attach_quantizers(module, weight_cfg, input_cfg): + module.weight_quantizer = TensorQuantizer(weight_cfg) + module.input_quantizer = TensorQuantizer(input_cfg) + + +def test_sync_linear_attn_fused_projection_fp8_amax(): + model = _GatedDeltaNetProjectionToy() + quant_cfg = QuantizerAttributeConfig(num_bits=(4, 3), axis=None) + _attach_quantizers(model.linear_attn.in_proj_qkv, quant_cfg, quant_cfg) + _attach_quantizers(model.linear_attn.in_proj_z, quant_cfg, quant_cfg) + + model.linear_attn.in_proj_qkv.weight_quantizer.amax = torch.tensor(3.0) + model.linear_attn.in_proj_z.weight_quantizer.amax = torch.tensor(5.0) + model.linear_attn.in_proj_qkv.input_quantizer.amax = torch.tensor(7.0) + model.linear_attn.in_proj_z.input_quantizer.amax = torch.tensor(11.0) + + synced = sync_linear_attn_fused_projection_amax(model) + + assert synced == 1 + assert torch.equal(model.linear_attn.in_proj_qkv.weight_quantizer.amax, torch.tensor(5.0)) + assert torch.equal(model.linear_attn.in_proj_z.weight_quantizer.amax, torch.tensor(5.0)) + assert torch.equal(model.linear_attn.in_proj_qkv.input_quantizer.amax, torch.tensor(11.0)) + assert torch.equal(model.linear_attn.in_proj_z.input_quantizer.amax, torch.tensor(11.0)) + assert hasattr(model.linear_attn, "in_proj_qkv") + assert hasattr(model.linear_attn, "in_proj_z") + + +def test_sync_linear_attn_fused_projection_nvfp4_global_amax(): + model = _GatedDeltaNetProjectionToy() + weight_cfg = QuantizerAttributeConfig( + num_bits=(2, 1), + block_sizes={-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + axis=None, + ) + input_cfg = QuantizerAttributeConfig( + num_bits=(2, 1), + block_sizes={-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + axis=None, + ) + _attach_quantizers(model.linear_attn.in_proj_qkv, weight_cfg, input_cfg) + _attach_quantizers(model.linear_attn.in_proj_z, weight_cfg, input_cfg) + NVFP4StaticQuantizer.from_tensor_quantizer( + model.linear_attn.in_proj_qkv.weight_quantizer, global_amax=torch.tensor(13.0) + ) + NVFP4StaticQuantizer.from_tensor_quantizer( + model.linear_attn.in_proj_z.weight_quantizer, global_amax=torch.tensor(17.0) + ) + model.linear_attn.in_proj_qkv.input_quantizer.amax = torch.tensor(19.0) + model.linear_attn.in_proj_z.input_quantizer.amax = torch.tensor(23.0) + + synced = sync_linear_attn_fused_projection_amax(model) + + assert synced == 1 + assert torch.equal( + model.linear_attn.in_proj_qkv.weight_quantizer.global_amax, torch.tensor(17.0) + ) + assert torch.equal(model.linear_attn.in_proj_z.weight_quantizer.global_amax, torch.tensor(17.0)) + assert torch.equal(model.linear_attn.in_proj_qkv.input_quantizer.amax, torch.tensor(23.0)) + assert torch.equal(model.linear_attn.in_proj_z.input_quantizer.amax, torch.tensor(23.0)) + assert hasattr(model.linear_attn, "in_proj_qkv") + assert hasattr(model.linear_attn, "in_proj_z") diff --git a/tests/unit/torch/quantization/test_autoquant.py b/tests/unit/torch/quantization/test_autoquant.py index 87ec73291e7..394a64ef563 100644 --- a/tests/unit/torch/quantization/test_autoquant.py +++ b/tests/unit/torch/quantization/test_autoquant.py @@ -15,6 +15,7 @@ import copy import io +from types import SimpleNamespace import pytest import torch @@ -24,6 +25,7 @@ import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq from modelopt.torch.quantization.algorithms import ( + AutoQuantizeGradientSearcher, QuantRecipe, QuantRecipeHparam, estimate_quant_compression, @@ -62,6 +64,31 @@ def get_input(self): return torch.randn(1, 4, 32) +class _AutoQuantMoeModel(torch.nn.Module): 
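+    """Tiny MoE stand-in: the config advertises top-2-of-8 routing (active
+    ratio 0.25); two routed experts are materialized plus one shared expert.
+    """
+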
+ def __init__(self, num_experts_attr="num_experts"): + super().__init__() + self.config = SimpleNamespace(text_config=SimpleNamespace(num_experts_per_tok=2)) + setattr(self.config.text_config, num_experts_attr, 8) + self.mlp = torch.nn.Module() + self.mlp.experts = torch.nn.ModuleList() + for _ in range(2): + expert = torch.nn.Module() + expert.gate_proj = torch.nn.Linear(32, 32) + expert.up_proj = torch.nn.Linear(32, 32) + expert.down_proj = torch.nn.Linear(32, 32) + self.mlp.experts.append(expert) + self.mlp.shared_expert = torch.nn.Linear(32, 32) + + def forward(self, x): + y = self.mlp.shared_expert(x) + for expert in self.mlp.experts: + y = y + expert.down_proj(expert.gate_proj(x) + expert.up_proj(x)) + return y + + def get_input(self): + return torch.randn(1, 4, 32) + + @pytest.mark.parametrize( ("quant_cfg", "other_quant_cfg", "is_less_than"), [ @@ -109,6 +136,80 @@ def test_quant_recipe_hparam(): assert torch.allclose(output_test, output_ref) +def test_quant_recipe_hparam_cost_weight(): + model_test = mtq.quantize(torch.nn.Linear(4, 16), mtq.INT8_DEFAULT_CFG) + search_recipes = [QuantRecipe(mtq.INT8_DEFAULT_CFG)] + hparam = QuantRecipeHparam( + search_recipes, + quant_modules=[model_test], + quant_module_names=["layers.0.mlp.experts.0.down_proj"], + cost_weight=0.25, + ) + + dense_cost = hparam.get_cost(QuantRecipe(quant_cfg=None)) + int8_cost = hparam.get_cost(QuantRecipe(mtq.INT8_DEFAULT_CFG)) + + assert dense_cost == pytest.approx(model_test.weight.numel() * 0.25) + assert int8_cost == pytest.approx(model_test.weight.numel() * 0.25 * 0.5) + + +@pytest.mark.parametrize("num_experts_attr", ["num_experts", "num_local_experts"]) +def test_auto_quantize_active_moe_cost_model(num_experts_attr): + model = _AutoQuantMoeModel(num_experts_attr) + + _, search_history = mtq.auto_quantize( + model, + constraints={"effective_bits": 6.0}, + quantization_formats=[mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT8_DEFAULT_CFG], + data_loader=[model.get_input() for _ in range(2)], + forward_step=lambda model, batch: model(batch), + loss_func=lambda output, data: output.sum(), + num_calib_steps=2, + num_score_steps=2, + cost_model="active_moe", + ) + + assert search_history["cost_model"] == "active_moe" + assert search_history["active_moe_expert_ratio"] == pytest.approx(0.25) + weighted_no_quant_cost = sum( + stats["costs"][-1] for stats in search_history["candidate_stats"].values() + ) + assert search_history["cost_denominator"] == pytest.approx(weighted_no_quant_cost) + routed_stats = [ + stats + for stats in search_history["candidate_stats"].values() + if any("mlp.experts" in name for name in stats["module_names"]) + ] + shared_stats = [ + stats + for stats in search_history["candidate_stats"].values() + if any("mlp.shared_expert" in name for name in stats["module_names"]) + ] + assert routed_stats + assert shared_stats + assert all(stats["cost_weight"] == pytest.approx(0.25) for stats in routed_stats) + assert all(stats["cost_weight"] == pytest.approx(1.0) for stats in shared_stats) + assert all("active_costs" in stats for stats in search_history["candidate_stats"].values()) + + +def test_active_moe_search_prefers_budget_lower_bound(): + searcher = AutoQuantizeGradientSearcher() + searcher.config = {"cost_model": "active_moe"} + searcher.cost_model = "active_moe" + searcher.candidate_stats = { + "layers.0.mlp.quant_recipe": { + "formats": ["under_budget", "near_budget"], + "costs": [1.0, 4.95], + "scores": [0.0, 10.0], + } + } + + best_recipes, is_satisfied = searcher.run_search_with_stats(5.0) + + 
assert is_satisfied + assert best_recipes["layers.0.mlp.quant_recipe"]["format"] == "near_budget" + + # use this config to test custom quantization config INT8_CUSTOM_QUANT_TEST_CFG = { "quant_cfg": [ @@ -508,3 +609,28 @@ def test_get_auto_quantize_config(method): fresh_model = mtq.quantize(fresh_model, config, forward_loop=lambda m: m(model.get_input())) output = fresh_model(model.get_input()) assert output is not None + + +def test_get_auto_quantize_config_keeps_selected_lm_head_enabled(): + recipe = QuantRecipe(mtq.FP8_DEFAULT_CFG) + search_state = { + "best": {"recipe": {"lm_head.quant_recipe": recipe}}, + "candidate_stats": {"lm_head.quant_recipe": {"module_names": ["lm_head"]}}, + } + + config = mtq.get_auto_quantize_config(search_state) + quant_cfg = config["quant_cfg"] + + default_disable_idx = next( + idx for idx, entry in enumerate(quant_cfg) if entry["quantizer_name"] == "*lm_head*" + ) + weight_idx = next( + idx + for idx, entry in enumerate(quant_cfg) + if entry["quantizer_name"] == "lm_head.weight_quantizer" + ) + weight_entry = quant_cfg[weight_idx] + + assert default_disable_idx < weight_idx + assert weight_entry["enable"] is True + assert weight_entry["cfg"]["num_bits"] == (4, 3)
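
[Reviewer note] End-to-end shape of the new knob as the tests above exercise it
(a sketch; ``build_moe_model`` and ``calib_batches`` are placeholders, while the
``mtq.auto_quantize`` keywords are the ones added in this diff):

    import modelopt.torch.quantization as mtq

    model = build_moe_model()  # any module whose config exposes num_experts_per_tok / num_experts
    quantized, history = mtq.auto_quantize(
        model,
        constraints={"effective_bits": 6.0},
        quantization_formats=[mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT8_DEFAULT_CFG],
        data_loader=calib_batches,
        forward_step=lambda m, batch: m(batch),
        loss_func=lambda output, data: output.sum(),
        cost_model="active_moe",  # ratio inferred from model config when not passed explicitly
    )
    assert history["cost_model"] == "active_moe"

The CLI equivalent in examples/llm_ptq/hf_ptq.py is --auto_quantize_cost_model
active_moe, optionally with --auto_quantize_active_moe_expert_ratio.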