diff --git a/examples/specdec_bench/dflash_kimi.yaml b/examples/specdec_bench/dflash_kimi.yaml
new file mode 100644
index 00000000000..0e34c5359e7
--- /dev/null
+++ b/examples/specdec_bench/dflash_kimi.yaml
@@ -0,0 +1,10 @@
+chat_template_args:
+  thinking: true
+engine_args:
+  mem_fraction_static: 0.9
+  speculative_num_draft_tokens: 8
+  # cuda_graph_max_bs: 128
+  speculative_dflash_draft_window_size: 4096
+  disable_cuda_graph: true
+sampling_kwargs:
+  temperature: 0
diff --git a/examples/specdec_bench/dflash_qwen.yaml b/examples/specdec_bench/dflash_qwen.yaml
new file mode 100644
index 00000000000..457e1eadaca
--- /dev/null
+++ b/examples/specdec_bench/dflash_qwen.yaml
@@ -0,0 +1,7 @@
+engine_args:
+  mem_fraction_static: 0.9
+  speculative_num_draft_tokens: 8
+  speculative_dflash_draft_window_size: 4096
+  mamba_scheduler_strategy: extra_buffer
+sampling_kwargs:
+  temperature: 0
diff --git a/examples/specdec_bench/run.py b/examples/specdec_bench/run.py
index f4fbf06c0e8..94932c787b8 100644
--- a/examples/specdec_bench/run.py
+++ b/examples/specdec_bench/run.py
@@ -265,7 +265,7 @@ def run_simple(args):
         type=str,
         required=False,
         default="EAGLE3",
-        choices=["EAGLE3", "EAGLE", "DRAFT_TARGET", "NGRAM", "MTP", "NONE"],
+        choices=["EAGLE3", "EAGLE", "DRAFT_TARGET", "NGRAM", "MTP", "DFLASH", "NONE"],
         help="Speculative algorithm to use",
     )
     parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory")
