diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 8a278dea7c..3e8f0e6830 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -75,6 +75,9 @@ from deepmd.pt.train.utils import ( NonFiniteGradGuard, clip_grad_norm_, + latest_checkpoint_path, + resolve_best_checkpoint_dir, + resolve_keep_ckpt_count, scoped_env_defaults, ) from deepmd.pt.train.validation import ( @@ -208,8 +211,13 @@ def __init__( self.disp_freq = training_params.get("disp_freq", 1000) self.disp_avg = training_params.get("disp_avg", False) self.save_ckpt = training_params.get("save_ckpt", "model.ckpt") + save_dir = training_params.get("save_dir") + self.save_dir = Path(save_dir) if save_dir else None + if self.save_dir is not None and self.rank == 0: + self.save_dir.mkdir(parents=True, exist_ok=True) self.save_freq = training_params.get("save_freq", 1000) self.max_ckpt_keep = training_params.get("max_ckpt_keep", 5) + self.ckpt_keep_ratio = training_params.get("ckpt_keep_ratio") self.enable_ema = bool(training_params.get("enable_ema", False)) self.ema_decay = float(training_params.get("ema_decay", 0.999)) self.ema_ckpt_keep = int(training_params.get("ema_ckpt_keep", 3)) @@ -723,6 +731,24 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR: rank=self.rank, ) + # === Derive checkpoint retention from ckpt_keep_ratio === + # num_steps is final here (including when derived from num_epoch), so the + # ratio can be converted into an absolute keep count once. + keep_ckpt_count = resolve_keep_ckpt_count( + self.ckpt_keep_ratio, self.num_steps, self.save_freq + ) + if keep_ckpt_count is not None: + self.max_ckpt_keep = keep_ckpt_count + self.ema_ckpt_keep = keep_ckpt_count + log.info( + "Resolved checkpoint retention to %d from ckpt_keep_ratio=%s " + "(num_steps=%d, save_freq=%d).", + keep_ckpt_count, + self.ckpt_keep_ratio, + self.num_steps, + self.save_freq, + ) + # Learning rate self.gradient_max_norm = training_params.get("gradient_max_norm", 0.0) self.nonfinite_grad_guard = NonFiniteGradGuard() @@ -1172,7 +1198,9 @@ def _create_full_validator( rank=self.rank, zero_stage=self.zero_stage, restart_training=self.restart_training, - checkpoint_dir=Path(self.save_ckpt).parent, + checkpoint_dir=resolve_best_checkpoint_dir( + validating_params, self.save_ckpt + ), ) def _create_ema_full_validator( @@ -1206,7 +1234,9 @@ def _create_ema_full_validator( rank=self.rank, zero_stage=self.zero_stage, restart_training=self.restart_training, - checkpoint_dir=Path(self.save_ckpt).parent, + checkpoint_dir=resolve_best_checkpoint_dir( + validating_params, self.save_ckpt + ), full_val_file=get_ema_validation_log_path( validating_params.get("full_val_file", "val.log") ), @@ -1795,21 +1825,25 @@ def log_loss_valid(_task_key: str = "Default") -> dict: self.zero_stage > 0 or self.rank == 0 or dist.get_rank() == 0 ): # Handle the case if rank 0 aborted and re-assigned - self.latest_model = Path(self.save_ckpt + f"-{display_step_id}.pt") + self.latest_model = latest_checkpoint_path( + self.save_ckpt, display_step_id, self.save_dir + ) self.save_model(self.latest_model, lr=cur_lr, step=_step_id) if self.rank == 0 or dist.get_rank() == 0: log.info(f"Saved model to {self.latest_model}") - symlink_prefix_files(self.latest_model.stem, self.save_ckpt) + symlink_prefix_files( + str(self.latest_model.with_suffix("")), self.save_ckpt + ) with open("checkpoint", "w") as f: f.write(str(self.latest_model)) if self.model_ema is not None: - self.latest_ema_model = Path( - self.ema_save_ckpt + f"-{display_step_id}.pt" + self.latest_ema_model = latest_checkpoint_path( + self.ema_save_ckpt, display_step_id, self.save_dir ) self.save_ema_model(self.latest_ema_model, lr=cur_lr, step=_step_id) if self.rank == 0 or dist.get_rank() == 0: symlink_prefix_files( - self.latest_ema_model.stem, + str(self.latest_ema_model.with_suffix("")), self.ema_save_ckpt, ) @@ -1888,30 +1922,36 @@ def log_loss_valid(_task_key: str = "Default") -> dict: self.get_sample_func[model_key], _bias_adjust_mode="change-by-statistic", ) - self.latest_model = Path(self.save_ckpt + f"-{self.num_steps}.pt") + self.latest_model = latest_checkpoint_path( + self.save_ckpt, self.num_steps, self.save_dir + ) cur_lr = self.lr_schedule.value(self.num_steps - 1) self.save_model(self.latest_model, lr=cur_lr, step=self.num_steps - 1) log.info(f"Saved model to {self.latest_model}") - symlink_prefix_files(self.latest_model.stem, self.save_ckpt) + symlink_prefix_files(str(self.latest_model.with_suffix("")), self.save_ckpt) with open("checkpoint", "w") as f: f.write(str(self.latest_model)) if self.model_ema is not None: - self.latest_ema_model = Path( - self.ema_save_ckpt + f"-{self.num_steps}.pt" + self.latest_ema_model = latest_checkpoint_path( + self.ema_save_ckpt, self.num_steps, self.save_dir ) self.save_ema_model( self.latest_ema_model, lr=cur_lr, step=self.num_steps - 1, ) - symlink_prefix_files(self.latest_ema_model.stem, self.ema_save_ckpt) + symlink_prefix_files( + str(self.latest_ema_model.with_suffix("")), self.ema_save_ckpt + ) if self.num_steps == 0 and self.zero_stage > 0: # ZeRO-1 / FSDP: all ranks participate in save_model (collective op) - self.latest_model = Path(self.save_ckpt + "-0.pt") + self.latest_model = latest_checkpoint_path(self.save_ckpt, 0, self.save_dir) self.save_model(self.latest_model, lr=0, step=0) if self.model_ema is not None: - self.latest_ema_model = Path(self.ema_save_ckpt + "-0.pt") + self.latest_ema_model = latest_checkpoint_path( + self.ema_save_ckpt, 0, self.save_dir + ) self.save_ema_model(self.latest_ema_model, lr=0, step=0) if ( @@ -1920,17 +1960,25 @@ def log_loss_valid(_task_key: str = "Default") -> dict: if self.num_steps == 0: if self.zero_stage == 0: # When num_steps is 0, the checkpoint is never saved in the loop - self.latest_model = Path(self.save_ckpt + "-0.pt") + self.latest_model = latest_checkpoint_path( + self.save_ckpt, 0, self.save_dir + ) self.save_model(self.latest_model, lr=0, step=0) if self.model_ema is not None: - self.latest_ema_model = Path(self.ema_save_ckpt + "-0.pt") + self.latest_ema_model = latest_checkpoint_path( + self.ema_save_ckpt, 0, self.save_dir + ) self.save_ema_model(self.latest_ema_model, lr=0, step=0) log.info(f"Saved model to {self.latest_model}") - symlink_prefix_files(self.latest_model.stem, self.save_ckpt) + symlink_prefix_files( + str(self.latest_model.with_suffix("")), self.save_ckpt + ) with open("checkpoint", "w") as f: f.write(str(self.latest_model)) if self.model_ema is not None: - symlink_prefix_files(self.latest_ema_model.stem, self.ema_save_ckpt) + symlink_prefix_files( + str(self.latest_ema_model.with_suffix("")), self.ema_save_ckpt + ) if self.timing_in_training and self.timed_steps: msg = f"average training time: {self.total_train_time / self.timed_steps:.4f} s/batch" diff --git a/deepmd/pt/train/utils.py b/deepmd/pt/train/utils.py index 489927e9ee..fe6d29b24e 100644 --- a/deepmd/pt/train/utils.py +++ b/deepmd/pt/train/utils.py @@ -9,8 +9,15 @@ from contextlib import ( contextmanager, ) +from math import ( + ceil, +) +from pathlib import ( + Path, +) from typing import ( TYPE_CHECKING, + Any, ) import torch @@ -191,3 +198,88 @@ def scoped_env_defaults(defaults: dict[str, str]) -> Generator[None, None, None] os.environ.pop(key, None) else: os.environ[key] = value + + +def latest_checkpoint_path(prefix: str, step_label: int, save_dir: Path | None) -> Path: + """ + Resolve the on-disk path of a periodic checkpoint file. + + Parameters + ---------- + prefix : str + The checkpoint prefix, e.g. ``model.ckpt`` or its EMA counterpart. + step_label : int + The training step encoded into the filename. + save_dir : Path or None + The configured checkpoint directory. When ``None`` the file follows + ``prefix`` relative to the working directory. + + Returns + ------- + Path + ``save_dir/-.pt`` when ``save_dir`` is set, otherwise + ``-.pt`` relative to the working directory. + """ + directory = save_dir if save_dir is not None else Path(prefix).parent + return directory / f"{Path(prefix).name}-{step_label}.pt" + + +def resolve_best_checkpoint_dir( + validating_params: dict[str, Any], save_ckpt: str +) -> Path: + """ + Resolve the directory for full-validation best checkpoints. + + Parameters + ---------- + validating_params : dict + The ``validating`` section of the training configuration. + save_ckpt : str + The regular checkpoint prefix from ``training.save_ckpt``. + + Returns + ------- + Path + ``validating.save_best_dir`` when set, otherwise the directory derived + from ``save_ckpt``. + """ + save_best_dir = validating_params.get("save_best_dir") + if save_best_dir: + return Path(save_best_dir) + return Path(save_ckpt).parent + + +def resolve_keep_ckpt_count( + ckpt_keep_ratio: float | None, num_steps: int, save_freq: int +) -> int | None: + """ + Convert a checkpoint-retention ratio into a sliding-window keep count. + + A checkpoint is written every ``save_freq`` steps and once more at the final + step, so a run of ``num_steps`` produces ``ceil(num_steps / save_freq)`` of + them in total (the terminal checkpoint is off-cadence when ``num_steps`` is + not a multiple of ``save_freq``). Keeping the most recent + ``ceil(ratio * total)`` is equivalent to retaining the final ``ratio`` + fraction of the run by step, without the caller computing the count by hand. + + Parameters + ---------- + ckpt_keep_ratio : float or None + The fraction of the training run, by step, whose periodic checkpoints + are retained. ``None`` leaves the keep count unchanged. + num_steps : int + The total number of training steps, already resolved (including when + derived from ``numb_epoch``). + save_freq : int + The checkpoint saving frequency in steps. + + Returns + ------- + int or None + The number of most recent checkpoints to keep (at least one), or + ``None`` when ``ckpt_keep_ratio`` is not set. + """ + if ckpt_keep_ratio is None: + return None + total_ckpts = max(1, ceil(num_steps / save_freq)) + return max(1, ceil(ckpt_keep_ratio * total_ckpts)) diff --git a/deepmd/pt/train/validation.py b/deepmd/pt/train/validation.py index f8874d8b19..a40e7113c7 100644 --- a/deepmd/pt/train/validation.py +++ b/deepmd/pt/train/validation.py @@ -273,6 +273,7 @@ def __init__( self.topk_records = self._load_topk_records() self._sync_state_store() if self.rank == 0: + self.checkpoint_dir.mkdir(parents=True, exist_ok=True) self._initialize_best_checkpoints(restart_training=restart_training) # Lazily-populated full test snapshot for LMDB validation. Mixed-nloc diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index ba1a2cb347..d942252e4a 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -4980,11 +4980,26 @@ def training_args( doc_disp_freq = "The frequency of printing learning curve." doc_save_freq = "The frequency of saving check point." doc_save_ckpt = "The path prefix of saving check point files." + doc_save_dir = ( + "The directory in which periodic checkpoint files are written, " + "including the regular checkpoints (the `save_ckpt` prefix) and, when " + "EMA is enabled, the EMA checkpoints. It is created recursively if it " + "does not exist. The latest-checkpoint symlinks (such as " + "`model.ckpt.pt`) and the `checkpoint` pointer file remain in the " + "working directory and reference the files in this directory. If not " + "set, checkpoints are written to the working directory." + ) doc_max_ckpt_keep = ( "The maximum number of checkpoints to keep. " "The oldest checkpoints will be deleted once the number of checkpoints exceeds max_ckpt_keep. " "Defaults to 5." ) + doc_ckpt_keep_ratio = ( + "An alternative to `max_ckpt_keep` that sets the number of retained " + "checkpoints as a fraction in (0, 1) of the run: the most recent " + "`ceil(ckpt_keep_ratio * ceil(numb_steps / save_freq))` checkpoints are kept. " + "When set, it overrides `max_ckpt_keep` and `ema_ckpt_keep`." + ) doc_enable_ema = ( "Whether to maintain an exponential moving average (EMA) of model " "parameters during training and save periodic EMA checkpoints with an " @@ -5109,10 +5124,26 @@ def training_args( ), Argument("disp_freq", int, optional=True, default=1000, doc=doc_disp_freq), Argument("save_freq", int, optional=True, default=1000, doc=doc_save_freq), + Argument( + "save_dir", + [str, None], + optional=True, + default=None, + doc=doc_only_pt_supported + doc_save_dir, + ), Argument( "save_ckpt", str, optional=True, default="model.ckpt", doc=doc_save_ckpt ), Argument("max_ckpt_keep", int, optional=True, default=5, doc=doc_max_ckpt_keep), + Argument( + "ckpt_keep_ratio", + [float, None], + optional=True, + default=None, + doc=doc_only_pt_supported + doc_ckpt_keep_ratio, + extra_check=lambda x: x is None or 0.0 < x < 1.0, + extra_check_errmsg="must be a fraction in the open interval (0, 1)", + ), Argument( "enable_ema", bool, @@ -5335,6 +5366,14 @@ def validating_args() -> Argument: "The frequency, in training steps, of running the full validation pass." ) doc_save_best = "Whether to save an extra checkpoint when the selected full validation metric reaches a new best value." + doc_save_best_dir = ( + "The directory in which the best checkpoints selected by full " + "validation are written (the `best.ckpt` prefix, and the " + "`best_ema.ckpt` prefix when EMA full validation is enabled). It is " + "created recursively if it does not exist. If not set, the best " + "checkpoints are written to the directory determined by " + "`training.save_ckpt`." + ) doc_ema_full_validation = ( "Whether to additionally run the same full validation flow on the " "EMA-smoothed model when `validating.full_validation=true`. This reuses " @@ -5412,6 +5451,13 @@ def validating_args() -> Argument: default=True, doc=doc_only_pt_supported + doc_save_best, ), + Argument( + "save_best_dir", + [str, None], + optional=True, + default=None, + doc=doc_only_pt_supported + doc_save_best_dir, + ), Argument( "max_best_ckpt", int, diff --git a/doc/train/training-advanced.md b/doc/train/training-advanced.md index 179041f6d9..0112462581 100644 --- a/doc/train/training-advanced.md +++ b/doc/train/training-advanced.md @@ -103,6 +103,8 @@ Other keys in the {ref}`training ` section are explained below: - {ref}`disp_file ` The file for printing learning curve. - {ref}`disp_freq ` The frequency of printing learning curve. Set in the unit of training steps - {ref}`save_freq ` The frequency of saving checkpoint. +- {ref}`save_dir ` The directory where periodic checkpoints are written (PyTorch backend). It is created recursively if missing, while the `model.ckpt.pt` symlinks and the `checkpoint` pointer file stay in the working directory. Defaults to the working directory. +- {ref}`ckpt_keep_ratio ` An alternative to `max_ckpt_keep` (PyTorch backend) that keeps a sliding window of `ceil(ckpt_keep_ratio * ceil(numb_steps / save_freq))` most recent checkpoints, i.e. the final `ckpt_keep_ratio` fraction of the run by step. It overrides `max_ckpt_keep` (and `ema_ckpt_keep`) when set, and works the same whether the run length is given by `numb_steps` or `numb_epoch`. ## Options and environment variables diff --git a/examples/water/dpa4/input.json b/examples/water/dpa4/input.json index 9eb48e201c..9a2260d0d6 100644 --- a/examples/water/dpa4/input.json +++ b/examples/water/dpa4/input.json @@ -99,6 +99,7 @@ "numb_steps": 2000000, "gradient_max_norm": 5.0, "save_freq": 2000, + "save_dir": "ckpt", "max_ckpt_keep": 3, "enable_ema": true, "ema_decay": 0.999, @@ -119,6 +120,7 @@ }, "validating": { "compiled_infer": false, - "tf32_infer": false + "tf32_infer": false, + "save_best_dir": "ckpt_best" } } diff --git a/source/tests/pt/test_train_utils.py b/source/tests/pt/test_train_utils.py index 49b8b74173..c3944f0972 100644 --- a/source/tests/pt/test_train_utils.py +++ b/source/tests/pt/test_train_utils.py @@ -6,6 +6,7 @@ from deepmd.pt.train.utils import ( NonFiniteGradGuard, clip_grad_norm_, + resolve_keep_ckpt_count, ) @@ -111,5 +112,22 @@ def test_resets_after_check(self) -> None: guard.raise_if_nonfinite(self._named(1.0)) +class TestResolveKeepCkptCount(unittest.TestCase): + def test_none_ratio_leaves_count_unchanged(self) -> None: + self.assertIsNone(resolve_keep_ckpt_count(None, 1000, 10)) + + def test_ratio_maps_to_recent_window_count(self) -> None: + # 1000 / 10 = 100 periodic checkpoints; 40% keeps the most recent 40. + self.assertEqual(resolve_keep_ckpt_count(0.4, 1000, 10), 40) + + def test_ratio_rounds_up(self) -> None: + # 4 periodic checkpoints; ceil(0.4 * 4) = ceil(1.6) = 2. + self.assertEqual(resolve_keep_ckpt_count(0.4, 4, 1), 2) + + def test_keeps_at_least_one(self) -> None: + # save_freq larger than num_steps yields a single (final) checkpoint. + self.assertEqual(resolve_keep_ckpt_count(0.4, 5, 100), 1) + + if __name__ == "__main__": unittest.main() diff --git a/source/tests/pt/test_training.py b/source/tests/pt/test_training.py index 7851850ba0..973fb5c864 100644 --- a/source/tests/pt/test_training.py +++ b/source/tests/pt/test_training.py @@ -2,6 +2,7 @@ import functools import json import os +import platform import shutil import signal import tempfile @@ -933,6 +934,7 @@ def tearDown(self) -> None: os.chdir(self._cwd) self._tmpdir.cleanup() + @TRAINING_TEST_TIMEOUT @patch("deepmd.pt.train.validation.FullValidator.evaluate_all_systems") def test_full_validation_rotates_best_checkpoint(self, mocked_eval) -> None: mocked_eval.side_effect = [ @@ -964,6 +966,29 @@ def test_full_validation_rotates_best_checkpoint(self, mocked_eval) -> None: self.assertEqual(val_lines[0].split()[1], "1000.0") self.assertEqual(val_lines[1].split()[1], "2000.0") + @TRAINING_TEST_TIMEOUT + @patch("deepmd.pt.train.validation.FullValidator.evaluate_all_systems") + def test_full_validation_save_best_dir(self, mocked_eval) -> None: + mocked_eval.side_effect = [ + {"mae_e_per_atom": 1.0}, + {"mae_e_per_atom": 2.0}, + {"mae_e_per_atom": 0.5}, + {"mae_e_per_atom": 1.5}, + ] + config = deepcopy(self.config) + config["validating"]["save_best_dir"] = "nested/best" + config["validating"]["max_best_ckpt"] = 1 + trainer = get_trainer(config) + trainer.run() + + best_dir = Path("nested/best") + self.assertTrue(best_dir.is_dir()) + # The single best checkpoint (lowest E:MAE, at step 3) lands under + # save_best_dir, and none is left in the working directory. + self.assertTrue((best_dir / "best.ckpt-3.t-1.pt").is_file()) + self.assertEqual(list(Path(".").glob("best.ckpt-*.pt")), []) + + @TRAINING_TEST_TIMEOUT @patch("deepmd.pt.train.validation.FullValidator.evaluate_all_systems") def test_full_validation_runs_when_start_step_is_final_step( self, mocked_eval @@ -1159,6 +1184,71 @@ def tearDown(self) -> None: os.chdir(self._cwd) self._tmpdir.cleanup() + @TRAINING_TEST_TIMEOUT + def test_ckpt_keep_ratio_overrides_keep_counts(self) -> None: + config = deepcopy(self.config) + config["training"]["ckpt_keep_ratio"] = 0.5 + trainer = get_trainer(config) + # 4 periodic checkpoints; ceil(0.5 * 4) = 2 overrides both the regular + # and EMA keep counts. + self.assertEqual(trainer.max_ckpt_keep, 2) + self.assertEqual(trainer.ema_ckpt_keep, 2) + save_ckpt = trainer.save_ckpt + ema_save_ckpt = trainer.ema_save_ckpt + trainer.run() + + self.assertEqual( + sorted(path.name for path in Path(".").glob(f"{save_ckpt}-*.pt")), + [f"{save_ckpt}-3.pt", f"{save_ckpt}-4.pt"], + ) + self.assertEqual( + sorted(path.name for path in Path(".").glob(f"{ema_save_ckpt}-*.pt")), + [f"{ema_save_ckpt}-3.pt", f"{ema_save_ckpt}-4.pt"], + ) + + @TRAINING_TEST_TIMEOUT + def test_save_dir_redirects_checkpoints_with_local_symlinks(self) -> None: + config = deepcopy(self.config) + config["training"]["save_dir"] = "ckpt" + trainer = get_trainer(config) + save_ckpt = trainer.save_ckpt + ema_save_ckpt = trainer.ema_save_ckpt + trainer.run() + + save_dir = Path("ckpt") + # Periodic regular/EMA checkpoints honor their keep limits inside save_dir. + self.assertEqual( + sorted(path.name for path in save_dir.glob(f"{save_ckpt}-*.pt")), + [f"{save_ckpt}-2.pt", f"{save_ckpt}-3.pt", f"{save_ckpt}-4.pt"], + ) + self.assertEqual( + sorted(path.name for path in save_dir.glob(f"{ema_save_ckpt}-*.pt")), + [f"{ema_save_ckpt}-3.pt", f"{ema_save_ckpt}-4.pt"], + ) + self.assertEqual(trainer.latest_model, save_dir / f"{save_ckpt}-4.pt") + + # The latest-checkpoint alias stays in the working directory and points + # at the newest checkpoint under save_dir; none is created inside + # save_dir. symlink_prefix_files() copies on Windows, so the alias is a + # content copy there rather than a symlink. + for prefix in (save_ckpt, ema_save_ckpt): + link = Path(f"{prefix}.pt") + target = save_dir / f"{prefix}-4.pt" + self.assertTrue(link.exists()) + if platform.system() == "Windows": + self.assertEqual(link.read_bytes(), target.read_bytes()) + else: + self.assertTrue(link.is_symlink()) + self.assertEqual(link.resolve(), target.resolve()) + self.assertFalse((save_dir / f"{save_ckpt}.pt").exists()) + + # The checkpoint pointer file stays in the working directory and points + # at the real file under save_dir. + self.assertEqual( + Path(Path("checkpoint").read_text().strip()), + save_dir / f"{save_ckpt}-4.pt", + ) + @TRAINING_TEST_TIMEOUT def test_ema_checkpoint_rotation(self) -> None: trainer = get_trainer(deepcopy(self.config)) diff --git a/source/tests/pt/test_validation.py b/source/tests/pt/test_validation.py index 157d45daaa..c3fee681e8 100644 --- a/source/tests/pt/test_validation.py +++ b/source/tests/pt/test_validation.py @@ -280,6 +280,43 @@ def test_full_validator_restores_top_k_checkpoints(self) -> None: ["best.ckpt-10.t-2.pt", "best.ckpt-20.t-1.pt"], ) + def test_full_validator_writes_best_into_custom_checkpoint_dir(self) -> None: + train_infos = {} + with tempfile.TemporaryDirectory() as tmpdir: + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + best_dir = Path("nested/best") + validator = FullValidator( + validating_params={ + "full_validation": True, + "validation_freq": 1, + "save_best": True, + "max_best_ckpt": 1, + "validation_metric": "E:MAE", + "full_val_file": "val.log", + "full_val_start": 0.0, + }, + validation_data=_DummyValidationData(), + model=_DummyModel(), + state_store=train_infos, + num_steps=10, + rank=0, + zero_stage=0, + restart_training=False, + checkpoint_dir=best_dir, + ) + # The directory is created recursively at construction time. + self.assertTrue(best_dir.is_dir()) + new_best_path = validator._update_best_state( + display_step=1, + selected_metric_value=2.0, + ) + finally: + os.chdir(old_cwd) + + self.assertEqual(new_best_path, str(best_dir / "best.ckpt-1.t-1.pt")) + def test_full_validator_lmdb_full_validation_iterates_nloc_groups(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: lmdb_path = _create_mixed_nloc_lmdb(f"{tmpdir}/mixed.lmdb")