add high precision init weights to fully_shard example #2785
base: main
@@ -13,8 +13,11 @@
    local shards on each rank's GPU.
 2. ``quantized_model_init`` -- Flags the model for FP8 weight initialization
    (actual quantization happens in ``reset_parameters`` after sharding).
-3. ``fully_shard`` -- PyTorch FSDP2 sharding of each TransformerLayer.
-4. ``FusedAdam`` with FP32 master weights for full-precision training updates.
+3. ``preserve_high_precision_init_val`` -- Keeps the original BF16 weight
+   values on CPU so they can seed the optimizer's FP32 master weights,
+   avoiding the precision loss of round-tripping through FP8.
+4. ``fully_shard`` -- PyTorch FSDP2 sharding of each TransformerLayer.
+5. ``FusedAdam`` with FP32 master weights for full-precision training updates.

 .. note::
    ``fuse_wgrad_accumulation`` is **not** used here. That feature writes
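The precision loss that item 3 avoids can be illustrated with a toy quantizer. This is pure Python, not the Transformer Engine API; the coarse rounding grid below only mimics the spacing of a low-mantissa format such as FP8 E4M3:

```python
def toy_quantize(x, mantissa_bits=3):
    """Round x to a grid with 2**mantissa_bits steps per power of two,
    loosely imitating FP8's coarse mantissa (illustration only)."""
    import math
    if x == 0.0:
        return 0.0
    e = math.floor(math.log2(abs(x)))       # exponent of the binade x falls in
    step = 2.0 ** (e - mantissa_bits)       # grid spacing inside that binade
    return round(x / step) * step

init_val = 0.8731                  # stands in for a BF16 initialization value
roundtrip = toy_quantize(init_val) # what dequantizing the FP8 weight gives back

# Seeding the FP32 master weight from the preserved init value is exact;
# seeding it by dequantizing the FP8 weight inherits this rounding error.
error = abs(roundtrip - init_val)
print(f"quantized={roundtrip:.4f}  error={error:.4f}")
```

The gap between `init_val` and `roundtrip` is exactly the noise the PR keeps out of the optimizer's master weights.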
@@ -38,10 +41,9 @@
 from torch.distributed.tensor import DTensor

 import transformer_engine.pytorch as te
-from transformer_engine.pytorch import QuantizedTensor
 from transformer_engine.pytorch.module.base import TransformerEngineBaseModule

-# ── Configuration (matches main.py) ──────────────────────────────────
+# ── Configuration ────────────────────────────────────────────────────
 HIDDEN_SIZE = 256
 FFN_HIDDEN_SIZE = 1024
 NUM_ATTENTION_HEADS = 8
@@ -60,10 +62,6 @@ def dist_print(msg):

 def main():
     # ── 1. Distributed setup ─────────────────────────────────────────
-    assert "TORCHELASTIC_RUN_ID" in os.environ, (
-        "This script must be launched with torchrun, e.g.:\n"
-        "    torchrun --nproc-per-node 2 fully_shard.py"
-    )
     world_size = int(os.environ["WORLD_SIZE"])
     local_rank = int(os.environ["LOCAL_RANK"])
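The two environment reads above rely on the variables that `torchrun` exports into every worker process. A standalone sketch of that contract, with values hard-coded to what a 2-GPU single-node launch would produce (illustrative only):

```python
import os

# torchrun sets these for each spawned worker; we fake a 2-process launch
# so the snippet runs standalone (values are illustrative, not from torchrun).
os.environ["WORLD_SIZE"] = "2"   # total processes across all nodes
os.environ["RANK"] = "1"         # global index of this process
os.environ["LOCAL_RANK"] = "1"   # index of this process on its own node

world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ["LOCAL_RANK"])  # typically used to pick the CUDA device
print(world_size, local_rank)
```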
@@ -74,10 +72,12 @@ def main():
     torch.manual_seed(42)
     torch.cuda.manual_seed(42)

-    # ── 2. Create model on meta device (zero memory) ────────────────
-    # quantized_model_init sets the flag for FP8 weight initialization,
-    # but with device="meta" no actual memory is allocated yet.
-    with te.quantized_model_init(enabled=True):
+    # ── 2. Create model on meta device (zero memory) ─────────────────
+    # quantized_model_init flags parameters for FP8 quantization.
+    # preserve_high_precision_init_val=True saves the original BF16
+    # values on CPU so they can seed optimizer master weights later,
+    # avoiding the precision loss of dequantizing from FP8.
+    with te.quantized_model_init(enabled=True, preserve_high_precision_init_val=True):
         model = torch.nn.Sequential(
             *[
                 te.TransformerLayer(
@@ -93,52 +93,51 @@ def main():
                 for _ in range(NUM_LAYERS)
             ]
         )

-    # Verify all parameters are on meta device (no GPU memory used).
-    for name, param in model.named_parameters():
-        assert param.device == torch.device("meta"), f"{name} is not on meta device"
     dist_print("Model created on meta device (zero GPU memory).")

-    # ── 3. FSDP2 sharding ──────────────────────────────────────────────
-    # Apply sharding to the meta-device model. FSDP2 wraps parameters
+    # ── 3. FSDP2 sharding ────────────────────────────────────────────
+    # Apply sharding to the meta-device model.  FSDP2 wraps parameters
     # as DTensors but no GPU memory is allocated yet.
     mesh = DeviceMesh("cuda", list(range(world_size)))
     for child in model.children():
         fully_shard(child, mesh=mesh)
     fully_shard(model, mesh=mesh)
     dist_print("FSDP2 sharding applied to meta-device model.")

-    # ── 4. Materialize parameters on GPU ──────────────────────────────
+    # ── 4. Materialize parameters on GPU ─────────────────────────────
     # reset_parameters() on each TE module materializes the local shard
     # on CUDA, applies weight initialization, and quantizes to FP8.
+    # Because preserve_high_precision_init_val=True, the pre-quantization
+    # BF16 values are saved on CPU for each local shard.
     for module in model.modules():
         if isinstance(module, TransformerEngineBaseModule):
             module.reset_parameters()
     dist_print("Parameters materialized on GPU.")

-    # Post-materialization verification.
-    for name, param in model.named_parameters():
-        assert isinstance(param, DTensor), f"{name} is not a DTensor after sharding"
-    qt_count = sum(
-        1
-        for _, p in model.named_parameters()
-        if isinstance(p, DTensor) and isinstance(p._local_tensor, QuantizedTensor)
-    )
-    assert qt_count > 0, "No QuantizedTensor local tensors after materialization"
-    dist_print(
-        f"Parameters materialized: {qt_count} FP8 (QuantizedTensor) weight params "
-        "wrapped in DTensors."
-    )

-    # ── 5. Optimizer ─────────────────────────────────────────────────
+    # ── 5. Optimizer with FP32 master weights ────────────────────────
     optimizer = te.optimizers.FusedAdam(
         model.parameters(),
         lr=1e-3,
         master_weights=True,
         master_weight_dtype=torch.float32,
     )
     dist_print("Using FusedAdam with master_weights=True.")

-    # ── 6. Training loop ─────────────────────────────────────────────
+    # ── 6. Seed master weights from high-precision init values ───────
+    # By default, FusedAdam initializes master weights by dequantizing
+    # the FP8 parameters, which introduces quantization noise. Instead,
+    # we seed them from the original BF16 init values preserved in step 2.
+    for param in model.parameters():
+        optimizer.initialize_state(param, store_param_remainders=False)
+        local = param._local_tensor if isinstance(param, DTensor) else param
+        hp_val = getattr(local, "get_high_precision_init_val", lambda: None)()
+        if hp_val is not None:
+            optimizer.set_scaled_state(
+                param, "master_param", hp_val.to(device=device, dtype=torch.float32)
+            )
+            local.clear_high_precision_init_val()
+    dist_print("Optimizer master weights seeded from high-precision init values.")

+    # ── 7. Training loop ─────────────────────────────────────────────
     x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=DTYPE, device=device)
     target = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=DTYPE, device=device)
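The `getattr` probe in the new step 6 degrades gracefully when a shard never preserved a high-precision value. That control flow can be exercised without Transformer Engine or a GPU; the classes below are mock stand-ins (not real TE or torch types) that reproduce only the relevant interface:

```python
class MockQuantizedShard:
    """Stand-in for a quantized local tensor that preserved its init value."""
    def __init__(self, init_val):
        self._hp = init_val
    def get_high_precision_init_val(self):
        return self._hp
    def clear_high_precision_init_val(self):
        self._hp = None

class MockPlainShard:
    """Stand-in for a non-quantized parameter: no high-precision hook."""
    pass

master_weights = {}
shards = {"fc.weight": MockQuantizedShard(0.8731), "ln.weight": MockPlainShard()}

for name, local in shards.items():
    # Same probe as in the diff: returns None when the hook is absent.
    hp_val = getattr(local, "get_high_precision_init_val", lambda: None)()
    if hp_val is not None:
        master_weights[name] = float(hp_val)  # the real code does .to(float32)
        local.clear_high_precision_init_val()  # free the CPU copy

print(master_weights)
```

Only the quantized shard contributes a seeded master weight; the plain shard is skipped, and the preserved CPU copy is released after use.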
@@ -153,56 +152,22 @@ def main():
         optimizer.step()
         dist_print(f"  Step {step}: loss = {loss.item():.6f}")

-    # ── 7. Post-training assertions ──────────────────────────────────
-    dist_print("\nVerifying invariants ...")
-
-    qt_after = 0
-    for name, param in model.named_parameters():
-        assert isinstance(param, DTensor), f"{name} lost DTensor wrapping"
-        if isinstance(param._local_tensor, QuantizedTensor):
-            qt_after += 1
-    assert qt_after > 0, "No QuantizedTensor local tensors after training"
-    dist_print(f"  {qt_after} params still have QuantizedTensor local tensors.")
-
-    # Optimizer states: master weights and moments should be float32.
-    for param in model.parameters():
-        state = optimizer.state[param]
-        if "master_param" in state:
-            assert (
-                state["master_param"].dtype == torch.float32
-            ), f"Master weight dtype {state['master_param'].dtype}, expected float32"
-        assert state["exp_avg"].dtype == torch.float32, "exp_avg should be float32"
-        assert state["exp_avg_sq"].dtype == torch.float32, "exp_avg_sq should be float32"
-
-    dist_print("All assertions passed!")
-    dist_print("  - Linear weight parameters: QuantizedTensor (FP8) wrapped in DTensor")
-    dist_print("  - Optimizer master weights: float32")
-    dist_print("  - Optimizer states (exp_avg, exp_avg_sq): float32")

     # ── 8. Distributed checkpoint: save and load ─────────────────────
     # torch.distributed.checkpoint (DCP) saves sharded state — each rank
-    # writes only its local shard. This preserves FP8 compute weights
-    # and the full optimizer state (master weights, moments, step count).
+    # writes only its local shard, preserving FP8 compute weights and
+    # the full optimizer state (master weights, moments, step count).
     import torch.distributed.checkpoint as dcp
-    from torch.distributed.checkpoint.state_dict import (
-        StateDictOptions,
-        get_model_state_dict,
-        get_optimizer_state_dict,
-    )

     # Use a fixed path so all ranks agree on the checkpoint location.
     checkpoint_dir = "/tmp/te_fsdp2_example_checkpoint"
     dist_print(f"\nSaving distributed checkpoint to {checkpoint_dir} ...")

     # Save sharded checkpoint. DCP handles DTensor shards natively —
     # each rank writes only its local shard to the filesystem.
     dcp.save(
         {"model": model.state_dict(), "optimizer": optimizer.state_dict()},
         checkpoint_id=checkpoint_dir,
     )
     dist_print("  Checkpoint saved (FP8 weights + optimizer state).")

-    # Load checkpoint back. Provide empty state dict containers with the
+    # Load checkpoint back.  Provide empty state dict containers with the
     # same structure; DCP fills them from the saved files.
     state_to_load = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
Contributor
Is it a requirement to store the model weights? If the optimizer state holds our FP32 master weights, and we simply quantize those to build our FP4/FP8 weights, then why do we need to store the model weights (quantized in FP4/FP8)?
     dcp.load(state_to_load, checkpoint_id=checkpoint_dir)
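The comment above (`dcp.load` fills pre-built containers of the same structure) describes a contract that can be imitated with plain dicts and pickle. This is a loose analogy only; real DCP additionally resolves DTensor shards so each rank reads just its slice:

```python
import io
import pickle

# "Saved checkpoint": a serialized state dict (stands in for DCP's files).
saved = io.BytesIO(pickle.dumps({"model": {"w": 1.0}, "optimizer": {"step": 3}}))

# Caller provides containers with the right structure but placeholder values,
# just as state_dict() provides the skeleton that dcp.load fills in.
state_to_load = {"model": {"w": 0.0}, "optimizer": {"step": 0}}

loaded = pickle.loads(saved.getvalue())
for key in state_to_load:
    state_to_load[key].update(loaded[key])  # fill in place, like dcp.load

print(state_to_load)
```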
@@ -225,6 +190,11 @@ def main():
     # authoritative FP32 values (more precise than dequantizing FP8).
     # All ranks must participate in gathering; only rank 0 saves.
     from safetensors.torch import save_file
+    from torch.distributed.checkpoint.state_dict import (
+        StateDictOptions,
+        get_model_state_dict,
+        get_optimizer_state_dict,
+    )

     full_opts = StateDictOptions(full_state_dict=True, cpu_offload=True)
@@ -238,10 +208,10 @@ def main():
     for key, value in full_model_state.items():
         if key in opt_param_states and "master_param" in opt_param_states[key]:
-            # Prefer optimizer's FP32 master weight (maintained throughout training).
+            # Prefer optimizer's FP32 master weight.
Contributor
Can we put this into a utility function, so we don't have to manually pull the pieces out of the optimizer?
             fp32_state[key] = opt_param_states[key]["master_param"].float()
-        elif isinstance(value, QuantizedTensor):
-            # Fallback: dequantize FP8 → FP32 (e.g. if master_weights was off).
+        elif isinstance(value, te.QuantizedTensor):
+            # Fallback: dequantize FP8 → FP32.
             fp32_state[key] = value.dequantize().float()
         else:
             # Non-FP8 params (e.g. LayerNorm weights): cast to FP32.
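The precedence in this loop (FP32 master weight, else dequantize, else plain cast) could be packaged as the utility function the review comment asks for. A pure-Python sketch with mock objects standing in for the real torch/TE tensor types (none of these names are from the PR):

```python
class MockQuantized:
    """Stand-in for a QuantizedTensor: dequantize() recovers an approximation."""
    def __init__(self, approx):
        self.approx = approx
    def dequantize(self):
        return self.approx

def fp32_export_value(key, value, opt_param_states):
    """Pick the highest-precision FP32 source for one parameter (sketch)."""
    state = opt_param_states.get(key, {})
    if "master_param" in state:
        return float(state["master_param"])   # 1. exact FP32 master weight
    if isinstance(value, MockQuantized):
        return float(value.dequantize())      # 2. fallback: dequantize FP8
    return float(value)                       # 3. non-quantized: plain cast

opt_states = {"fc.weight": {"master_param": 0.8731}}
model_state = {"fc.weight": MockQuantized(0.875), "ln.weight": 1.0}

fp32_state = {k: fp32_export_value(k, v, opt_states) for k, v in model_state.items()}
print(fp32_state)
```

The quantized weight's master value wins over its dequantized approximation, while the LayerNorm weight is simply cast.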
@@ -251,14 +221,6 @@ def main():
     save_file(fp32_state, save_path)
     dist_print(f"\nSaved FP32 model ({len(fp32_state)} params) to {save_path}")

-    # Quick verification: all saved tensors are float32.
-    from safetensors.torch import load_file
-
-    loaded = load_file(save_path)
-    for k, v in loaded.items():
-        assert v.dtype == torch.float32, f"{k}: expected float32, got {v.dtype}"
-    dist_print(f"  Verified: all {len(loaded)} tensors are float32.")

     dist.destroy_process_group()
Contributor
Do we know if params_dtype=torch.float32 would fix this issue, without requiring us to manually retrieve the BF16 weights and cast them to FP32?