Skip to content

Commit 21a4010

Browse files
soodoshll, shengliangxu, Edwardf0t1
authored
Add Quantizers for Qwen3VLMoeTextDecoderLayer (#666)
## What does this PR do? **Type of change:** ? new feature **Overview:** ? huggingface transformers library implements Qwen3VL Moe layer as a monolithic module, instead of assembling it using Linear layers, which cannot be recognized by modelopt's quantizer now. This PR introduces a conversion from hf's qwen3vl_moe MoE layers to qwen3_moe MoE layers which consist of a set of Linear layers. ## Testing Tested with ```python python hf_ptq.py --pyt_ckpt_path=Qwen/Qwen3-VL-30B-A3B-Instruct --qformat=nvfp4 --dataset wikipedia ``` ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes/No <!--- If No, explain why. --> - **Did you write any new necessary tests?**: Yes/No - **Did you add or update any necessary documentation?**: Yes/No - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No <!--- Only for new features, API changes, critical bug fixes or bw breaking changes. --> ## Additional Information <!-- E.g. related issue. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes * **New Features** * Added quantization support for Qwen3VL models with sparse mixture-of-experts (MoE) architecture, enabling efficient model compression for this model type. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Qidong Su <qidongs@nvidia.com> Signed-off-by: Qidong Su <soodoshll@gmail.com> Co-authored-by: Shengliang Xu <106840466+shengliangxu@users.noreply.github.com> Co-authored-by: Zhiyu <bestczy317@gmail.com>
1 parent b0e7d9f commit 21a4010

1 file changed

Lines changed: 98 additions & 0 deletions

File tree

modelopt/torch/quantization/plugins/huggingface.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,86 @@ def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
571571
return self.w2_linear[expert_idx](x1)
572572

573573

574+
class _QuantQwen3VLMoeTextExperts(QuantModule):
    """Quantized wrapper for HF ``Qwen3VLMoeTextExperts``.

    The upstream module stores all experts in two fused 3-D parameters
    (``gate_up_proj`` and ``down_proj``), which modelopt's quantizer cannot
    recognize. ``_setup`` unpacks them into per-expert ``nn.Linear`` layers so
    each projection can be quantized individually.
    """

    def _setup(self):
        """Modify the Qwen3VLMoeTextExperts by using nn.Linear layers."""
        from accelerate import init_empty_weights

        # Target dtype/device taken from the fused parameter being replaced.
        dtype, device = self.gate_up_proj.dtype, self.gate_up_proj.device

        def _copy_weight(module, weight):
            # Materialize the meta-device Linear, then copy the given slice of
            # the fused parameter into it.
            module.to_empty(device=device)
            with torch.no_grad():
                module.weight.data = weight.detach().data.to(dtype=dtype, device=device)

        # The attribute name was changed from `intermediate_size` to `intermediate_dim` in
        # https://github.com/huggingface/transformers/commit/0642963ba13f2dae0596fe489415569e1d91fbda
        if hasattr(self, "intermediate_size"):
            expert_dim = self.intermediate_size
        elif hasattr(self, "intermediate_dim"):
            expert_dim = self.intermediate_dim
        else:
            raise AttributeError("Could not find intermediate dimension size in model")

        # Allocate the per-expert Linear layers on the meta device first; real
        # storage is created by `to_empty` inside `_copy_weight`.
        with init_empty_weights():
            gate_proj = nn.ModuleList(
                [
                    nn.Linear(self.hidden_size, expert_dim, bias=False)
                    for _ in range(self.num_experts)
                ]
            )
            up_proj = nn.ModuleList(
                [
                    nn.Linear(self.hidden_size, expert_dim, bias=False)
                    for _ in range(self.num_experts)
                ]
            )
            down_proj = nn.ModuleList(
                [
                    nn.Linear(expert_dim, self.hidden_size, bias=False)
                    for _ in range(self.num_experts)
                ]
            )

        # gate_up_proj fuses gate (first expert_dim columns) and up (the rest);
        # slices are transposed because nn.Linear stores (out_features, in_features).
        for idx in range(self.num_experts):
            _copy_weight(gate_proj[idx], self.gate_up_proj[idx, :, :expert_dim].T)
            _copy_weight(up_proj[idx], self.gate_up_proj[idx, :, expert_dim:].T)
            _copy_weight(down_proj[idx], self.down_proj[idx, :].T)

        # Drop the fused parameters and expose the quantizable Linear layers.
        delattr(self, "gate_up_proj")
        delattr(self, "down_proj")
        self.gate_proj = gate_proj
        self.up_proj = up_proj
        self.down_proj = down_proj

    def forward(
        self,
        hidden_states: torch.Tensor,
        routing_weights: torch.Tensor,
        router_indices: torch.Tensor,
    ) -> torch.Tensor:
        """Run tokens through their routed experts and combine the outputs.

        Assumes ``hidden_states`` is (batch, seq, hidden_size) — it is
        flattened to (tokens, hidden_size) and restored at the end; presumably
        ``routing_weights`` is dense (tokens, num_experts) and
        ``router_indices`` holds top-k expert ids per token — TODO confirm
        against the HF caller.
        """
        batch_size = hidden_states.shape[0]
        hidden_states = hidden_states.reshape(-1, self.hidden_size)
        next_states = torch.zeros_like(hidden_states)
        with torch.no_grad():
            # One-hot over experts, then keep only experts that received >= 1 token.
            expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts)
            expert_mask = expert_mask.permute(2, 1, 0)
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
        for expert_idx in expert_hit:
            with torch.no_grad():
                # Token positions routed to this expert.
                _, token_idx = torch.where(expert_mask[expert_idx[0]])
            current_state = hidden_states[token_idx]
            # SwiGLU-style gating: down( up(x) * act(gate(x)) ).
            gate = self.gate_proj[expert_idx](current_state)
            up = self.up_proj[expert_idx](current_state)
            gated_output = up * self.act_fn(gate)
            out = self.down_proj[expert_idx](gated_output)
            weighted_output = out * routing_weights[token_idx, expert_idx, None]
            # Accumulate per-expert contributions back at the token positions.
            next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
        next_states = next_states.view(batch_size, -1, self.hidden_size)

        return next_states
574654
class _QuantDbrxFFN(_QuantSparseMoe):
575655
@property
576656
def num_experts(self):
@@ -660,6 +740,24 @@ def top_k(self, value):
660740
except ImportError:
661741
pass
662742

743+
try:
    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
        Qwen3VLMoeTextExperts,
        Qwen3VLMoeTextSparseMoeBlock,
    )
except ImportError:
    # Installed transformers version does not ship Qwen3VL-MoE; skip registration.
    pass
else:
    # Register quantized replacements for the Qwen3VL MoE modules, unless a
    # prior plugin already registered these classes.
    if Qwen3VLMoeTextSparseMoeBlock not in QuantModuleRegistry:
        QuantModuleRegistry.register(
            {Qwen3VLMoeTextSparseMoeBlock: "hf.Qwen3VLMoeTextSparseMoeBlock"}
        )(_QuantSparseMoe)

    if Qwen3VLMoeTextExperts not in QuantModuleRegistry:
        QuantModuleRegistry.register({Qwen3VLMoeTextExperts: "hf.Qwen3VLMoeTextExperts"})(
            _QuantQwen3VLMoeTextExperts
        )
663761

664762
class _QuantGptOssExperts(_QuantFunctionalMixin):
665763
"""Quantized wrapper for `transformers.GptOssExperts`.

0 commit comments

Comments
 (0)