29 changes: 28 additions & 1 deletion examples/puzzletron/README.md
@@ -11,7 +11,7 @@ To use the Puzzle algorithm effectively, we need to specify the target number of

In this example, we compress the [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) model, reducing GPU memory usage from 113 GiB to 96 GiB (a 15% reduction) with less than 1% regression in the token_accuracy_top_10 metric. Other supported models can be compressed in the same way. For GptOss, there is one [additional step to perform](GPTOSS.md).

> **Note:** Other models are also supported. See the [configs](./configs/) directory for additional model configurations (e.g., Llama-3.2-3B-Instruct on 1x H100, Qwen2.5-7B-Instruct on 1x H100, Qwen3-8B on 1x H100, Nemotron-Nano-12B-v2 on 1x H100, Mistral-Small-24B-Instruct-2501 on 4x H100). For information on adding support for new models, see the [AnyModel Guide](../../modelopt/torch/puzzletron/anymodel/README.md).
> **Note:** Other models are also supported. See the [configs](./configs/) directory for additional model configurations (e.g., Llama-3.2-3B-Instruct on 1x H100, Qwen2.5-7B-Instruct on 1x H100, Qwen3-8B on 1x H100, Nemotron-Nano-12B-v2 on 1x H100, Mistral-Small-24B-Instruct-2501 on 4x H100). For KV-head pruning, see [`llama-3_1-8B_pruneattn_runtime`](./configs/llama-3_1-8B_pruneattn_runtime/) and the [Attention Pruning](#attention-pruning-kv-head-reduction) and [Runtime-Based Latency Optimization](#runtime-based-latency-optimization) sections below. For information on adding support for new models, see the [AnyModel Guide](../../modelopt/torch/puzzletron/anymodel/README.md).

## Environment

@@ -343,6 +343,33 @@ See [Megatron-Bridge distillation](../megatron_bridge/README.md#distillation) fo

For distillation results on Puzzletron-compressed models, see [examples/pruning/puzzletron/](../pruning/puzzletron/README.md).

## Runtime-Based Latency Optimization

By default, subblock statistics use the `trt_torch` backend with theoretical memory proxies. You can instead enable **runtime stats** to measure actual inference latency via vLLM, which unlocks latency-based MIP constraints:

```yaml
calc_subblock_stats:
  runtime_stats:
    enabled: true
    synth_dataset_num_requests: 32
    backend: vllm
    num_warmup_iters: 2
    num_iters: 10
    batch_size: 1

mip:
  human_constraints:
    target_latency: 20 # seconds
```

Because vLLM startup adds substantial overhead during stats collection, extend the distributed process group timeout accordingly:

```yaml
dist_timeout_minutes: 60 # default is 10 if omitted
```

This field is supported in any Puzzletron YAML config and overrides the default 10-minute distributed timeout.
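
For reference, the snippet below is a minimal sketch of how the pipeline entry point consumes this field before initializing the process group (mirroring the `examples/puzzletron/main.py` change in this PR; the config path here is hypothetical):

```python
from datetime import timedelta
from pathlib import Path

from omegaconf import OmegaConf

# Load the Puzzletron YAML and honor dist_timeout_minutes when present;
# otherwise fall back to the default 10-minute process-group timeout.
config = OmegaConf.load(Path("puzzletron_config.yaml").resolve())  # hypothetical path
if hasattr(config, "dist_timeout_minutes"):
    timeout = timedelta(minutes=config.dist_timeout_minutes)
else:
    timeout = timedelta(minutes=10)
print(f"Distributed timeout: {timeout}")  # main.py passes this to dist.setup(timeout=...)
```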

## Advanced Usage

Modify the `llama-3_1-8B_pruneffn_memory.yaml` config for advanced compression scenarios.
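
For instance, the sketch below overrides keys that appear in the configs shipped with this example; the values are illustrative, not tuned recommendations:

```yaml
calc_subblock_stats:
  batch_sizes: [32, 64] # illustrative: fewer batch sizes to speed up stats collection

pruning:
  intermediate_size_list: [5888, 11520] # illustrative: coarser FFN search space (teacher is 14336)
```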
@@ -0,0 +1,105 @@
defaults:
  - pruning: ffn_pruning
  - scoring: ../validate_solutions_defaults
  - realize_model: ../validate_solutions_defaults
  - bypass:
  - override hydra/hydra_logging: disabled
  - _self_

puzzle_dir: ???
descriptor: llama
teacher_dir: ${puzzle_dir}/ckpts/teacher/
replacement_library_path: ${puzzle_dir}/replacement_library.json
dataset_path: ??? # path to Nemotron-Post-Training-Dataset-v2

skip_realize_model: false

build_replacement_library:
  add_ffn_no_ops: true
  add_attention_no_ops: true

calc_subblock_stats:
  batch_sizes: [64, 96, 128]
  prefill_seq_len: 4096
  generation_seq_len: 4096
  num_active_tokens_override: # Optional override for sequence lengths
  prefill_queue_size: 0
  allocate_prefill_query: false
  benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking
  merge_with_existing_stats: true
  subblock_stats_filename: "subblock_stats.json"
  moe_stats_filename: "moe_stats.json"
  runtime_stats:
    backend: trt_torch

scoring:
  descriptor: ${descriptor}
  solutions_to_validate:
  skip_existing_solutions: true

  replacement_library_path: ${replacement_library_path}
  solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json}
  teacher_dir: ${to_path:${teacher_dir}}
  output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation

  eval_samples: 8
  micro_batch_size: 1
  seed: 42
  shuffle_seed: 444
  dataset_path: ${dataset_path}

mip:
  single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}}
  subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}}
  output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions}
  gathered_metrics_path:
  puzzle_profile:

  # puzzle_profile:
  objective: metrics.cosine_embedding_loss_hidden_states
  bigger_is_better: false

  subblock_stats_args:
    - batch_size: 96
      weights_dtype: torch.bfloat16
      activations_dtype: torch.bfloat16
      kv_cache_dtype: torch.bfloat16

  report_additional_costs:
    - stats.memory_mib
    - stats.num_params
    - stats.num_kv_heads
    - stats.has_attention
    - stats.has_ffn
    - stats.kv_cache_memory_mib
    - stats.attention_memory_mib
    - stats.ffn_memory_mib
    - stats.ffn_num_params
    - stats.attention_num_params

  mip_constraints:
    metric_overrides:
    max_seconds_per_solution: 60

realize_model:
  descriptor: ${descriptor}
  teacher_dir: ${to_path:${teacher_dir}}
  tokenizer_name: ${to_path:${teacher_dir}}
  replacement_library_path: ${replacement_library_path}
  save_models: true
  solutions_path: # Filled dynamically

  # Validation params
  skip_validation: false # Set to true to skip validation of the model solution
  eval_samples: 128
  micro_batch_size: 1
  seed: 42
  shuffle_seed: 444
  dataset_path: ${dataset_path}

nccl_timeout_minutes: ${timedelta_minutes:10}

# This section redirects Hydra outputs
hydra:
  run:
    dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S}
@@ -0,0 +1,32 @@
defaults:
  - Llama-3_1-8B
  - _self_

# Input Hugging Face model to compress
input_hf_model_path: /workspace/hf_models/meta-llama/Llama-3.1-8B-Instruct

# Dataset path for pruning and NAS scoring
dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2

# Working directory for puzzletron outputs
puzzle_dir: /workspace/puzzle_dir

dist_timeout_minutes: 60

calc_subblock_stats:
  runtime_stats:
    enabled: true
    synth_dataset_num_requests: 32
    backend: vllm
    num_warmup_iters: 2
    num_iters: 10
    batch_size: 1

# MIP latency constraint (in seconds)
mip:
  human_constraints:
    target_latency: 21

# FFN intermediate sizes to search over (heterogeneous architecture)
pruning:
  intermediate_size_list: [3072, 5888, 8704, 11520] # teacher_intermediate_size is 14336
@@ -0,0 +1,23 @@
defaults:
  - pruning_defaults

hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IndependentKvHeadContributionHook}

activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id}

pruning_mixin:
  _target_: modelopt.torch.puzzletron.pruning.kv_heads_pruning_mixin.KVHeadsPruningMixIn
  layer_descriptor:
    _target_: modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor.LlamaKVHeadsLayerDescriptor

activation_hooks_kwargs:
  method: independent_kv_head_contribution
  optimize_for: memory # IndependentKvHeadContributionHook implementation that consumes less memory
  target_layer: "self_attn.o_proj"
  layer_input_descriptors_path:

# n_heads_in_group: 4
# num_attention_heads: 32 # num query heads
# num_kv_heads: 32 / 4 = 8 # num_query_heads // n_heads_in_group
n_heads_in_group_list: [8, 16, 32] # num_kv_heads = [4, 2, 1]
gqa_init_mode: "PruneKVHeads"
@@ -0,0 +1,19 @@
defaults:
  - pruning_defaults

pruning_mixin:
  _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn
  layer_descriptor:
    _target_: modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor.LlamaFFNIntermediateLayerDescriptor

hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook}

activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id}

activation_hooks_kwargs:
  method: iterative
  target_layer: "mlp.down_proj"
  layer_input_descriptors_path:

intermediate_size_list: [3072, 5888, 8704, 11520] # teacher_intermediate_size is 14336
mlp_init_mode: "PruneByActivationsLog"
@@ -0,0 +1,15 @@
defaults:
  - pruning_defaults

activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/hidden_dim_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id}

activation_hooks_kwargs:
  method: layer_norm_contribution
  target_layer: "layernorm"

# Hidden dimension pruning specific settings
hidden_size_list: [3072, 2048] # Target hidden sizes to prune to
hidden_size_init_mode: "PruneByChannelRanking"
mlp_init_mode: "Truncate" # TODO: make it work with CopyAsIs/FromTeacher
gqa_init_mode: "AverageKV" # TODO: make it work with CopyAsIs/FromTeacher
linear_init_mode: "FromTeacher"
@@ -0,0 +1,33 @@
defaults:
  - /validate_model_defaults

descriptor: ${descriptor}
model_name_or_path: ${teacher_dir}
experiment_id: ${pruning.eval_samples}samples_diverse_mini
activations_log_dir: ???
activation_hooks_kwargs: ???

# Data:
eval_samples: 1000 # default is 10000
micro_batch_size: 4
dataset_path: ${dataset_path}
val_dataset_name: train

# Prune ckpts
pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id}

## FFN pruning
ffn_list:
mlp_init_mode: "Truncate" # PruneByActivationsLog

## KV-heads pruning
n_heads_in_group_list:
gqa_init_mode: "AverageKV"

## Hidden dimension pruning
hidden_size_list:
hidden_size_init_mode: "PruneByChannelRanking"
linear_init_mode: "FromTeacher"

mlp_init_config_yaml:
  activations_log_dir: ${pruning.activations_log_dir}
@@ -0,0 +1,17 @@
model_dtype: torch.bfloat16 # dtype to cast the model for validate_model
autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model
block_size: 8192
bos_rate: 0.5
data_column: messages
val_dataset_name: validation
shuffle_seed: 81436
seed: 42
fim_rate: 0
fim_spm_rate: 0
source_datasets_to_discard:
varlen: false
write_results: false
calc_losses_on_cpu: false
activations_log_dir:
model_name_or_path:
load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn}
@@ -0,0 +1,10 @@
defaults:
  - /validate_model_defaults
  - _self_

solutions_to_validate:
skip_validation: false
save_models: false
bigger_is_better: false
sort_solutions_by:
calculate_full_score_ablations: false
15 changes: 14 additions & 1 deletion examples/puzzletron/main.py
@@ -68,7 +68,20 @@ def run_full_puzzletron(hydra_config_path: str):
        hydra_config_path: Path to the YAML configuration file
    """
    mtpz.tools.mprint("Puzzletron Progress 1/8: starting puzzletron pipeline")
    dist.setup(timeout=timedelta(minutes=10))
    # Read the Hydra config to pick up dist_timeout_minutes, if set, and size the timeout accordingly
    from omegaconf import OmegaConf

    # Resolve absolute path for Hydra config
    hydra_config_path = Path(hydra_config_path).resolve()
    hydra_config = OmegaConf.load(str(hydra_config_path))

    # Default timeout: 10 minutes, extended to dist_timeout_minutes if set in the config
    if hasattr(hydra_config, "dist_timeout_minutes"):
        timeout_minutes = timedelta(minutes=hydra_config.dist_timeout_minutes)
    else:
        timeout_minutes = timedelta(minutes=10)
    mtpz.tools.mprint(f"Puzzletron Progress 1/8: distributed timeout: {timeout_minutes}")
    dist.setup(timeout=timeout_minutes)

    # Register Hydra custom resolvers (needed for config resolution)
    mtpz.tools.register_hydra_resolvers()
23 changes: 23 additions & 0 deletions modelopt/torch/nas/subblock_stats/__init__.py
@@ -0,0 +1,23 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Subblock runtime statistics API for ModelOpt NAS.

This module provides utilities for measuring and calculating runtime statistics
of subblocks (e.g., Attention, FFN) within transformer architectures.

Primary API:
- calc_runtime_for_subblocks: Empirically measures runtime for candidate subblock configurations
"""
from .calc_runtime_stats import calc_runtime_for_subblocks