Conversation
Pull request overview
Adds “full validation” support to PyTorch training, including config argcheck validation and checkpoint rotation on best metric.
Changes:
- Introduces `FullValidator` to run periodic full-dataset validation, log `val.log`, and optionally save/rotate best checkpoints.
- Extends `deepmd.utils.argcheck.normalize()` to validate full-validation configs and supported metrics/prefactors.
- Adds unit/integration tests covering metric parsing/start-step resolution, argcheck failures, and best-checkpoint rotation.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| source/tests/pt/test_validation.py | Adds unit tests for helper functions and argcheck validation for full validation. |
| source/tests/pt/test_training.py | Adds trainer-level tests for full validation behavior and rejection paths (spin/multi-task). |
| deepmd/utils/argcheck.py | Adds validating config schema and cross-field validation for full validation. |
| deepmd/pt/train/validation.py | Implements FullValidator and full-validation metric/logging utilities. |
| deepmd/pt/train/training.py | Wires FullValidator into the training loop and enforces runtime constraints. |
Note: Reviews paused. This branch is under active development; to avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review.
📝 Walkthrough

Adds a runtime full-training validation subsystem: new validator module, CLI/config schema and validators, Trainer wiring to run periodic full validation, top-K best-checkpoint management and rotation, dataset accessor, metric/eval helpers, auto-batch silent option, and unit tests; validation writes `val.log`.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Trainer as Trainer
    participant Validator as FullValidator
    participant Model as Model
    participant Data as ValidationData
    participant CKPT as CheckpointManager
    participant Log as ValLog
    Trainer->>Validator: initialize(validating_params, validation_data, model, train_infos, num_steps, rank, ...)
    loop Training steps
        Trainer->>Trainer: perform training step
        alt display/logging step
            Trainer->>Validator: run(step_id, display_step, lr, save_checkpoint)
            Validator->>Validator: should_run(display_step)?
            alt run validation
                Validator->>Model: set eval mode
                Validator->>Data: iterate validation systems
                loop per system
                    Validator->>Model: predict(inputs)
                    Model-->>Validator: outputs (E, F, V)
                    Validator->>Validator: compute per-system metrics
                end
                Validator->>Validator: aggregate metrics, select validation metric
                alt new best metric
                    Validator->>CKPT: save new best checkpoint(s)
                    Validator->>CKPT: prune/rotate retained best checkpoints
                end
                Validator->>Log: append metrics row to val.log
                Validator-->>Trainer: return FullValidationResult
            else skip validation
                Validator-->>Trainer: return None
            end
        end
    end
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 6
🧹 Nitpick comments (1)
deepmd/pt/train/validation.py (1)
464-470: Disable autograd for validation forwards.
`eval()` changes module behavior, but gradients are still tracked here. On a full-dataset validation pass that is unnecessary memory and latency overhead.

Proposed refactor

```diff
-        batch_output = self.model(
-            coord_input,
-            type_input,
-            box=box_input,
-            fparam=fparam_input,
-            aparam=aparam_input,
-        )
+        with torch.inference_mode():
+            batch_output = self.model(
+                coord_input,
+                type_input,
+                box=box_input,
+                fparam=fparam_input,
+                aparam=aparam_input,
+            )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt/train/validation.py` around lines 464 - 470, The validation forward is running with gradients enabled; wrap the model inference call that produces batch_output = self.model(...) in a no-grad context (preferably with torch.inference_mode() or with torch.no_grad()) inside the validation routine (the method where self.model is called for validation) so autograd is disabled during validation forwards and reduces memory/latency overhead.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@deepmd/pt/train/training.py`:
- Around line 1421-1427: FullValidator.run currently receives self.save_model
directly, causing every stage-0 worker to build model_state and deepcopy
optimizer.state_dict() when a new best checkpoint is broadcast; change the call
so save_checkpoint is True only on the rank that should serialize (e.g. pass
save_checkpoint=(self.save_model and (self.global_rank == 0)) or use
torch.distributed.get_rank()==0 when self.stage == 0), i.e. gate the
save_checkpoint argument before calling self.full_validator.run to skip
serialization on nonzero ranks and avoid unnecessary deep copies.
- Around line 900-914: FullValidator.run() can deadlock other ranks if rank 0
raises during _evaluate()/save_checkpoint() because those ranks wait on
broadcast_object_list(); wrap the rank-0 evaluation/checkpoint block in a
try/except that captures any Exception, set a serializable error payload (e.g.,
tuple with True and the exception string), and immediately broadcast that
payload with broadcast_object_list() so all ranks are unblocked; on non-zero
ranks receive the payload, detect the error flag, and raise a matching exception
(or handle/clean up) so the failure is propagated instead of leaving ranks
blocked—modify deepmd/pt/train/validation.py FullValidator.run() to implement
this pattern around _evaluate() and save_checkpoint().
In `@deepmd/pt/train/validation.py`:
- Around line 255-259: The validator is disabled when start_step equals
num_steps due to a strict '<' check; update the initialization of self.enabled
(which uses self.full_validation, self.start_step, and num_steps) to allow
equality (use '<=' semantics) so full validation can run on the final training
step, and ensure the should_run() logic remains consistent with this change.
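The boundary condition is easy to see in isolation. This is a hypothetical reduction of the enabled/`should_run` gating, not the actual class:

```python
# Hypothetical reduction of the start-step gating logic described above.
def validator_enabled(full_validation: bool, start_step: int, num_steps: int) -> bool:
    # A strict '<' would silently disable validation when start_step == num_steps;
    # '<=' allows a single run on the final training step.
    return full_validation and start_step <= num_steps

def should_run(step_id: int, start_step: int, display_step: int) -> bool:
    # Run on display steps once the start step has been reached.
    return step_id >= start_step and step_id % display_step == 0
```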
- Around line 307-328: The current code only calls self._evaluate() on rank 0
which deadlocks when self.zero_stage >= 2 because forward passes require all
ranks; change the control flow so that when self.zero_stage >= 2 you call
self._evaluate() on every rank (remove the rank==0-only guard for that case) and
still use save_path = [None] + dist.broadcast_object_list(save_path, src=0) to
propagate the chosen checkpoint; keep the existing rank-0-only actions (calling
self._prune_best_checkpoints and self._log_result) but ensure
save_checkpoint(Path(save_path[0]), ...) and the broadcast happen after every
rank has produced or received save_path; update the branches around
self._evaluate, save_path, dist.broadcast_object_list, save_checkpoint,
_prune_best_checkpoints and _log_result accordingly so distributed stage-2/3
training doesn't hang.
In `@deepmd/utils/argcheck.py`:
- Around line 4180-4194: The code currently returns early on multi_task or
non-'ener' losses which lets validating.full_validation silently pass; instead,
check validating.get("full_validation") first and if true reject unsupported
modes: if multi_task is True or loss_params.get("type","ener") != "ener" raise a
ValueError explaining that full_validation is unsupported with multi-task or
non-'ener' losses. Also only run the validation_metric check (using
validating["validation_metric"], is_valid_full_validation_metric and
FULL_VALIDATION_METRIC_PREFS) when full_validation is enabled so invalid metrics
are rejected rather than silently ignored.
In `@source/tests/pt/test_validation.py`:
- Around line 135-139: The test test_normalize_rejects_invalid_metric currently
catches the broad Exception; replace this with the concrete validation error
type that normalize() raises (e.g., ValidationError or the project-specific
ValidationError class) and update imports accordingly so the assertion uses
assertRaisesRegex(ValidationError, "validation_metric") against
normalize(config); keep the same regex and test flow but narrow the exception to
the specific validation error class.
---
Nitpick comments:
In `@deepmd/pt/train/validation.py`:
- Around line 464-470: The validation forward is running with gradients enabled;
wrap the model inference call that produces batch_output = self.model(...) in a
no-grad context (preferably with torch.inference_mode() or with torch.no_grad())
inside the validation routine (the method where self.model is called for
validation) so autograd is disabled during validation forwards and reduces
memory/latency overhead.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 98de0e8f-e132-4770-ae44-25a71d16e73e
📒 Files selected for processing (5)
- deepmd/pt/train/training.py
- deepmd/pt/train/validation.py
- deepmd/utils/argcheck.py
- source/tests/pt/test_training.py
- source/tests/pt/test_validation.py
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```
@@            Coverage Diff             @@
##           master    #5336      +/-   ##
==========================================
- Coverage   82.38%   80.99%   -1.40%
==========================================
  Files         812      814       +2
  Lines       83611    84163     +552
  Branches     4091     4090       -1
==========================================
- Hits        68882    68165     -717
- Misses      13508    14777    +1269
  Partials     1221     1221
```

☔ View full report in Codecov by Sentry.
♻️ Duplicate comments (1)
deepmd/pt/train/validation.py (1)
318-338: ⚠️ Potential issue | 🟠 Major: Add runtime guard for `zero_stage >= 2`.

The documentation in `argcheck.py` states that `zero_stage >= 2` is not supported with full validation, but there's no runtime enforcement. With FSDP2 (stage 2/3), model forward passes require collective participation from all ranks. Since only rank 0 enters `_evaluate()` (line 330), this will hang when other ranks block on the barrier or broadcast.

Proposed fix

```diff
         if not self.should_run(display_step):
             return None
 
+        if self.is_distributed and self.zero_stage >= 2:
+            raise ValueError(
+                "validating.full_validation does not support training.zero_stage >= 2. "
+                "FSDP2 requires all ranks to participate in forward passes."
+            )
+
         if self.is_distributed:
             dist.barrier()
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt/train/validation.py` around lines 318 - 338, Add a runtime guard for zero_stage >= 2 before any distributed synchronization or calling _evaluate: check the model/training config value (e.g. self.zero_stage or self.config.zero_stage) right after should_run(...) and, if >= 2, ensure all ranks take the same action (either return None on all ranks or raise a RuntimeError on all ranks) and log a clear message; do not let only rank 0 call _evaluate() while others wait on dist.barrier() — perform the guard before is_distributed/dist.barrier() so no rank blocks.
🧹 Nitpick comments (1)
source/tests/pt/test_training.py (1)
818-822: Clarify or relax the `val.log` content assertions.

The test checks `val_lines[0].split()[1] == "1000.0"` and `val_lines[1].split()[1] == "2000.0"`, which appear to expect the MAE values multiplied by 1000 (from `mae_e_per_atom` values of 1.0 and 2.0). This relies on implementation details of the log format that may change.

Consider either:
- Adding a comment explaining the expected format (e.g., `# val.log format: step metric_meV ...`)
- Using a more flexible assertion like checking the line count or that lines contain expected step numbers
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@source/tests/pt/test_training.py` around lines 818 - 822, The assertions on val.log are too brittle because they depend on exact formatted metric values; update the test in test_training.py to either (a) add a clarifying comment above the val.log checks describing the expected file format (e.g., "# val.log format: step metric_meV ...") or (b) relax the assertions by checking structural properties instead of exact strings — for example parse each non-comment line into tokens via val_lines[i].split(), assert the step token equals the expected steps (e.g., "1000" and "2000") and/or parse the metric token to float and compare with the expected value using a numeric tolerance or scaled comparison rather than exact string equality for val_lines[0].split()[1] and val_lines[1].split()[1].
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@deepmd/pt/train/validation.py`:
- Around line 318-338: Add a runtime guard for zero_stage >= 2 before any
distributed synchronization or calling _evaluate: check the model/training
config value (e.g. self.zero_stage or self.config.zero_stage) right after
should_run(...) and, if >= 2, ensure all ranks take the same action (either
return None on all ranks or raise a RuntimeError on all ranks) and log a clear
message; do not let only rank 0 call _evaluate() while others wait on
dist.barrier() — perform the guard before is_distributed/dist.barrier() so no
rank blocks.
---
Nitpick comments:
In `@source/tests/pt/test_training.py`:
- Around line 818-822: The assertions on val.log are too brittle because they
depend on exact formatted metric values; update the test in test_training.py to
either (a) add a clarifying comment above the val.log checks describing the
expected file format (e.g., "# val.log format: step metric_meV ...") or (b)
relax the assertions by checking structural properties instead of exact strings
— for example parse each non-comment line into tokens via val_lines[i].split(),
assert the step token equals the expected steps (e.g., "1000" and "2000") and/or
parse the metric token to float and compare with the expected value using a
numeric tolerance or scaled comparison rather than exact string equality for
val_lines[0].split()[1] and val_lines[1].split()[1].
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: bcbdf8fa-2d30-4b2d-905b-49e7034d1081
📒 Files selected for processing (6)
- deepmd/pt/train/training.py
- deepmd/pt/train/validation.py
- deepmd/pt/utils/dataset.py
- deepmd/utils/argcheck.py
- source/tests/pt/test_training.py
- source/tests/pt/test_validation.py
✅ Files skipped from review due to trivial changes (1)
- source/tests/pt/test_validation.py
Code Review

Issues

1. RMSE aggregation is incorrect (Bug)

```python
# Current (wrong): averages RMSE values
rmse_e_per_atom = float(np.sqrt(np.mean(diff_e * diff_e)) / natoms)
metrics["rmse_e_per_atom"] = (rmse_e_per_atom, float(diff_e.size))
# Then weighted_average averages RMSE values across systems
```

Should compute MSE per-system, aggregate MSE, then sqrt.

2.

3.
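The RMSE point above can be made concrete with a small NumPy sketch (illustrative data, not the project's actual metric code): averaging per-system RMSEs is not the same as aggregating MSEs and taking the square root at the end.

```python
import numpy as np

# Two "systems" with per-frame energy errors (illustrative values).
errors = [np.array([1.0, 1.0]), np.array([3.0, 3.0, 3.0])]
weights = [e.size for e in errors]

# Wrong: compute RMSE per system, then weighted-average the RMSE values.
rmse_per_system = [np.sqrt(np.mean(e * e)) for e in errors]
rmse_wrong = np.average(rmse_per_system, weights=weights)

# Right: aggregate the MSE across systems first, then take the sqrt.
mse_per_system = [np.mean(e * e) for e in errors]
rmse_right = np.sqrt(np.average(mse_per_system, weights=weights))

# Ground truth: RMSE over all samples pooled together.
pooled = np.concatenate(errors)
rmse_pooled = np.sqrt(np.mean(pooled * pooled))
```

Here `rmse_right` matches the pooled RMSE exactly, while `rmse_wrong` underestimates it, because the square root is concave and averaging RMSEs weights large errors too little.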
4. Duplicate validation logic between

5.

6. Force metric uses vector L2 norm, not component-wise error

```python
diff_f_norm = np.linalg.vector_norm(diff_f, axis=1)
mae_f_vector = float(np.mean(diff_f_norm))
```

This computes the L2 norm of the (fx, fy, fz) error vector per atom, which differs from

7. No
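The distinction in the force-metric point can be shown numerically; a sketch with made-up per-atom force errors (not project code):

```python
import numpy as np

# Per-atom force errors (fx, fy, fz), illustrative values.
diff_f = np.array([[3.0, 4.0, 0.0],
                   [0.0, 0.0, 5.0]])

# Vector-style MAE: mean L2 norm of the per-atom error vector.
mae_f_vector = np.mean(np.linalg.norm(diff_f, axis=1))

# Component-wise MAE: mean absolute error over all force components.
mae_f_component = np.mean(np.abs(diff_f))
```

For these values the vector MAE is 5.0 (each atom's error vector has norm 5) while the component-wise MAE is 2.0, so the two conventions can differ substantially and should not be compared against each other across tools.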
8.
Minor
Actionable comments posted: 3
♻️ Duplicate comments (1)
deepmd/pt/train/validation.py (1)
328-337: ⚠️ Potential issue | 🔴 Critical: Rank-0-only evaluation may deadlock with FSDP (`zero_stage >= 2`).

When `zero_stage >= 2`, FSDP shards model parameters across ranks. Calling `self._evaluate()` only on rank 0 while other ranks wait at the barrier can deadlock because forward passes require collective participation from all ranks.

Add a guard to reject unsupported zero_stage values

```diff
         if not self.should_run(display_step):
             return None
 
+        if self.is_distributed and self.zero_stage >= 2:
+            raise ValueError(
+                "validating.full_validation does not support training.zero_stage >= 2. "
+                "Please use zero_stage=0 or zero_stage=1."
+            )
+
         if self.is_distributed:
             dist.barrier()
```

Alternatively, add this check to `validate_full_validation_config` in `argcheck.py` for early rejection at config time.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt/train/validation.py` around lines 328 - 337, The current rank-0-only evaluation call to self._evaluate can deadlock when using FSDP with zero_stage >= 2 because parameters are sharded across ranks; fix by adding a guard that rejects or prevents running full validation under unsupported FSDP sharding: either (A) in this function check the configuration value zero_stage (or self.config.zero_stage / zero_stage) and raise a clear error if zero_stage >= 2 before attempting the rank==0-only evaluation, or (B) add the same check to validate_full_validation_config in argcheck.py to fail fast at config time; reference symbols to change: self._evaluate (where the call currently occurs), the rank check (rank == 0), and validate_full_validation_config in argcheck.py for the alternative early-rejection location.
🧹 Nitpick comments (1)
deepmd/pt/train/validation.py (1)
620-628: Checkpoint globbing uses current working directory.

`Path(".")` at line 624 relies on the training process's CWD being the output directory. This should work for standard `dp train` workflows but could behave unexpectedly if the process is launched from a different directory.

Consider accepting an optional `output_dir` parameter or deriving it from `self.full_val_file.parent` to make the path handling more explicit.
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt/train/validation.py` around lines 620 - 628, The _list_best_checkpoints method currently globs BEST_CKPT_GLOB against Path(".") which assumes CWD is the output directory; change it to accept an optional output_dir argument (or derive the directory from self.full_val_file.parent) and use Path(output_dir) (or self.full_val_file.parent) instead of Path(".") so checkpoint discovery is deterministic; update the method signature _list_best_checkpoints(self, output_dir: Path | None = None) and replace Path(".").glob(...) with (Path(output_dir) or self.full_val_file.parent).glob(BEST_CKPT_GLOB), keeping the existing filtering and sorting logic.
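The explicit-directory suggestion might look like the following sketch. The glob pattern (`model_best_step*.pt`) and the step-in-filename convention are assumptions for illustration, not the actual DeePMD-kit checkpoint scheme:

```python
from pathlib import Path
import tempfile

BEST_CKPT_GLOB = "model_best_step*.pt"  # assumed naming convention

def list_best_checkpoints(output_dir: Path) -> list[Path]:
    # Glob against an explicit directory instead of Path(".").
    return sorted(output_dir.glob(BEST_CKPT_GLOB))

def prune_best_checkpoints(output_dir: Path, keep: int) -> None:
    # Keep only the newest `keep` checkpoints, ordered by the step
    # number embedded in the file name.
    ckpts = sorted(
        output_dir.glob(BEST_CKPT_GLOB),
        key=lambda p: int(p.stem.rsplit("step", 1)[1]),
    )
    for stale in ckpts[:-keep]:
        stale.unlink()
```

Passing the directory in (rather than relying on CWD) makes pruning deterministic regardless of where the training process was launched from.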
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@deepmd/pt/train/validation.py`:
- Around line 707-709: Replace the assertion in _log_result with an explicit
None-check: inside the _log_result method, test "if result is None:" and handle
it deterministically (e.g., raise a ValueError or log an error and return)
instead of using "assert result is not None"; ensure subsequent access to
result.display_step and other attributes only happens after this guard so
behavior is correct under Python -O optimizations.
- Around line 521-539: The validation forward is creating autograd graphs and
assumes virial always exists; wrap the model forward (the call to
self.model(...) that produces batch_output) inside a torch.no_grad() context to
avoid building gradients, and then handle virial safely by checking for its
presence (e.g., if "virial" in batch_output or using batch_output.get("virial"))
before calling detach()/cpu()/numpy()/reshape so you only include virial in the
returned dict when provided by the model; keep the existing
detach/cpu/numpy/reshape logic for energy and force and apply the same
safe-access pattern to virial.
- Around line 158-164: The code uses np.linalg.vector_norm which exists only in
NumPy ≥2.0; replace that call with np.linalg.norm to maintain compatibility with
NumPy ≥1.26.0: in the block guarded by find_force (where diff_f is computed from
prediction["force"] and test_data["force"]), compute diff_f_norm using
np.linalg.norm(diff_f, axis=1) and leave the subsequent mae_f_vector,
rmse_f_vector, and metrics assignments unchanged so behavior and shapes remain
the same.
---
Duplicate comments:
In `@deepmd/pt/train/validation.py`:
- Around line 328-337: The current rank-0-only evaluation call to self._evaluate
can deadlock when using FSDP with zero_stage >= 2 because parameters are sharded
across ranks; fix by adding a guard that rejects or prevents running full
validation under unsupported FSDP sharding: either (A) in this function check
the configuration value zero_stage (or self.config.zero_stage / zero_stage) and
raise a clear error if zero_stage >= 2 before attempting the rank==0-only
evaluation, or (B) add the same check to validate_full_validation_config in
argcheck.py to fail fast at config time; reference symbols to change:
self._evaluate (where the call currently occurs), the rank check (rank == 0),
and validate_full_validation_config in argcheck.py for the alternative
early-rejection location.
---
Nitpick comments:
In `@deepmd/pt/train/validation.py`:
- Around line 620-628: The _list_best_checkpoints method currently globs
BEST_CKPT_GLOB against Path(".") which assumes CWD is the output directory;
change it to accept an optional output_dir argument (or derive the directory
from self.full_val_file.parent) and use Path(output_dir) (or
self.full_val_file.parent) instead of Path(".") so checkpoint discovery is
deterministic; update the method signature _list_best_checkpoints(self,
output_dir: Path | None = None) and replace Path(".").glob(...) with
(Path(output_dir) or self.full_val_file.parent).glob(BEST_CKPT_GLOB), keeping
the existing filtering and sorting logic.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 0002b1c1-6f36-4e9c-8fb5-ebcf45c3da7d
📒 Files selected for processing (6)
- deepmd/pt/train/training.py
- deepmd/pt/train/validation.py
- deepmd/pt/utils/dataset.py
- deepmd/utils/argcheck.py
- source/tests/pt/test_training.py
- source/tests/pt/test_validation.py
✅ Files skipped from review due to trivial changes (2)
- deepmd/pt/utils/dataset.py
- source/tests/pt/test_validation.py
Actionable comments posted: 2
♻️ Duplicate comments (4)
deepmd/pt/train/validation.py (4)
707-709: ⚠️ Potential issue | 🟡 Minor: Replace `assert` with explicit check for production robustness.

The `assert result is not None` will be stripped when Python runs with optimizations (`-O` flag), potentially causing `AttributeError` on `result.display_step`.

Proposed fix

```diff
     def _log_result(self, result: FullValidationResult | None) -> None:
         """Log and persist full validation results on rank 0."""
-        assert result is not None
+        if result is None:
+            raise RuntimeError("_log_result called with None result on rank 0")
         self._write_log_file(result)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt/train/validation.py` around lines 707 - 709, In _log_result replace the bare assert with an explicit runtime check: if result is None, log an error (or warning) and return (or raise a clear ValueError) instead of proceeding; ensure you reference the FullValidationResult parameter and avoid calling result.display_step when result is None so production runs under -O don't hit AttributeError. Use the function name _log_result and the result variable in your change.
158-164: ⚠️ Potential issue | 🟠 Major: `np.linalg.vector_norm` requires NumPy ≥ 2.0; use a compatible alternative.

This function was introduced in NumPy 2.0.0, but the project supports `numpy>=1.26.0`. This will raise `AttributeError` on older NumPy versions.

Proposed fix for broader compatibility

```diff
     if find_force:
         diff_f = prediction["force"].reshape(-1, 3) - test_data["force"].reshape(-1, 3)
-        diff_f_norm = np.linalg.vector_norm(diff_f, axis=1)
+        diff_f_norm = np.sqrt(np.sum(diff_f * diff_f, axis=1))
         mae_f_vector = float(np.mean(diff_f_norm))
         rmse_f_vector = float(np.sqrt(np.mean(diff_f_norm * diff_f_norm)))
```
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt/train/validation.py` around lines 158 - 164, The code uses np.linalg.vector_norm (in the block guarded by find_force, calculating diff_f and diff_f_norm) which exists only in NumPy ≥2.0; replace it with the compatible np.linalg.norm call to compute row-wise norms (e.g., use np.linalg.norm(diff_f, axis=1)) so diff_f_norm is computed correctly across supported NumPy versions, then proceed to compute mae_f_vector, rmse_f_vector and store them in metrics["mae_f_vector"] / metrics["rmse_f_vector"] as before.
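Either replacement is numerically identical to the NumPy ≥ 2.0 call; `np.linalg.norm(..., axis=1)` computes the same row-wise L2 norm and has long been available in older NumPy releases. A quick check:

```python
import numpy as np

diff_f = np.array([[1.0, 2.0, 2.0],
                   [0.0, 3.0, 4.0]])

# Portable row-wise L2 norm, equivalent to np.linalg.vector_norm(diff_f, axis=1).
norm_compat = np.linalg.norm(diff_f, axis=1)
norm_manual = np.sqrt(np.sum(diff_f * diff_f, axis=1))
```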
321-363: ⚠️ Potential issue | 🟠 Major: Consider guarding against `zero_stage >= 2` to prevent distributed deadlock.

When `zero_stage >= 2`, the model is sharded across ranks and forward passes require collective participation from all ranks. Currently, only rank 0 calls `_evaluate()` while other ranks wait at `broadcast_object_list`, which can cause a hang. The fix at lines 346-350 addresses checkpoint saving but does not address the evaluation path itself.

Proposed fix to reject unsupported zero_stage at runtime

```diff
         if not self.should_run(display_step):
             return None
 
+        if self.is_distributed and self.zero_stage >= 2:
+            raise ValueError(
+                "validating.full_validation does not support training.zero_stage >= 2."
+            )
+
         if self.is_distributed:
             dist.barrier()
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt/train/validation.py` around lines 321 - 363, The validation flow can deadlock when zero_stage >= 2 because model forward requires all ranks but only rank 0 runs _evaluate; add a runtime guard in the validation routine to detect if self.is_distributed and self.zero_stage >= 2 and raise a clear exception (via self._raise_if_distributed_error or by setting caught_exception/error_message so _raise_if_distributed_error is invoked) before any rank waits at dist.broadcast_object_list; reference self.zero_stage, self.is_distributed, _evaluate and _raise_if_distributed_error to locate where to insert the check and ensure the error is raised on all ranks to avoid distributed hangs.
521-539: ⚠️ Potential issue | 🟠 Major: Missing `torch.no_grad()` and unconditional virial access.

Two issues in `predict_batch`:

1. The model forward pass runs without `torch.no_grad()`, unnecessarily building the autograd graph and consuming extra memory during validation.
2. `batch_output["virial"]` is accessed unconditionally, but models for non-PBC systems may not provide virial, causing a `KeyError`.

Proposed fix

```diff
+    @torch.no_grad()
     def predict_batch(
         coord_batch: np.ndarray,
         atom_types_batch: np.ndarray,
         box_batch: np.ndarray | None,
         fparam_batch: np.ndarray | None,
         aparam_batch: np.ndarray | None,
     ) -> dict[str, np.ndarray]:
         # ... tensor creation code unchanged ...
         batch_output = self.model(
             coord_input,
             type_input,
             box=box_input,
             fparam=fparam_input,
             aparam=aparam_input,
         )
         if isinstance(batch_output, tuple):
             batch_output = batch_output[0]
-        return {
-            "energy": batch_output["energy"].detach().cpu().numpy().reshape(-1, 1),
-            "force": batch_output["force"]
-            .detach()
-            .cpu()
-            .numpy()
-            .reshape(-1, natoms * 3),
-            "virial": batch_output["virial"].detach().cpu().numpy().reshape(-1, 9),
-        }
+        result = {
+            "energy": batch_output["energy"].detach().cpu().numpy().reshape(-1, 1),
+            "force": batch_output["force"]
+            .detach()
+            .cpu()
+            .numpy()
+            .reshape(-1, natoms * 3),
+        }
+        if "virial" in batch_output:
+            result["virial"] = batch_output["virial"].detach().cpu().numpy().reshape(-1, 9)
+        return result
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt/train/validation.py` around lines 521 - 539, predict_batch is missing a torch.no_grad() context and unconditionally accesses batch_output["virial"], which can raise KeyError for non-PBC models; wrap the forward call in torch.no_grad() around the self.model(...) invocation in predict_batch to avoid building the autograd graph, and guard the virial extraction by checking "virial" in batch_output (or using batch_output.get("virial") and handling None) before calling .detach().cpu().numpy().reshape(...) so prediction works for models that do not return virial.
🧹 Nitpick comments (2)
deepmd/pt/train/validation.py (2)
620-628: `Path(".")` glob depends on current working directory.

Using `Path(".")` assumes the current working directory is the training output directory. If the working directory changes during training or if the code is invoked from a different directory, checkpoint management will fail silently.

Consider accepting an explicit output directory path in the constructor and using it consistently throughout checkpoint management methods.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt/train/validation.py` around lines 620 - 628, The _list_best_checkpoints method currently uses Path(".") which depends on the process CWD and can fail if training is run from a different directory; update the class to accept an explicit output directory (e.g., add an output_dir Path parameter on the constructor) and replace Path(".") usages in _list_best_checkpoints (and other checkpoint-related methods that reference BEST_CKPT_GLOB) to use that stored output_dir.glob(BEST_CKPT_GLOB) so checkpoint listing/manipulation is always relative to the configured training output directory.
239-240: Suppress logs more robustly by calling parent's `_adjust_batch_size`.

Currently, this override reimplements the multiplication logic without calling `super()`. While the behavior is identical today, calling the parent method ensures this subclass evolves with future changes. The class already uses this pattern in `__init__` to suppress logs; apply the same approach here:

Proposed fix

```diff
     def _adjust_batch_size(self, factor: float) -> None:
-        self.current_batch_size = int(self.current_batch_size * factor)
+        batch_size_log = logging.getLogger(BATCH_SIZE_LOGGER_NAME)
+        old_disabled = batch_size_log.disabled
+        batch_size_log.disabled = True
+        try:
+            super()._adjust_batch_size(factor)
+        finally:
+            batch_size_log.disabled = old_disabled
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt/train/validation.py` around lines 239 - 240, Replace the manual multiplication in _adjust_batch_size with a call to the parent implementation: call super()._adjust_batch_size(factor) (so the subclass inherits any future behavior and log-suppression logic used elsewhere, e.g. in __init__) instead of setting self.current_batch_size = int(self.current_batch_size * factor) directly.
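The suppress-and-restore pattern behind this suggestion can be exercised standalone. The logger name and the two class names below are illustrative stand-ins, not the DeePMD-kit classes:

```python
import logging

BATCH_SIZE_LOGGER_NAME = "batch-size"  # illustrative logger name

class AutoBatchSize:
    def __init__(self):
        self.current_batch_size = 32

    def _adjust_batch_size(self, factor: float) -> None:
        self.current_batch_size = int(self.current_batch_size * factor)
        logging.getLogger(BATCH_SIZE_LOGGER_NAME).info("batch size adjusted")

class SilentAutoBatchSize(AutoBatchSize):
    def _adjust_batch_size(self, factor: float) -> None:
        # Temporarily disable the noisy logger, then defer to the parent so
        # any future changes to the base resize logic are inherited.
        log = logging.getLogger(BATCH_SIZE_LOGGER_NAME)
        old_disabled = log.disabled
        log.disabled = True
        try:
            super()._adjust_batch_size(factor)
        finally:
            log.disabled = old_disabled
```

The try/finally guarantees the logger's previous state is restored even if the parent method raises, which a plain disable/enable pair would not.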
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@deepmd/pt/train/validation.py`:
- Around line 551-555: The returned dict always includes "virial" but
predict_batch may omit it, causing a KeyError; update the return construction in
the function that builds the output from prediction (the block returning
{"energy": prediction["energy"], "force": prediction["force"], "virial":
prediction["virial"]}) to include "virial" only when present (e.g., check
"virial" in prediction or use prediction.get and conditionally add the key) so
the function safely returns energy and force and only adds virial if prediction
provides it.
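The conditional-key pattern the prompt asks for is small enough to show in full; a minimal sketch with plain dicts standing in for model outputs (the helper name is hypothetical):

```python
# Build the prediction dict, adding "virial" only when the model provides it,
# so non-PBC models that omit virial do not trigger a KeyError.
def collect_outputs(batch_output: dict) -> dict:
    result = {
        "energy": batch_output["energy"],
        "force": batch_output["force"],
    }
    if "virial" in batch_output:
        result["virial"] = batch_output["virial"]
    return result
```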
In `@source/tests/pt/test_training.py`:
- Around line 824-828: The test test_full_validation_rejects_spin_loss is
asserting a regex "spin-energy" but validate_full_validation_config (called via
get_trainer) actually emits the literal loss type 'ener_spin' in its error
message; update the assertion to match the real message (e.g.,
assertRaisesRegex(ValueError, "ener_spin") or use a regex that accepts either
"ener_spin" or "spin-energy") so the test matches the error produced by
validate_full_validation_config when config["loss"]["type"] = "ener_spin".
---
Duplicate comments:
In `@deepmd/pt/train/validation.py`:
- Around line 707-709: In _log_result replace the bare assert with an explicit
runtime check: if result is None, log an error (or warning) and return (or raise
a clear ValueError) instead of proceeding; ensure you reference the
FullValidationResult parameter and avoid calling result.display_step when result
is None so production runs under -O don't hit AttributeError. Use the function
name _log_result and the result variable in your change.
- Around line 158-164: The code uses np.linalg.vector_norm (in the block guarded
by find_force, calculating diff_f and diff_f_norm) which exists only in NumPy
≥2.0; replace it with the compatible np.linalg.norm call to compute row-wise
norms (e.g., use np.linalg.norm(diff_f, axis=1)) so diff_f_norm is computed
correctly across supported NumPy versions, then proceed to compute mae_f_vector,
rmse_f_vector and store them in metrics["mae_f_vector"] /
metrics["rmse_f_vector"] as before.
- Around line 321-363: The validation flow can deadlock when zero_stage >= 2
because model forward requires all ranks but only rank 0 runs _evaluate; add a
runtime guard in the validation routine to detect if self.is_distributed and
self.zero_stage >= 2 and raise a clear exception (via
self._raise_if_distributed_error or by setting caught_exception/error_message so
_raise_if_distributed_error is invoked) before any rank waits at
dist.broadcast_object_list; reference self.zero_stage, self.is_distributed,
_evaluate and _raise_if_distributed_error to locate where to insert the check
and ensure the error is raised on all ranks to avoid distributed hangs.
- Around line 521-539: predict_batch is missing a torch.no_grad() context and
unconditionally accesses batch_output["virial"], which can raise KeyError for
non-PBC models; wrap the forward call in torch.no_grad() around the
self.model(...) invocation in predict_batch to avoid building the autograd
graph, and guard the virial extraction by checking "virial" in batch_output (or
using batch_output.get("virial") and handling None) before calling
.detach().cpu().numpy().reshape(...) so prediction works for models that do not
return virial.
---
Nitpick comments:
In `@deepmd/pt/train/validation.py`:
- Around line 620-628: The _list_best_checkpoints method currently uses
Path(".") which depends on the process CWD and can fail if training is run from
a different directory; update the class to accept an explicit output directory
(e.g., add an output_dir Path parameter on the constructor) and replace
Path(".") usages in _list_best_checkpoints (and other checkpoint-related methods
that reference BEST_CKPT_GLOB) to use that stored
output_dir.glob(BEST_CKPT_GLOB) so checkpoint listing/manipulation is always
relative to the configured training output directory.
- Around line 239-240: Replace the manual multiplication in _adjust_batch_size
with a call to the parent implementation: call
super()._adjust_batch_size(factor) (so the subclass inherits any future behavior
and log-suppression logic used elsewhere, e.g. in __init__) instead of setting
self.current_batch_size = int(self.current_batch_size * factor) directly.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: accd30cc-c238-4318-9057-cb68289feae0
📒 Files selected for processing (6)
- deepmd/pt/train/training.py
- deepmd/pt/train/validation.py
- deepmd/pt/utils/dataset.py
- deepmd/utils/argcheck.py
- source/tests/pt/test_training.py
- source/tests/pt/test_validation.py
✅ Files skipped from review due to trivial changes (1)
- source/tests/pt/test_validation.py
🚧 Files skipped from review as they are similar to previous changes (3)
- deepmd/pt/train/training.py
- deepmd/utils/argcheck.py
- deepmd/pt/utils/dataset.py
Actionable comments posted: 7
♻️ Duplicate comments (1)
source/tests/pt/test_training.py (1)
840-844: ⚠️ Potential issue | 🟡 Minor

Fix the regex in this negative test.

Line 843 is matching "spin-energy", but `validate_full_validation_config()` now includes the literal loss type (`ener_spin`) in the `ValueError`, so this assertion will fail even when the implementation is correct.

Suggested fix

```diff
-        with self.assertRaisesRegex(ValueError, "spin-energy"):
+        with self.assertRaisesRegex(ValueError, "ener_spin"):
             get_trainer(config)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@source/tests/pt/test_training.py` around lines 840 - 844, The test test_full_validation_rejects_spin_loss currently expects the ValueError message to contain "spin-energy" but validate_full_validation_config() now includes the literal loss type "ener_spin" in the error; update the assertion in that test (the get_trainer call) to assert a regex matching "ener_spin" (or a pattern that accepts either "spin-energy" or "ener_spin", e.g. "spin-energy|ener_spin") so the negative test correctly matches the new error text from validate_full_validation_config().
🧹 Nitpick comments (2)
deepmd/utils/batch_size.py (1)
44-46: Clarify the `silent` docstring to match actual behavior.

Current wording says informational logs only, but the implementation also suppresses warnings.
Suggested doc tweak
```diff
-    silent : bool, default: False
-        whether to suppress auto batch size informational logs
+    silent : bool, default: False
+        whether to suppress auto batch size logs (including info/warning messages)
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@deepmd/utils/batch_size.py` around lines 44 - 46, The docstring for the parameter silent in deepmd/utils/batch_size.py is misleading: it says it only suppresses informational logs but the implementation also suppresses warnings; update the docstring for the silent parameter (where it is defined; e.g., the function/method that accepts silent in this module) to state that silent suppresses both informational logs and warnings, and adjust any nearby examples or summary text in that function's docstring to match the actual behavior.

deepmd/pt/train/validation.py (1)
110-119: Consider documenting the overloaded semantics of `full_val_start`.

The function handles three distinct cases (0.0-1.0 as ratio, 1.0 as disabled, >1 as step number), but this overloading isn't self-documenting. A docstring explaining these semantics would help maintainability.
Proposed docstring enhancement
```diff
 def resolve_full_validation_start_step(
     full_val_start: float, num_steps: int
 ) -> int | None:
-    """Resolve the first step at which full validation becomes active."""
+    """Resolve the first step at which full validation becomes active.
+
+    Parameters
+    ----------
+    full_val_start : float
+        - 0.0 to <1.0: treated as a ratio of num_steps (e.g., 0.5 = start at 50%)
+        - 1.0: full validation is disabled (returns None)
+        - >1.0: treated as an absolute step number
+    num_steps : int
+        Total number of training steps.
+
+    Returns
+    -------
+    int | None
+        The starting step, or None if disabled.
+    """
     start_value = float(full_val_start)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt/train/validation.py` around lines 110 - 119, Update the docstring of resolve_full_validation_start_step to explicitly describe the overloaded semantics of the full_val_start parameter: when full_val_start == 1.0 the function returns None to disable full validation, when 0.0 <= full_val_start < 1.0 it is interpreted as a fraction of num_steps and the function returns int(num_steps * full_val_start), and when full_val_start > 1.0 it is treated as an absolute step number and the function returns int(full_val_start); also mention return type int | None and any edge-case behavior (e.g., truncation via int()).
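Under the semantics described in the suggested docstring, a minimal reference implementation looks like the following. This is a sketch for illustration only; the actual function in deepmd/pt/train/validation.py may handle edge cases differently:

```python
def resolve_full_validation_start_step(full_val_start: float, num_steps: int):
    """Return the first active step, or None when full validation is disabled."""
    start_value = float(full_val_start)
    if start_value == 1.0:
        return None  # 1.0 is the sentinel for "disabled"
    if 0.0 <= start_value < 1.0:
        return int(num_steps * start_value)  # fraction of total steps
    return int(start_value)  # absolute step number


print(resolve_full_validation_start_step(0.5, 1000))  # ratio -> 500
print(resolve_full_validation_start_step(1.0, 1000))  # disabled -> None
print(resolve_full_validation_start_step(250, 1000))  # absolute -> 250
```

Making these three branches explicit in code (and in the docstring) is what resolves the "not self-documenting" concern above.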
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@deepmd/entrypoints/test.py`:
- Around line 417-440: The virial detail files are unconditionally written using
pv and pv_atom even when virial data is absent; update the logic around the
pv/pv_atom construction and the save_txt_file calls so they only run when virial
data is actually available (e.g., check the same condition used by
test_ener/find_virial or verify reference_virial and prediction_virial are not
None/empty). In practice, wrap the pv = np.concatenate(...), pv_atom = pv /
natoms, and both save_txt_file(...) calls in a single guard that mirrors the
existing virial-availability check, and skip writing .v.out and .v_peratom.out
when virial arrays are missing.
- Around line 295-307: The current code computes nloc_real by summing
np.count_nonzero(atype == ii) across the whole batch, which is wrong when
mixed_type=True because atype is 2D; instead compute a per-frame real-atom
count/mask (e.g., sum over the frame axis or build a boolean mask per frame from
atype and ntpyes_real) and use that per-frame mask/indices to split or select
columns from prediction_force and reference_force into real vs magnetic parts;
update the logic around nloc_real, force_real_prediction,
force_magnetic_prediction, force_real_reference, and force_magnetic_reference to
use the per-frame mask (or per-frame counts multiplied by 3 for column offsets)
rather than a single global nloc_real.
- Around line 315-323: The current code raises a ValueError if any of
prediction_force_mag, reference_force_mag, or mask_mag are None, but the PyTorch
spin path should allow missing magnetic-force arrays; remove the exception and
make the magnetic arrays optional in get-spin-return logic (the function around
the shown block in deepmd/entrypoints/test.py used by test_ener()). Concretely,
if any of prediction_force_mag, reference_force_mag, or mask_mag is None, return
prediction_force and reference_force and produce empty magnetic outputs (e.g.,
arrays shaped (0,3)) and an empty magnetic_mask so downstream code that checks
find_force_mag continues to work; otherwise keep the existing
reshape(-1,3)[magnetic_mask] behavior for prediction_force_mag,
reference_force_mag and mask_mag. Ensure you reference the symbols
prediction_force, reference_force, prediction_force_mag, reference_force_mag,
mask_mag, magnetic_mask and keep compatibility with test_ener()'s find_force_mag
guard.
In `@deepmd/pt/train/validation.py`:
- Around line 157-170: The function format_metric_number_for_log can call
np.log10 on values that are zero or extremely tiny, causing
infinities/overflows; fix by guarding tiny magnitudes: compute abs_value =
abs(metric_value) then if abs_value == 0.0 or abs_value <= np.finfo(float).tiny
return "0" (or "nan" if metric_value is nan already), otherwise compute
safe_log_input = max(abs_value, np.finfo(float).tiny) and use decimals =
VAL_LOG_SIGNIFICANT_DIGITS - int(np.floor(np.log10(safe_log_input))) - 1;
finally clamp decimals to a reasonable range (e.g., min_decimals=0,
max_decimals=12) before rounding/formatting so format_metric_number_for_log and
the constant VAL_LOG_SIGNIFICANT_DIGITS are used safely.
- Around line 568-576: The _list_best_checkpoints function currently uses
Path(".") which depends on the process CWD and can miss checkpoints; update
FullValidator to accept and store an explicit output directory (e.g. output_dir:
Path) passed from the training config/trainer, then change
_list_best_checkpoints to use self.output_dir.glob(BEST_CKPT_GLOB) and any other
checkpoint helpers (e.g. _best_checkpoint_name, checkpoint save/resolve logic)
to resolve paths relative to self.output_dir instead of Path("."), and modify
the train() caller to pass the configured output directory into FullValidator so
all checkpoint operations are consistent and independent of CWD.
In `@deepmd/utils/argcheck.py`:
- Around line 4318-4322: The early return using
_is_full_validation_active(validating, num_steps) prevents static
full-validation compatibility checks from running during normalize() when
training schedule isn't finalized; instead, change the logic so that when
validating.get("full_validation") is truthy you always run the static
compatibility checks (the existing full-validation validation routines)
regardless of num_steps, and use _is_full_validation_active only to decide
runtime start-step scheduling (not to skip validation). Concretely, replace the
current "if not _is_full_validation_active(...): return" flow with a branch that
(a) runs static full-validation compatibility checks whenever
validating.get("full_validation") is enabled, and (b) separately uses
_is_full_validation_active(validating, num_steps) only to set or defer
start-step scheduling logic used at runtime.
In `@deepmd/utils/eval_metrics.py`:
- Around line 122-134: The RMSE aggregation is wrong because compute_error_stat
currently stores rmse and weight=diff.size which leads weighted_average to
compute a weighted mean of RMSEs; instead change compute_error_stat (and the
ErrorStat structure) to return the sum of squared errors and count so aggregate
RMSE is computed as sqrt(sum(sse)/sum(n)). Specifically, have compute_error_stat
compute sse = np.sum((prediction - reference)**2) * (scale**2) and return that
as a field (e.g., sse) and weight as the integer count (diff.size), keep mae
as-is (or store sum_abs if you prefer consistent aggregation), then update any
downstream weighted_average/aggregation logic to compute final rmse =
sqrt(sum(sse_i)/sum(n_i)) rather than averaging rmse_i. Ensure references to
compute_error_stat, ErrorStat, mae, rmse (rename to sse or add sse alongside
rmse), and weighted_average are updated accordingly.
---
Duplicate comments:
In `@source/tests/pt/test_training.py`:
- Around line 840-844: The test test_full_validation_rejects_spin_loss currently
expects the ValueError message to contain "spin-energy" but
validate_full_validation_config() now includes the literal loss type "ener_spin"
in the error; update the assertion in that test (the get_trainer call) to assert
a regex matching "ener_spin" (or a pattern that accepts either "spin-energy" or
"ener_spin", e.g. "spin-energy|ener_spin") so the negative test correctly
matches the new error text from validate_full_validation_config().
---
Nitpick comments:
In `@deepmd/pt/train/validation.py`:
- Around line 110-119: Update the docstring of
resolve_full_validation_start_step to explicitly describe the overloaded
semantics of the full_val_start parameter: when full_val_start == 1.0 the
function returns None to disable full validation, when 0.0 <= full_val_start <
1.0 it is interpreted as a fraction of num_steps and the function returns
int(num_steps * full_val_start), and when full_val_start > 1.0 it is treated as
an absolute step number and the function returns int(full_val_start); also
mention return type int | None and any edge-case behavior (e.g., truncation via
int()).
In `@deepmd/utils/batch_size.py`:
- Around line 44-46: The docstring for the parameter silent in
deepmd/utils/batch_size.py is misleading: it says it only suppresses
informational logs but the implementation also suppresses warnings; update the
docstring for the silent parameter (where it is defined—e.g., the
function/method that accepts silent in this module) to state that silent
suppresses both informational logs and warnings, and adjust any nearby examples
or summary text in that function's docstring to match the actual behavior.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 2635eef0-a34a-4911-8710-4d5167f9c3dd
📒 Files selected for processing (13)
- deepmd/entrypoints/test.py
- deepmd/jax/utils/auto_batch_size.py
- deepmd/pd/utils/auto_batch_size.py
- deepmd/pt/train/training.py
- deepmd/pt/train/validation.py
- deepmd/pt/utils/auto_batch_size.py
- deepmd/pt/utils/dataset.py
- deepmd/tf/utils/batch_size.py
- deepmd/utils/argcheck.py
- deepmd/utils/batch_size.py
- deepmd/utils/eval_metrics.py
- source/tests/pt/test_training.py
- source/tests/pt/test_validation.py
✅ Files skipped from review due to trivial changes (1)
- deepmd/pt/utils/dataset.py
🚧 Files skipped from review as they are similar to previous changes (1)
- deepmd/pt/train/training.py
Summary by CodeRabbit
New Features
- `silent` option to suppress informational logs.

Changes
Tests