diff --git a/examples/specdec_bench/specdec_bench/models/sglang.py b/examples/specdec_bench/specdec_bench/models/sglang.py
index d5ff890ffd7..00ba1de1f44 100644
--- a/examples/specdec_bench/specdec_bench/models/sglang.py
+++ b/examples/specdec_bench/specdec_bench/models/sglang.py
@@ -41,46 +41,67 @@ def __init__(
             speculative_algorithm = "STANDALONE"
         elif speculative_algorithm == "NGRAM":
            speculative_algorithm = "LOOKAHEAD"
+        elif speculative_algorithm == "DFLASH":
+            pass  # SGLang native name, pass through
         elif speculative_algorithm == "NONE":
             speculative_algorithm = None
+
+        engine_kwargs = dict(
+            model_path=model_dir,
+            skip_tokenizer_init=True,
+            trust_remote_code=kwargs.get("trust_remote_code", False),
+            mem_fraction_static=kwargs.get("mem_fraction_static", 0.8),
+            disable_overlap_schedule=kwargs.get("disable_overlap_schedule", False),
+            tp_size=kwargs.get("tensor_parallel_size", 1),
+            ep_size=kwargs.get("moe_expert_parallel_size", 1),
+            torch_compile_max_bs=max_concurrent_requests,
+            max_running_requests=max_concurrent_requests,
+            attention_backend=kwargs.get("attention_backend"),
+            enable_torch_compile=kwargs.get("enable_torch_compile", False),
+            cuda_graph_max_bs=max_concurrent_requests,
+            disable_cuda_graph=False,
+        )
         if speculative_algorithm is not None:
             # https://github.com/sgl-project/sglang/pull/3582
-            self.model = sgl.Engine(
-                model_path=model_dir,
-                skip_tokenizer_init=True,
-                trust_remote_code=kwargs.get("trust_remote_code", False),
-                mem_fraction_static=0.8,
-                disable_overlap_schedule=kwargs.get("disable_overlap_schedule", False),
-                tp_size=kwargs.get("tensor_parallel_size", 1),
-                ep_size=kwargs.get("moe_expert_parallel_size", 1),
-                speculative_algorithm=speculative_algorithm,
-                speculative_num_steps=kwargs.get("speculative_num_steps", 3),
-                speculative_eagle_topk=kwargs.get("speculative_eagle_topk", 1),
-                speculative_num_draft_tokens=kwargs.get("speculative_num_draft_tokens", 4),
-                speculative_draft_model_path=kwargs.get("draft_model_dir"),
-                torch_compile_max_bs=max_concurrent_requests,
-                max_running_requests=max_concurrent_requests,
-                attention_backend=kwargs.get("attention_backend"),
-                enable_torch_compile=kwargs.get("enable_torch_compile", False),
-                cuda_graph_max_bs=max_concurrent_requests,
-                disable_cuda_graph=False,
-            )
-        else:
-            self.model = sgl.Engine(
-                model_path=model_dir,
-                skip_tokenizer_init=True,
-                trust_remote_code=kwargs.get("trust_remote_code", False),
-                mem_fraction_static=0.8,
-                disable_overlap_schedule=kwargs.get("disable_overlap_schedule", False),
-                tp_size=kwargs.get("tensor_parallel_size", 1),
-                ep_size=kwargs.get("moe_expert_parallel_size", 1),
-                torch_compile_max_bs=max_concurrent_requests,
-                max_running_requests=max_concurrent_requests,
-                attention_backend=kwargs.get("attention_backend"),
-                enable_torch_compile=kwargs.get("enable_torch_compile", False),
-                cuda_graph_max_bs=max_concurrent_requests,
-                disable_cuda_graph=False,
-            )
+            engine_kwargs["speculative_algorithm"] = speculative_algorithm
+            num_draft_tokens = kwargs.get("speculative_num_draft_tokens", 4)
+            engine_kwargs["speculative_num_draft_tokens"] = num_draft_tokens
+            engine_kwargs["speculative_draft_model_path"] = kwargs.get("draft_model_dir")
+            if speculative_algorithm == "DFLASH":
+                if "speculative_dflash_draft_window_size" in kwargs:
+                    engine_kwargs["speculative_dflash_draft_window_size"] = kwargs[
+                        "speculative_dflash_draft_window_size"
+                    ]
+                print(
+                    f"[specdec_bench] DFLASH ignores --draft_length / speculative_num_steps / "
+                    f"speculative_eagle_topk; effective draft block = "
+                    f"speculative_num_draft_tokens={num_draft_tokens}"
+                )
+            else:
+                engine_kwargs["speculative_num_steps"] = kwargs.get("speculative_num_steps", 3)
+                engine_kwargs["speculative_eagle_topk"] = kwargs.get("speculative_eagle_topk", 1)
+
+        # Forward any other kwargs (e.g. from runtime_params.engine_args) to
+        # sgl.Engine, letting yaml override the defaults set above. Skip only
+        # specdec_bench-internal routing keys that should never reach SGLang.
+        _internal_keys = frozenset({
+            "speculative_algorithm",
+            "draft_model_dir",
+            "speculative_num_steps",
+            "speculative_eagle_topk",
+            "speculative_num_draft_tokens",
+            "speculative_dflash_draft_window_size",
+            "tensor_parallel_size",
+            "moe_expert_parallel_size",
+            "tokenizer_path",
+            "use_draft_logits",
+        })
+        for _k, _v in kwargs.items():
+            if _k in _internal_keys:
+                continue
+            engine_kwargs[_k] = _v
+
+        self.model = sgl.Engine(**engine_kwargs)

         self.sampling_config = sampling_kwargs
diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py
index a0a28d78a7f..cfa88aadb00 100644
--- a/examples/speculative_decoding/eagle_utils.py
+++ b/examples/speculative_decoding/eagle_utils.py
@@ -108,7 +108,7 @@ def make_speculative_data_module(
         raise ValueError("sample_size must be -1 (use all samples) or a positive integer")
     if data_args.sample_size > 0:
         dumped_files = dumped_files[: data_args.sample_size]
-    train_dataset = OfflineSupervisedDataset(dumped_files, answer_only_loss=answer_only_loss)
+    train_dataset = OfflineSupervisedDataset(dumped_files, answer_only_loss=answer_only_loss, tokenizer=tokenizer)
     data_collator = EagleOfflineDataCollator(train_len=train_len)

     return {
@@ -159,10 +159,14 @@ def compute_loss(self, *args, **kwargs):
             self.state.training_accs = []
         if not hasattr(self.state, "component_losses"):
             self.state.component_losses = {"eagle": [], "preservation": []}
+        if not hasattr(self.state, "training_stats"):
+            self.state.training_stats = []
         kwargs.pop("num_items_in_batch", None)
         loss, outputs = super().compute_loss(return_outputs=True, *args, **kwargs)
         if hasattr(outputs, "train_acc") and any(outputs.train_acc):
             self.state.training_accs.append(outputs.train_acc)
+        if getattr(outputs, "train_stats", None):
+            self.state.training_stats.append(outputs.train_stats)
         # Track per-component losses
         for key, attr in [
             ("eagle", "eagle_loss"),
@@ -261,6 +265,22 @@ def on_log(self, args, state, control, **kwargs):
                 print_rank_0(f"Step {state.global_step} Estimated Training AR: {est_ar:.4f}")
                 logs["estimated_training_ar"] = est_ar

+        # Aggregate dflash debug stats over the log window.
+        if getattr(state, "training_stats", None):
+            keys = set()
+            for s in state.training_stats:
+                keys.update(s.keys())
+            for k in keys:
+                vals = [s[k] for s in state.training_stats if k in s]
+                if not vals:
+                    continue
+                if isinstance(vals[0], list):
+                    arr = np.array(vals)  # [N_steps, P]
+                    for j, m in enumerate(arr.mean(axis=0).tolist()):
+                        logs[f"train_stats/{k}_pos_{j}"] = float(m)
+                else:
+                    logs[f"train_stats/{k}"] = float(np.mean(vals))
+
         # log to wandb
         if wandb is not None and wandb.run is not None and is_master():
             if logs:
@@ -276,6 +296,8 @@ def on_log(self, args, state, control, **kwargs):
             state.training_accs = []
         if hasattr(state, "component_losses"):
             state.component_losses = {"eagle": [], "preservation": []}
+        if hasattr(state, "training_stats"):
+            state.training_stats = []
         return control

     def on_step_end(self, args, state, control, **kwargs):
diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py
index 7e399cf9603..0d3d82a31c3 100644
--- a/examples/speculative_decoding/main.py
+++ b/examples/speculative_decoding/main.py
@@ -53,7 +53,7 @@
 from modelopt.torch.speculative.utils import load_vlm_or_llm, patch_transformers5_params_loading
 from modelopt.torch.utils import print_rank_0

-torch.manual_seed(0)
+torch.manual_seed(3)
 mto.enable_huggingface_checkpointing()


@@ -250,6 +250,9 @@ def train():

     checkpoint = training_args.resume_from_checkpoint or last_checkpoint

+    # NOTE: patch for k25 dflash
+    # checkpoint = None
+
     use_offline_training = data_args.offline_data_path is not None

     if checkpoint:
@@ -370,7 +373,15 @@ def train():
     )

     print_rank_0("Start training...")
-    trainer.train(resume_from_checkpoint=checkpoint)
+    # trainer.train(resume_from_checkpoint=checkpoint)
+    # NOTE: patch for k25 dflash: rebuild the optimizer, reload its state from the checkpoint, and reset the LR.
+    trainer.create_optimizer_and_scheduler(num_training_steps=training_args.max_steps)
+    optimizer_path = os.path.join(checkpoint, "optimizer.pt")
+    trainer.optimizer.load_state_dict(torch.load(optimizer_path, map_location="cpu"))
+    for param_group in trainer.optimizer.param_groups:
+        param_group["lr"] = training_args.learning_rate
+    print_rank_0(f"Loaded optimizer from {optimizer_path}")
+    trainer.train()  # NOTE: patch for k25 dflash
     trainer.save_state()
     trainer.save_model(training_args.output_dir)
diff --git a/modelopt/torch/speculative/eagle/utils.py b/modelopt/torch/speculative/eagle/utils.py
index f74fcb1e9fb..60d2b0f35f3 100644
--- a/modelopt/torch/speculative/eagle/utils.py
+++ b/modelopt/torch/speculative/eagle/utils.py
@@ -78,6 +78,51 @@ def expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int | None = No
     return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


+def compute_assistant_mask_kimi(tokenizer, input_ids):
+    """Recover the assistant mask from already-tokenized Kimi chat IDs.
+
+    For every <|im_assistant|> token, locate the following <|im_middle|> and
+    matching <|im_end|>, and mark only the inner content span (exclusive of
+    both markers). This matches HF's generation-tag mask semantics: only the
+    assistant's actual reply tokens count, not role/separator markers.
+
+    An unmatched assistant span (interrupted by a new role marker, or a
+    trailing generation prompt at the end of the sequence) is marked from
+    <|im_middle|> + 1 up to but not including the next role marker / EOS. If
+    <|im_middle|> is absent within the span, nothing is marked for it.
+    """
+    ids_list = input_ids.tolist() if hasattr(input_ids, "tolist") else list(input_ids)
+
+    role_to_id = {
+        role: tokenizer.convert_tokens_to_ids(role)
+        for role in ("<|im_user|>", "<|im_assistant|>", "<|im_system|>")
+    }
+    assistant_id = role_to_id["<|im_assistant|>"]
+    other_role_ids = {tid for r, tid in role_to_id.items() if r != "<|im_assistant|>"}
+    end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
+    middle_id = tokenizer.convert_tokens_to_ids("<|im_middle|>")
+
+    mask = [0] * len(ids_list)
+    i = 0
+    n = len(ids_list)
+    while i < n:
+        if ids_list[i] != assistant_id:
+            i += 1
+            continue
+        j = i + 1
+        m = -1
+        while j < n and ids_list[j] != end_id and ids_list[j] not in other_role_ids:
+            if m < 0 and ids_list[j] == middle_id:
+                m = j
+            j += 1
+        if m >= 0:
+            for k in range(m + 1, j):
+                mask[k] = 1
+        i = j + 1 if (j < n and ids_list[j] == end_id) else j
+
+    return torch.tensor(mask, dtype=torch.long)
+
+
 class OfflineSupervisedDataset(Dataset):
     """Offline dataset for supervised fine-tuning with pre-dumped hidden states.
@@ -105,34 +150,42 @@ def __init__(
         self,
         dumped_files,
         answer_only_loss: bool = False,
+        tokenizer=None,
     ):
         """Initialize with a list of .pt file paths."""
         super().__init__()
         self.dumped_files = dumped_files
         self.answer_only_loss = answer_only_loss
+        self.tokenizer = tokenizer

     def __len__(self):
         return len(self.dumped_files)

     def __getitem__(self, i) -> dict[str, torch.Tensor]:
-        offline_data = torch.load(self.dumped_files[i], weights_only=True)
+        try:
+            offline_data = torch.load(self.dumped_files[i], weights_only=True)
+        except Exception as e:
+            print(f"Error loading {self.dumped_files[i]}: {e}, trying to load previous file")
+            return self.__getitem__(i - 1)
         labels = torch.full_like(offline_data["input_ids"], IGNORE_TOKEN_ID)
         labels[..., :-1] = offline_data["input_ids"][..., 1:]

         if self.answer_only_loss:
             if "loss_mask" not in offline_data:
-                raise ValueError(
-                    f"answer_only_loss=True requires a 'loss_mask' entry in the offline "
-                    f".pt file, but {self.dumped_files[i]} does not have one. Re-dump "
-                    f"with --answer-only-loss in compute_hidden_states_*.py."
-                )
-            loss_mask = offline_data["loss_mask"].to(offline_data["input_ids"].dtype)
+                # No pre-dumped loss_mask: recover the assistant mask from the Kimi chat markers.
+                loss_mask = compute_assistant_mask_kimi(self.tokenizer, offline_data["input_ids"])
+                ratio = loss_mask.float().mean().item()
+                if ratio < 0.3:
+                    # Too few assistant tokens to learn from; fall back to the previous sample.
+                    return self.__getitem__(i - 1)
+            else:
+                loss_mask = offline_data["loss_mask"].to(offline_data["input_ids"].dtype)
         else:
             loss_mask = torch.ones_like(offline_data["input_ids"])

         ret = {
-            "input_ids": offline_data["input_ids"],
+            "input_ids": offline_data["input_ids"].to(torch.long),
             "base_model_hidden_states": offline_data["hidden_states"],
             "aux_hidden_states": offline_data["aux_hidden_states"],
             "attention_mask": torch.ones_like(offline_data["input_ids"]),
@@ -149,14 +202,17 @@ def __init__(self, train_len):
         """Initialize with the target sequence length for truncation/padding."""
         self.train_len = train_len

-    def _pad_or_truncate(self, x: torch.Tensor, length: int, dim: int = 0):
-        """Pad or truncate a tensor to length along a given dimension."""
+    def _pad_or_truncate(self, x: torch.Tensor, length: int, dim: int = 0, padding_token_id: int = 0):
+        """Pad or truncate a tensor to length along a given dimension.
+        The padded region is filled with padding_token_id (the collator passes the Kimi pad id 163839 for input_ids).
+        """
         dim = dim % x.ndim  # support negative dimension

-        # allocate output tensor
+        # Allocate the output tensor pre-filled with the padding token
+        # (0 by default; the caller passes 163839 for input_ids).
         out_shape = list(x.shape)
         out_shape[dim] = length
-        out = x.new_zeros(out_shape)
+        out = x.new_full(out_shape, padding_token_id)

         # construct copy slice
         slc = [slice(None)] * x.ndim
@@ -168,9 +224,12 @@ def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:

     def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
         """Collate a list of feature dicts into a single padded/truncated batch."""
+        # input_ids are padded with the Kimi pad token 163839; the other fields pad with 0.
         base_batch = {
-            k: torch.stack([self._pad_or_truncate(item[k], self.train_len) for item in features])
-            for k in ["input_ids", "attention_mask", "loss_mask", "labels"]
+            "input_ids": torch.stack([self._pad_or_truncate(item["input_ids"], self.train_len, padding_token_id=163839) for item in features]),
+            "attention_mask": torch.stack([self._pad_or_truncate(item["attention_mask"], self.train_len) for item in features]),
+            "loss_mask": torch.stack([self._pad_or_truncate(item["loss_mask"], self.train_len) for item in features]),
+            "labels": torch.stack([self._pad_or_truncate(item["labels"], self.train_len) for item in features]),
         }

         base_model_outputs = {
diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py
index 1760cb2072d..1a231cb87d0 100644
--- a/modelopt/torch/speculative/plugins/hf_dflash.py
+++ b/modelopt/torch/speculative/plugins/hf_dflash.py
@@ -166,9 +166,11 @@ def modify(self, config):
             else base_config.num_hidden_layers
         )
         num_draft_layers = self.dflash_config.num_hidden_layers
-        self.target_layer_ids = build_target_layer_ids(num_target_layers, num_draft_layers)
+        # self.target_layer_ids = build_target_layer_ids(num_target_layers, num_draft_layers)
+        self.target_layer_ids = [1, 12, 24, 35, 47, 58]  # hardcoded target layers for the k25 dflash run
         self.dflash_config.target_layer_ids = self.target_layer_ids

+        # mask_token_id: validated by DFlashConfig, auto-detected from tokenizer context
         self.mask_token_id = config.dflash_mask_token_id
         logger.info("DFlash mask_token_id: %s", self.mask_token_id)
@@ -320,7 +322,7 @@ def _compute_loss(
             base_logits: Base model logits for KD loss [B, seq_len, vocab], or None for CE.

         Returns:
-            (loss, accuracy) tuple.
+            (loss, accuracy, stats) tuple; stats carries dflash debug metrics such as sim_accept_len.
""" bsz, seq_len = input_ids.shape block_size = self.dflash_block_size @@ -381,13 +383,26 @@ def _compute_loss( with torch.no_grad(): preds = flat_logits.argmax(dim=-1) correct = (preds == flat_targets) & (binary_eval_mask > 0.5) - accuracy = correct.sum().float() / (binary_eval_mask.sum() + 1e-6) - accuracy = accuracy.item() + accuracy = (correct.sum().float() / (binary_eval_mask.sum() + 1e-6)).item() + + # Simulated accept length: per kept block, count consecutive correct + # predictions from pos 1 until the first miss (or end of block). + # Range: [0, block_size - 1]. Averaged over block_keep_mask blocks. + correct_3d = correct.view(bsz, n_blocks, block_size).float() + cumprod = torch.cumprod(correct_3d[..., 1:], dim=-1) + per_block_accept_len = cumprod.sum(dim=-1) + kept_f = block_keep_mask.float() + sim_accept_len = ( + (per_block_accept_len * kept_f).sum() / (kept_f.sum() + 1e-6) + ).item() + + stats = {"sim_accept_len": sim_accept_len} else: loss = flat_logits.sum() * 0.0 accuracy = 0.0 + stats = {"sim_accept_len": 0.0} - return loss, accuracy + return loss, accuracy, stats def forward( self, @@ -493,7 +508,13 @@ def forward( if n_blocks == 0 or not block_keep_mask.any(): # Zero loss that still flows through dflash_module for DDP gradient sync dummy = self.dflash_module.fc.weight.sum() * 0.0 - return ModelOutput(loss=dummy, logits=base_outputs.logits, train_acc=[[0.0]]) + empty_stats = {"sim_accept_len": 0.0} + return ModelOutput( + loss=dummy, + logits=base_outputs.logits, + train_acc=[[0.0]], + train_stats=empty_stats, + ) # 4. Build draft inputs noise_embedding = self._build_noise_embedding( @@ -514,7 +535,7 @@ def forward( # 6. Compute loss and accuracy logits = self._base_model_lm_head(hidden) - loss, accuracy = self._compute_loss( + loss, accuracy, train_stats = self._compute_loss( logits, input_ids, anchor_positions, @@ -527,6 +548,7 @@ def forward( loss=loss, logits=base_outputs.logits, train_acc=[[accuracy]], + train_stats=train_stats, ) @torch.no_grad() diff --git a/modelopt_recipes/general/speculative_decoding/dflash.yaml b/modelopt_recipes/general/speculative_decoding/dflash.yaml index 3d43e0fe1d4..be477e3642b 100644 --- a/modelopt_recipes/general/speculative_decoding/dflash.yaml +++ b/modelopt_recipes/general/speculative_decoding/dflash.yaml @@ -55,6 +55,12 @@ dflash: dflash_self_logit_distillation: true dflash_loss_decay_factor: 4.0 dflash_architecture_config: - num_hidden_layers: 5 + num_hidden_layers: 6 + max_position_embeddings: 262144 + intermediate_size: 18432 + num_attention_heads: 64 + num_key_value_heads: 8 + rope_theta: 50000.0 + rms_norm_eps: 1e-05 # mask_token_id: auto-detected from model vocab (override for specific models) # sliding_window and layer_types are inherited from base model config automatically