Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 66 additions & 18 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@
from deepmd.pt.train.utils import (
NonFiniteGradGuard,
clip_grad_norm_,
latest_checkpoint_path,
resolve_best_checkpoint_dir,
resolve_keep_ckpt_count,
scoped_env_defaults,
)
from deepmd.pt.train.validation import (
Expand Down Expand Up @@ -208,8 +211,13 @@ def __init__(
self.disp_freq = training_params.get("disp_freq", 1000)
self.disp_avg = training_params.get("disp_avg", False)
self.save_ckpt = training_params.get("save_ckpt", "model.ckpt")
save_dir = training_params.get("save_dir")
self.save_dir = Path(save_dir) if save_dir else None
if self.save_dir is not None and self.rank == 0:
self.save_dir.mkdir(parents=True, exist_ok=True)
self.save_freq = training_params.get("save_freq", 1000)
self.max_ckpt_keep = training_params.get("max_ckpt_keep", 5)
self.ckpt_keep_ratio = training_params.get("ckpt_keep_ratio")
self.enable_ema = bool(training_params.get("enable_ema", False))
self.ema_decay = float(training_params.get("ema_decay", 0.999))
self.ema_ckpt_keep = int(training_params.get("ema_ckpt_keep", 3))
Expand Down Expand Up @@ -723,6 +731,24 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
rank=self.rank,
)

# === Derive checkpoint retention from ckpt_keep_ratio ===
# num_steps is final here (including when derived from num_epoch), so the
# ratio can be converted into an absolute keep count once.
keep_ckpt_count = resolve_keep_ckpt_count(
self.ckpt_keep_ratio, self.num_steps, self.save_freq
)
if keep_ckpt_count is not None:
self.max_ckpt_keep = keep_ckpt_count
self.ema_ckpt_keep = keep_ckpt_count
log.info(
"Resolved checkpoint retention to %d from ckpt_keep_ratio=%s "
"(num_steps=%d, save_freq=%d).",
keep_ckpt_count,
self.ckpt_keep_ratio,
self.num_steps,
self.save_freq,
)

# Learning rate
self.gradient_max_norm = training_params.get("gradient_max_norm", 0.0)
self.nonfinite_grad_guard = NonFiniteGradGuard()
Expand Down Expand Up @@ -1172,7 +1198,9 @@ def _create_full_validator(
rank=self.rank,
zero_stage=self.zero_stage,
restart_training=self.restart_training,
checkpoint_dir=Path(self.save_ckpt).parent,
checkpoint_dir=resolve_best_checkpoint_dir(
validating_params, self.save_ckpt
),
)

def _create_ema_full_validator(
Expand Down Expand Up @@ -1206,7 +1234,9 @@ def _create_ema_full_validator(
rank=self.rank,
zero_stage=self.zero_stage,
restart_training=self.restart_training,
checkpoint_dir=Path(self.save_ckpt).parent,
checkpoint_dir=resolve_best_checkpoint_dir(
validating_params, self.save_ckpt
),
full_val_file=get_ema_validation_log_path(
validating_params.get("full_val_file", "val.log")
),
Expand Down Expand Up @@ -1795,21 +1825,25 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
self.zero_stage > 0 or self.rank == 0 or dist.get_rank() == 0
):
# Handle the case if rank 0 aborted and re-assigned
self.latest_model = Path(self.save_ckpt + f"-{display_step_id}.pt")
self.latest_model = latest_checkpoint_path(
self.save_ckpt, display_step_id, self.save_dir
)
self.save_model(self.latest_model, lr=cur_lr, step=_step_id)
if self.rank == 0 or dist.get_rank() == 0:
log.info(f"Saved model to {self.latest_model}")
symlink_prefix_files(self.latest_model.stem, self.save_ckpt)
symlink_prefix_files(
str(self.latest_model.with_suffix("")), self.save_ckpt
)
with open("checkpoint", "w") as f:
f.write(str(self.latest_model))
if self.model_ema is not None:
self.latest_ema_model = Path(
self.ema_save_ckpt + f"-{display_step_id}.pt"
self.latest_ema_model = latest_checkpoint_path(
self.ema_save_ckpt, display_step_id, self.save_dir
)
self.save_ema_model(self.latest_ema_model, lr=cur_lr, step=_step_id)
if self.rank == 0 or dist.get_rank() == 0:
symlink_prefix_files(
self.latest_ema_model.stem,
str(self.latest_ema_model.with_suffix("")),
self.ema_save_ckpt,
)

Expand Down Expand Up @@ -1888,30 +1922,36 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
self.get_sample_func[model_key],
_bias_adjust_mode="change-by-statistic",
)
self.latest_model = Path(self.save_ckpt + f"-{self.num_steps}.pt")
self.latest_model = latest_checkpoint_path(
self.save_ckpt, self.num_steps, self.save_dir
)
cur_lr = self.lr_schedule.value(self.num_steps - 1)
self.save_model(self.latest_model, lr=cur_lr, step=self.num_steps - 1)
log.info(f"Saved model to {self.latest_model}")
symlink_prefix_files(self.latest_model.stem, self.save_ckpt)
symlink_prefix_files(str(self.latest_model.with_suffix("")), self.save_ckpt)
with open("checkpoint", "w") as f:
f.write(str(self.latest_model))
if self.model_ema is not None:
self.latest_ema_model = Path(
self.ema_save_ckpt + f"-{self.num_steps}.pt"
self.latest_ema_model = latest_checkpoint_path(
self.ema_save_ckpt, self.num_steps, self.save_dir
)
self.save_ema_model(
self.latest_ema_model,
lr=cur_lr,
step=self.num_steps - 1,
)
symlink_prefix_files(self.latest_ema_model.stem, self.ema_save_ckpt)
symlink_prefix_files(
str(self.latest_ema_model.with_suffix("")), self.ema_save_ckpt
)

if self.num_steps == 0 and self.zero_stage > 0:
# ZeRO-1 / FSDP: all ranks participate in save_model (collective op)
self.latest_model = Path(self.save_ckpt + "-0.pt")
self.latest_model = latest_checkpoint_path(self.save_ckpt, 0, self.save_dir)
self.save_model(self.latest_model, lr=0, step=0)
if self.model_ema is not None:
self.latest_ema_model = Path(self.ema_save_ckpt + "-0.pt")
self.latest_ema_model = latest_checkpoint_path(
self.ema_save_ckpt, 0, self.save_dir
)
self.save_ema_model(self.latest_ema_model, lr=0, step=0)

if (
Expand All @@ -1920,17 +1960,25 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
if self.num_steps == 0:
if self.zero_stage == 0:
# When num_steps is 0, the checkpoint is never saved in the loop
self.latest_model = Path(self.save_ckpt + "-0.pt")
self.latest_model = latest_checkpoint_path(
self.save_ckpt, 0, self.save_dir
)
self.save_model(self.latest_model, lr=0, step=0)
if self.model_ema is not None:
self.latest_ema_model = Path(self.ema_save_ckpt + "-0.pt")
self.latest_ema_model = latest_checkpoint_path(
self.ema_save_ckpt, 0, self.save_dir
)
self.save_ema_model(self.latest_ema_model, lr=0, step=0)
log.info(f"Saved model to {self.latest_model}")
symlink_prefix_files(self.latest_model.stem, self.save_ckpt)
symlink_prefix_files(
str(self.latest_model.with_suffix("")), self.save_ckpt
)
with open("checkpoint", "w") as f:
f.write(str(self.latest_model))
if self.model_ema is not None:
symlink_prefix_files(self.latest_ema_model.stem, self.ema_save_ckpt)
symlink_prefix_files(
str(self.latest_ema_model.with_suffix("")), self.ema_save_ckpt
)

if self.timing_in_training and self.timed_steps:
msg = f"average training time: {self.total_train_time / self.timed_steps:.4f} s/batch"
Expand Down
92 changes: 92 additions & 0 deletions deepmd/pt/train/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,15 @@
from contextlib import (
contextmanager,
)
from math import (
ceil,
)
from pathlib import (
Path,
)
from typing import (
TYPE_CHECKING,
Any,
)

import torch
Expand Down Expand Up @@ -191,3 +198,88 @@ def scoped_env_defaults(defaults: dict[str, str]) -> Generator[None, None, None]
os.environ.pop(key, None)
else:
os.environ[key] = value


def latest_checkpoint_path(prefix: str, step_label: int, save_dir: Path | None) -> Path:
"""
Resolve the on-disk path of a periodic checkpoint file.

Parameters
----------
prefix : str
The checkpoint prefix, e.g. ``model.ckpt`` or its EMA counterpart.
step_label : int
The training step encoded into the filename.
save_dir : Path or None
The configured checkpoint directory. When ``None`` the file follows
``prefix`` relative to the working directory.

Returns
-------
Path
``save_dir/<prefix name>-<step>.pt`` when ``save_dir`` is set, otherwise
``<prefix>-<step>.pt`` relative to the working directory.
"""
directory = save_dir if save_dir is not None else Path(prefix).parent
return directory / f"{Path(prefix).name}-{step_label}.pt"


def resolve_best_checkpoint_dir(
    validating_params: dict[str, Any], save_ckpt: str
) -> Path:
    """
    Pick the directory that receives the full-validation best checkpoints.

    Parameters
    ----------
    validating_params : dict
        The ``validating`` section of the training configuration; only its
        ``save_best_dir`` entry is read.
    save_ckpt : str
        The regular checkpoint prefix, whose directory part is the fallback.

    Returns
    -------
    Path
        ``save_best_dir`` when it holds a non-empty value, otherwise the
        parent directory of ``save_ckpt``.
    """
    override = validating_params.get("save_best_dir")
    if not override:
        # Missing, ``None`` or empty: fall back to the prefix's directory.
        return Path(save_ckpt).parent
    return Path(override)


def resolve_keep_ckpt_count(
ckpt_keep_ratio: float | None, num_steps: int, save_freq: int
) -> int | None:
"""
Convert a checkpoint-retention ratio into a sliding-window keep count.

A checkpoint is written every ``save_freq`` steps and once more at the final
step, so a run of ``num_steps`` produces ``ceil(num_steps / save_freq)`` of
them in total (the terminal checkpoint is off-cadence when ``num_steps`` is
not a multiple of ``save_freq``). Keeping the most recent
``ceil(ratio * total)`` is equivalent to retaining the final ``ratio``
fraction of the run by step, without the caller computing the count by hand.

Parameters
----------
ckpt_keep_ratio : float or None
The fraction of the training run, by step, whose periodic checkpoints
are retained. ``None`` leaves the keep count unchanged.
num_steps : int
The total number of training steps, already resolved (including when
derived from ``numb_epoch``).
save_freq : int
The checkpoint saving frequency in steps.

Returns
-------
int or None
The number of most recent checkpoints to keep (at least one), or
``None`` when ``ckpt_keep_ratio`` is not set.
"""
if ckpt_keep_ratio is None:
return None
total_ckpts = max(1, ceil(num_steps / save_freq))
return max(1, ceil(ckpt_keep_ratio * total_ckpts))
1 change: 1 addition & 0 deletions deepmd/pt/train/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ def __init__(
self.topk_records = self._load_topk_records()
self._sync_state_store()
if self.rank == 0:
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
self._initialize_best_checkpoints(restart_training=restart_training)

# Lazily-populated full test snapshot for LMDB validation. Mixed-nloc
Expand Down
46 changes: 46 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -4980,11 +4980,26 @@ def training_args(
doc_disp_freq = "The frequency of printing learning curve."
doc_save_freq = "The frequency of saving check point."
doc_save_ckpt = "The path prefix of saving check point files."
doc_save_dir = (
"The directory in which periodic checkpoint files are written, "
"including the regular checkpoints (the `save_ckpt` prefix) and, when "
"EMA is enabled, the EMA checkpoints. It is created recursively if it "
"does not exist. The latest-checkpoint symlinks (such as "
"`model.ckpt.pt`) and the `checkpoint` pointer file remain in the "
"working directory and reference the files in this directory. If not "
"set, checkpoints are written to the working directory."
)
doc_max_ckpt_keep = (
"The maximum number of checkpoints to keep. "
"The oldest checkpoints will be deleted once the number of checkpoints exceeds max_ckpt_keep. "
"Defaults to 5."
)
doc_ckpt_keep_ratio = (
"An alternative to `max_ckpt_keep` that sets the number of retained "
"checkpoints as a fraction in (0, 1) of the run: the most recent "
"`ceil(ckpt_keep_ratio * ceil(numb_steps / save_freq))` checkpoints are kept. "
"When set, it overrides `max_ckpt_keep` and `ema_ckpt_keep`."
)
doc_enable_ema = (
"Whether to maintain an exponential moving average (EMA) of model "
"parameters during training and save periodic EMA checkpoints with an "
Expand Down Expand Up @@ -5109,10 +5124,26 @@ def training_args(
),
Argument("disp_freq", int, optional=True, default=1000, doc=doc_disp_freq),
Argument("save_freq", int, optional=True, default=1000, doc=doc_save_freq),
Argument(
"save_dir",
[str, None],
optional=True,
default=None,
doc=doc_only_pt_supported + doc_save_dir,
),
Argument(
"save_ckpt", str, optional=True, default="model.ckpt", doc=doc_save_ckpt
),
Argument("max_ckpt_keep", int, optional=True, default=5, doc=doc_max_ckpt_keep),
Argument(
"ckpt_keep_ratio",
[float, None],
optional=True,
default=None,
doc=doc_only_pt_supported + doc_ckpt_keep_ratio,
extra_check=lambda x: x is None or 0.0 < x < 1.0,
extra_check_errmsg="must be a fraction in the open interval (0, 1)",
),
Argument(
"enable_ema",
bool,
Expand Down Expand Up @@ -5335,6 +5366,14 @@ def validating_args() -> Argument:
"The frequency, in training steps, of running the full validation pass."
)
doc_save_best = "Whether to save an extra checkpoint when the selected full validation metric reaches a new best value."
doc_save_best_dir = (
"The directory in which the best checkpoints selected by full "
"validation are written (the `best.ckpt` prefix, and the "
"`best_ema.ckpt` prefix when EMA full validation is enabled). It is "
"created recursively if it does not exist. If not set, the best "
"checkpoints are written to the directory determined by "
"`training.save_ckpt`."
)
doc_ema_full_validation = (
"Whether to additionally run the same full validation flow on the "
"EMA-smoothed model when `validating.full_validation=true`. This reuses "
Expand Down Expand Up @@ -5412,6 +5451,13 @@ def validating_args() -> Argument:
default=True,
doc=doc_only_pt_supported + doc_save_best,
),
Argument(
"save_best_dir",
[str, None],
optional=True,
default=None,
doc=doc_only_pt_supported + doc_save_best_dir,
),
Argument(
"max_best_ckpt",
int,
Expand Down
2 changes: 2 additions & 0 deletions doc/train/training-advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ Other keys in the {ref}`training <training>` section are explained below:
- {ref}`disp_file <training/disp_file>` The file for printing learning curve.
- {ref}`disp_freq <training/disp_freq>` The frequency of printing learning curve. Set in the unit of training steps
- {ref}`save_freq <training/save_freq>` The frequency of saving checkpoint.
- {ref}`save_dir <training/save_dir>` The directory where periodic checkpoints are written (PyTorch backend). It is created recursively if missing, while the `model.ckpt.pt` symlinks and the `checkpoint` pointer file stay in the working directory. Defaults to the working directory.
- {ref}`ckpt_keep_ratio <training/ckpt_keep_ratio>` An alternative to `max_ckpt_keep` (PyTorch backend) that keeps a sliding window of `ceil(ckpt_keep_ratio * ceil(numb_steps / save_freq))` most recent checkpoints, i.e. the final `ckpt_keep_ratio` fraction of the run by step. It overrides `max_ckpt_keep` (and `ema_ckpt_keep`) when set, and works the same whether the run length is given by `numb_steps` or `numb_epoch`.

## Options and environment variables

Expand Down
4 changes: 3 additions & 1 deletion examples/water/dpa4/input.json
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
"numb_steps": 2000000,
"gradient_max_norm": 5.0,
"save_freq": 2000,
"save_dir": "ckpt",
"max_ckpt_keep": 3,
"enable_ema": true,
"ema_decay": 0.999,
Expand All @@ -119,6 +120,7 @@
},
"validating": {
"compiled_infer": false,
"tf32_infer": false
"tf32_infer": false,
"save_best_dir": "ckpt_best"
Comment thread
OutisLi marked this conversation as resolved.
Comment thread
OutisLi marked this conversation as resolved.
}
}
Loading
Loading