89 changes: 75 additions & 14 deletions examples/llm_ptq/hf_ptq.py
@@ -137,6 +137,27 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None:
mto.enable_huggingface_checkpointing()


_QWEN36_AUTOQ_DISABLED_LAYERS = (
"*shared_expert_gate*",
"*linear_attn.in_proj_a*",
"*linear_attn.in_proj_b*",
)
_VLM_AUTOQ_DISABLED_LAYERS = ("*visual*", "*mtp*", "*vision_tower*")


def get_auto_quantize_disabled_layers(model) -> list[str]:
"""Return layer patterns that should be excluded from AutoQuantize search."""
disabled_layers = [
entry["quantizer_name"]
for entry in _default_disabled_quantizer_cfg
if "parent_class" not in entry and entry["quantizer_name"] != "*lm_head*"
]
disabled_layers.extend(p for p in _QWEN36_AUTOQ_DISABLED_LAYERS if p not in disabled_layers)
if is_multimodal_model(model):
disabled_layers.extend(p for p in _VLM_AUTOQ_DISABLED_LAYERS if p not in disabled_layers)
return disabled_layers
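
For reference, a minimal sketch of how these wildcard patterns behave, assuming `_default_disabled_quantizer_cfg` entries are dicts with a `quantizer_name` pattern and an optional `parent_class` key (the keys the comprehension above relies on) and that matching is fnmatch-style as the `*...*` patterns suggest; the entry values below are hypothetical:

```python
import fnmatch

# Hypothetical entries mirroring the keys used above.
default_disabled_quantizer_cfg = [
    {"quantizer_name": "*lm_head*"},                      # skipped: lm_head stays in the search
    {"quantizer_name": "*router*"},                       # kept: routers are never searched
    {"quantizer_name": "*gate*", "parent_class": "MoE"},  # skipped: class-scoped entries are dropped
]

disabled = [
    e["quantizer_name"]
    for e in default_disabled_quantizer_cfg
    if "parent_class" not in e and e["quantizer_name"] != "*lm_head*"
]
disabled += ["*shared_expert_gate*", "*linear_attn.in_proj_a*", "*linear_attn.in_proj_b*"]

# AutoQuantize-style wildcard match against a module name.
print(any(fnmatch.fnmatch("model.layers.3.linear_attn.in_proj_a", p) for p in disabled))  # True
```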


def extract_and_prepare_language_model_from_vl(full_model):
"""Extract language model from VL model and disable quantization for non-language components.

@@ -362,14 +383,11 @@ def forward_step(model, batch):
len(calib_dataloader), max(auto_quantize_score_size // args.batch_size, 1)
),
verbose=True,
# Disable all default disabled layers such as lm_head, mlp.gate, router etc.
disabled_layers=[
entry["quantizer_name"]
for entry in _default_disabled_quantizer_cfg
if "parent_class" not in entry
],
disabled_layers=get_auto_quantize_disabled_layers(language_model),
method=auto_quantize_method,
checkpoint=auto_quantize_checkpoint,
cost_model=args.auto_quantize_cost_model,
active_moe_expert_ratio=args.auto_quantize_active_moe_expert_ratio,
)

calibrate_loop = create_forward_loop(dataloader=calib_dataloader)
@@ -506,13 +524,15 @@ def load_model(args: argparse.Namespace):
: len(args.dataset)
]

# We only quantize the language model for VLMs other than the type supported above.
extracted_lm, extracted_model_type = extract_and_prepare_language_model_from_vl(
full_model
)
if extracted_lm is not None:
language_model = extracted_lm
model_type = extracted_model_type
# AutoQuantize walks the outer CausalLM so lm_head is visible to the
# search. Visual/MTP siblings are excluded by disabled-layer patterns.
if args.auto_quantize_bits is None:
extracted_lm, extracted_model_type = extract_and_prepare_language_model_from_vl(
full_model
)
if extracted_lm is not None:
language_model = extracted_lm
model_type = extracted_model_type

tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code)

@@ -1018,10 +1038,18 @@ def quantize_main(
"Auto quantization needs multiple quantization format."
)

# For VL models, autoquant must walk submodules of the OUTER CausalLM
# (which carries lm_head and the LM-head forward path) — otherwise
# lm_head and any sibling-of-language_model modules are silently
# invisible to the search. ``forward_step`` also needs the outer model
# to produce ``CausalLMOutputWithPast`` (for ``.loss`` / ``.logits``).
# Visual tower and MTP siblings are auto-excluded inside
# ``auto_quantize()`` via *visual* / *mtp* / *vision_tower* patterns.
auto_quantize(
args,
language_model,
full_model,
calib_dataloader,
auto_quantize_method=args.auto_quantize_method,
)

else:
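
As a rough illustration of the comment above (not the repository's actual `forward_step`), a hedged sketch of why the calibration loop has to call the outer CausalLM, assuming batches are plain dicts of input tensors:

```python
def forward_step(model, batch):
    # Calling the outer CausalLM runs embeddings -> decoder -> lm_head and returns a
    # CausalLMOutputWithPast, so .loss / .logits are available for AutoQuantize scoring.
    # Calling only the inner language_model would bypass lm_head, and sibling modules
    # (visual tower, MTP head) would never be registered with the search.
    outputs = model(**batch)
    return outputs.loss if getattr(outputs, "loss", None) is not None else outputs.logits
```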
@@ -1326,6 +1354,27 @@ def parse_args() -> argparse.Namespace:
"(sensitivity scores, costs, etc.). Only used when auto_quantize_bits is specified."
),
)
parser.add_argument(
"--auto_quantize_cost_model",
type=str,
default="weight",
choices=["weight", "active_moe"],
help=(
"Cost model for auto_quantize effective-bits accounting. 'weight' counts all "
"quantizable weights equally. 'active_moe' scales routed MoE expert weights by "
"--auto_quantize_active_moe_expert_ratio, or infers top_k/num_experts from model config."
),
)
parser.add_argument(
"--auto_quantize_active_moe_expert_ratio",
type=float,
default=None,
help=(
"Routed MoE expert active ratio for --auto_quantize_cost_model active_moe. "
"For top-k MoE this is top_k / num_experts. If omitted, common model config "
"fields such as num_experts_per_tok and num_experts are used when available."
),
)
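
To make the 'active_moe' accounting concrete, here is one plausible reading of the help text above (an illustrative sketch, not the library's exact formula): routed-expert weights contribute to the effective-bits average scaled by the active ratio, while the 'weight' cost model counts everything equally.

```python
def effective_bits(layers, active_moe_expert_ratio=None):
    """layers: iterable of (num_weights, bits, is_routed_expert) tuples."""
    weighted_bits, weighted_count = 0.0, 0.0
    for num_weights, bits, is_routed_expert in layers:
        scale = 1.0  # 'weight' cost model: every quantizable weight counts equally
        if active_moe_expert_ratio is not None and is_routed_expert:
            scale = active_moe_expert_ratio  # 'active_moe': e.g. top_k / num_experts
        weighted_bits += scale * num_weights * bits
        weighted_count += scale * num_weights
    return weighted_bits / weighted_count

# Attention in FP8 (8-bit) plus 64 routed experts in NVFP4 (4-bit), top_k = 8:
layers = [(1_000_000, 8, False)] + [(500_000, 4, True)] * 64
print(effective_bits(layers))                                  # 'weight' view    ~= 4.12
print(effective_bits(layers, active_moe_expert_ratio=8 / 64))  # 'active_moe' view ~= 4.80
```

Under this reading, `--auto_quantize_active_moe_expert_ratio 0.125` would correspond to top_k=8 over 64 experts; per the help text, the ratio is inferred from config fields such as `num_experts_per_tok` / `num_experts` when it is omitted.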
parser.add_argument(
"--moe_calib_experts_ratio",
type=float,
@@ -1347,6 +1396,18 @@
args = parser.parse_args()
if args.moe_calib_experts_ratio is not None and not (0.0 < args.moe_calib_experts_ratio <= 1.0):
parser.error("--moe_calib_experts_ratio must be in the range (0.0, 1.0].")
if args.auto_quantize_active_moe_expert_ratio is not None and not (
0.0 < args.auto_quantize_active_moe_expert_ratio <= 1.0
):
parser.error("--auto_quantize_active_moe_expert_ratio must be in the range (0.0, 1.0].")
if (
args.auto_quantize_cost_model == "weight"
and args.auto_quantize_active_moe_expert_ratio is not None
):
parser.error(
"--auto_quantize_active_moe_expert_ratio requires "
"--auto_quantize_cost_model active_moe."
)

if args.specdec_offline_dataset is not None and args.sparsity_fmt != "dense":
parser.error("--specdec_offline_dataset is only supported with --sparsity_fmt dense (PTQ).")
91 changes: 91 additions & 0 deletions modelopt/torch/export/layer_utils.py
@@ -1173,6 +1173,97 @@ def set_expert_quantizer_amax(
_GATE_UP_PAIRS = [("gate_proj", "up_proj"), ("w1", "w3")]


_LINEAR_ATTN_FUSED_PAIRS = [
("in_proj_qkv", "in_proj_z"),
("in_proj_b", "in_proj_a"),
]


def _tensor_values_equal(left: torch.Tensor | None, right: torch.Tensor | None) -> bool:
if left is None or right is None:
return left is right
if left.is_meta or right.is_meta:
return False
return torch.equal(left, right)


def _safe_quantizer_amax(quantizer) -> torch.Tensor | None:
try:
return getattr(quantizer, "amax", None)
except AssertionError:
return None


def _linear_fusion_scales_match(left: nn.Module, right: nn.Module) -> bool:
left_iq = getattr(left, "input_quantizer", None)
right_iq = getattr(right, "input_quantizer", None)
if (
left_iq is not None
and right_iq is not None
and getattr(left_iq, "is_enabled", False)
and getattr(right_iq, "is_enabled", False)
and not _tensor_values_equal(_safe_quantizer_amax(left_iq), _safe_quantizer_amax(right_iq))
):
return False

left_wq = getattr(left, "weight_quantizer", None)
right_wq = getattr(right, "weight_quantizer", None)
if left_wq is None or right_wq is None:
return True

if isinstance(left_wq, SequentialQuantizer) and isinstance(right_wq, SequentialQuantizer):
if (
len(left_wq) > 0
and len(right_wq) > 0
and getattr(left_wq[-1], "is_enabled", False)
and getattr(right_wq[-1], "is_enabled", False)
):
return _tensor_values_equal(
_safe_quantizer_amax(left_wq[-1]), _safe_quantizer_amax(right_wq[-1])
)
return True

if hasattr(left_wq, "global_amax") and hasattr(right_wq, "global_amax"):
return _tensor_values_equal(left_wq.global_amax, right_wq.global_amax)

if getattr(left_wq, "is_enabled", False) and getattr(right_wq, "is_enabled", False):
return _tensor_values_equal(_safe_quantizer_amax(left_wq), _safe_quantizer_amax(right_wq))

return True


def sync_linear_attn_fused_projection_amax(model: nn.Module) -> int:
"""Sync quantizer amaxes for GDN projections that serving engines fuse.

Qwen3.5/Qwen3-Next GDN exports keep ``in_proj_qkv`` and ``in_proj_z`` as
separate HF tensors, but vLLM fuses them into ``in_proj_qkvz`` at load time.
Likewise ``in_proj_b`` and ``in_proj_a`` may be fused as ``in_proj_ba``.
Sharing the quantizer scale domains before export avoids serving-time fused
loaders having to reconcile different scalar/global scales.

Returns:
Number of projection pairs whose scale state changed.
"""
changed = 0
for _, sub_module in model.named_modules():
for left_name, right_name in _LINEAR_ATTN_FUSED_PAIRS:
left = getattr(sub_module, left_name, None)
right = getattr(sub_module, right_name, None)
if left is None or right is None:
continue
left_format = get_quantization_format(left)
right_format = get_quantization_format(right)
if left_format != right_format or left_format is None:
continue
if left_format == QUANTIZATION_NONE:
continue
matched_before = _linear_fusion_scales_match(left, right)
preprocess_linear_fusion([left, right])
if not matched_before:
changed += 1
return changed
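
To illustrate what goes wrong without the sync (a hedged sketch; it assumes `preprocess_linear_fusion` reconciles the pair by taking the element-wise max of their amaxes, consistent with the MoE gate/up warning in `unified_export_hf.py`): once vLLM concatenates the two projections into a single `in_proj_qkvz` weight, only one scale domain can apply to the fused tensor.

```python
import torch

# Hypothetical per-tensor amaxes calibrated independently on the two projections.
amax_qkv = torch.tensor(0.61)
amax_z = torch.tensor(0.47)

# The serving engine fuses the weights along dim 0, so one scale must cover both halves.
shared_amax = torch.maximum(amax_qkv, amax_z)
fp8_scale = shared_amax / 448.0  # 448 = max representable magnitude of FP8 E4M3

# Syncing before export writes the same amax (hence the same exported scale) into both
# quantizers, so the fused loader never has to reconcile 0.61 against 0.47.
print(shared_amax.item(), round(fp8_scale.item(), 6))
```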


def sync_moe_gate_up_amax(model: nn.Module) -> int:
"""Take element-wise max of gate and up weight quantizer amaxes per expert.

10 changes: 10 additions & 0 deletions modelopt/torch/export/unified_export_hf.py
@@ -73,6 +73,7 @@
is_moe,
is_quantlinear,
set_expert_quantizer_amax,
sync_linear_attn_fused_projection_amax,
sync_moe_gate_up_amax,
)
from .model_config import (
Expand Down Expand Up @@ -810,6 +811,15 @@ def _export_transformers_checkpoint(
f"Taking element-wise max of amaxes for serving-engine fusion."
)

# Safety net for Qwen3.5/Qwen3-Next GDN projections. These remain separate
# HF tensors, but vLLM fuses qkv+z and b+a at load time.
synced = sync_linear_attn_fused_projection_amax(model)
if synced:
warnings.warn(
f"Synced quantizer amax/global_amax for {synced} linear-attention "
f"projection pair(s) that are fused by serving engines."
)

# Process all quantized modules and export weights
_process_quantized_modules(model, dtype, is_modelopt_qlora)
