Skip to content

Commit 436bf95

Browse files
committed
chore: Merge remote-tracking branch 'origin/main' into hf_checkpoint_conversion_for_fsdp2
2 parents 1b2aca5 + d328e2e commit 436bf95

8 files changed

Lines changed: 440 additions & 9 deletions

File tree

docs/components/components.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
| scheduler | constant_lr | [ConstantLR](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ConstantLR.html#torch.optim.lr_scheduler.ConstantLR)| [ConstantLRSchedulerConfig](../../src/modalities/config/config.py) | [LRScheduler](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) | Multiplies the learning rate of each parameter group by a small constant factor until the number of steps reaches a pre-defined milestone |
4141
| scheduler | onecycle_lr | [OneCycleLR](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html#torch.optim.lr_scheduler.OneCycleLR)| [OneCycleLRSchedulerConfig](../../src/modalities/config/config.py) | [LRScheduler](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) | Sets the learning rate of each parameter group according to the 1cycle learning rate policy. |
4242
| scheduler | cosine_annealing_lr | [CosineAnnealingLR](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingLR.html#torch.optim.lr_scheduler.CosineAnnealingLR)| [CosineAnnealingLRSchedulerConfig](../../src/modalities/config/config.py) | [LRScheduler](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) | Set the learning rate of each parameter group using a cosine annealing schedule |
43+
| scheduler | linear_warmup_cosine_annealing_lr | [LinearWarmupCosineAnnealingLRScheduler](../../src/modalities/optimizers/lr_schedulers.py) | [LinearWarmupCosineAnnealingLRSchedulerConfig](../../src/modalities/config/config.py) | [LRScheduler](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) | Linearly warms up to the base learning rate, then decays with cosine annealing for the remaining training steps |
4344

4445

4546
## Tokenization
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
#!/usr/bin/env python3
2+
3+
import argparse
4+
import json
5+
import os
6+
import re
7+
from pathlib import Path
8+
from typing import cast
9+
10+
import torch
11+
import torch.distributed as dist
12+
from pydantic import BaseModel
13+
from torch.distributed.device_mesh import DeviceMesh
14+
from torch.distributed.tensor import DTensor
15+
16+
from modalities.checkpointing.fsdp.fsdp_checkpoint_loading import DCPCheckpointLoading
17+
from modalities.checkpointing.stateful.app_state import AppState
18+
from modalities.config.config import ProcessGroupBackendType
19+
from modalities.config.pydantic_if_types import PydanticAppStateType, PydanticDeviceMeshIFType
20+
from modalities.main import Main
21+
from modalities.running_env.cuda_env import CudaEnv
22+
from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_mesh_for_parallelism_method
23+
24+
25+
class ComponentsInstantiationModel(BaseModel):
    """Pydantic model for the components built from the YAML config.

    Only the app state (and, optionally, the device mesh) is required to load a
    DCP checkpoint and compute parameter norms.
    """

    # Stateful container holding the model parts that the checkpoint is loaded into.
    app_state: PydanticAppStateType
    # Device mesh used to resolve the DP-shard process group; None when the config defines no mesh.
    device_mesh: PydanticDeviceMeshIFType | None = None
28+
29+
30+
def _parse_args() -> argparse.Namespace:
31+
parser = argparse.ArgumentParser(description="Load one or more Modalities DCP checkpoints into an app state.")
32+
parser.add_argument("--config-file-path", type=Path, required=True, help="Path to the YAML config file.")
33+
parser.add_argument(
34+
"--experiments-root-path",
35+
type=Path,
36+
required=True,
37+
help="Path passed to Main for resolver/context setup.",
38+
)
39+
parser.add_argument(
40+
"--checkpoint-dir-paths",
41+
type=Path,
42+
nargs="+",
43+
required=True,
44+
help="Paths to multiple checkpoint directories containing *.distcp files.",
45+
)
46+
parser.add_argument(
47+
"--json-output-path",
48+
type=Path,
49+
default=Path("layer_norms_across_checkpoints.json"),
50+
help="Output path for raw per-checkpoint norms as JSON.",
51+
)
52+
return parser.parse_args()
53+
54+
55+
def _resolve_checkpoint_dir_paths(args: argparse.Namespace) -> list[Path]:
56+
return list(args.checkpoint_dir_paths)
57+
58+
59+
def _normalize_parameter_name(parameter_name: str) -> str:
60+
name = parameter_name
61+
for prefix in ("module.", "_orig_mod.", "_fsdp_wrapped_module."):
62+
if name.startswith(prefix):
63+
name = name[len(prefix) :]
64+
return name
65+
66+
67+
def _get_dp_shard_group(device_mesh: DeviceMesh | None):
    """Best-effort lookup of the DP-shard process group for *device_mesh*.

    Returns None when no mesh is given or when a dedicated DP-shard sub-mesh
    cannot be resolved; callers then fall back to the default process group.
    """
    if device_mesh is None:
        return None
    try:
        dp_shard_mesh = get_mesh_for_parallelism_method(device_mesh, ParallelismDegrees.DP_SHARD)
        return dp_shard_mesh.get_group()
    except Exception:
        # Deliberate best-effort: any resolution failure means "use the default group" (None).
        return None
75+
76+
77+
def _compute_and_print_parameter_norms(app_state: AppState, dp_shard_group) -> dict[str, float]:
    """Compute the global L2 norm of every trainable parameter across DP-shards.

    Accumulates local squared sums per parameter, sums them over the DP-shard
    process group, and prints the resulting norms on rank 0. Returns a mapping
    from normalized parameter name to its global L2 norm.
    """
    local_sq_sums: dict[str, torch.Tensor] = {}

    multiple_parts = len(app_state.model_parts) > 1
    for part_idx, model_part in enumerate(app_state.model_parts):
        for name, param in model_part.named_parameters():
            if not param.requires_grad:
                continue
            qualified_name = f"model_part_{part_idx}.{name}" if multiple_parts else name
            key = _normalize_parameter_name(qualified_name)

            # FSDP2 parameters can be DTensors. Convert to local shard first so c10d all_reduce
            # operates on plain tensors instead of DTensors.
            local_shard = param.to_local() if isinstance(param, DTensor) else param
            local_sq_sums[key] = local_shard.detach().float().pow(2).sum()

    # Aggregate over the DP-shard group to reconstruct global norms for sharded parameters
    # (all_reduce operates in place on each squared-sum tensor).
    for key, sq_sum in local_sq_sums.items():
        dist.all_reduce(sq_sum, op=dist.ReduceOp.SUM, group=dp_shard_group)
        local_sq_sums[key] = sq_sum

    parameter_norms = {key: torch.sqrt(sq_sum).item() for key, sq_sum in local_sq_sums.items()}

    if dist.get_rank() == 0:
        print("Per-parameter L2 norms (global across DP-shards):")
        for key in sorted(parameter_norms):
            print(f"{key}: {parameter_norms[key]:.6f}")

    return parameter_norms
106+
107+
108+
def _extract_checkpoint_label(checkpoint_dir_path: Path) -> str:
109+
match = re.search(r"seen_steps_(\d+)", checkpoint_dir_path.name)
110+
if match:
111+
return f"steps_{match.group(1)}"
112+
return checkpoint_dir_path.name
113+
114+
115+
def _save_json_results(results: list[dict], output_path: Path) -> None:
116+
output_path.parent.mkdir(parents=True, exist_ok=True)
117+
with open(output_path, "w", encoding="utf-8") as f:
118+
json.dump(results, f, indent=2)
119+
120+
121+
def main() -> None:
    """Load each requested DCP checkpoint, compute global parameter norms, and save them as JSON.

    Runs inside a NCCL process group (CudaEnv). For every checkpoint directory the
    components are rebuilt from the config, the checkpoint is loaded into the app
    state, and per-parameter L2 norms are computed across DP-shards. Rank 0
    collects the results and writes them to ``--json-output-path``.
    """
    args = _parse_args()
    checkpoint_dir_paths = _resolve_checkpoint_dir_paths(args)

    with CudaEnv(process_group_backend=ProcessGroupBackendType.nccl):
        rank = dist.get_rank()
        collected_results: list[dict] = []

        for checkpoint_dir_path in checkpoint_dir_paths:
            # Rebuild components per checkpoint because AppState only supports one load call.
            main_obj = Main(
                config_path=args.config_file_path,
                experiments_root_path=args.experiments_root_path,
            )
            components = cast(
                ComponentsInstantiationModel,
                main_obj.build_components(components_model_type=ComponentsInstantiationModel),
            )

            app_state = cast(AppState, getattr(components, "app_state"))
            # device_mesh is optional in the config; fall back to None when absent.
            device_mesh = cast(DeviceMesh | None, getattr(components, "device_mesh", None))

            loader = DCPCheckpointLoading(global_rank=rank)
            loader.load_checkpoint_(app_state=app_state, checkpoint_dir_path=checkpoint_dir_path)

            dp_shard_group = _get_dp_shard_group(device_mesh)
            if rank == 0:
                print(f"\n=== {checkpoint_dir_path} ===")
            # All ranks participate in the norm computation (collective ops inside).
            parameter_norms = _compute_and_print_parameter_norms(app_state, dp_shard_group)

            if rank == 0:
                collected_results.append(
                    {
                        "checkpoint_path": str(checkpoint_dir_path),
                        "checkpoint_label": _extract_checkpoint_label(checkpoint_dir_path),
                        "parameter_norms": parameter_norms,
                    }
                )
                print(
                    f"Loaded checkpoint from {checkpoint_dir_path} on world size {dist.get_world_size()} "
                    f"(pid={os.getpid()})."
                )

        # Only rank 0 persists the aggregated results.
        if rank == 0:
            _save_json_results(collected_results, args.json_output_path)
            print(f"Saved raw parameter norms JSON to {args.json_output_path}")
167+
168+
169+
if __name__ == "__main__":
    # Run only when executed as a script (not on import).
    main()
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
#!/usr/bin/env python3
2+
3+
import argparse
4+
import json
5+
import re
6+
from pathlib import Path
7+
8+
import matplotlib.pyplot as plt
9+
from matplotlib.backends.backend_pdf import PdfPages
10+
11+
12+
def _parse_args() -> argparse.Namespace:
13+
parser = argparse.ArgumentParser(description="Plot parameter norms across checkpoints from a JSON log file.")
14+
parser.add_argument(
15+
"--layer-norms-json-path",
16+
type=Path,
17+
required=True,
18+
help="Path to JSON produced by scripts/compute_layer_norms.py.",
19+
)
20+
parser.add_argument(
21+
"--plot-output-path",
22+
type=Path,
23+
default=Path("parameter_norms_grouped_by_layer.pdf"),
24+
help="Output PDF path containing one plot page per layer.",
25+
)
26+
parser.add_argument(
27+
"--layer-filter-regex",
28+
type=str,
29+
default=r".*",
30+
help="Regex to select layer keys in the visualization.",
31+
)
32+
return parser.parse_args()
33+
34+
35+
def _load_results(path: Path) -> list[dict]:
36+
with open(path, "r", encoding="utf-8") as f:
37+
results = json.load(f)
38+
if not isinstance(results, list) or not results:
39+
raise ValueError("Expected a non-empty JSON list of checkpoint results.")
40+
return results
41+
42+
43+
def _extract_layer_key(parameter_name: str) -> str:
44+
tokens = parameter_name.split(".")
45+
for i in range(len(tokens) - 1):
46+
if tokens[i] in {"h", "layers", "blocks"} and tokens[i + 1].isdigit():
47+
if i > 0:
48+
return ".".join(tokens[i - 1 : i + 2])
49+
return ".".join(tokens[i : i + 2])
50+
return ".".join(tokens[:-1]) if len(tokens) > 1 else parameter_name
51+
52+
53+
def _layer_sort_key(layer_key: str) -> tuple:
54+
# Prefer numeric ordering for transformer block keys like h.0, layers.12, blocks.3.
55+
match = re.search(r"(?:^|\.)(h|layers|blocks)\.(\d+)(?:\.|$)", layer_key)
56+
if match:
57+
return (0, match.group(1), int(match.group(2)), layer_key)
58+
return (1, layer_key)
59+
60+
61+
def _group_parameters_by_layer(parameter_names: list[str]) -> dict[str, list[str]]:
    """Group parameter names by their layer key (insertion order preserved)."""
    grouped: dict[str, list[str]] = {}
    for parameter_name in parameter_names:
        grouped.setdefault(_extract_layer_key(parameter_name), []).append(parameter_name)
    return grouped


def _render_summary_page(pdf, checkpoint_labels, grouped_parameters, ordered_layer_keys, num_parameters) -> None:
    """Render the first PDF page: overall counts plus per-layer parameter counts."""
    summary_lines = [
        f"checkpoints: {len(checkpoint_labels)}",
        f"layers: {len(grouped_parameters)}",
        f"parameters plotted: {num_parameters}",
        "",
        "Layer -> #parameters",
    ]
    for layer_key in ordered_layer_keys:
        summary_lines.append(f"{layer_key}: {len(grouped_parameters[layer_key])}")

    fig, ax = plt.subplots(figsize=(10, 12))
    ax.axis("off")
    ax.text(0.01, 0.99, "\n".join(summary_lines), va="top", ha="left", fontsize=10)
    fig.tight_layout()
    pdf.savefig(fig)
    plt.close(fig)


def _render_layer_page(pdf, results, metric_key, checkpoint_labels, layer_key, parameter_names) -> None:
    """Render one PDF page with a curve per parameter of *layer_key* across checkpoints."""
    x = list(range(len(checkpoint_labels)))
    fig, ax = plt.subplots(figsize=(12, 6))
    for parameter_name in sorted(parameter_names):
        # Parameters absent from a checkpoint plot as gaps (NaN).
        y = [checkpoint_result[metric_key].get(parameter_name, float("nan")) for checkpoint_result in results]
        short_name = (
            parameter_name[len(layer_key) + 1 :]
            if parameter_name.startswith(layer_key + ".")
            else parameter_name
        )
        ax.plot(x, y, marker="o", linewidth=1.5, label=short_name)

    ax.set_title(f"{layer_key} parameter norms across checkpoints")
    ax.set_xlabel("Checkpoint")
    ax.set_ylabel("L2 norm")
    ax.set_xticks(x)
    ax.set_xticklabels(checkpoint_labels, rotation=45, ha="right")
    ax.grid(True, alpha=0.25)
    ax.legend(loc="best", fontsize=8)
    fig.tight_layout()
    pdf.savefig(fig)
    plt.close(fig)


def _plot_checkpoint_comparison(
    results: list[dict],
    plot_output_path: Path,
    layer_filter_regex: str,
) -> None:
    """Write a multi-page PDF comparing parameter norms across checkpoints.

    Page 1 is a summary (counts per layer); every following page shows all
    parameter curves of one layer. Decomposed into private helpers for the
    summary page and the per-layer pages; behavior matches the original.

    Args:
        results: per-checkpoint dicts with "checkpoint_label" and a norms mapping.
        plot_output_path: target PDF path (parent directories are created).
        layer_filter_regex: regex selecting which parameter names to plot.

    Raises:
        ValueError: if no parameter name matches *layer_filter_regex*.
    """
    # Older JSON files used the key "layer_norms"; prefer the current name.
    metric_key = "parameter_norms" if "parameter_norms" in results[0] else "layer_norms"
    layer_pattern = re.compile(layer_filter_regex)
    filtered_parameters = sorted(
        {
            parameter_name
            for checkpoint_result in results
            for parameter_name in checkpoint_result[metric_key].keys()
            if layer_pattern.search(parameter_name)
        }
    )
    if not filtered_parameters:
        raise ValueError(f"No layer names matched --layer-filter-regex={layer_filter_regex!r}.")

    checkpoint_labels = [checkpoint_result["checkpoint_label"] for checkpoint_result in results]
    grouped_parameters = _group_parameters_by_layer(filtered_parameters)
    ordered_layer_keys = sorted(grouped_parameters, key=_layer_sort_key)

    plot_output_path.parent.mkdir(parents=True, exist_ok=True)
    with PdfPages(plot_output_path) as pdf:
        _render_summary_page(pdf, checkpoint_labels, grouped_parameters, ordered_layer_keys, len(filtered_parameters))
        for layer_key in ordered_layer_keys:
            _render_layer_page(pdf, results, metric_key, checkpoint_labels, layer_key, grouped_parameters[layer_key])
131+
132+
133+
def main() -> None:
    """CLI entry point: read the norms JSON and render the grouped-by-layer PDF."""
    cli_args = _parse_args()
    checkpoint_results = _load_results(cli_args.layer_norms_json_path)
    _plot_checkpoint_comparison(
        results=checkpoint_results,
        plot_output_path=cli_args.plot_output_path,
        layer_filter_regex=cli_args.layer_filter_regex,
    )
    print(f"Saved grouped parameter-norm plots to {cli_args.plot_output_path}")
142+
143+
144+
if __name__ == "__main__":
    # Run only when executed as a script (not on import).
    main()

src/modalities/config/config.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ class OneCycleLRSchedulerConfig(BaseModel):
188188
steps_per_epoch: Optional[Annotated[int, Field(strict=True, gt=0)]] = None
189189
pct_start: Annotated[float, Field(strict=True, gt=0.0, le=1.0)]
190190
anneal_strategy: str
191-
cycle_momentum: bool = True
191+
cycle_momentum: bool = False
192192
base_momentum: Annotated[float, Field(strict=True, gt=0)] | list[
193193
Annotated[float, Field(strict=True, gt=0.0)]
194194
] = 0.85
@@ -229,6 +229,22 @@ class CosineAnnealingLRSchedulerConfig(BaseModel):
229229
last_epoch: Annotated[int, Field(strict=True, ge=-1)] = -1
230230

231231

232+
class LinearWarmupCosineAnnealingLRSchedulerConfig(BaseModel):
    """Config for the linear-warmup + cosine-annealing LR scheduler.

    Validates that the warmup phase is strictly shorter than the total schedule.
    """

    # Optimizer whose parameter-group learning rates the scheduler drives.
    optimizer: PydanticOptimizerIFType
    # Number of linear-warmup steps (must be < total_steps, enforced below).
    warmup_steps: Annotated[int, Field(strict=True, gt=0)]
    # Total number of scheduler steps (warmup + cosine annealing).
    total_steps: Annotated[int, Field(strict=True, gt=0)]
    # Learning rate at the start of warmup — TODO confirm exact semantics in the scheduler impl.
    initial_lr: Annotated[float, Field(strict=True, ge=0.0)]
    # Learning rate reached at the end of cosine annealing.
    final_lr: Annotated[float, Field(strict=True, ge=0.0)]
    # Peak learning rate reached at the end of warmup.
    max_lr: Annotated[float, Field(strict=True, ge=0.0)]
    # Index of the last epoch; -1 means "start from scratch" (PyTorch LRScheduler convention).
    last_epoch: Annotated[int, Field(strict=True, ge=-1)] = -1

    @model_validator(mode="after")
    def check_total_steps_greater_than_warmup_steps(self) -> "LinearWarmupCosineAnnealingLRSchedulerConfig":
        # Cosine annealing needs at least one step after warmup.
        if self.total_steps <= self.warmup_steps:
            raise ValueError("total_steps must be greater than warmup_steps.")
        return self
246+
247+
232248
class FSDP1CheckpointedOptimizerConfig(BaseModel):
233249
checkpoint_loading: PydanticFSDP1CheckpointLoadingIFType
234250
checkpoint_path: Path

0 commit comments

Comments
 (0)