
add high precision init weights to fully_shard example #2785

Draft
pstjohn wants to merge 1 commit into NVIDIA:main from pstjohn:worktree-pstjohn/clean-up-example

Conversation

@pstjohn (Contributor) commented Mar 20, 2026

WIP; would also like to add tests around preserve_high_precision_init_val.

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
# By default, FusedAdam initializes master weights by dequantizing
# the FP8 parameters, which introduces quantization noise. Instead,
# we seed them from the original BF16 init values preserved in step 2.
for param in model.parameters():
Contributor

Do we know if params_dtype=torch.float32 would fix this issue and not require us to manually retrieve BF16 weights and manually cast them to FP32?
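The quantization-noise concern in the quoted comment can be illustrated with a small numpy sketch. The uniform quantizer below is a stand-in assumption (real FP8 uses non-uniform exponent/mantissa formats), but the round-trip error it introduces is analogous:

```python
import numpy as np

def quantize_dequantize(x, levels=16):
    # Crude uniform quantizer standing in for an FP8 cast (assumption:
    # real FP8 is non-uniform, but the round-trip error is analogous).
    scale = np.max(np.abs(x)) / (levels - 1)
    return np.round(x / scale) * scale

rng = np.random.default_rng(0)
bf16_init = rng.standard_normal(1024).astype(np.float32)  # high-precision init

# Seeding master weights by dequantizing the FP8 params carries the
# round-trip quantization error ...
master_from_dequant = quantize_dequantize(bf16_init)
err_dequant = np.abs(master_from_dequant - bf16_init).max()

# ... while seeding them from the preserved high-precision init is exact.
master_from_init = bf16_init.copy()
err_preserved = np.abs(master_from_init - bf16_init).max()
```

Here err_dequant is strictly positive while err_preserved is zero, which is the gap the PR's manual seeding step closes.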

# Load checkpoint back. Provide empty state dict containers with the
# same structure; DCP fills them from the saved files.
state_to_load = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
Contributor

Is it a requirement to store the model weights? If the optimizer weights hold our FP32 master weights, and we simply quantize those to build our FP4/FP8 weights, then why do we need to store the model weights (quantized in FP4/FP8)?
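One way to read this suggestion: if the checkpoint keeps only the optimizer's FP32 masters, the low-precision model weights could be rebuilt by re-quantizing them on load. A minimal sketch with dict-of-arrays stand-ins for the real state dicts (quantize_fp8_like is a crude uniform quantizer, not real FP8):

```python
import numpy as np

def quantize_fp8_like(x, levels=16):
    # Stand-in for the FP8/FP4 cast; real formats are non-uniform.
    scale = np.max(np.abs(x)) / (levels - 1)
    return np.round(x / scale) * scale

# Hypothetical checkpoint that stores only FP32 master weights.
optimizer_state = {
    "layer1.weight": {"master_param": np.linspace(-1.0, 1.0, 8, dtype=np.float32)},
    "layer2.weight": {"master_param": np.linspace(-0.5, 0.5, 8, dtype=np.float32)},
}

# On load, rebuild the low-precision model weights from the masters
# instead of also persisting the quantized copies.
model_state = {
    name: quantize_fp8_like(state["master_param"])
    for name, state in optimizer_state.items()
}
```

Whether this is viable in practice depends on the quantization being deterministic and on per-tensor scaling state also being restorable from the checkpoint.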

for key, value in full_model_state.items():
if key in opt_param_states and "master_param" in opt_param_states[key]:
# Prefer optimizer's FP32 master weight (maintained throughout training).
Contributor

Can we put this into a utility function, so we don't have to manually pull these pieces out of the optimizer state?
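A possible shape for such a helper (hypothetical name and signature; it just centralizes the master-weight-preferred merge from the quoted loop):

```python
def merge_master_params(full_model_state, opt_param_states):
    """Hypothetical utility: for each parameter, prefer the optimizer's
    FP32 master weight over the (quantized) model weight when present."""
    merged = {}
    for key, value in full_model_state.items():
        param_state = opt_param_states.get(key, {})
        merged[key] = param_state.get("master_param", value)
    return merged

# Usage with toy stand-ins for the real state dicts:
model_state = {"w1": "fp8_w1", "w2": "fp8_w2"}
opt_states = {"w1": {"master_param": "fp32_w1"}}
merged = merge_master_params(model_state, opt_states)
# merged["w1"] comes from the optimizer; merged["w2"] falls back to the model.
```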

