@dataclass(frozen=True)
class TrainingTaskConfig:
    """Immutable, normalized description of a single training task.

    Instances are yielded by ``iter_training_task_configs``, which fills the
    fields from either the single-task or the multi-task config layout so
    that callers can treat both layouts uniformly.
    """

    # Task name: ``DEFAULT_TASK_KEY`` for single-task input, otherwise the
    # task's key inside ``config["model"]["model_dict"]``.
    key: str
    # Model section for this task (the whole ``model`` section in the
    # single-task layout, the ``model_dict`` entry otherwise).
    model_params: Mapping[str, Any]
    # The task's ``training_data`` section.
    training_data_params: Mapping[str, Any]
    # The task's ``validation_data`` section, or ``None`` when it is absent.
    validation_data_params: Mapping[str, Any] | None
    # The task's ``stat_file`` entry, or ``None`` when it is absent.
    stat_file: str | None
    # Validation batch count derived from ``validation_data_params`` by
    # ``_valid_numb_batch`` (never smaller than 1).
    valid_numb_batch: int
def make_task_maps(
    config: Mapping[str, Any],
    factory: Callable[[TrainingTaskConfig], tuple[Any, Any | None, Any | None]],
) -> tuple[dict[str, Any], dict[str, Any | None], dict[str, Any | None]]:
    """Collect per-task training, validation, and stat objects keyed by task.

    Parameters
    ----------
    config
        Training config; it is expanded into task configs by
        ``iter_training_task_configs``.
    factory
        Callable invoked once per task config, in iteration order. It must
        return a ``(train_item, valid_item, stat_item)`` triple.

    Returns
    -------
    tuple
        Three dicts (training, validation, stat), each keyed by task key.
    """
    train_map: dict[str, Any] = {}
    valid_map: dict[str, Any | None] = {}
    stat_map: dict[str, Any | None] = {}
    targets = (train_map, valid_map, stat_map)
    for task in iter_training_task_configs(config):
        # Unpack explicitly so a factory returning the wrong arity fails here.
        train_item, valid_item, stat_item = factory(task)
        for target, item in zip(targets, (train_item, valid_item, stat_item)):
            target[task.key] = item
    return train_map, valid_map, stat_map
def _print_summary(data: Any, name: str, prob: list[float] | None) -> None:
    """Call ``data.print_summary`` with as many arguments as it accepts.

    ``print_summary(name, prob)`` is preferred. When the callable's signature
    rejects two positional arguments but accepts one, ``print_summary(name)``
    is used instead. If neither form binds, the ``TypeError`` from the
    two-argument attempt is raised. Callables whose signature cannot be
    inspected are called with both arguments.
    """
    summarize = data.print_summary
    try:
        sig = inspect.signature(summarize)
    except (TypeError, ValueError):
        # No introspectable signature: fall back to the two-argument form.
        summarize(name, prob)
        return

    try:
        sig.bind(name, prob)
    except TypeError as two_arg_error:
        try:
            sig.bind(name)
        except TypeError:
            # Report the failure of the preferred call form, not the fallback.
            raise two_arg_error from None
        summarize(name)
        return
    summarize(name, prob)
    def run(self, options: TrainEntrypointOptions) -> None:
        """Run the training entrypoint.

        The hooks are called in a fixed order: options are prepared, the
        config is loaded and validated, then preprocessed, converted by
        ``update_input`` and normalized. Unless ``options.skip_neighbor_stat``
        is set, ``update_neighbor_stat`` runs next. The resulting config is
        dumped to ``options.output`` and the summary is printed before the
        backend run is set up and training starts.

        Parameters
        ----------
        options
            Command options; the value returned by ``prepare_options`` is
            the one used for every later step.
        """
        # The path is logged from the options as passed in, before
        # ``prepare_options`` has a chance to replace them.
        log.info("Configuration path: %s", options.input_file)
        options = self.prepare_options(options)
        config = self.load_config(options.input_file)
        self.validate_options(config, options)

        config = self.preprocess_config(config, options)
        # The layout is detected on the preprocessed config, before
        # ``update_input`` and ``normalize_config`` rewrite it.
        multi_task = self.is_multi_task(config)
        config = self.update_input(config)
        config = self.normalize_config(config, multi_task=multi_task)

        # ``neighbor_stat`` stays ``None`` when the update step is skipped.
        neighbor_stat = None
        if not options.skip_neighbor_stat:
            config, neighbor_stat = self.update_neighbor_stat(
                config,
                options,
                multi_task=multi_task,
            )

        self.dump_config(config, options.output)
        self.print_summary()

        # ``teardown_run`` is called even when ``setup_run`` or
        # ``run_training`` raises.
        try:
            self.setup_run(options, config)
            self.run_training(config, options, neighbor_stat)
        finally:
            self.teardown_run(options, config)
Any], + options: TrainEntrypointOptions, + ) -> None: + """Validate backend feature support before mutating the config.""" + return None + + def preprocess_config( + self, + config: dict[str, Any], + options: TrainEntrypointOptions, + ) -> dict[str, Any]: + """Apply backend-specific config preprocessing before argcheck.""" + return config + + def is_multi_task(self, config: dict[str, Any]) -> bool: + """Return whether the config is in multi-task layout.""" + return "model_dict" in config.get("model", {}) + + def update_input(self, config: dict[str, Any]) -> dict[str, Any]: + """Apply DeePMD input-version compatibility conversion.""" + return update_deepmd_input(config, warning=True, dump="input_v2_compat.json") + + def normalize_config( + self, + config: dict[str, Any], + *, + multi_task: bool, + ) -> dict[str, Any]: + """Run DeePMD argcheck normalization.""" + return normalize(config, multi_task=multi_task) + + def update_neighbor_stat( + self, + config: dict[str, Any], + options: TrainEntrypointOptions, + *, + multi_task: bool, + ) -> tuple[dict[str, Any], Any]: + """Update descriptor selections from neighbor statistics.""" + return config, None + + def dump_config(self, config: dict[str, Any], output: str) -> None: + """Dump the normalized config used for training.""" + with open(output, "w") as fp: + json.dump(config, fp, indent=4) + + def print_summary(self) -> None: + """Print backend summary information.""" + return None + + def setup_run( + self, + options: TrainEntrypointOptions, + config: dict[str, Any], + ) -> None: + """Set up backend runtime state before trainer execution.""" + return None + + def teardown_run( + self, + options: TrainEntrypointOptions, + config: dict[str, Any], + ) -> None: + """Tear down backend runtime state after trainer execution.""" + return None + + @abstractmethod + def run_training( + self, + config: dict[str, Any], + options: TrainEntrypointOptions, + neighbor_stat: Any, + ) -> None: + """Build backend data/trainer objects 
and run training.""" + raise NotImplementedError diff --git a/deepmd/dpmodel/train/trainer.py b/deepmd/dpmodel/train/trainer.py new file mode 100644 index 0000000000..83d3900e3f --- /dev/null +++ b/deepmd/dpmodel/train/trainer.py @@ -0,0 +1,695 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Backend-independent training driver abstractions. + +The classes in this module intentionally know nothing about TensorFlow, +PyTorch, JAX, or Paddle tensor semantics. Backend trainers provide the +numerical hooks; this layer owns task/rank normalization, display scheduling, +learning-curve output, and checkpoint cadence. +""" + +from __future__ import ( + annotations, +) + +import datetime +import logging +import time +from abc import ( + ABC, + abstractmethod, +) +from collections.abc import ( + Callable, + Iterator, + Mapping, + Sequence, +) +from dataclasses import ( + dataclass, + field, +) +from pathlib import ( + Path, +) +from typing import ( + Any, + TextIO, +) + +import numpy as np + +from deepmd.loggers.training import ( + format_training_message, + format_training_message_per_task, +) + +DEFAULT_TASK_KEY = "Default" + +log = logging.getLogger(__name__) + +LossResults = dict[str, float] +TaskResults = dict[str, LossResults | None] +DisplayResults = LossResults | TaskResults + + +@dataclass(frozen=True) +class RankContext: + """Rank metadata used by a trainer. + + A single-process run is represented as rank 0 in a world of size 1. This + makes single-rank training a special case of multi-rank training. 
@dataclass(frozen=True)
class TrainerConfig:
    """Backend-neutral trainer settings.

    ``from_training_params`` builds an instance from a training section and
    applies the same defaults as the field declarations below.
    """

    num_steps: int
    start_step: int = 0
    disp_file: str = "lcurve.out"
    disp_freq: int = 1000
    save_ckpt: str = "model.ckpt"
    save_freq: int = 1000
    max_ckpt_keep: int = 5
    display_in_training: bool = True
    timing_in_training: bool = True
    restart_training: bool = False

    @classmethod
    def from_training_params(
        cls,
        training_params: Mapping[str, Any],
        *,
        num_steps: int | None = None,
        start_step: int = 0,
        restart_training: bool = False,
    ) -> TrainerConfig:
        """Build a config from a normalized ``training`` section.

        Parameters
        ----------
        training_params
            Mapping holding the training options. ``"numb_steps"`` is
            required unless *num_steps* is given; every other key is
            optional and falls back to the field default.
        num_steps
            Explicit step count; when not ``None`` it takes precedence over
            ``training_params["numb_steps"]``.
        start_step
            Step to start from.
        restart_training
            Stored on the config unchanged.

        Raises
        ------
        KeyError
            If *num_steps* is ``None`` and ``"numb_steps"`` is missing.
        """
        if num_steps is None:
            # Only consult the mapping when no explicit override was passed.
            num_steps = training_params["numb_steps"]
        settings = {
            "num_steps": int(num_steps),
            "start_step": int(start_step),
            "disp_file": str(training_params.get("disp_file", "lcurve.out")),
            "disp_freq": int(training_params.get("disp_freq", 1000)),
            "save_ckpt": str(training_params.get("save_ckpt", "model.ckpt")),
            "save_freq": int(training_params.get("save_freq", 1000)),
            "max_ckpt_keep": int(training_params.get("max_ckpt_keep", 5)),
            "display_in_training": bool(training_params.get("disp_training", True)),
            "timing_in_training": bool(training_params.get("time_training", True)),
            "restart_training": restart_training,
        }
        return cls(**settings)
class TrainingTaskCollection:
    """Ordered set of training tasks together with their sampling weights.

    Tasks are stored by key in insertion order. Sampling weights are
    validated and normalized to sum to one at construction time.
    """

    def __init__(
        self,
        tasks: Mapping[str, TrainingTask] | Sequence[TrainingTask],
        probabilities: Mapping[str, float] | Sequence[float] | None = None,
    ) -> None:
        """Store *tasks* and normalize *probabilities*.

        Parameters
        ----------
        tasks
            Either a mapping from task key to task, or a sequence of tasks
            that is keyed by each task's ``key`` attribute.
        probabilities
            Optional sampling weights, see ``_normalize_probabilities``.

        Raises
        ------
        ValueError
            If a sequence contains duplicate keys, if no task is given, or
            if a mapping key differs from the ``key`` of its task.
        """
        if isinstance(tasks, Mapping):
            by_key = dict(tasks)
        else:
            ordered = list(tasks)
            by_key = {task.key: task for task in ordered}
            # A dict shorter than the input means two tasks shared a key.
            if len(ordered) != len(by_key):
                raise ValueError("Training task keys must be unique.")
        if len(by_key) == 0:
            raise ValueError("At least one training task is required.")
        for mapping_key, task in by_key.items():
            if mapping_key == task.key:
                continue
            raise ValueError(
                f"Task mapping key {mapping_key!r} does not match task key {task.key!r}."
            )
        self._tasks = by_key
        self._keys = list(by_key)
        # Must run last: normalization reads ``_keys`` and ``_tasks``.
        self._probabilities = self._normalize_probabilities(probabilities)

    @classmethod
    def single(
        cls,
        training_data: Any,
        validation_data: Any | None = None,
        *,
        key: str = DEFAULT_TASK_KEY,
        valid_numb_batch: int = 1,
        data_requirements: list[Any] | None = None,
    ) -> TrainingTaskCollection:
        """Wrap one training/validation data pair in a one-task collection."""
        requirements = list(data_requirements or [])
        only_task = TrainingTask(
            key=key,
            training_data=training_data,
            validation_data=validation_data,
            valid_numb_batch=valid_numb_batch,
            data_requirements=requirements,
        )
        return cls([only_task])

    @property
    def keys(self) -> list[str]:
        """A fresh list of the task keys, in iteration order."""
        return list(self._keys)

    @property
    def probabilities(self) -> np.ndarray:
        """A copy of the normalized sampling probabilities."""
        return self._probabilities.copy()

    @property
    def is_multitask(self) -> bool:
        """``True`` when the collection holds more than one task."""
        return len(self._tasks) > 1

    def __len__(self) -> int:
        return len(self._tasks)

    def __iter__(self) -> Iterator[TrainingTask]:
        yield from (self._tasks[key] for key in self._keys)

    def __getitem__(self, key: str) -> TrainingTask:
        return self._tasks[key]

    def select(
        self,
        choice: Callable[..., Any] | None = None,
    ) -> TrainingTask:
        """Pick one task at random according to the stored probabilities.

        Parameters
        ----------
        choice
            Optional sampler called as ``choice(indices, p=probabilities)``;
            ``numpy.random.choice`` is used when it is omitted. It is not
            called at all when the collection holds a single task.
        """
        task_count = len(self._keys)
        if task_count == 1:
            return self._tasks[self._keys[0]]
        sampler = choice or np.random.choice
        candidates = np.arange(task_count, dtype=np.int_)
        picked = int(sampler(candidates, p=self._probabilities))
        return self._tasks[self._keys[picked]]

    def _normalize_probabilities(
        self,
        probabilities: Mapping[str, float] | Sequence[float] | None,
    ) -> np.ndarray:
        """Validate sampling weights and scale them to sum to one.

        ``None`` gives every task the same weight. A mapping must contain
        exactly the task keys. A sequence must have one entry per task, in
        key order.

        Raises
        ------
        ValueError
            On missing or unknown keys, a wrong shape, non-finite or
            negative entries, or a non-positive sum.
        """
        task_count = len(self._keys)
        if probabilities is None:
            weights = np.ones(task_count, dtype=np.float64)
        elif isinstance(probabilities, Mapping):
            missing = [key for key in self._keys if key not in probabilities]
            if missing:
                raise ValueError(f"Missing task probabilities for {missing}.")
            unknown = [key for key in probabilities if key not in self._tasks]
            if unknown:
                raise ValueError(f"Unknown task probabilities for {unknown}.")
            weights = np.asarray(
                [probabilities[key] for key in self._keys], dtype=np.float64
            )
        else:
            weights = np.asarray(probabilities, dtype=np.float64)

        if weights.shape != (task_count,):
            raise ValueError("Task probabilities must match the number of tasks.")
        if not np.isfinite(weights).all():
            raise ValueError("Task probabilities must be finite.")
        if (weights < 0.0).any():
            raise ValueError("Task probabilities must be non-negative.")
        total = float(weights.sum())
        if total <= 0.0:
            raise ValueError("Task probabilities must sum to a positive value.")
        return weights / total
results.""" + if self._is_multitask(train_results): + task_results = train_results + valid_task_results = ( + valid_results if isinstance(valid_results, Mapping) else {} + ) + assert isinstance(task_results, Mapping) + for task_key in self._ordered_task_keys(task_results): + task_train = task_results.get(task_key) + if task_train is None: + continue + log.info( + format_training_message_per_task( + batch=step, + task_name=f"{task_key}_trn", + rmse=task_train, + learning_rate=learning_rate, + ) + ) + task_valid = valid_task_results.get(task_key) + if task_valid: + log.info( + format_training_message_per_task( + batch=step, + task_name=f"{task_key}_val", + rmse=task_valid, + learning_rate=None, + ) + ) + else: + assert not self._is_multitask(train_results) + log.info( + format_training_message_per_task( + batch=step, + task_name="trn", + rmse=train_results, + learning_rate=learning_rate, + ) + ) + if valid_results: + assert not self._is_multitask(valid_results) + log.info( + format_training_message_per_task( + batch=step, + task_name="val", + rmse=valid_results, + learning_rate=None, + ) + ) + + def format_header( + self, + train_results: DisplayResults, + valid_results: DisplayResults | None, + ) -> str: + """Format a learning-curve header.""" + header = "# {:5s}".format("step") + if self._is_multitask(train_results): + assert isinstance(train_results, Mapping) + valid_task_results = ( + valid_results if isinstance(valid_results, Mapping) else {} + ) + for task_key in self._ordered_task_keys(train_results): + task_train = train_results.get(task_key) + if not task_train: + continue + task_valid = valid_task_results.get(task_key) + if task_valid: + for key in task_train: + header += ( + f" {key + '_val_' + task_key:>11s}" + f" {key + '_trn_' + task_key:>11s}" + ) + else: + for key in task_train: + header += f" {key + '_trn_' + task_key:>11s}" + else: + assert not self._is_multitask(train_results) + if valid_results is not None: + assert not 
self._is_multitask(valid_results) + for key in train_results: + header += f" {key + '_val':>11s} {key + '_trn':>11s}" + else: + for key in train_results: + header += f" {key + '_trn':>11s}" + header += " {:8s}\n".format("lr") + header += "# If there is no available reference data, rmse_*_{val,trn} will print nan\n" + return header + + def format_row( + self, + *, + step: int, + learning_rate: float, + train_results: DisplayResults, + valid_results: DisplayResults | None, + ) -> str: + """Format one learning-curve row.""" + row = f"{step:7d}" + if self._is_multitask(train_results): + assert isinstance(train_results, Mapping) + valid_task_results = ( + valid_results if isinstance(valid_results, Mapping) else {} + ) + for task_key in self._ordered_task_keys(train_results): + task_train = train_results.get(task_key) + if not task_train: + continue + task_valid = valid_task_results.get(task_key) + if task_valid: + for key in task_train: + row += ( + f" {float(task_valid.get(key, float('nan'))):11.2e}" + f" {float(task_train[key]):11.2e}" + ) + else: + for key in task_train: + row += f" {float(task_train[key]):11.2e}" + else: + assert not self._is_multitask(train_results) + if valid_results is not None: + assert not self._is_multitask(valid_results) + for key in train_results: + row += ( + f" {float(valid_results.get(key, float('nan'))):11.2e}" + f" {float(train_results[key]):11.2e}" + ) + else: + for key in train_results: + row += f" {float(train_results[key]):11.2e}" + row += f" {learning_rate:8.1e}\n" + return row + + def _ordered_task_keys(self, results: Mapping[str, Any]) -> list[str]: + keys = self.task_keys or list(results) + return [key for key in keys if key in results] + + @staticmethod + def _is_multitask(results: Any) -> bool: + if not isinstance(results, Mapping) or not results: + return False + return all( + isinstance(value, Mapping) or value is None for value in results.values() + ) + + +class AbstractTrainer(ABC): + """Backend-independent trainer driver. 
+ + Backend trainers implement one optimizer step, metric evaluation, learning + rate lookup, and checkpoint persistence. This base class handles the + common training loop around those hooks. + """ + + def __init__( + self, + trainer_config: TrainerConfig, + *, + rank_context: RankContext | None = None, + ) -> None: + self.trainer_config = trainer_config + self.rank_context = rank_context or RankContext() + self.lcurve_writer = LearningCurveWriter() + + def run(self, tasks: TrainingTaskCollection) -> None: + """Run the common training loop.""" + self.lcurve_writer = LearningCurveWriter(tasks.keys) + start_step = self.trainer_config.start_step + num_steps = self.trainer_config.num_steps + fout: TextIO | None = None + try: + self.on_train_begin(tasks) + fout = self._open_learning_curve() + wall_start = time.time() + last_log_time = wall_start + last_log_step = start_step + for step in range(start_step, num_steps): + task = self.select_task(tasks) + step_result = self.train_step(task, step) + display_step = step + 1 + + if self._should_display(display_step): + if self.rank_context.is_chief: + train_results, valid_results = self.collect_display_results( + tasks, + active_task=task, + step=step, + step_result=step_result, + ) + current_time = time.time() + interval_wall_time = current_time - last_log_time + interval_steps = max(1, display_step - last_log_step) + self._log_interval( + display_step=display_step, + interval_wall_time=interval_wall_time, + interval_steps=interval_steps, + wall_elapsed=current_time - wall_start, + ) + current_lr = self.learning_rate(step) + self.lcurve_writer.log_results( + step=display_step, + learning_rate=current_lr, + train_results=train_results, + valid_results=valid_results, + ) + if fout is not None: + if fout.tell() == 0: + self.lcurve_writer.write_header( + fout, + train_results=train_results, + valid_results=valid_results, + ) + self.lcurve_writer.write_row( + fout, + step=display_step, + learning_rate=current_lr, + 
train_results=train_results, + valid_results=valid_results, + ) + last_log_time = current_time + last_log_step = display_step + + self.run_full_validation( + step=step, + display_step=display_step, + learning_rate=self.learning_rate(step), + ) + + if ( + self.rank_context.is_chief + and self.trainer_config.save_freq > 0 + and display_step % self.trainer_config.save_freq == 0 + ): + self.save_checkpoint(display_step) + + if self._should_save_final_checkpoint(): + self.save_checkpoint(num_steps) + finally: + if fout is not None: + fout.close() + self.on_train_end(tasks) + + def select_task(self, tasks: TrainingTaskCollection) -> TrainingTask: + """Select the task for the next optimizer step.""" + return tasks.select() + + def collect_display_results( + self, + tasks: TrainingTaskCollection, + *, + active_task: TrainingTask, + step: int, + step_result: TrainStepResult, + ) -> tuple[DisplayResults, DisplayResults | None]: + """Collect training and validation results for display.""" + if not tasks.is_multitask: + return ( + self.evaluate_training(active_task, step, step_result), + self.evaluate_validation(active_task, step, step_result), + ) + + train_results: TaskResults = {} + valid_results: TaskResults = {} + for task in tasks: + task_step_result = step_result if task.key == active_task.key else None + train_results[task.key] = self.evaluate_training( + task, + step, + task_step_result, + ) + valid_results[task.key] = self.evaluate_validation( + task, + step, + task_step_result, + ) + return train_results, valid_results + + def on_train_begin(self, tasks: TrainingTaskCollection) -> None: + """Hook called before the first optimizer step.""" + return None + + def on_train_end(self, tasks: TrainingTaskCollection) -> None: + """Hook called after training resources have been closed.""" + return None + + def run_full_validation( + self, + *, + step: int, + display_step: int, + learning_rate: float, + ) -> None: + """Run optional backend-specific full validation for one 
step.""" + return None + + @abstractmethod + def train_step(self, task: TrainingTask, step: int) -> TrainStepResult: + """Run one backend-specific optimizer step.""" + + @abstractmethod + def evaluate_training( + self, + task: TrainingTask, + step: int, + step_result: TrainStepResult | None, + ) -> LossResults: + """Evaluate training metrics for one task.""" + + def evaluate_validation( + self, + task: TrainingTask, + step: int, + step_result: TrainStepResult | None, + ) -> LossResults | None: + """Evaluate validation metrics for one task.""" + return None + + @abstractmethod + def learning_rate(self, step: int) -> float: + """Return the learning rate associated with a zero-based step.""" + + @abstractmethod + def save_checkpoint(self, step: int) -> None: + """Persist a checkpoint for a one-based step.""" + + def _open_learning_curve(self) -> TextIO | None: + if ( + not self.rank_context.is_chief + or not self.trainer_config.display_in_training + ): + return None + disp_path = Path(self.trainer_config.disp_file) + append = ( + self.trainer_config.restart_training or self.trainer_config.start_step > 0 + ) and disp_path.exists() + return open(disp_path, "a" if append else "w") + + def _should_display(self, display_step: int) -> bool: + if not self.trainer_config.display_in_training: + return False + return display_step == 1 or ( + self.trainer_config.disp_freq > 0 + and display_step % self.trainer_config.disp_freq == 0 + ) + + def _should_save_final_checkpoint(self) -> bool: + if not self.rank_context.is_chief: + return False + if self.trainer_config.num_steps <= self.trainer_config.start_step: + return False + if self.trainer_config.save_freq <= 0: + return True + return self.trainer_config.num_steps % self.trainer_config.save_freq != 0 + + def _log_interval( + self, + *, + display_step: int, + interval_wall_time: float, + interval_steps: int, + wall_elapsed: float, + ) -> None: + if self.trainer_config.timing_in_training: + completed = max(1, display_step - 
self.trainer_config.start_step) + eta = int( + (self.trainer_config.num_steps - display_step) + / completed + * wall_elapsed + ) + log.info( + format_training_message( + batch=display_step, + wall_time=interval_wall_time, + eta=eta, + current_time=datetime.datetime.fromtimestamp( + time.time(), + tz=datetime.timezone.utc, + ).astimezone(), + step_time=interval_wall_time / interval_steps, + ) + ) + else: + log.info( + format_training_message( + batch=display_step, + wall_time=interval_wall_time, + ) + ) diff --git a/deepmd/dpmodel/train/validation.py b/deepmd/dpmodel/train/validation.py new file mode 100644 index 0000000000..1f567362f0 --- /dev/null +++ b/deepmd/dpmodel/train/validation.py @@ -0,0 +1,573 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Backend-independent helpers for full validation during training.""" + +from __future__ import ( + annotations, +) + +import logging +import re +import shutil +import traceback +from abc import ( + ABC, + abstractmethod, +) +from contextlib import ( + nullcontext, +) +from dataclasses import ( + dataclass, +) +from pathlib import ( + Path, +) +from typing import ( + Any, +) + +import numpy as np + +from deepmd.utils.argcheck import ( + normalize_full_validation_metric, + resolve_full_validation_start_step, +) +from deepmd.utils.eval_metrics import ( + FULL_VALIDATION_METRIC_FAMILY_BY_KEY, + FULL_VALIDATION_METRIC_KEY_MAP, +) + +log = logging.getLogger(__name__) + +LOG_COLUMN_ORDER = [ + ("E_MAE", "mae_e_per_atom"), + ("E_RMSE", "rmse_e_per_atom"), + ("F_MAE", "mae_f"), + ("F_RMSE", "rmse_f"), + ("V_MAE", "mae_v_per_atom"), + ("V_RMSE", "rmse_v_per_atom"), +] + +TOPK_RECORDS_INFO_KEY = "full_validation_topk_records" +BEST_METRIC_NAME_INFO_KEY = "full_validation_metric" +STALE_FULL_VALIDATION_INFO_KEYS = ( + "full_validation_best_metric", + "full_validation_best_step", + "full_validation_best_path", + "full_validation_best_records", +) +BEST_CKPT_PREFIX = "best.ckpt" +VAL_LOG_SIGNIFICANT_DIGITS = 5 
def build_best_checkpoint_pattern(
    best_checkpoint_prefix: str,
    best_checkpoint_suffix: str,
) -> re.Pattern[str]:
    """Compile the regex that recognizes managed best-checkpoint file names.

    The pattern is anchored at both ends and matches
    ``<prefix>-<digits>.t-<digits><suffix>``. The prefix and suffix are
    matched literally, and the two digit runs are captured as groups 1
    and 2.
    """
    prefix_re = re.escape(best_checkpoint_prefix)
    suffix_re = re.escape(best_checkpoint_suffix)
    pieces = ("^", prefix_re, r"-(\d+)\.t-(\d+)", suffix_re, "$")
    return re.compile("".join(pieces))
def format_metric_number_for_log(
    metric_value: float,
    *,
    significant_digits: int | None = None,
) -> str:
    """Format one metric value for `val.log` and best-save messages.

    Finite non-zero values are rendered with a fixed number of significant
    digits. Fixed-point notation is used unless more than 16 decimals would
    be needed, in which case scientific notation is used instead.

    Parameters
    ----------
    metric_value : float
        Value to format.
    significant_digits : int, optional
        Positive number of significant digits to keep. Defaults to
        ``VAL_LOG_SIGNIFICANT_DIGITS`` when omitted.

    Returns
    -------
    str
        ``"nan"``, ``"inf"`` or ``"-inf"`` for non-finite input, ``"0"`` for
        zero and for magnitudes below the smallest normal float, otherwise
        the formatted number.
    """
    if significant_digits is None:
        significant_digits = VAL_LOG_SIGNIFICANT_DIGITS
    if np.isnan(metric_value):
        return "nan"
    if np.isposinf(metric_value):
        return "inf"
    if np.isneginf(metric_value):
        return "-inf"
    if metric_value == 0.0:
        return "0"
    abs_value = abs(metric_value)
    if abs_value < np.finfo(float).tiny:
        return "0"
    # Take the decimal exponent *after* rounding to the requested precision.
    # Deriving it from the unrounded value yields one digit too many whenever
    # rounding carries into the next decade (9.99996 -> "10.0000").
    exponent = int(f"{abs_value:.{significant_digits - 1}e}".partition("e")[2])
    decimals = significant_digits - exponent - 1
    if decimals > 16:
        return f"{metric_value:.{significant_digits - 1}e}"
    rounded_value = round(metric_value, decimals)
    if rounded_value == 0.0:
        # Normalize a possible negative zero so it never prints as "-0".
        rounded_value = 0.0
    if decimals > 0:
        return f"{rounded_value:.{decimals}f}"
    return f"{rounded_value:.0f}"
= None, + full_val_file: str | Path | None = None, + best_checkpoint_prefix: str = BEST_CKPT_PREFIX, + best_checkpoint_suffix: str, + metric_name_info_key: str = BEST_METRIC_NAME_INFO_KEY, + topk_records_info_key: str = TOPK_RECORDS_INFO_KEY, + stale_state_keys: tuple[str, ...] = STALE_FULL_VALIDATION_INFO_KEYS, + emit_best_save_log: bool = True, + ) -> None: + self.state_store = state_store + self.rank = rank + self.checkpoint_dir = ( + Path(checkpoint_dir) if checkpoint_dir is not None else Path(".") + ) + self.metric_name_info_key = metric_name_info_key + self.topk_records_info_key = topk_records_info_key + self.stale_state_keys = stale_state_keys + self.best_checkpoint_prefix = best_checkpoint_prefix + self.best_checkpoint_suffix = best_checkpoint_suffix + self.best_checkpoint_glob = build_best_checkpoint_glob( + best_checkpoint_prefix, best_checkpoint_suffix + ) + self.best_checkpoint_pattern = build_best_checkpoint_pattern( + best_checkpoint_prefix, best_checkpoint_suffix + ) + self.emit_best_save_log = emit_best_save_log + + self.full_validation = bool(validating_params.get("full_validation", False)) + self.validation_freq = int(validating_params.get("validation_freq", 5000)) + self.save_best = bool(validating_params.get("save_best", True)) + self.max_best_ckpt = int(validating_params.get("max_best_ckpt", 1)) + self.metric_name, self.metric_key = parse_validation_metric( + str(validating_params.get("validation_metric", "E:MAE")) + ) + resolved_log_file = ( + full_val_file + if full_val_file is not None + else validating_params.get("full_val_file", "val.log") + ) + self.full_val_file = Path(resolved_log_file) + self.start_step = resolve_full_validation_start_step( + validating_params.get("full_val_start", 0.5), + num_steps, + ) + self.enabled = ( + self.full_validation + and self.start_step is not None + and self.start_step <= num_steps + ) + self.step_column_width = max(len("step"), len(str(num_steps))) + self._write_mode = "a" if restart_training else "w" + 
def should_run(self, display_step: int) -> bool:
    """Tell whether full validation is due at ``display_step``.

    Validation is due when the validator is enabled, a start step is set,
    the step has reached that start step, and the distance from the start
    step is a whole multiple of ``validation_freq``.
    """
    first_step = self.start_step
    active = self.enabled and first_step is not None
    if not active or display_step < first_step:
        return False
    steps_since_start = display_step - first_step
    return steps_since_start % self.validation_freq == 0
None + + self._raise_if_error(error_message, caught_exception) + + if self.rank == 0: + try: + self._log_result(result) + except Exception as exc: + caught_exception = exc + error_message = ( + "Full validation failed while writing logs:\n" + f"{traceback.format_exc()}" + ) + else: + error_message = None + caught_exception = None + + self._raise_if_error(error_message, caught_exception) + + return result if self.rank == 0 else None + + def _evaluate(self, display_step: int) -> FullValidationResult: + """Evaluate all validation systems and update best state.""" + with self.evaluation_context(): + metrics = self.evaluate_all_systems() + + if self.metric_key not in metrics or np.isnan(metrics[self.metric_key]): + raise RuntimeError( + "The selected full validation metric is unavailable on the " + f"validation dataset: {self.metric_name.upper()}." + ) + + selected_metric_value = float(metrics[self.metric_key]) + saved_best_path = self._update_best_state( + display_step=display_step, + selected_metric_value=selected_metric_value, + ) + return FullValidationResult( + display_step=display_step, + metrics=metrics, + selected_metric_key=self.metric_key, + selected_metric_value=selected_metric_value, + saved_best_path=saved_best_path, + ) + + def evaluation_context(self) -> Any: + """Return a context manager active during model evaluation.""" + return nullcontext() + + @abstractmethod + def evaluate_all_systems(self) -> dict[str, float]: + """Evaluate every validation system and aggregate metrics.""" + + def _update_best_state( + self, + *, + display_step: int, + selected_metric_value: float, + ) -> str | None: + """Update the top-K records and return the checkpoint path to save.""" + candidate = BestCheckpointRecord( + metric=selected_metric_value, + step=display_step, + ) + updated_records = [ + record for record in self.topk_records if record.step != display_step + ] + updated_records.append(candidate) + updated_records.sort() + updated_records = updated_records[: 
self.max_best_ckpt] + if candidate not in updated_records: + return None + + self.topk_records = updated_records + self._sync_state_store() + if not self.save_best: + return None + candidate_rank = self.topk_records.index(candidate) + 1 + return str(self._best_checkpoint_path(display_step, candidate_rank)) + + def _sync_state_store(self) -> None: + """Synchronize top-K validation state into the configured state store.""" + for key in self.stale_state_keys: + self.state_store.pop(key, None) + self.state_store[self.metric_name_info_key] = self.metric_name + self.state_store[self.topk_records_info_key] = [ + {"metric": record.metric, "step": record.step} + for record in self.topk_records + ] + + def _load_topk_records(self) -> list[BestCheckpointRecord]: + """Load top-K records from the configured state store.""" + if self.state_store.get(self.metric_name_info_key) != self.metric_name: + return [] + raw_records = self.state_store.get(self.topk_records_info_key, []) + if not isinstance(raw_records, list): + return [] + records = [] + for raw_record in raw_records: + if not isinstance(raw_record, dict): + continue + if "metric" not in raw_record or "step" not in raw_record: + continue + records.append( + BestCheckpointRecord( + metric=float(raw_record["metric"]), + step=int(raw_record["step"]), + ) + ) + records.sort() + return records[: self.max_best_ckpt] + + def _best_checkpoint_name(self, step: int, rank: int) -> str: + """Build the best-checkpoint filename for one step.""" + return f"{self.best_checkpoint_prefix}-{step}.t-{rank}{self.best_checkpoint_suffix}" + + def _best_checkpoint_path(self, step: int, rank: int) -> Path: + """Build the best-checkpoint path for one step.""" + return self.checkpoint_dir / self._best_checkpoint_name(step, rank) + + def _list_best_checkpoints(self) -> list[Path]: + """List all managed best checkpoints in the checkpoint directory.""" + best_checkpoints = [ + path + for path in self.checkpoint_dir.glob(self.best_checkpoint_glob) + if 
path.exists() and not path.is_symlink() + ] + best_checkpoints.sort(key=lambda path: path.stat().st_mtime) + return best_checkpoints + + @staticmethod + def _remove_checkpoint_path(path: Path) -> None: + """Remove one managed checkpoint path, file or directory.""" + if path.is_dir() and not path.is_symlink(): + shutil.rmtree(path) + else: + path.unlink(missing_ok=True) + + def _expected_topk_checkpoint_names(self) -> dict[int, str]: + """Return the expected checkpoint filename for each retained step.""" + return { + record.step: self._best_checkpoint_name(record.step, rank) + for rank, record in enumerate(self.topk_records, start=1) + } + + def _reconcile_best_checkpoints(self) -> None: + """Rename retained best checkpoints to ranked names and delete stale ones.""" + expected_names = self._expected_topk_checkpoint_names() + current_files = self._list_best_checkpoints() + files_by_step: dict[int, list[Path]] = {} + stale_files: list[Path] = [] + for checkpoint_path in current_files: + match = self.best_checkpoint_pattern.match(checkpoint_path.name) + if match is None: + stale_files.append(checkpoint_path) + continue + step = int(match.group(1)) + files_by_step.setdefault(step, []).append(checkpoint_path) + + temp_moves: list[tuple[Path, Path]] = [] + for step, checkpoint_paths in files_by_step.items(): + expected_name = expected_names.get(step) + if expected_name is None: + stale_files.extend(checkpoint_paths) + continue + + keep_path = next( + ( + checkpoint_path + for checkpoint_path in checkpoint_paths + if checkpoint_path.name == expected_name + ), + checkpoint_paths[0], + ) + for checkpoint_path in checkpoint_paths: + if checkpoint_path != keep_path: + stale_files.append(checkpoint_path) + if keep_path.name != expected_name: + temp_path = keep_path.with_name(f"{keep_path.name}.tmp") + keep_path.rename(temp_path) + temp_moves.append((temp_path, keep_path.with_name(expected_name))) + + for checkpoint_path in stale_files: + 
def _write_log_file(self, result: FullValidationResult) -> None:
    """Append one full validation entry to the dedicated log file.

    The first write of a run also emits a two-line header: the column titles
    and a note explaining the E/F/V normalization. After the data row, a
    comment line is added when this validation saved a best checkpoint.

    Parameters
    ----------
    result : FullValidationResult
        Outcome of the validation run to record.
    """
    # The column headers contain the non-ASCII unit "meV/Å", so the encoding
    # is pinned to UTF-8. The locale default may be unable to encode it,
    # which would raise UnicodeEncodeError on the first write.
    with self.full_val_file.open(
        self._write_mode, buffering=1, encoding="utf-8"
    ) as fout:
        if self._should_write_header:
            header = VAL_LOG_HEADER_PREFIX + f"{'step':^{self.step_column_width}s}"
            for _, header_label, column_width in self.table_column_specs:
                header += VAL_LOG_COLUMN_GAP + f"{header_label:^{column_width}s}"
            header += "\n"
            header += (
                "# E uses per-atom energy, F uses component-wise force errors, "
                "and V uses virial normalized by natoms.\n"
            )
            fout.write(header)
            # Later calls in this run must append and must not repeat the header.
            self._should_write_header = False
            self._write_mode = "a"

        line = (
            VAL_LOG_DATA_PREFIX
            + f"{result.display_step:^{self.step_column_width}d}"
        )
        for metric_key, _, column_width in self.table_column_specs:
            # Metrics absent from this run are printed as "nan".
            metric_value = result.metrics.get(metric_key, float("nan"))
            if not np.isnan(metric_value):
                metric_value, _ = format_metric_value_for_table(
                    metric_key, metric_value
                )
            metric_text = format_metric_number_for_log(metric_value)
            line += VAL_LOG_COLUMN_GAP + f"{metric_text:^{column_width}s}"
        line += "\n"
        fout.write(line)
        if result.saved_best_path is not None:
            metric_label, metric_value, metric_unit = format_metric_for_log(
                self.metric_name, result.selected_metric_value
            )
            fout.write(
                "# saved best checkpoint: "
                f"{result.saved_best_path} ({metric_label} = "
                f"{format_metric_number_for_log(metric_value)} {metric_unit})\n"
            )
) for ii, model_key in enumerate(model_keys): - model_prob[ii] = float(len(model_training_data[model_key])) + model_prob[ii] = float(_training_data_size(model_training_data[model_key])) if not np.all(np.isfinite(model_prob)): raise ValueError("Model prob must be finite.") if np.any(model_prob < 0.0): @@ -121,6 +122,14 @@ def resolve_model_prob( return model_prob / sum_prob +def _training_data_size(training_data: object) -> int: + if hasattr(training_data, "get_nsystems"): + return int(training_data.get_nsystems()) + if not isinstance(training_data, Sized): + return 1 + return len(training_data) + + def resolve_model_prob_from_epochs( model_keys: list[str], num_epoch_dict_config: dict[str, float], diff --git a/deepmd/jax/entrypoints/train.py b/deepmd/jax/entrypoints/train.py index afcda0df90..fdc695548d 100644 --- a/deepmd/jax/entrypoints/train.py +++ b/deepmd/jax/entrypoints/train.py @@ -4,15 +4,19 @@ Can handle local training. """ -import json import logging import time from typing import ( Any, ) -from deepmd.common import ( - j_loader, +from deepmd.dpmodel.train import ( + AbstractTrainEntrypoint, + TrainEntrypointOptions, + TrainingTaskConfig, + iter_training_task_configs, + make_task_maps, + print_data_summaries, ) from deepmd.jax.env import ( jax, @@ -24,16 +28,13 @@ from deepmd.jax.train.trainer import ( DPTrainer, ) +from deepmd.jax.utils.serialization import ( + serialize_from_file, +) from deepmd.jax.utils.update_sel import ( use_jax_update_sel, ) from deepmd.utils import random as dp_random -from deepmd.utils.argcheck import ( - normalize, -) -from deepmd.utils.compat import ( - update_deepmd_input, -) from deepmd.utils.data_system import ( get_data, ) @@ -79,6 +80,123 @@ def get_device_name(self) -> str: return "Unknown" +class JAXTrainEntrypoint(AbstractTrainEntrypoint): + """JAX implementation of the common training entrypoint pipeline.""" + + def __init__(self) -> None: + self.finetune_links: dict[str, Any] | None = None + + def validate_options( + 
def run_training(
    self,
    config: dict[str, Any],
    options: TrainEntrypointOptions,
    neighbor_stat: float | dict[str, float | None] | None,
) -> None:
    """Create the JAX trainer and its data systems, then train.

    The trainer is built first, the data-system random seed is set next
    (offset by the JAX process index), and only then are the per-task
    training and validation data created and handed to the trainer.
    """
    # make necessary checks
    assert "training" in config

    trainer = DPTrainer(
        config,
        init_model=options.init_model,
        restart=options.restart,
        finetune_model=options.finetune,
        finetune_links=self.finetune_links,
    )
    if neighbor_stat is not None:
        trainer.set_min_nbor_dist(neighbor_stat)

    # init random seed of data systems
    data_seed = config["training"].get("seed", None)
    if data_seed is not None:
        data_seed = (data_seed + jax.process_index()) % (2**32)
    dp_random.seed(data_seed)

    def build_task_data(
        task_config: TrainingTaskConfig,
    ) -> tuple[Any, Any | None, None]:
        # One task: training data, optional validation data, no stat data.
        task_model = trainer.models[task_config.key]
        model_type_map = task_model.get_type_map()
        if len(model_type_map) > 0:
            input_type_map = model_type_map
        else:
            input_type_map = None
        training_set = get_data(
            dict(task_config.training_data_params),
            task_model.get_rcut(),
            input_type_map,
            None,
        )
        if task_config.validation_data_params is None:
            validation_set = None
        else:
            # Validation uses the type map of the training data.
            validation_set = get_data(
                dict(task_config.validation_data_params),
                task_model.get_rcut(),
                training_set.type_map,
                None,
            )
        return training_set, validation_set, None

    training_by_task, validation_by_task, _ = make_task_maps(config, build_task_data)
    print_data_summaries(training_by_task, validation_by_task)

    t_start = time.time()
    trainer.train(training_by_task, validation_by_task)
    t_end = time.time()
    log.info("finished training")
    log.info(f"wall time: {(t_end - t_start):.3f} s")
""" - # load json database - jdata = j_loader(INPUT) - - if init_frz_model: - raise NotImplementedError("JAX training does not support init_frz_model yet") - if finetune: - raise NotImplementedError("JAX training does not support finetune yet") - if use_pretrain_script: - raise NotImplementedError( - "JAX training does not support use_pretrain_script yet" + JAXTrainEntrypoint().run( + TrainEntrypointOptions( + input_file=INPUT, + output=output, + init_model=init_model, + restart=restart, + init_frz_model=init_frz_model, + finetune=finetune, + model_branch=model_branch, + use_pretrain_script=use_pretrain_script, + skip_neighbor_stat=skip_neighbor_stat, ) - - jdata = update_deepmd_input(jdata, warning=True, dump="input_v2_compat.json") - - jdata = normalize(jdata) - min_nbor_dist = None - if not skip_neighbor_stat: - jdata, min_nbor_dist = update_sel(jdata) - - with open(output, "w") as fp: - json.dump(jdata, fp, indent=4) - SummaryPrinter()() - - # make necessary checks - assert "training" in jdata - - # init the model - - model = DPTrainer( - jdata, - init_model=init_model, - restart=restart, ) - if min_nbor_dist is not None: - model.model.min_nbor_dist = min_nbor_dist - rcut = model.model.get_rcut() - type_map = model.model.get_type_map() - if len(type_map) == 0: - ipt_type_map = None - else: - ipt_type_map = type_map - - # init random seed of data systems - seed = jdata["training"].get("seed", None) - if seed is not None: - seed += jax.process_index() - seed = seed % (2**32) - dp_random.seed(seed) - - # init data - train_data = get_data(jdata["training"]["training_data"], rcut, ipt_type_map, None) - train_data.add_data_requirements(model.data_requirements) - train_data.print_summary("training") - if jdata["training"].get("validation_data", None) is not None: - valid_data = get_data( - jdata["training"]["validation_data"], - rcut, - train_data.type_map, - None, - ) - valid_data.add_data_requirements(model.data_requirements) - valid_data.print_summary("validation") 
def update_sel(
    jdata: dict,
    *,
    multi_task: bool | None = None,
) -> tuple[dict, float | dict[str, float | None] | None]:
    """Update descriptor selections from neighbor statistics when available.

    Parameters
    ----------
    jdata : dict
        Training configuration. This function does not write to it, nor to
        its ``model`` / ``model_dict`` containers.
    multi_task : bool, optional
        Whether to treat the configuration as multi-task. When omitted, it is
        detected from the presence of ``"model_dict"`` in ``jdata["model"]``.

    Returns
    -------
    tuple
        The updated configuration copy and the neighbor-statistics result:
        the value from ``BaseModel.update_sel`` in single-task mode, or a
        dict of those values keyed by task in multi-task mode.
    """
    log.info(
        "Calculate neighbor statistics... (add --skip-neighbor-stat to skip this step)"
    )
    jdata_cpy = jdata.copy()
    if multi_task is None:
        multi_task = "model_dict" in jdata["model"]
    if multi_task:
        # ``dict.copy`` is shallow. Without private copies of the nested
        # containers, the per-task writes below would land in the caller's
        # config, which is also the dict the task iterator is reading.
        model_cpy = dict(jdata["model"])
        model_cpy["model_dict"] = dict(model_cpy["model_dict"])
        jdata_cpy["model"] = model_cpy
    min_nbor_dist: dict[str, float | None] = {}
    with use_jax_update_sel():
        for task_config in iter_training_task_configs(jdata):
            type_map = task_config.model_params.get("type_map")
            train_data = get_data(
                dict(task_config.training_data_params),
                0,  # not used
                type_map,
                None,  # not used
            )
            updated_model, task_min_nbor_dist = BaseModel.update_sel(
                train_data, type_map, dict(task_config.model_params)
            )
            if multi_task:
                jdata_cpy["model"]["model_dict"][task_config.key] = updated_model
                min_nbor_dist[task_config.key] = task_min_nbor_dist
            else:
                # Single-task: only the first yielded task is processed.
                jdata_cpy["model"] = updated_model
                return jdata_cpy, task_min_nbor_dist
    return jdata_cpy, min_nbor_dist
import platform import shutil -import time +from collections.abc import ( + Mapping, +) +from copy import ( + deepcopy, +) from pathlib import ( Path, ) from typing import ( - TextIO, + Any, ) import numpy as np @@ -27,6 +33,18 @@ from deepmd.dpmodel.model.transform_output import ( communicate_extended_output, ) +from deepmd.dpmodel.train import ( + DEFAULT_TASK_KEY, + AbstractTrainer, + RankContext, + TrainerConfig, + TrainingTask, + TrainingTaskCollection, + TrainStepResult, +) +from deepmd.dpmodel.train.validation import ( + resolve_best_checkpoint_dir, +) from deepmd.dpmodel.utils.learning_rate import ( LearningRateExp, ) @@ -37,8 +55,12 @@ from deepmd.dpmodel.utils.region import ( normalize_coord, ) +from deepmd.dpmodel.utils.training_utils import ( + resolve_model_prob, +) from deepmd.jax.env import ( flax_version, + jax, jnp, nnx, ) @@ -51,9 +73,8 @@ from deepmd.jax.utils.serialization import ( serialize_from_file, ) -from deepmd.loggers.training import ( - format_training_message, - format_training_message_per_task, +from deepmd.utils.argcheck import ( + resolve_full_validation_start_step, ) from deepmd.utils.data import ( DataRequirementItem, @@ -61,6 +82,9 @@ from deepmd.utils.data_system import ( DeepmdDataSystem, ) +from deepmd.utils.finetune import ( + warn_configuration_mismatch_during_finetune, +) from deepmd.utils.model_stat import ( make_stat_input, ) @@ -68,7 +92,7 @@ log = logging.getLogger(__name__) -class DPTrainer: +class DPTrainer(AbstractTrainer): """Train JAX DeePMD models on local devices.""" def __init__( @@ -76,52 +100,89 @@ def __init__( jdata: dict, init_model: str | None = None, restart: str | None = None, + finetune_model: str | None = None, + finetune_links: dict[str, Any] | None = None, ) -> None: """Initialize the trainer from input data and optional checkpoints.""" + if finetune_model is not None and ( + init_model is not None or restart is not None + ): + raise ValueError( + "finetune_model cannot be combined with init_model or 
restart." + ) self.init_model = init_model self.restart = restart - self.model_def_script = jdata["model"] - self.start_step = 0 - if self.init_model is not None: - model_dict = serialize_from_file(self.init_model) - self.model = BaseModel.deserialize(model_dict["model"]) - elif self.restart is not None: - model_dict = serialize_from_file(self.restart) - self.model = BaseModel.deserialize(model_dict["model"]) - self.start_step = model_dict.get("model_def_script", {}).get( - "current_step", - model_dict.get("@variables", {}).get("current_step", 0), - ) - else: - # from scratch - self.model = get_model(jdata["model"]) + self.finetune_model = finetune_model + self.finetune_links = finetune_links + self.restart_training = restart is not None self.training_param = jdata["training"] + self.validating_param = jdata.get("validating", {}) or {} self.num_steps = self.training_param["numb_steps"] + self.start_step = 0 - def get_lr_and_coef(lr_param: dict) -> LearningRateExp: - lr_type = lr_param.get("type", "exp") - if lr_type == "exp": - lr = LearningRateExp( - **lr_param, - num_steps=self.num_steps, + self.model_def_script = jdata["model"] + self.multi_task = "model_dict" in self.model_def_script + self.model_keys = ( + list(self.model_def_script["model_dict"]) + if self.multi_task + else [DEFAULT_TASK_KEY] + ) + self.model_params_by_task = self._model_params_by_task(self.model_def_script) + + if init_model is not None or restart is not None: + checkpoint_path = init_model if init_model is not None else restart + assert checkpoint_path is not None + checkpoint_data = serialize_from_file(checkpoint_path) + checkpoint_multi_task = "model_dict" in checkpoint_data["model_def_script"] + if checkpoint_multi_task != self.multi_task: + raise ValueError( + "JAX init/restart checkpoint task layout does not match input config." 
) - else: - raise RuntimeError("unknown learning_rate type " + lr_type) - return lr + checkpoint_keys = list( + checkpoint_data["model_def_script"].get( + "model_dict", {DEFAULT_TASK_KEY: None} + ) + ) + if checkpoint_keys != self.model_keys: + raise ValueError( + "JAX init/restart checkpoint task keys do not match input config." + ) + self.models = self._deserialize_models(checkpoint_data) + if restart is not None: + self.model_def_script = checkpoint_data["model_def_script"] + self.model_params_by_task = self._model_params_by_task( + self.model_def_script + ) + self.start_step = int( + checkpoint_data.get("model_def_script", {}).get( + "current_step", + checkpoint_data.get("@variables", {}).get("current_step", 0), + ) + ) + else: + self.models = { + model_key: get_model(deepcopy(self.model_params_by_task[model_key])) + for model_key in self.model_keys + } + self.model = self.models if self.multi_task else self.models[DEFAULT_TASK_KEY] learning_rate_param = jdata["learning_rate"] - self.lr = get_lr_and_coef(learning_rate_param) - loss_param = jdata.get("loss", {}) - loss_param["starter_learning_rate"] = learning_rate_param["start_lr"] + self.lr = self._get_lr_and_coef(learning_rate_param) + self.losses = self._build_losses(jdata, learning_rate_param) + self.loss = self.losses if self.multi_task else self.losses[DEFAULT_TASK_KEY] + self.data_requirements_by_task = { + model_key: list(self.losses[model_key].label_requirement) + for model_key in self.model_keys + } - loss_type = loss_param.get("type", "ener") - if loss_type == "ener": - self.loss = EnergyLoss.get_loss(loss_param) - else: - raise RuntimeError("unknown loss type " + loss_type) + self.valid_numb_batch_by_task = self._valid_numb_batch_by_task() + self.valid_numb_batch = ( + self.valid_numb_batch_by_task + if self.multi_task + else self.valid_numb_batch_by_task[DEFAULT_TASK_KEY] + ) - # training - tr_data = jdata["training"] + tr_data = self.training_param self.disp_file = tr_data.get("disp_file", 
"lcurve.out") self.disp_freq = tr_data.get("disp_freq", 1000) self.save_freq = tr_data.get("save_freq", 1000) @@ -139,49 +200,390 @@ def get_lr_and_coef(lr_param: dict) -> LearningRateExp: self.change_bias_after_training = tr_data.get( "change_bias_after_training", False ) - self.numb_fparam = self.model.get_dim_fparam() - - if tr_data.get("validation_data", None) is not None: - self.valid_numb_batch = max( - tr_data["validation_data"].get("numb_btch", 1), - 1, - ) - else: - self.valid_numb_batch = 1 + self.numb_fparam = ( + {key: model.get_dim_fparam() for key, model in self.models.items()} + if self.multi_task + else self.models[DEFAULT_TASK_KEY].get_dim_fparam() + ) - # if init the graph with the frozen model self.frz_model = None self.ckpt_meta = None self.model_type = None + self.optimizers: dict[str, nnx.Optimizer] = {} + self.optimizer: nnx.Optimizer | None = None + self._train_step_impls: dict[str, Any] = {} + self._loss_fn_more_loss: dict[str, Any] = {} + self._sample_funcs: dict[str, Any] = {} + self.model_prob: np.ndarray | None = None + self.full_validator: Any | None = None + + super().__init__( + TrainerConfig.from_training_params( + tr_data, + num_steps=self.num_steps, + start_step=self.start_step, + restart_training=self.restart is not None, + ), + rank_context=RankContext( + rank=int(jax.process_index()), + world_size=int(jax.process_count()), + ), + ) + + @staticmethod + def _model_params_by_task( + model_params: dict[str, Any], + ) -> dict[str, dict[str, Any]]: + if "model_dict" in model_params: + return { + model_key: model_params["model_dict"][model_key] + for model_key in model_params["model_dict"] + } + return {DEFAULT_TASK_KEY: model_params} + + @staticmethod + def _deserialize_models(model_data: dict[str, Any]) -> dict[str, BaseModel]: + if "model_dict" in model_data["model_def_script"]: + return { + model_key: BaseModel.deserialize( + model_data["model"]["model_dict"][model_key] + ) + for model_key in 
model_data["model_def_script"]["model_dict"] + } + return {DEFAULT_TASK_KEY: BaseModel.deserialize(model_data["model"])} + + def _get_lr_and_coef(self, lr_param: dict[str, Any]) -> LearningRateExp: + lr_type = lr_param.get("type", "exp") + if lr_type == "exp": + return LearningRateExp(**lr_param, num_steps=self.num_steps) + raise RuntimeError("unknown learning_rate type " + lr_type) + + def _build_losses( + self, + jdata: dict[str, Any], + learning_rate_param: dict[str, Any], + ) -> dict[str, EnergyLoss]: + losses: dict[str, EnergyLoss] = {} + for model_key in self.model_keys: + loss_param = deepcopy( + jdata["loss_dict"][model_key] + if self.multi_task + else jdata.get("loss", {}) + ) + loss_param["starter_learning_rate"] = learning_rate_param["start_lr"] + loss_type = loss_param.get("type", "ener") + if loss_type != "ener": + raise RuntimeError("unknown loss type " + loss_type) + losses[model_key] = EnergyLoss.get_loss(loss_param) + return losses + + def _valid_numb_batch_by_task(self) -> dict[str, int]: + result: dict[str, int] = {} + for model_key in self.model_keys: + if self.multi_task: + valid_params = ( + self.training_param["data_dict"][model_key].get( + "validation_data", {} + ) + or {} + ) + else: + valid_params = self.training_param.get("validation_data", {}) or {} + result[model_key] = max(int(valid_params.get("numb_btch", 1)), 1) + return result + + def _create_full_validator( + self, + *, + validating_params: dict[str, Any], + validation_data: DeepmdDataSystem | None, + ) -> Any | None: + """Create the runtime full validator when it is active.""" + if not self._is_validation_requested(validating_params, "full_validation"): + return None + self._raise_if_full_validation_unsupported(validation_data) + if validation_data is None: + raise RuntimeError( + "validation_data must be available after full validation checks." 
+ ) + from deepmd.jax.train.validation import ( + JAXFullValidator, + ) + + return JAXFullValidator( + validating_params=validating_params, + validation_data=validation_data, + model=self.models[DEFAULT_TASK_KEY], + state_store=self.model_def_script, + num_steps=self.num_steps, + rank=int(jax.process_index()), + restart_training=self.restart_training, + checkpoint_dir=resolve_best_checkpoint_dir( + validating_params, self.save_ckpt + ), + ) + + def _is_validation_requested( + self, + validating_params: dict[str, Any], + flag_name: str, + ) -> bool: + """Check whether a full validation flow can trigger during this run.""" + if not validating_params.get(flag_name, False): + return False + start_step = resolve_full_validation_start_step( + validating_params.get("full_val_start", 0.5), + self.num_steps, + ) + return start_step is not None and start_step <= self.num_steps + + def _raise_if_full_validation_unsupported( + self, + validation_data: DeepmdDataSystem | None, + ) -> None: + """Validate runtime full validation constraints.""" + if self.multi_task: + raise ValueError( + "validating.full_validation only supports single-task energy " + "training; multi-task training is not supported." + ) + + if not isinstance(self.loss, EnergyLoss): + raise ValueError( + "validating.full_validation only supports single-task energy training." + ) + + if validation_data is None: + raise ValueError( + "validating.full_validation requires `training.validation_data` " + "to be configured." 
+ ) @property def data_requirements(self) -> list[DataRequirementItem]: - """Labels required by the configured loss.""" - return self.loss.label_requirement + """Labels required by the configured loss for single-task callers.""" + return self.data_requirements_by_task[DEFAULT_TASK_KEY] + + def set_min_nbor_dist( + self, + min_nbor_dist: float | Mapping[str, float | None] | None, + ) -> None: + """Attach neighbor-stat minimum distances to task models.""" + if min_nbor_dist is None: + return + if isinstance(min_nbor_dist, Mapping): + for model_key, value in min_nbor_dist.items(): + if value is not None and model_key in self.models: + self.models[model_key].min_nbor_dist = value + return + self.models[DEFAULT_TASK_KEY].min_nbor_dist = min_nbor_dist def train( - self, train_data: DeepmdDataSystem, valid_data: DeepmdDataSystem | None = None + self, + train_data: DeepmdDataSystem | Mapping[str, DeepmdDataSystem], + valid_data: DeepmdDataSystem + | Mapping[str, DeepmdDataSystem | None] + | None = None, ) -> None: """Run the training loop with optional validation data.""" - model = self.model - tx = optax.adam( - learning_rate=lambda step: self.lr.value(self.start_step + step), + train_data_by_task = self._normalize_data_map(train_data) + valid_data_by_task = self._normalize_data_map(valid_data, optional=True) + self._setup_training(train_data_by_task, valid_data_by_task) + tasks = TrainingTaskCollection( + [ + TrainingTask( + key=model_key, + training_data=train_data_by_task[model_key], + validation_data=valid_data_by_task[model_key], + valid_numb_batch=self.valid_numb_batch_by_task[model_key], + data_requirements=self.data_requirements_by_task[model_key], + ) + for model_key in self.model_keys + ], + probabilities=self.model_prob, ) - optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) + self.run(tasks) + + def _normalize_data_map( + self, + data: Any, + *, + optional: bool = False, + ) -> dict[str, Any]: + if isinstance(data, Mapping): + return {model_key: 
data.get(model_key) for model_key in self.model_keys} + if optional and data is None: + return dict.fromkeys(self.model_keys) + return {DEFAULT_TASK_KEY: data} + + def _setup_training( + self, + train_data_by_task: Mapping[str, DeepmdDataSystem], + valid_data_by_task: Mapping[str, DeepmdDataSystem | None], + ) -> None: + """Initialize statistics, fine-tuning, optimizers, and JIT functions.""" + for model_key in self.model_keys: + train_data_by_task[model_key].add_data_requirements( + self.data_requirements_by_task[model_key] + ) + if valid_data_by_task[model_key] is not None: + valid_data_by_task[model_key].add_data_requirements( + self.data_requirements_by_task[model_key] + ) + + if self.multi_task: + self.model_prob = resolve_model_prob( + self.model_keys, + self.training_param.get("model_prob"), + dict(train_data_by_task), + ) + + for model_key in self.model_keys: + self._sample_funcs[model_key] = self._make_sample_func( + train_data_by_task[model_key], + self.model_params_by_task[model_key].get("data_stat_nbatch", 10), + ) - # data stat if self.init_model is None and self.restart is None: - data_stat_nbatch = self.model_def_script.get("data_stat_nbatch", 10) + for model_key in self.model_keys: + finetune_has_new_type = ( + self.finetune_model is not None + and self.finetune_links is not None + and model_key in self.finetune_links + and self.finetune_links[model_key].get_has_new_type() + ) + if self.finetune_model is None or finetune_has_new_type: + self.models[model_key].atomic_model.compute_or_load_stat( + self._sample_funcs[model_key] + ) + + if self.finetune_model is not None: + self._apply_finetune() + + for model_key in self.model_keys: + tx = optax.chain( + optax.scale_by_adam(), + _scale_by_global_learning_rate(), + ) + self.optimizers[model_key] = nnx.Optimizer( + self.models[model_key], tx, wrt=nnx.Param + ) + ( + self._train_step_impls[model_key], + self._loss_fn_more_loss[model_key], + ) = self._make_step_functions(self.losses[model_key]) + 
self.optimizer = ( + self.optimizers[DEFAULT_TASK_KEY] if not self.multi_task else None + ) + self.model = self.models if self.multi_task else self.models[DEFAULT_TASK_KEY] + self.full_validator = self._create_full_validator( + validating_params=self.validating_param, + validation_data=valid_data_by_task[DEFAULT_TASK_KEY] + if not self.multi_task + else None, + ) + + @staticmethod + def _make_sample_func( + train_data: DeepmdDataSystem, + data_stat_nbatch: int, + ) -> Any: + @functools.lru_cache + def sample() -> list[dict[str, Any]]: stat_data = make_stat_input(train_data, data_stat_nbatch) - stat_data_jax = [ + return [ { - kk: jnp.asarray(vv) if isinstance(vv, np.ndarray) else vv - for kk, vv in single_data.items() + key: jnp.asarray(value) if isinstance(value, np.ndarray) else value + for key, value in single_data.items() } for single_data in stat_data ] - model.atomic_model.compute_or_load_stat(lambda: stat_data_jax) + return sample + + def _apply_finetune(self) -> None: + if self.finetune_model is None or self.finetune_links is None: + return + pretrained_data = serialize_from_file(self.finetune_model) + pretrained_params = pretrained_data["model_def_script"] + pretrained_models = self._deserialize_models(pretrained_data) + for model_key in self.model_keys: + finetune_rule = self.finetune_links[model_key] + source_key = finetune_rule.get_model_branch() + if source_key not in pretrained_models: + raise ValueError( + f"Pretrained model branch {source_key!r} is not available." 
+ ) + source_model = pretrained_models[source_key] + if finetune_rule.get_finetune_tmap() != source_model.get_type_map(): + model_with_new_type_stat = ( + self.models[model_key] if finetune_rule.get_has_new_type() else None + ) + source_model.change_type_map( + finetune_rule.get_finetune_tmap(), + model_with_new_type_stat=model_with_new_type_stat, + ) + self._warn_finetune_config_mismatch( + model_key, source_key, pretrained_params + ) + self.models[model_key] = self._copy_finetune_state( + self.models[model_key], + source_model, + random_fitting=finetune_rule.get_random_fitting(), + ) + if finetune_rule.get_resuming(): + log.info("Model branch %s will resume training.", model_key) + continue + bias_mode = ( + "change-by-statistic" + if not finetune_rule.get_random_fitting() + else "set-by-statistic" + ) + self.models[model_key].change_out_bias( + self._sample_funcs[model_key], + bias_adjust_mode=bias_mode, + ) + + def _warn_finetune_config_mismatch( + self, + model_key: str, + source_key: str, + pretrained_params: dict[str, Any], + ) -> None: + input_model_params = self.model_params_by_task[model_key] + branch_pretrained_params = ( + pretrained_params["model_dict"][source_key] + if "model_dict" in pretrained_params + else pretrained_params + ) + if ( + "descriptor" in input_model_params + and "descriptor" in branch_pretrained_params + ): + warn_configuration_mismatch_during_finetune( + input_model_params["descriptor"], + branch_pretrained_params["descriptor"], + source_key, + ) + + @staticmethod + def _copy_finetune_state( + target_model: BaseModel, + source_model: BaseModel, + *, + random_fitting: bool, + ) -> BaseModel: + graphdef, target_state = nnx.split(target_model) + _, source_state = nnx.split(source_model) + copied = _copy_matching_state_tree( + target_state.to_pure_dict(), + source_state.to_pure_dict(), + random_fitting=random_fitting, + ) + target_state.replace_by_pure_dict(copied) + return nnx.merge(graphdef, target_state) + + def 
_make_step_functions(self, loss_obj: EnergyLoss) -> tuple[Any, Any]: def loss_fn( model: BaseModel, lr: float, @@ -193,25 +595,10 @@ def loss_fn( fp: jnp.ndarray | None, ap: jnp.ndarray | None, ) -> jnp.ndarray: - model_dict_lower = model.call_common_lower( - extended_coord, - extended_atype, - nlist, - mapping, - fp, - ap, - ) - model_dict = communicate_extended_output( - model_dict_lower, - model.model_output_def(), - mapping, - do_atomic_virial=False, + model_dict = _evaluate_model_dict( + model, extended_coord, extended_atype, nlist, mapping, fp, ap ) - model_dict["atom_energy"] = model_dict["energy"] - model_dict["energy"] = model_dict["energy_redu"] - model_dict["force"] = model_dict["energy_derv_r"].squeeze(-2) - model_dict["virial"] = model_dict["energy_derv_c_redu"].squeeze(-2) - loss, more_loss = self.loss( + loss, _ = loss_obj( learning_rate=lr, natoms=label_dict["type"].shape[1], model_dict=model_dict, @@ -231,25 +618,10 @@ def loss_fn_more_loss( fp: jnp.ndarray | None, ap: jnp.ndarray | None, ) -> dict[str, jnp.ndarray]: - model_dict_lower = model.call_common_lower( - extended_coord, - extended_atype, - nlist, - mapping, - fp, - ap, + model_dict = _evaluate_model_dict( + model, extended_coord, extended_atype, nlist, mapping, fp, ap ) - model_dict = communicate_extended_output( - model_dict_lower, - model.model_output_def(), - mapping, - do_atomic_virial=False, - ) - model_dict["atom_energy"] = model_dict["energy"] - model_dict["energy"] = model_dict["energy_redu"] - model_dict["force"] = model_dict["energy_derv_r"].squeeze(-2) - model_dict["virial"] = model_dict["energy_derv_c_redu"].squeeze(-2) - loss, more_loss = self.loss( + _, more_loss = loss_obj( learning_rate=lr, natoms=label_dict["type"].shape[1], model_dict=model_dict, @@ -282,127 +654,186 @@ def train_step( ap, ) if Version(flax_version) >= Version("0.11.0"): - optimizer.update(model, grads) + optimizer.update(model, grads, learning_rate=lr) else: - optimizer.update(grads) - - start_time = 
time.time() - disp_path = Path(self.disp_file) - disp_mode = "a" if self.start_step > 0 and disp_path.exists() else "w" - with open(disp_path, disp_mode) as disp_file_fp: - for step in range(self.start_step, self.num_steps): - batch_data = train_data.get_batch() - # numpy to jax - jax_data = convert_numpy_data_to_jax_data(batch_data) - extended_coord, extended_atype, nlist, mapping, fp, ap = prepare_input( - rcut=model.get_rcut(), - sel=model.get_sel(), - coord=jax_data["coord"], - atype=jax_data["type"], - box=jax_data["box"] if jax_data["find_box"] else None, - fparam=jax_data.get("fparam", None), - aparam=jax_data.get("aparam", None), - ) - train_step( - model, - optimizer, - self.lr.value(step), - jax_data, - extended_coord, - extended_atype, - nlist, - mapping, - fp, - ap, - ) - if self.display_in_training and ( - step == 0 or (step + 1) % self.disp_freq == 0 - ): - wall_time = time.time() - start_time - log.info( - format_training_message( - batch=step + 1, - wall_time=wall_time, - ) - ) - more_loss = loss_fn_more_loss( - model, - self.lr.value(step), - jax_data, - extended_coord, - extended_atype, - nlist, - mapping, - fp, - ap, - ) - if valid_data is not None: - valid_more_loss_list = [] - for _ in range(self.valid_numb_batch): - valid_batch_data = valid_data.get_batch() - jax_valid_data = convert_numpy_data_to_jax_data( - valid_batch_data - ) - extended_coord, extended_atype, nlist, mapping, fp, ap = ( - prepare_input( - rcut=model.get_rcut(), - sel=model.get_sel(), - coord=jax_valid_data["coord"], - atype=jax_valid_data["type"], - box=jax_valid_data["box"] - if jax_valid_data["find_box"] - else None, - fparam=jax_valid_data.get("fparam", None), - aparam=jax_valid_data.get("aparam", None), - ) - ) - valid_more_loss_list.append( - loss_fn_more_loss( - model, - self.lr.value(step), - jax_valid_data, - extended_coord, - extended_atype, - nlist, - mapping, - fp, - ap, - ) - ) - valid_more_loss = { - key: sum(loss[key] for loss in valid_more_loss_list) - / 
len(valid_more_loss_list) - for key in valid_more_loss_list[0] - } - else: - valid_more_loss = None - if disp_file_fp.tell() == 0: - self.print_header( - disp_file_fp, - train_results=more_loss, - valid_results=valid_more_loss, - ) - self.print_on_training( - disp_file_fp, - train_results=more_loss, - valid_results=valid_more_loss, - cur_batch=step + 1, - cur_lr=self.lr.value(step), - ) - start_time = time.time() - if (step + 1) % self.save_freq == 0: - self._save_checkpoint(model, step + 1) - if self.num_steps > self.start_step and self.num_steps % self.save_freq != 0: - self._save_checkpoint(model, self.num_steps) + _legacy_optimizer_update(optimizer, grads, lr) + + return train_step, loss_fn_more_loss + + def select_task(self, tasks: TrainingTaskCollection) -> TrainingTask: + """Select a task using DeePMD's seeded random helper.""" + if len(tasks) == 1: + return tasks[tasks.keys[0]] + from deepmd.utils import random as dp_random + + model_index = dp_random.choice( + np.arange(len(tasks), dtype=np.int_), + p=tasks.probabilities, + ) + return tasks[tasks.keys[int(model_index)]] + + def train_step(self, task: TrainingTask, step: int) -> TrainStepResult: + """Run one JAX optimizer step.""" + task_key = task.key + if task_key not in self.optimizers or task_key not in self._train_step_impls: + raise RuntimeError("JAX trainer has not been initialized.") + prepared = self._prepare_batch(task.training_data.get_batch(), task_key) + self._train_step_impls[task_key]( + self.models[task_key], + self.optimizers[task_key], + self.lr.value(step), + *prepared, + ) + return TrainStepResult(task_key=task_key, step=step, payload=prepared) + + def evaluate_training( + self, + task: TrainingTask, + step: int, + step_result: TrainStepResult | None, + ) -> dict[str, float]: + """Evaluate training loss terms for display.""" + prepared = ( + step_result.payload + if step_result is not None and step_result.task_key == task.key + else self._prepare_batch(task.training_data.get_batch(), 
task.key) + ) + return self._evaluate_prepared_batch(task.key, step, prepared) + + def evaluate_validation( + self, + task: TrainingTask, + step: int, + step_result: TrainStepResult | None, + ) -> dict[str, float] | None: + """Evaluate validation loss terms for display.""" + if task.validation_data is None: + return None + valid_more_loss_list = [ + self._evaluate_prepared_batch( + task.key, + step, + self._prepare_batch(task.validation_data.get_batch(), task.key), + ) + for _ in range(task.valid_numb_batch) + ] + return { + key: sum(loss[key] for loss in valid_more_loss_list) + / len(valid_more_loss_list) + for key in valid_more_loss_list[0] + } + + def learning_rate(self, step: int) -> float: + """Return the configured learning rate for a zero-based step.""" + return float(self.lr.value(step)) + + def save_checkpoint(self, step: int) -> None: + """Persist a JAX checkpoint for a one-based step.""" + self._save_checkpoint(step) + + def run_full_validation( + self, + *, + step: int, + display_step: int, + learning_rate: float, + ) -> None: + """Run optional full validation for one step.""" + if self.full_validator is None: + return None + self.full_validator.run( + step_id=display_step, + display_step=display_step, + lr=learning_rate, + save_checkpoint=self._save_full_validation_checkpoint, + ) + return None + + def _prepare_batch( + self, + batch_data: dict[str, np.ndarray | np.floating], + task_key: str, + ) -> tuple[ + dict[str, jnp.ndarray | bool], + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + jnp.ndarray | None, + jnp.ndarray | None, + jnp.ndarray | None, + ]: + """Convert one data-system batch into JAX model inputs.""" + model = self.models[task_key] + jax_data = convert_numpy_data_to_jax_data(batch_data) + extended_coord, extended_atype, nlist, mapping, fp, ap = prepare_input( + rcut=model.get_rcut(), + sel=model.get_sel(), + coord=jax_data["coord"], + atype=jax_data["type"], + box=jax_data["box"] if jax_data["find_box"] else None, + 
fparam=jax_data.get("fparam", None), + aparam=jax_data.get("aparam", None), + ) + return jax_data, extended_coord, extended_atype, nlist, mapping, fp, ap - def _save_checkpoint(self, model: BaseModel, step: int) -> None: + def _evaluate_prepared_batch( + self, + task_key: str, + step: int, + prepared: tuple[ + dict[str, jnp.ndarray | bool], + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + jnp.ndarray | None, + jnp.ndarray | None, + jnp.ndarray | None, + ], + ) -> dict[str, float]: + if task_key not in self._loss_fn_more_loss: + raise RuntimeError("JAX trainer has not been initialized.") + more_loss = self._loss_fn_more_loss[task_key]( + self.models[task_key], + self.lr.value(step), + *prepared, + ) + return {key: float(value) for key, value in more_loss.items()} + + def _save_checkpoint(self, step: int) -> None: """Save a JAX checkpoint and update the stable checkpoint pointer.""" - _, state = nnx.split(model) ckpt_path = Path(f"{self.save_ckpt}-{step}.jax") + self._write_checkpoint(ckpt_path, step=step) + log.info(f"Trained model has been saved to: {ckpt_path!s}") + _link_checkpoint(ckpt_path, Path(f"{self.save_ckpt}.jax")) + self._cleanup_old_checkpoints() + with open("checkpoint", "w") as fp: + fp.write(f"{self.save_ckpt}.jax") + + def _save_full_validation_checkpoint( + self, + save_path: Path, + lr: float = 0.0, + step: int = 0, + ) -> None: + """Save a full-validation-selected JAX checkpoint.""" + del lr + self._write_checkpoint(save_path, step=step) + + def _write_checkpoint(self, ckpt_path: Path, *, step: int) -> None: + """Write a JAX checkpoint directory to an explicit path.""" + if self.multi_task: + state = { + "models": { + model_key: nnx.split(model)[1].to_pure_dict() + for model_key, model in self.models.items() + } + } + else: + _, single_state = nnx.split(self.models[DEFAULT_TASK_KEY]) + state = single_state.to_pure_dict() if ckpt_path.is_dir(): - # remove old checkpoint if it exists shutil.rmtree(ckpt_path) - model_def_script_cpy = 
self.model_def_script.copy() + model_def_script_cpy = deepcopy(self.model_def_script) model_def_script_cpy["current_step"] = step with ocp.Checkpointer( ocp.CompositeCheckpointHandler("state", "model_def_script") @@ -410,15 +841,10 @@ def _save_checkpoint(self, model: BaseModel, step: int) -> None: checkpointer.save( ckpt_path.absolute(), ocp.args.Composite( - state=ocp.args.StandardSave(state.to_pure_dict()), + state=ocp.args.StandardSave(state), model_def_script=ocp.args.JsonSave(model_def_script_cpy), ), ) - log.info(f"Trained model has been saved to: {ckpt_path!s}") - _link_checkpoint(ckpt_path, Path(f"{self.save_ckpt}.jax")) - self._cleanup_old_checkpoints() - with open("checkpoint", "w") as fp: - fp.write(f"{self.save_ckpt}.jax") def _cleanup_old_checkpoints(self) -> None: """Remove old checkpoint directories beyond the retention limit.""" @@ -436,68 +862,127 @@ def _cleanup_old_checkpoints(self) -> None: for _, path in sorted(checkpoints)[: -self.max_ckpt_keep]: shutil.rmtree(path) - @staticmethod - def print_on_training( - fp: TextIO, - train_results: dict[str, float], - valid_results: dict[str, float] | None, - cur_batch: int, - cur_lr: float, - ) -> None: - """Append one training/validation loss row to the learning-curve file.""" - print_str = "" - print_str += f"{cur_batch:7d}" - if valid_results is not None: - prop_fmt = " %11.2e %11.2e" - for k in valid_results.keys(): - # assert k in train_results.keys() - print_str += prop_fmt % (valid_results[k], train_results[k]) - else: - prop_fmt = " %11.2e" - for k in train_results.keys(): - print_str += prop_fmt % (train_results[k]) - print_str += f" {cur_lr:8.1e}\n" - log.info( - format_training_message_per_task( - batch=cur_batch, - task_name="trn", - rmse=train_results, - learning_rate=cur_lr, - ) + +def _evaluate_model_dict( + model: BaseModel, + extended_coord: jnp.ndarray, + extended_atype: jnp.ndarray, + nlist: jnp.ndarray, + mapping: jnp.ndarray | None, + fp: jnp.ndarray | None, + ap: jnp.ndarray | 
None, +) -> dict[str, jnp.ndarray]: + model_dict_lower = model.call_common_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fp, + ap, + ) + model_dict = communicate_extended_output( + model_dict_lower, + model.model_output_def(), + mapping, + do_atomic_virial=False, + ) + model_dict["atom_energy"] = model_dict["energy"] + model_dict["energy"] = model_dict["energy_redu"] + model_dict["force"] = model_dict["energy_derv_r"].squeeze(-2) + model_dict["virial"] = model_dict["energy_derv_c_redu"].squeeze(-2) + return model_dict + + +def _init_empty_state(params: Any) -> optax.EmptyState: + """Initialize an empty Optax state without requiring optax.init_empty_state. + + Older Optax releases expose ``EmptyState`` but not the convenience helper + ``init_empty_state``. Constructing the state locally keeps this transform + compatible with the Optax versions selected by the current JAX/Flax pins. + """ + del params + return optax.EmptyState() + + +def _scale_by_global_learning_rate() -> optax.GradientTransformationExtraArgs: + """Scale optimizer updates by the learning rate from the global step.""" + + def update_fn( + updates: Any, + state: optax.EmptyState, + params: Any = None, + **kwargs: Any, + ) -> tuple[Any, optax.EmptyState]: + del params + learning_rate = kwargs["learning_rate"] + updates = jax.tree_util.tree_map( + lambda update: -learning_rate * update, updates ) - if valid_results is not None: - log.info( - format_training_message_per_task( - batch=cur_batch, - task_name="val", - rmse=valid_results, - learning_rate=None, - ) + return updates, state + + return optax.GradientTransformationExtraArgs(_init_empty_state, update_fn) + + +def _legacy_optimizer_update(optimizer: Any, grads: Any, lr: float) -> None: + """Run an NNX optimizer update with extra args on Flax before 0.11.""" + from flax.nnx.training.optimizer import ( + _opt_state_variables_to_state, + _update_opt_state, + ) + + params = nnx.state(optimizer.model, optimizer.wrt) + opt_state = 
_opt_state_variables_to_state(optimizer.opt_state) + updates, new_opt_state = optimizer.tx.update( + grads, + opt_state, + params, + learning_rate=lr, + ) + new_params = optax.apply_updates(params, updates) + optimizer.step.value += 1 + nnx.update(optimizer.model, new_params) + _update_opt_state(optimizer.opt_state, new_opt_state) + + +def _copy_matching_state_tree( + target: Any, + source: Any, + *, + random_fitting: bool, + path: tuple[Any, ...] = (), +) -> Any: + if isinstance(target, dict): + if not isinstance(source, dict): + return target + return { + key: _copy_matching_state_tree( + value, + source.get(key), + random_fitting=random_fitting, + path=(*path, key), ) - fp.write(print_str) - fp.flush() + for key, value in target.items() + } + if source is None: + return target + if random_fitting and not any("descriptor" in str(part) for part in path): + return target + if _same_state_leaf(target, source): + return source + return target - @staticmethod - def print_header( - fp: TextIO, - train_results: dict[str, float], - valid_results: dict[str, float] | None, - ) -> None: - """Write the learning-curve header for the configured loss terms.""" - print_str = "" - print_str += "# {:5s}".format("step") - if valid_results is not None: - prop_fmt = " %11s %11s" - for k in train_results.keys(): - print_str += prop_fmt % (k + "_val", k + "_trn") - else: - prop_fmt = " %11s" - for k in train_results.keys(): - print_str += prop_fmt % (k + "_trn") - print_str += " {:8s}\n".format("lr") - print_str += "# If there is no available reference data, rmse_*_{val,trn} will print nan\n" - fp.write(print_str) - fp.flush() + +def _same_state_leaf(target: Any, source: Any) -> bool: + target_shape = getattr(target, "shape", None) + source_shape = getattr(source, "shape", None) + target_dtype = getattr(target, "dtype", None) + source_dtype = getattr(source, "dtype", None) + return ( + target_shape is not None + and source_shape is not None + and target_shape == source_shape + and 
target_dtype == source_dtype + ) def _link_checkpoint(source: Path, target: Path) -> None: diff --git a/deepmd/jax/train/validation.py b/deepmd/jax/train/validation.py new file mode 100644 index 0000000000..62cb21563c --- /dev/null +++ b/deepmd/jax/train/validation.py @@ -0,0 +1,235 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Full validation support for the JAX trainer.""" + +from __future__ import ( + annotations, +) + +from typing import ( + TYPE_CHECKING, + Any, +) + +import numpy as np + +from deepmd.dpmodel.train.validation import ( + LOG_COLUMN_ORDER, + FullValidatorBase, +) +from deepmd.jax.env import ( + jax, + jnp, +) +from deepmd.jax.utils.auto_batch_size import ( + AutoBatchSize, +) +from deepmd.utils.eval_metrics import ( + FULL_VALIDATION_WEIGHTED_METRIC_KEYS, + compute_energy_type_metrics, +) +from deepmd.utils.weight_avg import ( + weighted_average, +) + +if TYPE_CHECKING: + from deepmd.jax.model.base_model import ( + BaseModel, + ) + + +class JAXFullValidator(FullValidatorBase): + """Run full validation for a single-task JAX energy model.""" + + def __init__( + self, + *, + validating_params: dict[str, Any], + validation_data: Any, + model: BaseModel, + state_store: dict[str, Any], + num_steps: int, + rank: int, + restart_training: bool, + checkpoint_dir: Any = None, + ) -> None: + self.validation_data = validation_data + self.model = model + self.auto_batch_size = AutoBatchSize(silent=True) + super().__init__( + validating_params=validating_params, + state_store=state_store, + num_steps=num_steps, + rank=rank, + restart_training=restart_training, + checkpoint_dir=checkpoint_dir, + best_checkpoint_suffix=".jax", + ) + + def evaluate_all_systems(self) -> dict[str, float]: + """Evaluate every validation system and aggregate metrics.""" + system_metrics = [ + self._evaluate_system(data_system) + for data_system in self._iter_validation_data_systems() + ] + aggregated = weighted_average([metric for metric in system_metrics if metric]) + return 
{ + metric_key: float(aggregated[metric_key]) + for _, metric_key in LOG_COLUMN_ORDER + if metric_key in aggregated + } + + def propagate_error(self, error_message: str | None) -> str | None: + """Broadcast rank-0 full-validation failures to all JAX processes.""" + if jax.process_count() <= 1: + return error_message + from jax.experimental import ( + multihost_utils, + ) + + has_error = multihost_utils.broadcast_one_to_all( + np.asarray(error_message is not None, dtype=np.bool_), + is_source=self.rank == 0, + ) + if not bool(np.asarray(has_error).item()): + return None + return error_message or "Full validation failed on rank 0; see rank-0 logs." + + def _iter_validation_data_systems(self) -> Any: + """Yield DeepmdData-like validation systems.""" + validation_data = self.validation_data + if hasattr(validation_data, "data_systems"): + yield from validation_data.data_systems + return + if hasattr(validation_data, "get_test"): + yield validation_data + return + if hasattr(validation_data, "systems"): + for dataset in validation_data.systems: + yield getattr(dataset, "data_system", dataset) + return + raise TypeError( + "JAX full validation expects a DeepmdDataSystem, DeepmdData-like " + f"object, or loader set; got {type(validation_data)!r}." 
+ ) + + def _evaluate_system(self, data_system: Any) -> dict[str, tuple[float, float]]: + """Evaluate one validation system.""" + test_data = data_system.get_test() + natoms = int(test_data["type"].shape[1]) + nframes = int(test_data["coord"].shape[0]) + has_pbc = bool(getattr(data_system, "pbc", False)) + include_virial = has_pbc and bool(test_data.get("find_virial", 0.0)) + prediction = self._predict_outputs( + coord=test_data["coord"].reshape(nframes, -1), + atom_types=test_data["type"], + box=test_data["box"] if has_pbc else None, + fparam=test_data["fparam"] + if self.model.get_dim_fparam() > 0 + and bool(test_data.get("find_fparam", 0.0)) + else None, + aparam=test_data["aparam"] if self.model.get_dim_aparam() > 0 else None, + include_virial=include_virial, + natoms=natoms, + nframes=nframes, + ) + shared_metrics = compute_energy_type_metrics( + prediction=prediction, + test_data=test_data, + natoms=natoms, + has_pbc=has_pbc, + ) + return shared_metrics.as_weighted_average_errors( + FULL_VALIDATION_WEIGHTED_METRIC_KEYS + ) + + def _predict_outputs( + self, + *, + coord: np.ndarray, + atom_types: np.ndarray, + box: np.ndarray | None, + fparam: np.ndarray | None, + aparam: np.ndarray | None, + include_virial: bool, + natoms: int, + nframes: int, + ) -> dict[str, np.ndarray]: + """Predict energy, force, and virial for the full validation batch.""" + + def predict_batch( + coord_batch: np.ndarray, + atom_types_batch: np.ndarray, + box_batch: np.ndarray | None, + fparam_batch: np.ndarray | None, + aparam_batch: np.ndarray | None, + ) -> dict[str, np.ndarray]: + from deepmd.jax.train.trainer import ( + _evaluate_model_dict, + prepare_input, + ) + + coord_input = jnp.asarray(coord_batch.reshape(-1, natoms * 3)) + type_input = jnp.asarray(atom_types_batch) + box_input = ( + jnp.asarray(box_batch.reshape(-1, 9)) if box_batch is not None else None + ) + fparam_input = ( + jnp.asarray(fparam_batch.reshape(-1, self.model.get_dim_fparam())) + if fparam_batch is not None + 
else None + ) + aparam_input = ( + jnp.asarray( + aparam_batch.reshape(-1, natoms, self.model.get_dim_aparam()) + ) + if aparam_batch is not None + else None + ) + extended_coord, extended_atype, nlist, mapping, fp, ap = prepare_input( + rcut=self.model.get_rcut(), + sel=self.model.get_sel(), + coord=coord_input, + atype=type_input, + box=box_input, + fparam=fparam_input, + aparam=aparam_input, + ) + batch_output = _evaluate_model_dict( + self.model, + extended_coord, + extended_atype, + nlist, + mapping, + fp, + ap, + ) + prediction = { + "energy": np.asarray(batch_output["energy"]).reshape(-1, 1), + "force": np.asarray(batch_output["force"]).reshape(-1, natoms * 3), + } + if include_virial: + if "virial" not in batch_output: + raise KeyError( + "Full validation requested virial metrics, but model " + "output does not contain `virial`." + ) + prediction["virial"] = np.asarray(batch_output["virial"]).reshape(-1, 9) + return prediction + + batch_prediction = self.auto_batch_size.execute_all( + predict_batch, + nframes, + natoms, + coord, + atom_types, + box, + fparam, + aparam, + ) + prediction = { + "energy": np.asarray(batch_prediction["energy"]), + "force": np.asarray(batch_prediction["force"]), + } + if include_virial: + prediction["virial"] = np.asarray(batch_prediction["virial"]) + return prediction diff --git a/deepmd/jax/utils/finetune.py b/deepmd/jax/utils/finetune.py new file mode 100644 index 0000000000..88de03f67d --- /dev/null +++ b/deepmd/jax/utils/finetune.py @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Fine-tuning config utilities for the JAX backend.""" + +from __future__ import ( + annotations, +) + +from typing import ( + Any, +) + +from deepmd.jax.utils.serialization import ( + serialize_from_file, +) +from deepmd.utils.finetune import ( + FinetuneRuleItem, + get_finetune_rules_from_model_params, +) + + +def _load_model_params(finetune_model: str) -> dict[str, Any]: + """Extract model params from a JAX checkpoint.""" + 
return serialize_from_file(finetune_model)["model_def_script"] + + +def get_finetune_rules( + finetune_model: str, + model_config: dict[str, Any], + model_branch: str = "", + change_model_params: bool = True, +) -> tuple[dict[str, Any], dict[str, FinetuneRuleItem]]: + """Build JAX fine-tuning rules for single-task or multi-task configs.""" + return get_finetune_rules_from_model_params( + _load_model_params(finetune_model), + model_config, + model_branch=model_branch, + change_model_params=change_model_params, + ) diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index e484475235..a2df36160c 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -25,6 +25,31 @@ ) +def _convert_str_to_int_key(item: dict) -> None: + """Convert Orbax-restored numeric index keys from strings back to ints.""" + for key, value in item.copy().items(): + if isinstance(value, dict): + _convert_str_to_int_key(value) + if isinstance(key, str) and key.isdigit(): + item[int(key)] = item.pop(key) + + +def _normalize_restored_state_keys( + state: dict, + model_def_script: dict, +) -> None: + """Normalize restored state keys while preserving multi-task branch names.""" + if "model_dict" in model_def_script: + state_by_model = state.get("models", state) + for model_key in model_def_script["model_dict"]: + if model_key in state_by_model and isinstance( + state_by_model[model_key], dict + ): + _convert_str_to_int_key(state_by_model[model_key]) + return + _convert_str_to_int_key(state) + + def _state_sequence_to_numpy_list(state_value: Any) -> list[np.ndarray]: """Convert an Orbax-restored list/dict sequence to NumPy arrays.""" if isinstance(state_value, dict): @@ -171,7 +196,6 @@ def deserialize_to_file(model_file: str, data: dict) -> None: The dictionary to be deserialized. 
""" if model_file.endswith(".jax"): - model = BaseModel.deserialize(data["model"]) model_def_script = data["model_def_script"].copy() min_nbor_dist = _to_optional_float(data.get("min_nbor_dist")) if min_nbor_dist is None: @@ -180,14 +204,28 @@ def deserialize_to_file(model_file: str, data: dict) -> None: ) if min_nbor_dist is not None: model_def_script["_min_nbor_dist"] = min_nbor_dist - _, state = nnx.split(model) + if "model_dict" in model_def_script: + models = { + model_key: BaseModel.deserialize(data["model"]["model_dict"][model_key]) + for model_key in model_def_script["model_dict"] + } + state = { + "models": { + model_key: nnx.split(model)[1].to_pure_dict() + for model_key, model in models.items() + } + } + else: + model = BaseModel.deserialize(data["model"]) + _, state = nnx.split(model) + state = state.to_pure_dict() with ocp.Checkpointer( ocp.CompositeCheckpointHandler("state", "model_def_script") ) as checkpointer: checkpointer.save( Path(model_file).absolute(), ocp.args.Composite( - state=ocp.args.StandardSave(state.to_pure_dict()), + state=ocp.args.StandardSave(state), model_def_script=ocp.args.JsonSave(model_def_script), ), ) @@ -319,27 +357,31 @@ def serialize_from_file(model_file: str) -> dict: ), ) state = data.state - - # convert str "1" to int 1 key - def convert_str_to_int_key(item: dict) -> None: - for key, value in item.copy().items(): - if isinstance(value, dict): - convert_str_to_int_key(value) - if key.isdigit(): - item[int(key)] = item.pop(key) - - convert_str_to_int_key(state) - model_def_script = data.model_def_script - abstract_model = get_model(model_def_script) - _restore_compression_slots_from_state(abstract_model, state) - graphdef, abstract_state = nnx.split(abstract_model) - abstract_state.replace_by_pure_dict(state) - model = nnx.merge(graphdef, abstract_state) - model_dict = model.serialize() - min_nbor_dist = _to_optional_float(model.get_min_nbor_dist()) - if min_nbor_dist is None: - min_nbor_dist = 
_to_optional_float(model_def_script.get("_min_nbor_dist")) + _normalize_restored_state_keys(state, model_def_script) + min_nbor_dist = None + + def restore_model(model_params: dict, model_state: dict) -> BaseModel: + abstract_model = get_model(model_params) + _restore_compression_slots_from_state(abstract_model, model_state) + graphdef, abstract_state = nnx.split(abstract_model) + abstract_state.replace_by_pure_dict(model_state) + return nnx.merge(graphdef, abstract_state) + + if "model_dict" in model_def_script: + state_by_model = state.get("models", state) + model_dict = {"model_dict": {}} + for model_key, model_params in model_def_script["model_dict"].items(): + model = restore_model(model_params, state_by_model[model_key]) + model_dict["model_dict"][model_key] = model.serialize() + else: + model = restore_model(model_def_script, state) + model_dict = model.serialize() + min_nbor_dist = _to_optional_float(model.get_min_nbor_dist()) + if min_nbor_dist is None: + min_nbor_dist = _to_optional_float( + model_def_script.get("_min_nbor_dist") + ) data = { "backend": "JAX", "jax_version": jax.__version__, diff --git a/deepmd/pd/utils/finetune.py b/deepmd/pd/utils/finetune.py index 510a91c392..3bd259cc82 100644 --- a/deepmd/pd/utils/finetune.py +++ b/deepmd/pd/utils/finetune.py @@ -1,210 +1,40 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import logging -from copy import ( - deepcopy, +"""Fine-tuning rule helpers for the Paddle backend.""" + +from typing import ( + Any, ) import paddle from deepmd.utils.finetune import ( FinetuneRuleItem, - warn_descriptor_config_differences, +) +from deepmd.utils.finetune import get_finetune_rule_single as get_finetune_rule_single +from deepmd.utils.finetune import ( + get_finetune_rules_from_model_params, ) -log = logging.getLogger(__name__) - - -def get_finetune_rule_single( - _single_param_target: dict, - _model_param_pretrained: dict, - from_multitask: bool = False, - model_branch: str = "Default", - model_branch_from: str = "", 
- change_model_params: bool = False, -) -> dict: - single_config = deepcopy(_single_param_target) - new_fitting = False - model_branch_chosen = "Default" - - if not from_multitask: - single_config_chosen = deepcopy(_model_param_pretrained) - if model_branch_from == "RANDOM": - # not ["", "RANDOM"], because single-from-single finetune uses pretrained fitting in default - new_fitting = True - else: - model_dict_params = _model_param_pretrained["model_dict"] - if model_branch_from in ["", "RANDOM"]: - model_branch_chosen = next(iter(model_dict_params.keys())) - new_fitting = True - log.warning( - "The fitting net will be re-init instead of using that in the pretrained model! " - "The bias_adjust_mode will be set-by-statistic!" - ) - else: - model_branch_chosen = model_branch_from - assert model_branch_chosen in model_dict_params, ( - f"No model branch named '{model_branch_chosen}'! " - f"Available ones are {list(model_dict_params.keys())}." - ) - single_config_chosen = deepcopy(model_dict_params[model_branch_chosen]) - old_type_map, new_type_map = ( - single_config_chosen["type_map"], - single_config["type_map"], - ) - finetune_rule = FinetuneRuleItem( - p_type_map=old_type_map, - type_map=new_type_map, - model_branch=model_branch_chosen, - random_fitting=new_fitting, - ) - if change_model_params: - if "descriptor" in single_config and "descriptor" in single_config_chosen: - warn_descriptor_config_differences( - single_config["descriptor"], - single_config_chosen["descriptor"], - model_branch_chosen, - ) - trainable_param = { - "descriptor": single_config.get("descriptor", {}).get("trainable", True), - "fitting_net": single_config.get("fitting_net", {}).get("trainable", True), - } - single_config["descriptor"] = single_config_chosen["descriptor"] - if not new_fitting: - single_config["fitting_net"] = single_config_chosen["fitting_net"] - log.info( - f"Change the '{model_branch}' model configurations according to the model branch " - f"'{model_branch_chosen}' in the 
pretrained one..." - ) - for net_type in trainable_param: - if net_type in single_config: - single_config[net_type]["trainable"] = trainable_param[net_type] - else: - single_config[net_type] = {"trainable": trainable_param[net_type]} - return single_config, finetune_rule +__all__ = ["get_finetune_rule_single", "get_finetune_rules"] def get_finetune_rules( finetune_model: str, - model_config: dict, + model_config: dict[str, Any], model_branch: str = "", change_model_params: bool = True, -) -> tuple[dict, str]: - """ - Get fine-tuning rules and (optionally) change the model_params according to the pretrained one. - - This function gets the fine-tuning rules and (optionally) changes input in different modes as follows: - 1. Single-task fine-tuning from a single-task pretrained model: - - The model will be fine-tuned based on the pretrained model. - - (Optional) Updates the model parameters based on the pretrained model. - 2. Single-task fine-tuning from a multi-task pretrained model: - - The model will be fine-tuned based on the selected branch in the pretrained model. - The chosen branch can be defined from the command-line or `finetune_head` input parameter. - If not defined, model parameters in the fitting network will be randomly initialized. - - (Optional) Updates the model parameters based on the selected branch in the pretrained model. - 3. Multi-task fine-tuning from a single-task pretrained model: - - The model in each branch will be fine-tuned or resumed based on the single branch ('Default') in the pretrained model. - The chosen branches can be defined from the `finetune_head` input parameter of each branch. - - If `finetune_head` is defined as 'Default', - it will be fine-tuned based on the single branch ('Default') in the pretrained model. - - If `finetune_head` is not defined and the model_key is 'Default', - it will resume from the single branch ('Default') in the pretrained model without fine-tuning. 
- - If `finetune_head` is not defined and the model_key is not 'Default', - it will be fine-tuned based on the single branch ('Default') in the pretrained model, - while model parameters in the fitting network of the branch will be randomly initialized. - - (Optional) Updates model parameters in each branch based on the single branch ('Default') in the pretrained model. - 4. Multi-task fine-tuning from a multi-task pretrained model: - - The model in each branch will be fine-tuned or resumed based on the chosen branches in the pretrained model. - The chosen branches can be defined from the `finetune_head` input parameter of each branch. - - If `finetune_head` is defined as one of the branches in the pretrained model, - it will be fine-tuned based on the chosen branch in the pretrained model. - - If `finetune_head` is not defined and the model_key is the same as one of those in the pretrained model, - it will resume from the model_key branch in the pretrained model without fine-tuning. - - If `finetune_head` is not defined and a new model_key is used, - it will be fine-tuned based on the chosen branch in the pretrained model, - while model parameters in the fitting network of the branch will be randomly initialized. - - (Optional) Updates model parameters in each branch based on the chosen branches in the pretrained model. - - Parameters - ---------- - finetune_model - The pretrained model. - model_config - The fine-tuning input parameters. - model_branch - The model branch chosen in command-line mode, only for single-task fine-tuning. - change_model_params - Whether to change the model parameters according to the pretrained one. - - Returns - ------- - model_config: - Updated model parameters. - finetune_links: - Fine-tuning rules in a dict format, with `model_branch`: FinetuneRuleItem pairs. 
- """ - multi_task = "model_dict" in model_config +) -> tuple[dict[str, Any], dict[str, FinetuneRuleItem]]: + """Get fine-tuning rules for a single-task or multi-task Paddle model.""" state_dict = paddle.load(finetune_model) if "model" in state_dict: state_dict = state_dict["model"] - last_model_params = state_dict["_extra_state"]["model_params"] - finetune_from_multi_task = "model_dict" in last_model_params - finetune_links = {} - if not multi_task: - # use command-line first - if model_branch == "" and "finetune_head" in model_config: - model_branch = model_config["finetune_head"] - model_config, finetune_rule = get_finetune_rule_single( - model_config, - last_model_params, - from_multitask=finetune_from_multi_task, - model_branch="Default", - model_branch_from=model_branch, - change_model_params=change_model_params, - ) - finetune_links["Default"] = finetune_rule - else: - assert model_branch == "", ( + return get_finetune_rules_from_model_params( + state_dict["_extra_state"]["model_params"], + model_config, + model_branch=model_branch, + change_model_params=change_model_params, + multitask_branch_error=( "Multi-task fine-tuning does not support command-line branches chosen!" "Please define the 'finetune_head' in each model params!" - ) - target_keys = model_config["model_dict"].keys() - if not finetune_from_multi_task: - pretrained_keys = ["Default"] - else: - pretrained_keys = last_model_params["model_dict"].keys() - for model_key in target_keys: - resuming = False - if ( - "finetune_head" in model_config["model_dict"][model_key] - and model_config["model_dict"][model_key]["finetune_head"] != "RANDOM" - ): - pretrained_key = model_config["model_dict"][model_key]["finetune_head"] - assert pretrained_key in pretrained_keys, ( - f"'{pretrained_key}' head chosen to finetune not exist in the pretrained model!" 
- f"Available heads are: {list(pretrained_keys)}" - ) - model_branch_from = pretrained_key - elif ( - "finetune_head" not in model_config["model_dict"][model_key] - and model_key in pretrained_keys - ): - # not do anything if not defined "finetune_head" in heads that exist in the pretrained model - # this will just do resuming - model_branch_from = model_key - resuming = True - else: - # if not defined "finetune_head" in new heads or "finetune_head" is "RANDOM", the fitting net will bre randomly initialized - model_branch_from = "RANDOM" - model_config["model_dict"][model_key], finetune_rule = ( - get_finetune_rule_single( - model_config["model_dict"][model_key], - last_model_params, - from_multitask=finetune_from_multi_task, - model_branch=model_key, - model_branch_from=model_branch_from, - change_model_params=change_model_params, - ) - ) - finetune_links[model_key] = finetune_rule - finetune_links[model_key].resuming = resuming - return model_config, finetune_links + ), + ) diff --git a/deepmd/pt/train/validation.py b/deepmd/pt/train/validation.py index a40e7113c7..f206df03c3 100644 --- a/deepmd/pt/train/validation.py +++ b/deepmd/pt/train/validation.py @@ -7,6 +7,7 @@ import logging import re +import shutil import traceback from contextlib import ( nullcontext, @@ -122,14 +123,21 @@ class BestCheckpointRecord: step: int -def build_best_checkpoint_glob(best_checkpoint_prefix: str) -> str: +def build_best_checkpoint_glob( + best_checkpoint_prefix: str, best_checkpoint_suffix: str = ".pt" +) -> str: """Build the glob pattern for managed best checkpoints.""" - return f"{best_checkpoint_prefix}-*.t-*.pt" + return f"{best_checkpoint_prefix}-*.t-*{best_checkpoint_suffix}" -def build_best_checkpoint_pattern(best_checkpoint_prefix: str) -> re.Pattern[str]: +def build_best_checkpoint_pattern( + best_checkpoint_prefix: str, best_checkpoint_suffix: str = ".pt" +) -> re.Pattern[str]: """Build the regex pattern for managed best checkpoints.""" - return 
re.compile(rf"^{re.escape(best_checkpoint_prefix)}-(\d+)\.t-(\d+)\.pt$") + return re.compile( + rf"^{re.escape(best_checkpoint_prefix)}-(\d+)\.t-(\d+)" + rf"{re.escape(best_checkpoint_suffix)}$" + ) def parse_validation_metric(metric: str) -> tuple[str, str]: @@ -208,6 +216,7 @@ def __init__( checkpoint_dir: Path | None = None, full_val_file: str | Path | None = None, best_checkpoint_prefix: str = BEST_CKPT_PREFIX, + best_checkpoint_suffix: str = ".pt", metric_name_info_key: str = BEST_METRIC_NAME_INFO_KEY, topk_records_info_key: str = TOPK_RECORDS_INFO_KEY, stale_state_keys: tuple[str, ...] = STALE_FULL_VALIDATION_INFO_KEYS, @@ -227,9 +236,12 @@ def __init__( self.topk_records_info_key = topk_records_info_key self.stale_state_keys = stale_state_keys self.best_checkpoint_prefix = best_checkpoint_prefix - self.best_checkpoint_glob = build_best_checkpoint_glob(best_checkpoint_prefix) + self.best_checkpoint_suffix = best_checkpoint_suffix + self.best_checkpoint_glob = build_best_checkpoint_glob( + best_checkpoint_prefix, best_checkpoint_suffix + ) self.best_checkpoint_pattern = build_best_checkpoint_pattern( - best_checkpoint_prefix + best_checkpoint_prefix, best_checkpoint_suffix ) self.emit_best_save_log = emit_best_save_log self.model_eval_context = model_eval_context or nullcontext @@ -432,6 +444,16 @@ def _iter_validation_data_systems(self) -> Iterator[Any]: yield LmdbTestDataNlocView(lmdb_test_data, nloc) return + if hasattr(validation_data, "_reader"): + lmdb_test_data = self._get_lmdb_test_data_snapshot(validation_data) + for nloc in sorted(lmdb_test_data.nloc_groups.keys()): + yield LmdbTestDataNlocView(lmdb_test_data, nloc) + return + + if hasattr(validation_data, "data_systems"): + yield from validation_data.data_systems + return + for dataset in validation_data.systems: if not isinstance(dataset, DeepmdDataSetForLoader): raise TypeError( @@ -440,7 +462,7 @@ def _iter_validation_data_systems(self) -> Iterator[Any]: ) yield dataset.data_system - def 
_get_lmdb_test_data_snapshot(self, lmdb_dataset: LmdbDataset) -> LmdbTestData: + def _get_lmdb_test_data_snapshot(self, lmdb_dataset: Any) -> LmdbTestData: """Build (once) and return the cached LMDB test snapshot. Reuses the ``type_map`` and previously-registered @@ -451,12 +473,32 @@ def _get_lmdb_test_data_snapshot(self, lmdb_dataset: LmdbDataset) -> LmdbTestDat if self._lmdb_test_data is not None: return self._lmdb_test_data + reader = getattr(lmdb_dataset, "_reader", None) + lmdb_path = getattr(lmdb_dataset, "lmdb_path", None) + if lmdb_path is None and reader is not None: + lmdb_path = getattr(reader, "lmdb_path", None) + type_map = getattr(lmdb_dataset, "type_map", None) + if type_map is None and reader is not None: + type_map = getattr(reader, "type_map", None) + data_requirements = getattr(lmdb_dataset, "data_requirements", None) + if data_requirements is None and reader is not None: + data_requirements = getattr(reader, "data_requirements", None) + if lmdb_path is None: + raise TypeError( + "Full validation could not resolve the LMDB path from " + f"{type(lmdb_dataset)!r}." + ) + if type_map is None: + raise TypeError( + "Full validation could not resolve the LMDB type_map from " + f"{type(lmdb_dataset)!r}." 
+ ) + self._lmdb_test_data = LmdbTestData( - lmdb_dataset.lmdb_path, - type_map=list(lmdb_dataset.type_map), + lmdb_path, + type_map=list(type_map), shuffle_test=False, ) - data_requirements = lmdb_dataset.data_requirements if data_requirements: self._lmdb_test_data.add_data_requirement(data_requirements) return self._lmdb_test_data @@ -521,7 +563,7 @@ def predict_batch( ), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE, - ) + ).requires_grad_(True) type_input = torch.tensor( atom_types_batch.astype(np.int64), dtype=torch.long, @@ -662,7 +704,7 @@ def _load_topk_records(self) -> list[BestCheckpointRecord]: def _best_checkpoint_name(self, step: int, rank: int) -> str: """Build the best-checkpoint filename for one step.""" - return f"{self.best_checkpoint_prefix}-{step}.t-{rank}.pt" + return f"{self.best_checkpoint_prefix}-{step}.t-{rank}{self.best_checkpoint_suffix}" def _best_checkpoint_path(self, step: int, rank: int) -> Path: """Build the best-checkpoint path for one step.""" @@ -673,11 +715,19 @@ def _list_best_checkpoints(self) -> list[Path]: best_checkpoints = [ path for path in self.checkpoint_dir.glob(self.best_checkpoint_glob) - if path.is_file() and not path.is_symlink() + if path.exists() and not path.is_symlink() ] best_checkpoints.sort(key=lambda path: path.stat().st_mtime) return best_checkpoints + @staticmethod + def _remove_checkpoint_path(path: Path) -> None: + """Remove one managed checkpoint path, file or directory.""" + if path.is_dir() and not path.is_symlink(): + shutil.rmtree(path) + else: + path.unlink(missing_ok=True) + def _expected_topk_checkpoint_names(self) -> dict[int, str]: """Return the expected checkpoint filename for each retained step.""" return { @@ -723,9 +773,9 @@ def _reconcile_best_checkpoints(self) -> None: temp_moves.append((temp_path, keep_path.with_name(expected_name))) for checkpoint_path in stale_files: - checkpoint_path.unlink(missing_ok=True) + self._remove_checkpoint_path(checkpoint_path) for temp_path, final_path in 
temp_moves: - final_path.unlink(missing_ok=True) + self._remove_checkpoint_path(final_path) temp_path.rename(final_path) def _initialize_best_checkpoints(self, restart_training: bool) -> None: @@ -734,7 +784,7 @@ def _initialize_best_checkpoints(self, restart_training: bool) -> None: self._reconcile_best_checkpoints() return for checkpoint_path in self._list_best_checkpoints(): - checkpoint_path.unlink(missing_ok=True) + self._remove_checkpoint_path(checkpoint_path) def _raise_if_distributed_error( self, diff --git a/deepmd/pt/utils/finetune.py b/deepmd/pt/utils/finetune.py index c4db694578..bf1c328408 100644 --- a/deepmd/pt/utils/finetune.py +++ b/deepmd/pt/utils/finetune.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import logging -from copy import ( - deepcopy, -) +"""Fine-tuning rule helpers for the PyTorch backend.""" + from typing import ( Any, ) @@ -14,85 +12,13 @@ ) from deepmd.utils.finetune import ( FinetuneRuleItem, - warn_descriptor_config_differences, ) -from deepmd.utils.model_branch_dict import ( - get_model_dict, +from deepmd.utils.finetune import get_finetune_rule_single as get_finetune_rule_single +from deepmd.utils.finetune import ( + get_finetune_rules_from_model_params, ) -log = logging.getLogger(__name__) - - -def get_finetune_rule_single( - _single_param_target: dict[str, Any], - _model_param_pretrained: dict[str, Any], - from_multitask: bool = False, - model_branch: str = "Default", - model_branch_from: str = "", - change_model_params: bool = False, -) -> tuple[dict[str, Any], FinetuneRuleItem]: - single_config = deepcopy(_single_param_target) - new_fitting = False - model_branch_chosen = "Default" - - if not from_multitask: - single_config_chosen = deepcopy(_model_param_pretrained) - if model_branch_from == "RANDOM": - # not ["", "RANDOM"], because single-from-single finetune uses pretrained fitting in default - new_fitting = True - else: - model_dict_params = _model_param_pretrained["model_dict"] - if model_branch_from 
in ["", "RANDOM"]: - model_branch_chosen = next(iter(model_dict_params.keys())) - new_fitting = True - log.warning( - "The fitting net will be re-init instead of using that in the pretrained model! " - "The bias_adjust_mode will be set-by-statistic!" - ) - else: - model_branch_chosen = model_branch_from - model_alias_dict, model_branch_dict = get_model_dict(model_dict_params) - assert model_branch_chosen in model_alias_dict, ( - f"No model branch or alias named '{model_branch_chosen}'! " - f"Available ones are {list(model_dict_params.keys())}." - f"Use `dp --pt show your_model.pt model-branch` to show detail information." - ) - model_branch_chosen = model_alias_dict[model_branch_chosen] - single_config_chosen = deepcopy(model_dict_params[model_branch_chosen]) - old_type_map, new_type_map = ( - single_config_chosen["type_map"], - single_config["type_map"], - ) - finetune_rule = FinetuneRuleItem( - p_type_map=old_type_map, - type_map=new_type_map, - model_branch=model_branch_chosen, - random_fitting=new_fitting, - ) - if change_model_params: - if "descriptor" in single_config and "descriptor" in single_config_chosen: - warn_descriptor_config_differences( - single_config["descriptor"], - single_config_chosen["descriptor"], - model_branch_chosen, - ) - trainable_param = { - "descriptor": single_config.get("descriptor", {}).get("trainable", True), - "fitting_net": single_config.get("fitting_net", {}).get("trainable", True), - } - single_config["descriptor"] = single_config_chosen["descriptor"] - if not new_fitting: - single_config["fitting_net"] = single_config_chosen["fitting_net"] - log.info( - f"Change the '{model_branch}' model configurations according to the model branch " - f"'{model_branch_chosen}' in the pretrained one..." 
- ) - for net_type in trainable_param: - if net_type in single_config: - single_config[net_type]["trainable"] = trainable_param[net_type] - else: - single_config[net_type] = {"trainable": trainable_param[net_type]} - return single_config, finetune_rule +__all__ = ["get_finetune_rule_single", "get_finetune_rules"] def get_finetune_rules( @@ -101,122 +27,17 @@ def get_finetune_rules( model_branch: str = "", change_model_params: bool = True, ) -> tuple[dict[str, Any], dict[str, FinetuneRuleItem]]: - """ - Get fine-tuning rules and (optionally) change the model_params according to the pretrained one. - - This function gets the fine-tuning rules and (optionally) changes input in different modes as follows: - 1. Single-task fine-tuning from a single-task pretrained model: - - The model will be fine-tuned based on the pretrained model. - - (Optional) Updates the model parameters based on the pretrained model. - 2. Single-task fine-tuning from a multi-task pretrained model: - - The model will be fine-tuned based on the selected branch in the pretrained model. - The chosen branch can be defined from the command-line or `finetune_head` input parameter. - If not defined, model parameters in the fitting network will be randomly initialized. - - (Optional) Updates the model parameters based on the selected branch in the pretrained model. - 3. Multi-task fine-tuning from a single-task pretrained model: - - The model in each branch will be fine-tuned or resumed based on the single branch ('Default') in the pretrained model. - The chosen branches can be defined from the `finetune_head` input parameter of each branch. - - If `finetune_head` is defined as 'Default', - it will be fine-tuned based on the single branch ('Default') in the pretrained model. - - If `finetune_head` is not defined and the model_key is 'Default', - it will resume from the single branch ('Default') in the pretrained model without fine-tuning. 
- - If `finetune_head` is not defined and the model_key is not 'Default', - it will be fine-tuned based on the single branch ('Default') in the pretrained model, - while model parameters in the fitting network of the branch will be randomly initialized. - - (Optional) Updates model parameters in each branch based on the single branch ('Default') in the pretrained model. - 4. Multi-task fine-tuning from a multi-task pretrained model: - - The model in each branch will be fine-tuned or resumed based on the chosen branches in the pretrained model. - The chosen branches can be defined from the `finetune_head` input parameter of each branch. - - If `finetune_head` is defined as one of the branches in the pretrained model, - it will be fine-tuned based on the chosen branch in the pretrained model. - - If `finetune_head` is not defined and the model_key is the same as one of those in the pretrained model, - it will resume from the model_key branch in the pretrained model without fine-tuning. - - If `finetune_head` is not defined and a new model_key is used, - it will be fine-tuned based on the chosen branch in the pretrained model, - while model parameters in the fitting network of the branch will be randomly initialized. - - (Optional) Updates model parameters in each branch based on the chosen branches in the pretrained model. - - Parameters - ---------- - finetune_model - The pretrained model. - model_config - The fine-tuning input parameters. - model_branch - The model branch chosen in command-line mode, only for single-task fine-tuning. - change_model_params - Whether to change the model parameters according to the pretrained one. - - Returns - ------- - model_config: - Updated model parameters. - finetune_links: - Fine-tuning rules in a dict format, with `model_branch`: FinetuneRuleItem pairs. 
- """ - multi_task = "model_dict" in model_config + """Get fine-tuning rules for a single-task or multi-task PyTorch model.""" state_dict = torch.load(finetune_model, map_location=env.DEVICE, weights_only=True) if "model" in state_dict: state_dict = state_dict["model"] - last_model_params = state_dict["_extra_state"]["model_params"] - finetune_from_multi_task = "model_dict" in last_model_params - finetune_links = {} - if not multi_task: - # use command-line first - if model_branch == "" and "finetune_head" in model_config: - model_branch = model_config["finetune_head"] - model_config, finetune_rule = get_finetune_rule_single( - model_config, - last_model_params, - from_multitask=finetune_from_multi_task, - model_branch="Default", - model_branch_from=model_branch, - change_model_params=change_model_params, - ) - finetune_links["Default"] = finetune_rule - else: - assert model_branch == "", ( + return get_finetune_rules_from_model_params( + state_dict["_extra_state"]["model_params"], + model_config, + model_branch=model_branch, + change_model_params=change_model_params, + multitask_branch_error=( "Multi-task fine-tuning does not support command-line branches chosen!" "Please define the 'finetune_head' in each model params!" - ) - target_keys = model_config["model_dict"].keys() - if not finetune_from_multi_task: - pretrained_keys = ["Default"] - else: - pretrained_keys = last_model_params["model_dict"].keys() - for model_key in target_keys: - resuming = False - if ( - "finetune_head" in model_config["model_dict"][model_key] - and model_config["model_dict"][model_key]["finetune_head"] != "RANDOM" - ): - pretrained_key = model_config["model_dict"][model_key]["finetune_head"] - assert pretrained_key in pretrained_keys, ( - f"'{pretrained_key}' head chosen to finetune not exist in the pretrained model!" 
- f"Available heads are: {list(pretrained_keys)}" - ) - model_branch_from = pretrained_key - elif ( - "finetune_head" not in model_config["model_dict"][model_key] - and model_key in pretrained_keys - ): - # not do anything if not defined "finetune_head" in heads that exist in the pretrained model - # this will just do resuming - model_branch_from = model_key - resuming = True - else: - # if not defined "finetune_head" in new heads or "finetune_head" is "RANDOM", the fitting net will bre randomly initialized - model_branch_from = "RANDOM" - model_config["model_dict"][model_key], finetune_rule = ( - get_finetune_rule_single( - model_config["model_dict"][model_key], - last_model_params, - from_multitask=finetune_from_multi_task, - model_branch=model_key, - model_branch_from=model_branch_from, - change_model_params=change_model_params, - ) - ) - finetune_links[model_key] = finetune_rule - finetune_links[model_key].resuming = resuming - return model_config, finetune_links + ), + ) diff --git a/deepmd/pt_expt/entrypoints/main.py b/deepmd/pt_expt/entrypoints/main.py index da28229bf4..9d09938f5d 100644 --- a/deepmd/pt_expt/entrypoints/main.py +++ b/deepmd/pt_expt/entrypoints/main.py @@ -2,9 +2,11 @@ """Training entrypoint for the pt_expt backend.""" import argparse -import json import logging import os +from dataclasses import ( + replace, +) from pathlib import ( Path, ) @@ -14,6 +16,14 @@ import h5py +from deepmd.dpmodel.train import ( + AbstractTrainEntrypoint, + TrainEntrypointOptions, + TrainingTaskConfig, + iter_training_task_configs, + make_task_maps, + print_data_summaries, +) from deepmd.dpmodel.utils.lmdb_data import ( is_lmdb, ) @@ -23,12 +33,6 @@ from deepmd.pt_expt.utils.lmdb_dataset import ( LmdbDataSystem, ) -from deepmd.utils.argcheck import ( - normalize, -) -from deepmd.utils.compat import ( - update_deepmd_input, -) from deepmd.utils.data_system import ( DeepmdDataSystem, get_data, @@ -37,9 +41,19 @@ from deepmd.utils.path import ( DPPath, ) +from 
deepmd.utils.summary import SummaryPrinter as BaseSummaryPrinter log = logging.getLogger(__name__) +_PT_EXPT_MODEL_SUFFIXES = (".pt", ".pte", ".pt2") + + +def _ensure_pt_expt_model_suffix(model_path: str | None) -> str | None: + """Append the default checkpoint suffix when a model path is a prefix.""" + if model_path is not None and not model_path.endswith(_PT_EXPT_MODEL_SUFFIXES): + return f"{model_path}.pt" + return model_path + def _update_changed_model_tensors( target_state_dict: dict[str, Any], @@ -150,6 +164,21 @@ def _build_data_system( ) +def _ensure_stat_file_path(stat_file_path: str | None) -> DPPath | None: + """Create a stat-file target and return a DPPath wrapper.""" + if stat_file_path is None: + return None + path = Path(stat_file_path) + if not path.exists(): + if stat_file_path.endswith((".h5", ".hdf5")): + path.parent.mkdir(parents=True, exist_ok=True) + with h5py.File(path, "w"): + pass + else: + path.mkdir(parents=True, exist_ok=True) + return DPPath(stat_file_path, "a") + + def get_trainer( config: dict[str, Any], init_model: str | None = None, @@ -159,74 +188,42 @@ def get_trainer( shared_links: dict | None = None, ) -> training.Trainer: """Build a :class:`training.Trainer` from a normalised config.""" - model_params = config["model"] training_params = config["training"] - multi_task = "model_dict" in model_params + multi_task = "model_dict" in config["model"] data_seed = training_params.get("seed", None) - if not multi_task: - type_map = model_params["type_map"] - - # ----- training data ------------------------------------------------ - training_dataset_params = training_params["training_data"] + def factory( + task_config: TrainingTaskConfig, + ) -> tuple[DeepmdDataSystem | LmdbDataSystem, Any | None, DPPath | None]: + type_map = list(task_config.model_params["type_map"]) train_data = _build_data_system( - training_dataset_params, type_map, seed=data_seed + dict(task_config.training_data_params), type_map, seed=data_seed ) - - # ----- 
validation data ---------------------------------------------- validation_data = None - validation_dataset_params = training_params.get("validation_data", None) - if validation_dataset_params is not None: + if task_config.validation_data_params is not None: validation_data = _build_data_system( - validation_dataset_params, type_map, seed=data_seed + dict(task_config.validation_data_params), type_map, seed=data_seed ) + return ( + train_data, + validation_data, + _ensure_stat_file_path(task_config.stat_file), + ) - # ----- stat file path ----------------------------------------------- - stat_file_path = training_params.get("stat_file", None) - if stat_file_path is not None: - if not Path(stat_file_path).exists(): - if stat_file_path.endswith((".h5", ".hdf5")): - with h5py.File(stat_file_path, "w"): - pass - else: - Path(stat_file_path).mkdir(parents=True, exist_ok=True) - stat_file_path = DPPath(stat_file_path, "a") + train_data_map, validation_data_map, stat_file_path_map = make_task_maps( + config, factory + ) + print_data_summaries(train_data_map, validation_data_map) + if multi_task: + train_data = train_data_map + validation_data = validation_data_map + stat_file_path = stat_file_path_map else: - # Multi-task: build per-task data systems - train_data = {} - validation_data = {} - stat_file_path = {} - for model_key in model_params["model_dict"]: - type_map = model_params["model_dict"][model_key]["type_map"] - data_params = training_params["data_dict"][model_key] - - # training data - train_data[model_key] = _build_data_system( - data_params["training_data"], type_map, seed=data_seed - ) - - # validation data - vd_params = data_params.get("validation_data", None) - if vd_params is not None: - validation_data[model_key] = _build_data_system( - vd_params, type_map, seed=data_seed - ) - else: - validation_data[model_key] = None - - # stat file - _sf = data_params.get("stat_file", None) - if _sf is not None: - if not Path(_sf).exists(): - if _sf.endswith((".h5", 
".hdf5")): - with h5py.File(_sf, "w"): - pass - else: - Path(_sf).mkdir(parents=True, exist_ok=True) - stat_file_path[model_key] = DPPath(_sf, "a") - else: - stat_file_path[model_key] = None + task_key = next(iter(train_data_map)) + train_data = train_data_map[task_key] + validation_data = validation_data_map[task_key] + stat_file_path = stat_file_path_map[task_key] trainer = training.Trainer( config, @@ -242,6 +239,201 @@ def get_trainer( return trainer +class SummaryPrinter(BaseSummaryPrinter): + """Summary printer for pt_expt.""" + + def is_built_with_cuda(self) -> bool: + """Check if PyTorch was built with CUDA.""" + import torch + + return torch.version.cuda is not None + + def is_built_with_rocm(self) -> bool: + """Check if PyTorch was built with ROCm.""" + import torch + + return torch.version.hip is not None + + def get_compute_device(self) -> str: + """Get the selected compute device.""" + from deepmd.pt_expt.utils.env import ( + DEVICE, + ) + + return str(DEVICE) + + def get_ngpus(self) -> int: + """Get the number of visible CUDA devices.""" + import torch + + return torch.cuda.device_count() + + def get_backend_info(self) -> dict: + """Get backend information.""" + import torch + + return { + "Backend": "PyTorch Experimental", + "PT Ver": f"v{torch.__version__}-g{torch.version.git_version[:11]}", + } + + def get_device_name(self) -> str | None: + """Return the current CUDA device name when available.""" + import torch + + if torch.cuda.is_available(): + return torch.cuda.get_device_name(torch.cuda.current_device()) + return None + + +class PTExptTrainEntrypoint(AbstractTrainEntrypoint): + """pt_expt implementation of the common training entrypoint pipeline.""" + + def __init__(self) -> None: + self.finetune_links: dict[str, Any] | None = None + self.shared_links: dict[str, Any] | None = None + self._owns_process_group = False + + def prepare_options( + self, + options: TrainEntrypointOptions, + ) -> TrainEntrypointOptions: + """Normalize checkpoint 
prefixes accepted by the train CLI.""" + return replace( + options, + init_model=_ensure_pt_expt_model_suffix(options.init_model), + restart=_ensure_pt_expt_model_suffix(options.restart), + ) + + def preprocess_config( + self, + config: dict[str, Any], + options: TrainEntrypointOptions, + ) -> dict[str, Any]: + """Apply pt_expt multi-task, finetune, and pretrained-model preprocessing.""" + import torch + + from deepmd.pt_expt.utils.env import ( + DEVICE, + ) + + self.finetune_links = None + self.shared_links = None + + if self.is_multi_task(config): + from deepmd.pt_expt.utils.multi_task import ( + preprocess_shared_params, + ) + + config["model"], self.shared_links = preprocess_shared_params( + config["model"] + ) + if "RANDOM" in config["model"]["model_dict"]: + raise ValueError("Model name can not be 'RANDOM' in multi-task mode!") + + if options.finetune is not None: + from deepmd.pt_expt.utils.finetune import ( + get_finetune_rules, + ) + + config["model"], self.finetune_links = get_finetune_rules( + options.finetune, + config["model"], + model_branch=options.model_branch, + change_model_params=options.use_pretrain_script, + ) + + if options.init_model is not None and options.use_pretrain_script: + init_state_dict = torch.load( + options.init_model, map_location=DEVICE, weights_only=True + ) + if "model" in init_state_dict: + init_state_dict = init_state_dict["model"] + config["model"] = init_state_dict["_extra_state"]["model_params"] + + return config + + def update_neighbor_stat( + self, + config: dict[str, Any], + options: TrainEntrypointOptions, + *, + multi_task: bool, + ) -> tuple[dict[str, Any], None]: + """Update pt_expt descriptor selections from neighbor statistics.""" + log.info( + "Calculate neighbor statistics... 
" + "(add --skip-neighbor-stat to skip this step)" + ) + from deepmd.pt_expt.model import ( + BaseModel, + ) + + for task_config in iter_training_task_configs(config): + type_map = task_config.model_params.get("type_map") + train_data = _get_neighbor_stat_data( + dict(task_config.training_data_params), type_map + ) + updated_model_params, _ = BaseModel.update_sel( + train_data, type_map, dict(task_config.model_params) + ) + if multi_task: + config["model"]["model_dict"][task_config.key] = updated_model_params + else: + config["model"] = updated_model_params + return config, None + + def print_summary(self) -> None: + """Print pt_expt backend summary.""" + SummaryPrinter()() + + def setup_run( + self, + options: TrainEntrypointOptions, + config: dict[str, Any], + ) -> None: + """Initialize pt_expt distributed training when launched by torchrun/srun.""" + self._owns_process_group = False + if os.environ.get("LOCAL_RANK") is not None: + import torch.distributed as dist + + if dist.is_available() and dist.is_initialized(): + return + dist.init_process_group(backend="cuda:nccl,cpu:gloo") + self._owns_process_group = True + + def teardown_run( + self, + options: TrainEntrypointOptions, + config: dict[str, Any], + ) -> None: + """Destroy the pt_expt distributed process group if this entrypoint made one.""" + if not self._owns_process_group: + return + import torch.distributed as dist + + if dist.is_available() and dist.is_initialized(): + dist.destroy_process_group() + self._owns_process_group = False + + def run_training( + self, + config: dict[str, Any], + options: TrainEntrypointOptions, + neighbor_stat: Any, + ) -> None: + """Build and run the pt_expt trainer.""" + trainer = get_trainer( + config, + options.init_model, + options.restart, + finetune_model=options.finetune, + finetune_links=self.finetune_links, + shared_links=self.shared_links, + ) + trainer.run() + + def train( input_file: str, init_model: str | None = None, @@ -273,114 +465,18 @@ def train( output : 
str Where to dump the normalised config. """ - import torch - - from deepmd.common import ( - j_loader, - ) - from deepmd.pt_expt.utils.env import ( - DEVICE, - ) - - log.info("Configuration path: %s", input_file) - config = j_loader(input_file) - - # suffix fix - if init_model is not None and not init_model.endswith(".pt"): - init_model += ".pt" - if restart is not None and not restart.endswith(".pt"): - restart += ".pt" - - # Multi-task detection and shared params preprocessing - multi_task = "model_dict" in config.get("model", {}) - shared_links = None - if multi_task: - from deepmd.pt_expt.utils.multi_task import ( - preprocess_shared_params, - ) - - config["model"], shared_links = preprocess_shared_params(config["model"]) - assert "RANDOM" not in config["model"]["model_dict"], ( - "Model name can not be 'RANDOM' in multi-task mode!" - ) - - # update fine-tuning config - finetune_links = None - if finetune is not None: - from deepmd.pt_expt.utils.finetune import ( - get_finetune_rules, - ) - - config["model"], finetune_links = get_finetune_rules( - finetune, - config["model"], + PTExptTrainEntrypoint().run( + TrainEntrypointOptions( + input_file=input_file, + output=output, + init_model=init_model, + restart=restart, + finetune=finetune, model_branch=model_branch, - change_model_params=use_pretrain_script, - ) - - # update init_model config if --use-pretrain-script - if init_model is not None and use_pretrain_script: - init_state_dict = torch.load(init_model, map_location=DEVICE, weights_only=True) - if "model" in init_state_dict: - init_state_dict = init_state_dict["model"] - config["model"] = init_state_dict["_extra_state"]["model_params"] - - # argcheck - config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json") - config = normalize(config, multi_task=multi_task) - - # neighbour stat - if not skip_neighbor_stat: - log.info( - "Calculate neighbor statistics... 
" - "(add --skip-neighbor-stat to skip this step)" - ) - from deepmd.pt_expt.model import ( - BaseModel, - ) - - if not multi_task: - type_map = config["model"].get("type_map") - train_data = _get_neighbor_stat_data( - config["training"]["training_data"], type_map - ) - config["model"], _ = BaseModel.update_sel( - train_data, type_map, config["model"] - ) - else: - for model_key in config["model"]["model_dict"]: - type_map = config["model"]["model_dict"][model_key]["type_map"] - train_data = _get_neighbor_stat_data( - config["training"]["data_dict"][model_key]["training_data"], - type_map, - ) - config["model"]["model_dict"][model_key], _ = BaseModel.update_sel( - train_data, - type_map, - config["model"]["model_dict"][model_key], - ) - - with open(output, "w") as fp: - json.dump(config, fp, indent=4) - - import torch.distributed as dist - - if os.environ.get("LOCAL_RANK") is not None: - dist.init_process_group(backend="cuda:nccl,cpu:gloo") - - try: - trainer = get_trainer( - config, - init_model, - restart, - finetune_model=finetune, - finetune_links=finetune_links, - shared_links=shared_links, + use_pretrain_script=use_pretrain_script, + skip_neighbor_stat=skip_neighbor_stat, ) - trainer.run() - finally: - if dist.is_available() and dist.is_initialized(): - dist.destroy_process_group() + ) def freeze( diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index bd6fdb02a3..d879180f43 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -6,9 +6,9 @@ converted to torch tensors at the boundary. 
""" -import datetime import functools import logging +import os import time from copy import ( deepcopy, @@ -27,6 +27,15 @@ from deepmd.dpmodel.common import ( to_numpy_array, ) +from deepmd.dpmodel.train import ( + DEFAULT_TASK_KEY, + AbstractTrainer, + RankContext, + TrainerConfig, + TrainingTask, + TrainingTaskCollection, + TrainStepResult, +) from deepmd.dpmodel.utils.batch import ( normalize_batch, split_batch, @@ -34,9 +43,12 @@ from deepmd.dpmodel.utils.learning_rate import ( LearningRateExp, ) -from deepmd.loggers.training import ( - format_training_message, - format_training_message_per_task, +from deepmd.pt.train.utils import ( + resolve_best_checkpoint_dir, +) +from deepmd.pt.train.validation import ( + FullValidator, + resolve_full_validation_start_step, ) from deepmd.pt.utils.compile_compat import next_safe_prime as _next_safe_prime from deepmd.pt.utils.compile_compat import rebuild_graph_module as _rebuild_graph_module @@ -269,6 +281,26 @@ def get_additional_data_requirement(_model: Any) -> list[DataRequirementItem]: return additional_data_requirement +def _as_task_map( + value: Any, + *, + multi_task: bool, + model_keys: list[str], +) -> dict[str, Any]: + """Return a task-keyed mapping, wrapping single-task values as Default.""" + if multi_task: + return {model_key: value[model_key] for model_key in model_keys} + return {DEFAULT_TASK_KEY: value} + + +def _replace_latest_checkpoint_link(latest: Path, ckpt_path: Path) -> None: + """Point latest to ckpt_path using a target relative to latest's directory.""" + if latest.is_symlink() or latest.exists(): + latest.unlink() + latest.parent.mkdir(parents=True, exist_ok=True) + latest.symlink_to(os.path.relpath(ckpt_path, latest.parent)) + + # --------------------------------------------------------------------------- # torch.compile helpers # --------------------------------------------------------------------------- @@ -757,7 +789,7 @@ def forward( # 
--------------------------------------------------------------------------- -class Trainer: +class Trainer(AbstractTrainer): """Training driver for the pt_expt backend. Uses ``DeepmdDataSystem`` for data loading (numpy batches converted @@ -806,13 +838,37 @@ def __init__( model_params = config["model"] training_params = config["training"] + validating_params = config.get("validating", {}) or {} - # Multi-task detection + # Task normalization -------------------------------------------------- self.multi_task = "model_dict" in model_params self.model_keys = ( - list(model_params["model_dict"]) if self.multi_task else ["Default"] + list(model_params["model_dict"]) if self.multi_task else [DEFAULT_TASK_KEY] ) self.num_model = len(self.model_keys) + self.model_params_by_task = ( + { + model_key: model_params["model_dict"][model_key] + for model_key in self.model_keys + } + if self.multi_task + else {DEFAULT_TASK_KEY: model_params} + ) + self.training_data_by_task = _as_task_map( + training_data, + multi_task=self.multi_task, + model_keys=self.model_keys, + ) + self.validation_data_by_task = _as_task_map( + validation_data, + multi_task=self.multi_task, + model_keys=self.model_keys, + ) + self.stat_file_path_by_task = _as_task_map( + stat_file_path, + multi_task=self.multi_task, + model_keys=self.model_keys, + ) # Distributed training detection self.is_distributed = dist.is_available() and dist.is_initialized() @@ -825,127 +881,110 @@ def __init__( self.disp_freq = training_params.get("disp_freq", 1000) self.save_ckpt = training_params.get("save_ckpt", "model.ckpt") self.save_freq = training_params.get("save_freq", 1000) + self.max_ckpt_keep = int(training_params.get("max_ckpt_keep", 5)) self.display_in_training = training_params.get("disp_training", True) self.timing_in_training = training_params.get("time_training", True) - self.lcurve_should_print_header = True # Model --------------------------------------------------------------- - if not self.multi_task: - 
self.model = get_model(deepcopy(model_params)).to(DEVICE) - else: - self.model = {} - do_case_embd, case_embd_index = _get_case_embd_config(model_params) - for model_key in self.model_keys: - self.model[model_key] = get_model( - deepcopy(model_params["model_dict"][model_key]) - ).to(DEVICE) - if do_case_embd and not resuming: - self.model[model_key].set_case_embd(case_embd_index[model_key]) + self.models: dict[str, torch.nn.Module] = {} + do_case_embd, case_embd_index = ( + _get_case_embd_config(model_params) if self.multi_task else (False, {}) + ) + for model_key in self.model_keys: + self.models[model_key] = get_model( + deepcopy(self.model_params_by_task[model_key]) + ).to(DEVICE) + if do_case_embd and not resuming: + self.models[model_key].set_case_embd(case_embd_index[model_key]) + self.model = self.models if self.multi_task else self.models[DEFAULT_TASK_KEY] # Loss ---------------------------------------------------------------- - if not self.multi_task: - self.loss = get_loss( - config.get("loss", {}), + self.losses: dict[str, EnergyLoss] = {} + for model_key in self.model_keys: + loss_param = ( + config["loss_dict"][model_key] + if self.multi_task + else config.get("loss", {}) + ) + self.losses[model_key] = get_loss( + deepcopy(loss_param), config["learning_rate"]["start_lr"], - len(model_params["type_map"]), - self.model, + len(self.model_params_by_task[model_key]["type_map"]), + self.models[model_key], ) - else: - self.loss = {} - for model_key in self.model_keys: - loss_param = config["loss_dict"][model_key] - lr_param = config["learning_rate"]["start_lr"] - ntypes = len(model_params["model_dict"][model_key]["type_map"]) - self.loss[model_key] = get_loss( - loss_param, lr_param, ntypes, self.model[model_key] - ) + self.loss = self.losses if self.multi_task else self.losses[DEFAULT_TASK_KEY] # Data requirements --------------------------------------------------- - if not self.multi_task: - data_requirement = self.loss.label_requirement - data_requirement 
+= get_additional_data_requirement(self.model) - training_data.add_data_requirements(data_requirement) - if validation_data is not None: - validation_data.add_data_requirements(data_requirement) - - self.training_data = training_data - self.validation_data = validation_data - self.valid_numb_batch = training_params.get("validation_data", {}).get( - "numb_btch", 1 + self.valid_numb_batch_by_task: dict[str, int] = {} + for model_key in self.model_keys: + data_requirement = list(self.losses[model_key].label_requirement) + data_requirement += get_additional_data_requirement(self.models[model_key]) + self.training_data_by_task[model_key].add_data_requirements( + data_requirement ) - else: - self.training_data = {} - self.validation_data = {} - self.valid_numb_batch = {} - for model_key in self.model_keys: - data_requirement = self.loss[model_key].label_requirement - data_requirement += get_additional_data_requirement( - self.model[model_key] + if self.validation_data_by_task[model_key] is not None: + self.validation_data_by_task[model_key].add_data_requirements( + data_requirement ) - training_data[model_key].add_data_requirements(data_requirement) - if validation_data[model_key] is not None: - validation_data[model_key].add_data_requirements(data_requirement) - self.training_data[model_key] = training_data[model_key] - self.validation_data[model_key] = validation_data[model_key] - self.valid_numb_batch[model_key] = ( - training_params["data_dict"][model_key] - .get("validation_data", {}) - .get("numb_btch", 1) + if self.multi_task: + valid_params = ( + training_params["data_dict"][model_key].get("validation_data", {}) + or {} ) + else: + valid_params = training_params.get("validation_data", {}) or {} + self.valid_numb_batch_by_task[model_key] = max( + int(valid_params.get("numb_btch", 1)), + 1, + ) + self.training_data = ( + self.training_data_by_task + if self.multi_task + else self.training_data_by_task[DEFAULT_TASK_KEY] + ) + self.validation_data = ( + 
self.validation_data_by_task + if self.multi_task + else self.validation_data_by_task[DEFAULT_TASK_KEY] + ) + self.valid_numb_batch = ( + self.valid_numb_batch_by_task + if self.multi_task + else self.valid_numb_batch_by_task[DEFAULT_TASK_KEY] + ) # Statistics ---------------------------------------------------------- - if not self.multi_task: - data_stat_nbatch = model_params.get("data_stat_nbatch", 10) + self._finetune_update_stat = False + self._sample_funcs: dict[str, Any] = {} + for model_key in self.model_keys: + _nbatch = self.model_params_by_task[model_key].get("data_stat_nbatch", 10) + _data = self.training_data_by_task[model_key] + _stat_path = self.stat_file_path_by_task[model_key] @functools.lru_cache - def get_sample() -> list[dict[str, np.ndarray]]: - return make_stat_input(training_data, data_stat_nbatch) + def _make_sample( + _d: DeepmdDataSystem = _data, _n: int = _nbatch + ) -> list[dict[str, np.ndarray]]: + return make_stat_input(_d, _n) + + self._sample_funcs[model_key] = _make_sample - finetune_has_new_type = ( + _finetune_has_new_type = ( finetune_model is not None and finetune_links is not None - and finetune_links["Default"].get_has_new_type() + and model_key in finetune_links + and finetune_links[model_key].get_has_new_type() ) - if (not resuming or finetune_has_new_type) and self.rank == 0: - self.model.compute_or_load_stat( - sampled_func=get_sample, - stat_file_path=stat_file_path, + if _finetune_has_new_type: + self._finetune_update_stat = True + if (not resuming or _finetune_has_new_type) and self.rank == 0: + self.models[model_key].compute_or_load_stat( + sampled_func=_make_sample, + stat_file_path=_stat_path, ) - if self.is_distributed: - self._broadcast_model_stat(self.model) - else: - self._finetune_update_stat = False - self._sample_funcs: dict[str, Any] = {} + if self.is_distributed: for model_key in self.model_keys: - _nbatch = model_params["model_dict"][model_key].get( - "data_stat_nbatch", 10 - ) - _data = 
training_data[model_key] - _stat_path = stat_file_path[model_key] if stat_file_path else None - - def _make_sample( - _d: DeepmdDataSystem = _data, _n: int = _nbatch - ) -> list[dict[str, np.ndarray]]: - return make_stat_input(_d, _n) - - self._sample_funcs[model_key] = _make_sample - - _finetune_has_new_type = ( - finetune_model is not None - and finetune_links is not None - and model_key in finetune_links - and finetune_links[model_key].get_has_new_type() - ) - if _finetune_has_new_type: - self._finetune_update_stat = True - if (not resuming or _finetune_has_new_type) and self.rank == 0: - self.model[model_key].compute_or_load_stat( - sampled_func=_make_sample, - stat_file_path=_stat_path, - ) - if self.is_distributed: - for model_key in self.model_keys: - self._broadcast_model_stat(self.model[model_key]) + self._broadcast_model_stat(self.models[model_key]) # Model probability (multi-task) -------------------------------------- if self.multi_task: @@ -956,7 +995,7 @@ def _make_sample( self.model_prob = resolve_model_prob( self.model_keys, training_params.get("model_prob"), - training_data, + self.training_data_by_task, ) else: self.model_prob = None @@ -1180,39 +1219,29 @@ def _make_sample( self._unwrapped.load_state_dict(new_state) # Per-branch bias adjustment (rank 0 only, then broadcast) - if not self.multi_task: - finetune_rule = finetune_links["Default"] + for model_key in self.model_keys: + finetune_rule = finetune_links[model_key] + if finetune_rule.get_resuming(): + log.info(f"Model branch {model_key} will resume training.") + continue + if self.multi_task: + log.info(f"Model branch {model_key} will be fine-tuned.") bias_mode = ( "change-by-statistic" if not finetune_rule.get_random_fitting() else "set-by-statistic" ) if self.rank == 0: - self.model = model_change_out_bias( - self.model, get_sample, _bias_adjust_mode=bias_mode + self.models[model_key] = model_change_out_bias( + self.models[model_key], + self._sample_funcs[model_key], + 
_bias_adjust_mode=bias_mode, ) if self.is_distributed: - self._broadcast_model_stat(self.model) - else: - for model_key in self.model_keys: - finetune_rule = finetune_links[model_key] - if finetune_rule.get_resuming(): - log.info(f"Model branch {model_key} will resume training.") - continue - log.info(f"Model branch {model_key} will be fine-tuned.") - bias_mode = ( - "change-by-statistic" - if not finetune_rule.get_random_fitting() - else "set-by-statistic" - ) - if self.rank == 0: - self.model[model_key] = model_change_out_bias( - self.model[model_key], - self._sample_funcs[model_key], - _bias_adjust_mode=bias_mode, - ) - if self.is_distributed: - self._broadcast_model_stat(self.model[model_key]) + self._broadcast_model_stat(self.models[model_key]) + self.model = ( + self.models if self.multi_task else self.models[DEFAULT_TASK_KEY] + ) else: # --- Normal resume (init_model / restart) -------------------- self._unwrapped.load_state_dict(state_dict) @@ -1245,6 +1274,94 @@ def _make_sample( log.info("Compiling model with torch.compile (%s)", compile_opts) self._compile_model(compile_opts) + self.training_tasks = self._make_training_tasks() + super().__init__( + TrainerConfig.from_training_params( + training_params, + num_steps=self.num_steps, + start_step=self.start_step, + restart_training=self.restart_training, + ), + rank_context=RankContext(rank=self.rank, world_size=self.world_size), + ) + self.full_validator = self._create_full_validator( + validating_params=validating_params, + validation_data=self.validation_data if not self.multi_task else None, + ) + + def _create_full_validator( + self, + *, + validating_params: dict[str, Any], + validation_data: Any | None, + ) -> FullValidator | None: + """Create the runtime full validator when it is active.""" + if not self._is_validation_requested(validating_params, "full_validation"): + return None + self._raise_if_full_validation_unsupported(validation_data) + if validation_data is None: + raise RuntimeError( + 
"validation_data must be available after full validation checks." + ) + return FullValidator( + validating_params=validating_params, + validation_data=validation_data, + model=self.models[DEFAULT_TASK_KEY], + state_store=self._unwrapped.train_infos, + num_steps=self.num_steps, + rank=self.rank, + zero_stage=0, + restart_training=self.restart_training, + checkpoint_dir=resolve_best_checkpoint_dir( + validating_params, self.save_ckpt + ), + ) + + def _is_validation_requested( + self, + validating_params: dict[str, Any], + flag_name: str, + ) -> bool: + """Check whether a full validation flow can trigger during this run.""" + if not validating_params.get(flag_name, False): + return False + start_step = resolve_full_validation_start_step( + validating_params.get("full_val_start", 0.5), + self.num_steps, + ) + return start_step is not None and start_step <= self.num_steps + + def _raise_if_full_validation_unsupported( + self, + validation_data: Any | None, + ) -> None: + """Validate runtime full validation constraints.""" + if self.multi_task: + raise ValueError( + "validating.full_validation only supports single-task energy " + "training; multi-task training is not supported." + ) + + has_spin = getattr(self.models[DEFAULT_TASK_KEY], "has_spin", False) + if callable(has_spin): + has_spin = has_spin() + if has_spin or isinstance(self.loss, EnergySpinLoss): + raise ValueError( + "validating.full_validation only supports single-task energy " + "training; spin-energy training is not supported." + ) + + if not isinstance(self.loss, EnergyLoss): + raise ValueError( + "validating.full_validation only supports single-task energy training." + ) + + if validation_data is None: + raise ValueError( + "validating.full_validation requires `training.validation_data` " + "to be configured." 
+ ) + # ------------------------------------------------------------------ # torch.compile helpers # ------------------------------------------------------------------ @@ -1394,14 +1511,12 @@ def get_data( ------- input_dict, label_dict """ - if not self.multi_task: - data_sys = self.training_data if is_train else self.validation_data - else: - data_sys = ( - self.training_data[task_key] - if is_train - else self.validation_data[task_key] - ) + task_key = task_key if self.multi_task else DEFAULT_TASK_KEY + data_sys = ( + self.training_data_by_task[task_key] + if is_train + else self.validation_data_by_task[task_key] + ) if data_sys is None: return {}, {} @@ -1464,6 +1579,25 @@ def _broadcast_model_stat(model: torch.nn.Module) -> None: # ------------------------------------------------------------------ def save_checkpoint(self, step: int) -> None: + ckpt_path = Path(f"{self.save_ckpt}-{step}.pt") + self._save_checkpoint_to_path(ckpt_path, step=step) + latest = Path(f"{self.save_ckpt}.pt") + _replace_latest_checkpoint_link(latest, ckpt_path) + self._cleanup_old_checkpoints() + log.info(f"Saved checkpoint to {ckpt_path}") + + def _save_full_validation_checkpoint( + self, + save_path: Path, + lr: float = 0.0, + step: int = 0, + ) -> None: + """Save a checkpoint selected by full validation.""" + del lr + self._save_checkpoint_to_path(save_path, step=step) + + def _save_checkpoint_to_path(self, ckpt_path: Path, *, step: int) -> None: + """Serialize the current trainer state to an explicit checkpoint path.""" self._unwrapped.train_infos["step"] = step # When compiled, wrapper.model[key] is _CompiledModel whose state_dict # uses keys like "original_model.*". 
Restart would load into a plain @@ -1485,14 +1619,25 @@ def save_checkpoint(self, step: int) -> None: finally: for task_key, compiled in compiled_backup.items(): wrapper.model[task_key] = compiled - ckpt_path = f"{self.save_ckpt}-{step}.pt" + ckpt_path.parent.mkdir(parents=True, exist_ok=True) torch.save(state, ckpt_path) - # symlink latest - latest = Path(f"{self.save_ckpt}.pt") - if latest.is_symlink() or latest.exists(): - latest.unlink() - latest.symlink_to(ckpt_path) - log.info(f"Saved checkpoint to {ckpt_path}") + + def _cleanup_old_checkpoints(self) -> None: + """Remove old step checkpoint files beyond the retention limit.""" + if self.max_ckpt_keep <= 0: + return + ckpt_prefix_path = Path(self.save_ckpt) + ckpt_parent = ckpt_prefix_path.parent + ckpt_prefix = ckpt_prefix_path.name + checkpoints: list[tuple[int, Path]] = [] + for path in ckpt_parent.glob(f"{ckpt_prefix}-*.pt"): + if path.is_dir() or path.is_symlink(): + continue + step_text = path.name.removeprefix(f"{ckpt_prefix}-").removesuffix(".pt") + if step_text.isdigit(): + checkpoints.append((int(step_text), path)) + for _, path in sorted(checkpoints)[: -self.max_ckpt_keep]: + path.unlink(missing_ok=True) # ------------------------------------------------------------------ # Training loop @@ -1508,321 +1653,191 @@ def _optimizer_step(self) -> None: self.optimizer.step() self.scheduler.step() - def run(self) -> None: - from deepmd.utils import random as dp_random - - fout = ( - open( - self.disp_file, - mode="w" if not self.restart_training else "a", - buffering=1, - ) - if self.rank == 0 - else None + def _make_training_tasks(self) -> TrainingTaskCollection: + """Build the backend-independent task collection.""" + return TrainingTaskCollection( + [ + TrainingTask( + key=model_key, + training_data=self.training_data_by_task[model_key], + validation_data=self.validation_data_by_task[model_key], + valid_numb_batch=self.valid_numb_batch_by_task[model_key], + ) + for model_key in self.model_keys + ], + 
probabilities=self.model_prob, ) - log.info("Start to train %d steps.", self.num_steps) - self.wrapper.train() + def run(self) -> None: + """Run pt_expt training through the backend-independent trainer loop.""" + log.info("Start to train %d steps.", self.num_steps) wall_start = time.time() - last_log_time = wall_start - last_log_step = self.start_step + super().run(self.training_tasks) + log.info("Training finished. Total wall time: %.2fs", time.time() - wall_start) - for step_id in range(self.start_step, self.num_steps): - cur_lr = float(self.lr_schedule.value(step_id)) + def run_full_validation( + self, + *, + step: int, + display_step: int, + learning_rate: float, + ) -> None: + """Run optional full validation for one step.""" + if self.full_validator is None: + return None + self.full_validator.run( + step_id=display_step, + display_step=display_step, + lr=learning_rate, + save_checkpoint=self._save_full_validation_checkpoint, + ) + return None - # --- task selection (multi-task) --- - task_key = "Default" - if self.multi_task: - model_index = dp_random.choice( - np.arange(self.num_model, dtype=np.int_), - p=self.model_prob, - ) - task_key = self.model_keys[model_index] + def select_task(self, tasks: TrainingTaskCollection) -> TrainingTask: + """Select a task using DeePMD's seeded random helper.""" + if not tasks.is_multitask: + return tasks[tasks.keys[0]] + from deepmd.utils import random as dp_random - # --- forward / backward --- - self.optimizer.zero_grad(set_to_none=True) - input_dict, label_dict = self.get_data(is_train=True, task_key=task_key) + model_index = dp_random.choice( + np.arange(len(tasks), dtype=np.int_), + p=tasks.probabilities, + ) + return tasks[tasks.keys[int(model_index)]] - cur_lr_sched = self.scheduler.get_last_lr()[0] - _model_pred, loss, more_loss = self.wrapper( - **input_dict, - cur_lr=cur_lr_sched, - label=label_dict, - task_key=task_key if self.multi_task else None, + def on_train_begin(self, tasks: TrainingTaskCollection) -> 
None: + """Switch the wrapper to training mode.""" + self.wrapper.train() + + def collect_display_results( + self, + tasks: TrainingTaskCollection, + *, + active_task: TrainingTask, + step: int, + step_result: TrainStepResult, + ) -> tuple[dict[str, Any], dict[str, Any] | None]: + """Collect display metrics in eval mode, then resume training mode.""" + self.wrapper.eval() + self._display_cur_lr_sched = step_result.payload["cur_lr_sched"] + try: + return super().collect_display_results( + tasks, + active_task=active_task, + step=step, + step_result=step_result, ) - loss.backward() + finally: + self._display_cur_lr_sched = None + self.wrapper.train() + + def train_step(self, task: TrainingTask, step: int) -> TrainStepResult: + """Run one pt_expt optimizer step.""" + task_key = task.key + self.optimizer.zero_grad(set_to_none=True) + input_dict, label_dict = self.get_data(is_train=True, task_key=task_key) + + cur_lr_sched = self.scheduler.get_last_lr()[0] + _model_pred, loss, more_loss = self.wrapper( + **input_dict, + cur_lr=cur_lr_sched, + label=label_dict, + task_key=task_key, + ) + loss.backward() - if self.gradient_max_norm > 0.0: - torch.nn.utils.clip_grad_norm_( - self.wrapper.parameters(), self.gradient_max_norm - ) + if self.gradient_max_norm > 0.0: + torch.nn.utils.clip_grad_norm_( + self.wrapper.parameters(), self.gradient_max_norm + ) - self._optimizer_step() + self._optimizer_step() + return TrainStepResult( + task_key=task_key, + step=step, + payload={ + "loss": loss, + "more_loss": more_loss, + "cur_lr_sched": cur_lr_sched, + }, + ) - # --- display --- - display_step_id = step_id + 1 - if self.display_in_training and ( - display_step_id % self.disp_freq == 0 or display_step_id == 1 - ): - self.wrapper.eval() - - if self.rank == 0: - - def _to_float(v: Any) -> float: - return v.detach().item() if torch.is_tensor(v) else float(v) - - if not self.multi_task: - train_results = { - k: _to_float(v) - for k, v in more_loss.items() - if "l2_" not in k - } - - # 
validation - valid_results: dict[str, Any] = {} - if self.validation_data is not None: - sum_natoms = 0 - for _ii in range(self.valid_numb_batch): - val_input, val_label = self.get_data(is_train=False) - if not val_input: - break - _, _vloss, _vmore = self._unwrapped( - **val_input, - cur_lr=cur_lr_sched, - label=val_label, - ) - natoms = int(val_input["atype"].shape[-1]) - sum_natoms += natoms - for k, v in _vmore.items(): - if "l2_" not in k: - valid_results[k] = ( - valid_results.get(k, 0.0) - + _to_float(v) * natoms - ) - if sum_natoms > 0: - valid_results = { - k: v / sum_natoms for k, v in valid_results.items() - } - else: - # Multi-task: compute loss for ALL tasks - train_results = {_key: {} for _key in self.model_keys} - valid_results = {_key: {} for _key in self.model_keys} - - # current task already has loss - train_results[task_key] = { - k: _to_float(v) - for k, v in more_loss.items() - if "l2_" not in k - } - - # compute loss for other tasks - for _key in self.model_keys: - if _key != task_key: - self.optimizer.zero_grad() - _inp, _lab = self.get_data(is_train=True, task_key=_key) - _, _loss, _more = self._unwrapped( - **_inp, - cur_lr=cur_lr_sched, - label=_lab, - task_key=_key, - ) - train_results[_key] = { - k: _to_float(v) - for k, v in _more.items() - if "l2_" not in k - } - - # validation for each task - _vdata = self.validation_data[_key] - if _vdata is not None: - _sum_natoms = 0 - _vres: dict[str, Any] = {} - for _ii in range(self.valid_numb_batch[_key]): - _vi, _vl = self.get_data( - is_train=False, task_key=_key - ) - if not _vi: - break - _, _vloss, _vmore = self._unwrapped( - **_vi, - cur_lr=cur_lr_sched, - label=_vl, - task_key=_key, - ) - natoms = int(_vi["atype"].shape[-1]) - _sum_natoms += natoms - for k, v in _vmore.items(): - if "l2_" not in k: - _vres[k] = ( - _vres.get(k, 0.0) - + _to_float(v) * natoms - ) - if _sum_natoms > 0: - _vres = { - k: v / _sum_natoms for k, v in _vres.items() - } - valid_results[_key] = _vres - # 
wall-clock time - current_time = time.time() - wall_elapsed = current_time - wall_start - interval_wall_time = current_time - last_log_time - # average wall time per step over the interval since the - # last log (number of steps counted exactly once across - # intervals via last_log_step) - interval_steps = max(1, display_step_id - last_log_step) - step_time = interval_wall_time / interval_steps - last_log_time = current_time - last_log_step = display_step_id - if self.timing_in_training: - steps_completed_since_restart = max( - 1, - display_step_id - self.start_step, - ) - eta = int( - (self.num_steps - display_step_id) - / steps_completed_since_restart - * wall_elapsed - ) - log.info( - format_training_message( - batch=display_step_id, - wall_time=interval_wall_time, - eta=eta, - current_time=datetime.datetime.fromtimestamp( - current_time, - tz=datetime.timezone.utc, - ).astimezone(), - step_time=step_time, - ) - ) - else: - log.info( - format_training_message( - batch=display_step_id, - wall_time=interval_wall_time, - ) - ) + def evaluate_training( + self, + task: TrainingTask, + step: int, + step_result: TrainStepResult | None, + ) -> dict[str, float]: + """Evaluate training loss terms for one task.""" + if step_result is not None and step_result.task_key == task.key: + return self._more_loss_to_float(step_result.payload["more_loss"]) + + self.optimizer.zero_grad() + input_dict, label_dict = self.get_data(is_train=True, task_key=task.key) + _, _loss, more_loss = self._unwrapped( + **input_dict, + cur_lr=self._get_display_cur_lr_sched(), + label=label_dict, + task_key=task.key, + ) + return self._more_loss_to_float(more_loss) - # log - if not self.multi_task: - log.info( - format_training_message_per_task( - batch=display_step_id, - task_name="trn", - rmse=train_results, - learning_rate=cur_lr, - ) - ) - if valid_results: - log.info( - format_training_message_per_task( - batch=display_step_id, - task_name="val", - rmse=valid_results, - learning_rate=None, - ) - 
) - else: - for _key in self.model_keys: - log.info( - format_training_message_per_task( - batch=display_step_id, - task_name=_key + "_trn", - rmse=train_results[_key], - learning_rate=cur_lr, - ) - ) - if valid_results[_key]: - log.info( - format_training_message_per_task( - batch=display_step_id, - task_name=_key + "_val", - rmse=valid_results[_key], - learning_rate=None, - ) - ) - - # lcurve file - if self.lcurve_should_print_header: - self.print_header(fout, train_results, valid_results) - self.lcurve_should_print_header = False - self.print_on_training( - fout, display_step_id, cur_lr, train_results, valid_results + def evaluate_validation( + self, + task: TrainingTask, + step: int, + step_result: TrainStepResult | None, + ) -> dict[str, float] | None: + """Evaluate validation loss terms for one task.""" + if task.validation_data is None: + return None + + valid_results: dict[str, float] = {} + sum_natoms = 0 + for _ii in range(task.valid_numb_batch): + val_input, val_label = self.get_data(is_train=False, task_key=task.key) + if not val_input: + break + _, _vloss, vmore = self._unwrapped( + **val_input, + cur_lr=self._get_display_cur_lr_sched(), + label=val_label, + task_key=task.key, + ) + natoms = int(val_input["atype"].shape[-1]) + sum_natoms += natoms + for key, value in vmore.items(): + if "l2_" not in key: + valid_results[key] = ( + valid_results.get(key, 0.0) + self._to_float(value) * natoms ) + if sum_natoms > 0: + valid_results = { + key: value / sum_natoms for key, value in valid_results.items() + } + return valid_results - self.wrapper.train() - - # --- checkpoint --- - if display_step_id % self.save_freq == 0 and self.rank == 0: - self.save_checkpoint(display_step_id) - - # final save - if self.rank == 0: - self.save_checkpoint(self.num_steps) - wall_total = time.time() - wall_start - if fout is not None: - fout.close() - log.info("Training finished. 
Total wall time: %.2fs", wall_total) + def learning_rate(self, step: int) -> float: + """Return the configured learning rate for a zero-based step.""" + return float(self.lr_schedule.value(step)) - # ------------------------------------------------------------------ - # Printing helpers - # ------------------------------------------------------------------ - - def print_header( - self, - fout: Any, - train_results: dict[str, Any], - valid_results: dict[str, Any], - ) -> None: - header = "# {:5s}".format("step") - if not self.multi_task: - train_keys = sorted(train_results.keys()) - if valid_results: - for k in train_keys: - header += f" {k + '_val':>11s} {k + '_trn':>11s}" - else: - for k in train_keys: - header += f" {k + '_trn':>11s}" - else: - for model_key in self.model_keys: - if valid_results[model_key]: - for k in sorted(train_results[model_key].keys()): - header += f" {k + '_val_' + model_key:>11s} {k + '_trn_' + model_key:>11s}" - else: - for k in sorted(train_results[model_key].keys()): - header += f" {k + '_trn_' + model_key:>11s}" - header += " {:8s}\n".format("lr") - fout.write(header) - fout.flush() + @staticmethod + def _to_float(value: Any) -> float: + return value.detach().item() if torch.is_tensor(value) else float(value) - def print_on_training( - self, - fout: Any, - step_id: int, - cur_lr: float, - train_results: dict, - valid_results: dict, - ) -> None: - line = f"{step_id:7d}" - if not self.multi_task: - train_keys = sorted(train_results.keys()) - if valid_results: - for k in train_keys: - line += f" {valid_results.get(k, float('nan')):11.2e} {train_results[k]:11.2e}" - else: - for k in train_keys: - line += f" {train_results[k]:11.2e}" - else: - for model_key in self.model_keys: - if valid_results[model_key]: - for k in sorted(valid_results[model_key].keys()): - line += f" {valid_results[model_key][k]:11.2e} {train_results[model_key][k]:11.2e}" - else: - for k in sorted(train_results[model_key].keys()): - line += f" 
{train_results[model_key][k]:11.2e}" - line += f" {cur_lr:8.1e}\n" - fout.write(line) - fout.flush() + def _get_display_cur_lr_sched(self) -> float: + cur_lr_sched = getattr(self, "_display_cur_lr_sched", None) + if cur_lr_sched is None: + cur_lr_sched = self.scheduler.get_last_lr()[0] + return cur_lr_sched + + @classmethod + def _more_loss_to_float(cls, more_loss: dict[str, Any]) -> dict[str, float]: + return { + key: cls._to_float(value) + for key, value in more_loss.items() + if "l2_" not in key + } def model_change_out_bias( diff --git a/deepmd/pt_expt/utils/finetune.py b/deepmd/pt_expt/utils/finetune.py index 473bb43710..c44de996dc 100644 --- a/deepmd/pt_expt/utils/finetune.py +++ b/deepmd/pt_expt/utils/finetune.py @@ -10,15 +10,18 @@ import torch -from deepmd.pt.utils.finetune import ( - get_finetune_rule_single, -) from deepmd.pt_expt.utils.env import ( DEVICE, ) from deepmd.utils.finetune import ( FinetuneRuleItem, ) +from deepmd.utils.finetune import get_finetune_rule_single as get_finetune_rule_single +from deepmd.utils.finetune import ( + get_finetune_rules_from_model_params, +) + +__all__ = ["get_finetune_rule_single", "get_finetune_rules"] def _is_pte(path: str) -> bool: @@ -74,79 +77,18 @@ def get_finetune_rules( single-task, or per-branch keys for multi-task). """ last_model_params = _load_model_params(finetune_model) - - if change_model_params and "descriptor" not in last_model_params: - # For multi-task pretrained, check inside model_dict - if "model_dict" not in last_model_params or "descriptor" not in next( - iter(last_model_params["model_dict"].values()) - ): - raise ValueError( - "Cannot use --use-pretrain-script: the pretrained model does not " - "contain full model params. If finetuning from a .pte file, " - "re-freeze it with the latest code so that model_def_script is embedded." 
- ) - - multi_task = "model_dict" in model_config - finetune_from_multi_task = "model_dict" in last_model_params - finetune_links: dict[str, FinetuneRuleItem] = {} - - if not multi_task: - # Single-task target - if model_branch == "" and "finetune_head" in model_config: - model_branch = model_config["finetune_head"] - model_config, finetune_rule = get_finetune_rule_single( - model_config, - last_model_params, - from_multitask=finetune_from_multi_task, - model_branch="Default", - model_branch_from=model_branch, - change_model_params=change_model_params, - ) - finetune_links["Default"] = finetune_rule - else: - # Multi-task target — mirrors PT's logic - if model_branch != "": - raise ValueError( - "Multi-task fine-tuning does not support command-line branches chosen! " - "Please define the 'finetune_head' in each model params!" - ) - if not finetune_from_multi_task: - pretrained_keys = ["Default"] - else: - pretrained_keys = list(last_model_params["model_dict"].keys()) - for model_key in model_config["model_dict"]: - resuming = False - if ( - "finetune_head" in model_config["model_dict"][model_key] - and model_config["model_dict"][model_key]["finetune_head"] != "RANDOM" - ): - pretrained_key = model_config["model_dict"][model_key]["finetune_head"] - if pretrained_key not in pretrained_keys: - raise ValueError( - f"'{pretrained_key}' head chosen to finetune not exist in the pretrained model! 
" - f"Available heads are: {list(pretrained_keys)}" - ) - model_branch_from = pretrained_key - elif ( - "finetune_head" not in model_config["model_dict"][model_key] - and model_key in pretrained_keys - ): - # resume — no finetune - model_branch_from = model_key - resuming = True - else: - # new branch or RANDOM → random fitting - model_branch_from = "RANDOM" - model_config["model_dict"][model_key], finetune_rule = ( - get_finetune_rule_single( - model_config["model_dict"][model_key], - last_model_params, - from_multitask=finetune_from_multi_task, - model_branch=model_key, - model_branch_from=model_branch_from, - change_model_params=change_model_params, - ) - ) - finetune_links[model_key] = finetune_rule - finetune_links[model_key].resuming = resuming - return model_config, finetune_links + return get_finetune_rules_from_model_params( + last_model_params, + model_config, + model_branch=model_branch, + change_model_params=change_model_params, + multitask_branch_error=( + "Multi-task fine-tuning does not support command-line branches chosen! " + "Please define the 'finetune_head' in each model params!" + ), + missing_model_params_error=( + "Cannot use --use-pretrain-script: the pretrained model does not " + "contain full model params. If finetuning from a .pte file, " + "re-freeze it with the latest code so that model_def_script is embedded." + ), + ) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 07b256e21b..31a1b9cf05 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -5495,6 +5495,9 @@ def resolve_full_validation_start_step( def validating_args() -> Argument: """Generate full validation arguments.""" valid_metrics = ", ".join(item.upper() for item in FULL_VALIDATION_METRIC_PREFS) + doc_full_validation_supported = ( + "(Supported Backend: PyTorch, PyTorch Experimental, JAX) " + ) doc_full_validation = ( "Whether to run an additional full validation pass over the entire " "validation dataset during training. 
This flow is independent from the " @@ -5566,7 +5569,7 @@ def validating_args() -> Argument: bool, optional=True, default=False, - doc=doc_only_pt_supported + doc_full_validation, + doc=doc_full_validation_supported + doc_full_validation, ), Argument( "ema_full_validation", @@ -5580,7 +5583,7 @@ def validating_args() -> Argument: int, optional=True, default=5000, - doc=doc_only_pt_supported + doc_validation_freq, + doc=doc_full_validation_supported + doc_validation_freq, extra_check=lambda x: x > 0, extra_check_errmsg="must be greater than 0", ), @@ -5589,21 +5592,21 @@ def validating_args() -> Argument: bool, optional=True, default=True, - doc=doc_only_pt_supported + doc_save_best, + doc=doc_full_validation_supported + doc_save_best, ), Argument( "save_best_dir", [str, None], optional=True, default=None, - doc=doc_only_pt_supported + doc_save_best_dir, + doc=doc_full_validation_supported + doc_save_best_dir, ), Argument( "max_best_ckpt", int, optional=True, default=1, - doc=doc_only_pt_supported + doc_max_best_ckpt, + doc=doc_full_validation_supported + doc_max_best_ckpt, extra_check=lambda x: x > 0, extra_check_errmsg="must be greater than 0", ), @@ -5612,7 +5615,7 @@ def validating_args() -> Argument: str, optional=True, default="E:MAE", - doc=doc_only_pt_supported + doc_validation_metric, + doc=doc_full_validation_supported + doc_validation_metric, extra_check=is_valid_full_validation_metric, extra_check_errmsg=( "must be one of " @@ -5624,14 +5627,14 @@ def validating_args() -> Argument: str, optional=True, default="val.log", - doc=doc_only_pt_supported + doc_full_val_file, + doc=doc_full_validation_supported + doc_full_val_file, ), Argument( "full_val_start", [int, float], optional=True, default=0.5, - doc=doc_only_pt_supported + doc_full_val_start, + doc=doc_full_validation_supported + doc_full_val_start, extra_check=lambda x: x >= 0, extra_check_errmsg="must be greater than or equal to 0", ), @@ -5657,7 +5660,7 @@ def validating_args() -> Argument: 
sub_variants=[], optional=True, default={}, - doc=doc_only_pt_supported + doc=doc_full_validation_supported + "Independent full validation options for single-task energy training.", ) diff --git a/deepmd/utils/finetune.py b/deepmd/utils/finetune.py index 9bcb85dc82..7263566641 100644 --- a/deepmd/utils/finetune.py +++ b/deepmd/utils/finetune.py @@ -10,6 +10,10 @@ Any, ) +from deepmd.utils.model_branch_dict import ( + get_model_dict, +) + log = logging.getLogger(__name__) _IGNORED_DESCRIPTOR_KEYS = frozenset({"trainable"}) @@ -258,6 +262,256 @@ def get_finetune_tmap(self) -> list[str]: return self.type_map +class FinetuneRuleBuilder: + """Build backend-independent fine-tuning config rules. + + The builder only handles model-config and branch-selection semantics. Backend + wrappers are still responsible for loading pretrained model params, and + backend trainers are responsible for copying tensor/state values. + """ + + def __init__( + self, + pretrained_model_params: dict[str, Any], + target_model_config: dict[str, Any], + *, + model_branch: str = "", + change_model_params: bool = True, + default_branch: str = "Default", + multitask_branch_error: str | None = None, + missing_model_params_error: str | None = None, + ) -> None: + self.pretrained_model_params = pretrained_model_params + self.target_model_config = target_model_config + self.model_branch = model_branch + self.change_model_params = change_model_params + self.default_branch = default_branch + self.multitask_branch_error = multitask_branch_error or ( + "Multi-task fine-tuning does not support command-line branch " + "selection. Define 'finetune_head' in each model branch." + ) + self.missing_model_params_error = missing_model_params_error or ( + "Cannot use --use-pretrain-script: the pretrained model does not " + "contain full model params." 
+ ) + + @property + def target_is_multitask(self) -> bool: + """Whether the target model config is multi-task.""" + return "model_dict" in self.target_model_config + + @property + def pretrained_is_multitask(self) -> bool: + """Whether the pretrained model params are multi-task.""" + return "model_dict" in self.pretrained_model_params + + def build(self) -> tuple[dict[str, Any], dict[str, FinetuneRuleItem]]: + """Return updated target model config and fine-tuning rules.""" + model_config = deepcopy(self.target_model_config) + finetune_links: dict[str, FinetuneRuleItem] = {} + + if not self.target_is_multitask: + model_branch = self.model_branch + if model_branch == "" and "finetune_head" in model_config: + model_branch = model_config["finetune_head"] + model_config, finetune_rule = self.build_single_rule( + model_config, + model_branch=self.default_branch, + model_branch_from=model_branch, + ) + finetune_links[self.default_branch] = finetune_rule + return model_config, finetune_links + + if self.model_branch != "": + raise ValueError(self.multitask_branch_error) + + pretrained_keys = self._pretrained_keys() + pretrained_aliases = self._pretrained_aliases() + for model_key in model_config["model_dict"]: + target_branch = model_config["model_dict"][model_key] + resuming = False + if ( + "finetune_head" in target_branch + and target_branch["finetune_head"] != "RANDOM" + ): + model_branch_from = target_branch["finetune_head"] + if model_branch_from not in pretrained_aliases: + raise ValueError( + f"'{model_branch_from}' head chosen to finetune does not " + "exist in the pretrained model. 
Available heads are: " + f"{pretrained_keys}" + ) + elif "finetune_head" not in target_branch and model_key in pretrained_keys: + model_branch_from = model_key + resuming = True + else: + model_branch_from = "RANDOM" + + model_config["model_dict"][model_key], finetune_rule = ( + self.build_single_rule( + target_branch, + model_branch=model_key, + model_branch_from=model_branch_from, + ) + ) + finetune_rule.resuming = resuming + finetune_links[model_key] = finetune_rule + return model_config, finetune_links + + def build_single_rule( + self, + single_param_target: dict[str, Any], + *, + model_branch: str = "Default", + model_branch_from: str = "", + ) -> tuple[dict[str, Any], FinetuneRuleItem]: + """Build a fine-tuning rule for one target branch.""" + single_config = deepcopy(single_param_target) + new_fitting = False + model_branch_chosen = self.default_branch + + if not self.pretrained_is_multitask: + if model_branch_from not in ("", self.default_branch, "RANDOM"): + raise ValueError( + "Single-task pretrained models only provide the " + f"{self.default_branch!r} branch, got {model_branch_from!r}." + ) + single_config_chosen = deepcopy(self.pretrained_model_params) + if model_branch_from == "RANDOM": + new_fitting = True + else: + model_dict_params = self.pretrained_model_params["model_dict"] + if model_branch_from in ["", "RANDOM"]: + model_branch_chosen = next(iter(model_dict_params)) + new_fitting = True + log.warning( + "The fitting net will be re-initialized instead of using the " + "pretrained fitting net. The bias_adjust_mode will be " + "set-by-statistic." + ) + else: + model_branch_chosen = model_branch_from + model_alias_dict, _ = get_model_dict(model_dict_params) + if model_branch_chosen not in model_alias_dict: + raise ValueError( + f"No model branch or alias named '{model_branch_chosen}'. " + f"Available branches are {list(model_dict_params)}." 
+ ) + model_branch_chosen = model_alias_dict[model_branch_chosen] + single_config_chosen = deepcopy(model_dict_params[model_branch_chosen]) + + old_type_map = single_config_chosen["type_map"] + new_type_map = single_config["type_map"] + finetune_rule = FinetuneRuleItem( + p_type_map=old_type_map, + type_map=new_type_map, + model_branch=model_branch_chosen, + random_fitting=new_fitting, + ) + if self.change_model_params: + self._apply_pretrained_model_params( + single_config, + single_config_chosen, + model_branch=model_branch, + model_branch_chosen=model_branch_chosen, + random_fitting=new_fitting, + ) + return single_config, finetune_rule + + def _apply_pretrained_model_params( + self, + single_config: dict[str, Any], + pretrained_config: dict[str, Any], + *, + model_branch: str, + model_branch_chosen: str, + random_fitting: bool, + ) -> None: + if "descriptor" not in pretrained_config: + raise ValueError(self.missing_model_params_error) + if not random_fitting and "fitting_net" not in pretrained_config: + raise ValueError(self.missing_model_params_error) + if "descriptor" in single_config: + warn_descriptor_config_differences( + single_config["descriptor"], + pretrained_config["descriptor"], + model_branch_chosen, + ) + trainable_param = { + "descriptor": single_config.get("descriptor", {}).get("trainable", True), + "fitting_net": single_config.get("fitting_net", {}).get("trainable", True), + } + single_config["descriptor"] = deepcopy(pretrained_config["descriptor"]) + if not random_fitting: + single_config["fitting_net"] = deepcopy(pretrained_config["fitting_net"]) + log.info( + "Change the '%s' model configurations according to pretrained branch '%s'.", + model_branch, + model_branch_chosen, + ) + for net_type, trainable in trainable_param.items(): + if net_type in single_config: + single_config[net_type]["trainable"] = trainable + else: + single_config[net_type] = {"trainable": trainable} + + def _pretrained_keys(self) -> list[str]: + if 
self.pretrained_is_multitask: + return list(self.pretrained_model_params["model_dict"]) + return [self.default_branch] + + def _pretrained_aliases(self) -> dict[str, str]: + if self.pretrained_is_multitask: + model_alias_dict, _ = get_model_dict( + self.pretrained_model_params["model_dict"] + ) + return dict(model_alias_dict) + return {self.default_branch: self.default_branch} + + +def get_finetune_rule_single( + single_param_target: dict[str, Any], + model_param_pretrained: dict[str, Any], + from_multitask: bool = False, + model_branch: str = "Default", + model_branch_from: str = "", + change_model_params: bool = False, +) -> tuple[dict[str, Any], FinetuneRuleItem]: + """Build one backend-independent fine-tuning rule.""" + builder = FinetuneRuleBuilder( + model_param_pretrained, + single_param_target, + change_model_params=change_model_params, + ) + if from_multitask != builder.pretrained_is_multitask: + raise ValueError("from_multitask does not match pretrained model params.") + return builder.build_single_rule( + single_param_target, + model_branch=model_branch, + model_branch_from=model_branch_from, + ) + + +def get_finetune_rules_from_model_params( + pretrained_model_params: dict[str, Any], + model_config: dict[str, Any], + *, + model_branch: str = "", + change_model_params: bool = True, + multitask_branch_error: str | None = None, + missing_model_params_error: str | None = None, +) -> tuple[dict[str, Any], dict[str, FinetuneRuleItem]]: + """Build fine-tuning rules from already-loaded pretrained model params.""" + return FinetuneRuleBuilder( + pretrained_model_params, + model_config, + model_branch=model_branch, + change_model_params=change_model_params, + multitask_branch_error=multitask_branch_error, + missing_model_params_error=missing_model_params_error, + ).build() + + def get_index_between_two_maps( old_map: list[str], new_map: list[str], diff --git a/source/tests/common/dpmodel/test_train_abstract_trainer.py 
b/source/tests/common/dpmodel/test_train_abstract_trainer.py new file mode 100644 index 0000000000..5ac21b71ee --- /dev/null +++ b/source/tests/common/dpmodel/test_train_abstract_trainer.py @@ -0,0 +1,306 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from pathlib import ( + Path, +) +from typing import ( + Any, +) + +import pytest + +from deepmd.dpmodel.train import ( + AbstractTrainer, + LearningCurveWriter, + RankContext, + TrainerConfig, + TrainingTask, + TrainingTaskCollection, + TrainStepResult, +) + + +class DummyData: + def __init__(self, values: list[float]) -> None: + self.values = values + self.index = 0 + self.requirements: list[Any] = [] + + def __len__(self) -> int: + return len(self.values) + + def get_batch(self) -> dict[str, float]: + value = self.values[self.index % len(self.values)] + self.index += 1 + return {"value": value} + + def add_data_requirements(self, requirements: list[Any]) -> None: + self.requirements.extend(requirements) + + +class DummyTrainer(AbstractTrainer): + def __init__( + self, + trainer_config: TrainerConfig, + *, + rank_context: RankContext | None = None, + ) -> None: + super().__init__(trainer_config, rank_context=rank_context) + self.steps: list[tuple[str, int, float]] = [] + self.checkpoints: list[int] = [] + + def train_step(self, task: TrainingTask, step: int) -> TrainStepResult: + batch = task.training_data.get_batch() + self.steps.append((task.key, step, batch["value"])) + return TrainStepResult(task_key=task.key, step=step, payload=batch) + + def evaluate_training( + self, + task: TrainingTask, + step: int, + step_result: TrainStepResult | None, + ) -> dict[str, float]: + if step_result is None or step_result.task_key != task.key: + return {"rmse": 0.0} + return {"rmse": float(step_result.payload["value"])} + + def evaluate_validation( + self, + task: TrainingTask, + step: int, + step_result: TrainStepResult | None, + ) -> dict[str, float] | None: + if task.validation_data is None: + return None + values = [ + 
float(task.validation_data.get_batch()["value"]) + for _ in range(task.valid_numb_batch) + ] + return {"rmse": sum(values) / len(values)} + + def learning_rate(self, step: int) -> float: + return 0.1 / (step + 1) + + def save_checkpoint(self, step: int) -> None: + self.checkpoints.append(step) + + +def _lcurve_steps(path: Path) -> list[int]: + steps = [] + for line in path.read_text().splitlines(): + if line.startswith("#") or not line.strip(): + continue + steps.append(int(line.split()[0])) + return steps + + +def test_single_task_collection_adds_requirements() -> None: + train_data = DummyData([1.0]) + valid_data = DummyData([2.0]) + tasks = TrainingTaskCollection.single( + train_data, + valid_data, + data_requirements=["energy"], + ) + + task = tasks.select() + task.add_data_requirements() + + assert not tasks.is_multitask + assert tasks.select() is task + assert train_data.requirements == ["energy"] + assert valid_data.requirements == ["energy"] + + +def test_task_collection_rejects_duplicate_sequence_keys() -> None: + with pytest.raises(ValueError, match="unique"): + TrainingTaskCollection( + [ + TrainingTask("task", DummyData([1.0])), + TrainingTask("task", DummyData([2.0])), + ] + ) + + +def test_task_collection_rejects_unknown_probability_keys() -> None: + with pytest.raises(ValueError, match="Unknown task probabilities"): + TrainingTaskCollection( + [ + TrainingTask("task_a", DummyData([1.0])), + TrainingTask("task_b", DummyData([2.0])), + ], + probabilities={"task_a": 1.0, "task_b": 1.0, "stale": 1.0}, + ) + + +def test_learning_curve_row_uses_training_metric_order() -> None: + row = LearningCurveWriter().format_row( + step=1, + learning_rate=0.1, + train_results={"rmse": 1.0, "mae": 2.0}, + valid_results={"mae": 3.0}, + ) + + assert row.split() == ["1", "nan", "1.00e+00", "3.00e+00", "2.00e+00", "1.0e-01"] + + +def test_abstract_trainer_drives_single_task_loop(tmp_path: Path) -> None: + lcurve = tmp_path / "lcurve.out" + trainer = DummyTrainer( + 
def test_non_chief_rank_skips_user_visible_outputs(tmp_path: Path) -> None:
    """Rank 1 of 2 runs every step but saves no checkpoint and writes no lcurve."""
    curve_path = tmp_path / "lcurve.out"
    config = TrainerConfig(
        num_steps=3,
        disp_file=str(curve_path),
        disp_freq=1,
        save_freq=1,
        timing_in_training=False,
    )
    worker_rank = RankContext(rank=1, world_size=2)
    trainer = DummyTrainer(config, rank_context=worker_rank)
    train_values = [1.0, 2.0, 3.0]
    tasks = TrainingTaskCollection.single(
        DummyData(train_values),
        DummyData([10.0]),
    )

    trainer.run(tasks)

    # One recorded step per training value, in order, under the default task key.
    expected_steps = [
        ("Default", step, value) for step, value in enumerate(train_values)
    ]
    assert trainer.steps == expected_steps
    assert trainer.checkpoints == []
    assert not curve_path.exists()
def test_multitask_training_uses_single_task_as_collection_item(
    tmp_path: Path,
) -> None:
    """With probabilities 1.0/0.0 only ``task_a`` is stepped; both appear in lcurve."""
    curve_path = tmp_path / "lcurve.out"
    config = TrainerConfig(
        num_steps=1,
        disp_file=str(curve_path),
        disp_freq=1,
        save_freq=1,
        timing_in_training=False,
    )
    trainer = DummyTrainer(config)
    first_task = TrainingTask("task_a", DummyData([1.0]), DummyData([11.0]))
    second_task = TrainingTask("task_b", DummyData([2.0]), DummyData([22.0]))
    tasks = TrainingTaskCollection(
        [first_task, second_task],
        probabilities={"task_a": 1.0, "task_b": 0.0},
    )

    trainer.run(tasks)

    assert trainer.steps == [("task_a", 0, 1.0)]
    assert trainer.checkpoints == [1]
    # The zero-probability task must never have served a training batch.
    assert tasks["task_b"].training_data.index == 0
    written = curve_path.read_text()
    for column in ("rmse_val_task_a", "rmse_trn_task_b"):
        assert column in written
class RecordingTrainEntrypoint(AbstractTrainEntrypoint):
    """Entrypoint test double that logs every hook call into ``self.events``.

    Each overridden hook appends a short tag (sometimes with an argument
    embedded) to ``events`` so tests can assert the exact call order.
    ``fail_setup`` / ``fail_training`` make the matching hook raise
    ``RuntimeError`` after its tag has been recorded.
    """

    def __init__(
        self,
        *,
        fail_setup: bool = False,
        fail_training: bool = False,
    ) -> None:
        self.events: list[str] = []
        self.fail_setup = fail_setup
        self.fail_training = fail_training
        # String sentinel, so a later ``None`` assignment in run_training is
        # distinguishable from "run_training was never called".
        self.neighbor_stat: Any = "unset"

    def prepare_options(
        self,
        options: TrainEntrypointOptions,
    ) -> TrainEntrypointOptions:
        """Record the call and return ``options`` unchanged."""
        self.events.append("prepare_options")
        return options

    def load_config(self, input_file: str) -> dict[str, Any]:
        """Record the requested file and return a config with empty sections."""
        self.events.append(f"load_config:{input_file}")
        return {
            "model": {},
            "training": {},
        }

    def validate_options(
        self,
        config: dict[str, Any],
        options: TrainEntrypointOptions,
    ) -> None:
        """Record the call; performs no validation."""
        self.events.append("validate_options")

    def preprocess_config(
        self,
        config: dict[str, Any],
        options: TrainEntrypointOptions,
    ) -> dict[str, Any]:
        """Record the call and inject a one-entry ``model_dict`` into ``config``."""
        self.events.append("preprocess_config")
        # Mutates the passed-in config in place and returns the same object.
        config["model"]["model_dict"] = {"task": {"type_map": ["O"]}}
        return config

    def update_input(self, config: dict[str, Any]) -> dict[str, Any]:
        """Record the call and flag ``config`` with ``compat_updated``."""
        self.events.append("update_input")
        config["compat_updated"] = True
        return config

    def normalize_config(
        self,
        config: dict[str, Any],
        *,
        multi_task: bool,
    ) -> dict[str, Any]:
        """Record the call (including ``multi_task``) and flag ``normalized``."""
        self.events.append(f"normalize_config:{multi_task}")
        config["normalized"] = True
        return config

    def update_neighbor_stat(
        self,
        config: dict[str, Any],
        options: TrainEntrypointOptions,
        *,
        multi_task: bool,
    ) -> tuple[dict[str, Any], dict[str, bool]]:
        """Record the call and return ``config`` plus a stand-in neighbor stat."""
        self.events.append(f"update_neighbor_stat:{multi_task}")
        config["neighbor_updated"] = True
        # The second element echoes ``multi_task`` so tests can check it is
        # forwarded to run_training.
        return config, {"multi_task": multi_task}

    def dump_config(self, config: dict[str, Any], output: str) -> None:
        """Record the output file name, then delegate to the base class."""
        self.events.append(f"dump_config:{Path(output).name}")
        super().dump_config(config, output)

    def print_summary(self) -> None:
        """Record the call; prints nothing."""
        self.events.append("print_summary")

    def setup_run(
        self,
        options: TrainEntrypointOptions,
        config: dict[str, Any],
    ) -> None:
        """Record the call; raise ``RuntimeError`` when ``fail_setup`` is set."""
        self.events.append("setup_run")
        if self.fail_setup:
            raise RuntimeError("setup failed")

    def teardown_run(
        self,
        options: TrainEntrypointOptions,
        config: dict[str, Any],
    ) -> None:
        """Record the call."""
        self.events.append("teardown_run")

    def run_training(
        self,
        config: dict[str, Any],
        options: TrainEntrypointOptions,
        neighbor_stat: Any,
    ) -> None:
        """Record the call, keep ``neighbor_stat``, and optionally fail."""
        self.events.append("run_training")
        self.neighbor_stat = neighbor_stat
        if self.fail_training:
            raise RuntimeError("training failed")
def test_train_entrypoint_can_skip_neighbor_stat(tmp_path: Path) -> None:
    """``skip_neighbor_stat=True`` bypasses the neighbor-stat hook entirely."""
    entrypoint = RecordingTrainEntrypoint()
    options = TrainEntrypointOptions(
        input_file="input.json",
        output=str(tmp_path / "out.json"),
        skip_neighbor_stat=True,
    )

    entrypoint.run(options)

    # The hook was never recorded, and run_training received None in its place.
    assert entrypoint.events.count("update_neighbor_stat:True") == 0
    assert entrypoint.neighbor_stat is None
__len__(self) -> int: + raise TypeError("broken length") + + +def test_resolve_model_prob_uses_get_nsystems() -> None: + prob = resolve_model_prob( + ["systems", "sized"], + None, + { + "systems": SystemCountData(), + "sized": SizedData(), + }, + rank=1, + ) + + assert prob.tolist() == pytest.approx([5.0 / 8.0, 3.0 / 8.0]) + + +def test_resolve_model_prob_falls_back_for_unsized_data_only() -> None: + prob = resolve_model_prob( + ["unsized", "sized"], + None, + { + "unsized": UnsizedData(), + "sized": SizedData(), + }, + rank=1, + ) + + assert prob.tolist() == pytest.approx([0.25, 0.75]) + + +def test_resolve_model_prob_propagates_broken_len() -> None: + with pytest.raises(TypeError, match="broken length"): + resolve_model_prob( + ["broken"], + None, + {"broken": BrokenLenData()}, + rank=1, + ) diff --git a/source/tests/common/test_finetune_utils.py b/source/tests/common/test_finetune_utils.py index 776d7bb131..cce2ca1850 100644 --- a/source/tests/common/test_finetune_utils.py +++ b/source/tests/common/test_finetune_utils.py @@ -1,11 +1,35 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging +from copy import ( + deepcopy, +) from deepmd.utils import ( finetune, ) +def _model_config( + type_map: list[str], + *, + descriptor_sel: list[int] | None = None, + fitting_neuron: list[int] | None = None, + trainable: bool = True, +) -> dict: + return { + "type_map": type_map, + "descriptor": { + "type": "se_e2_a", + "sel": descriptor_sel or [1 for _ in type_map], + "trainable": trainable, + }, + "fitting_net": { + "neuron": fitting_neuron or [4], + "trainable": trainable, + }, + } + + def test_descriptor_normalization_uses_descriptor_type_count(): assert finetune._infer_synthetic_type_count({"sel": [16, 24, 32]}) == 3 assert finetune._infer_synthetic_type_count({"exclude_types": [[0, 3]]}) == 4 @@ -97,3 +121,154 @@ def test_descriptor_config_warning_distinguishes_none_from_missing(monkeypatch, assert "input_none: input=None, pretrained=(missing)" in caplog.text 
def test_finetune_rule_builder_random_fitting_keeps_target_fitting_net():
    """``model_branch="RANDOM"`` adopts the pretrained descriptor only."""
    type_map = ["O", "H"]
    pretrained = _model_config(type_map, descriptor_sel=[8, 16], fitting_neuron=[32])
    target = _model_config(type_map, descriptor_sel=[1, 1], fitting_neuron=[2])

    builder = finetune.FinetuneRuleBuilder(
        pretrained,
        target,
        model_branch="RANDOM",
        change_model_params=True,
    )
    updated, links = builder.build()

    # Descriptor comes from the pretrained model, fitting net stays the target's.
    assert updated["descriptor"]["sel"] == [8, 16]
    assert updated["fitting_net"]["neuron"] == [2]
    assert links["Default"].get_random_fitting()
def test_finetune_rule_builder_accepts_multitask_finetune_head_alias():
    """A ``finetune_head`` naming a branch alias resolves to the real branch."""
    pretrained_a = _model_config(["O", "H"], descriptor_sel=[8, 16])
    pretrained_b = _model_config(["O", "H"], descriptor_sel=[4, 4])
    pretrained_b["model_branch_alias"] = ["alias_b"]
    pretrained = {"model_dict": {"task_a": pretrained_a, "task_b": pretrained_b}}

    target_c = _model_config(["O", "H"], descriptor_sel=[1, 1])
    target_c["finetune_head"] = "alias_b"
    target = {"model_dict": {"task_c": target_c}}

    builder = finetune.FinetuneRuleBuilder(
        pretrained,
        target,
        change_model_params=True,
    )
    updated, links = builder.build()

    assert links["task_c"].get_model_branch() == "task_b"
    assert updated["model_dict"]["task_c"]["descriptor"]["sel"] == [4, 4]
def test_finetune_rule_builder_rejects_multitask_cli_branch():
    """Passing ``model_branch`` for a multi-task target raises ``ValueError``."""
    pretrained = {"model_dict": {"task_a": _model_config(["O", "H"])}}
    target = {"model_dict": {"task_a": _model_config(["O", "H"])}}

    caught = None
    try:
        finetune.FinetuneRuleBuilder(
            pretrained,
            target,
            model_branch="task_a",
        ).build()
    except ValueError as exc:
        caught = exc

    # No exception at all is a failure in its own right.
    if caught is None:
        raise AssertionError("expected ValueError")
    assert "Multi-task fine-tuning" in str(caught)
np.testing.assert_allclose(np.asarray(updates["w"]), -0.2, rtol=1e-5) + + updates, _ = tx.update( + grads, + state, + params, + learning_rate=jnp.asarray(0.05), + ) + np.testing.assert_allclose(np.asarray(updates["w"]), -0.05, rtol=1e-5) + + +@patch("deepmd.jax.utils.finetune.serialize_from_file") +def test_jax_finetune_load_model_params_accepts_loader_paths( + serialize_from_file, +) -> None: + """Fine-tuning accepts every checkpoint path handled by serialize_from_file.""" + model_params = {"type_map": ["O"], "descriptor": {}, "fitting_net": {}} + serialize_from_file.return_value = {"model_def_script": model_params} + + assert _load_model_params("checkpoint") == model_params + serialize_from_file.assert_called_once_with("checkpoint") + + +def _minimal_jax_config(model_params: dict) -> dict: + return { + "model": model_params, + "training": { + "numb_steps": 1, + }, + "learning_rate": { + "type": "exp", + "start_lr": 0.001, + "stop_lr": 1e-8, + "decay_steps": 1, + }, + "loss": {}, + } + + +@patch("deepmd.jax.train.trainer.DPTrainer._build_losses") +@patch("deepmd.jax.train.trainer.DPTrainer._deserialize_models") +@patch("deepmd.jax.train.trainer.serialize_from_file") +def test_jax_init_model_preserves_input_model_script( + serialize_from_file, + deserialize_models, + build_losses, +) -> None: + """init_model loads weights without replacing input model metadata.""" + input_model = {"type_map": ["O"], "descriptor": {"input": True}} + checkpoint_model = {"type_map": ["O"], "descriptor": {"checkpoint": True}} + serialize_from_file.return_value = { + "model": {}, + "model_def_script": checkpoint_model, + } + deserialize_models.return_value = { + "Default": SimpleNamespace(get_dim_fparam=lambda: 0) + } + build_losses.return_value = {"Default": SimpleNamespace(label_requirement=[])} + + trainer = DPTrainer(_minimal_jax_config(input_model), init_model="model-1.jax") + + assert trainer.model_def_script == input_model + assert trainer.model_params_by_task["Default"] == 
@patch("deepmd.jax.train.trainer.DPTrainer._build_losses")
@patch("deepmd.jax.train.trainer.DPTrainer._deserialize_models")
@patch("deepmd.jax.train.trainer.serialize_from_file")
def test_jax_restart_uses_checkpoint_model_script(
    serialize_from_file,
    deserialize_models,
    build_losses,
) -> None:
    """On restart the checkpoint's model script and ``current_step`` win."""
    input_model = {"type_map": ["O"], "descriptor": {"input": True}}
    checkpoint_model = {
        "type_map": ["O"],
        "descriptor": {"checkpoint": True},
        "current_step": 7,
    }
    # Stand-ins for the objects the patched trainer hooks would produce.
    fake_model = SimpleNamespace(get_dim_fparam=lambda: 0)
    fake_loss = SimpleNamespace(label_requirement=[])
    serialize_from_file.return_value = {
        "model": {},
        "model_def_script": checkpoint_model,
    }
    deserialize_models.return_value = {"Default": fake_model}
    build_losses.return_value = {"Default": fake_loss}
    config = _minimal_jax_config(input_model)

    trainer = DPTrainer(config, restart="model-7.jax")

    assert trainer.model_def_script == checkpoint_model
    assert trainer.model_params_by_task["Default"] == checkpoint_model
    assert trainer.start_step == 7
validator.run( + step_id=1, + display_step=1, + lr=0.001, + save_checkpoint=save_checkpoint, + ) + + assert result is not None + assert (tmp_path / "best.ckpt-1.t-1.jax").is_dir() + assert state_store["full_validation_topk_records"] == [{"metric": 1.0, "step": 1}] + assert "1000.0" in (tmp_path / "val.log").read_text() + + +def test_jax_full_validator_broadcasts_rank_zero_errors(tmp_path: Path) -> None: + """JAX full validation synchronizes rank-0 failures to peer processes.""" + from deepmd.jax.train.validation import ( + JAXFullValidator, + ) + + validator = JAXFullValidator( + validating_params={ + "full_validation": True, + "validation_freq": 1, + "save_best": False, + "max_best_ckpt": 1, + "validation_metric": "E:MAE", + "full_val_file": str(tmp_path / "val.log"), + "full_val_start": 0.0, + }, + validation_data=SimpleNamespace(), + model=SimpleNamespace(), + state_store={}, + num_steps=2, + rank=1, + restart_training=False, + checkpoint_dir=tmp_path, + ) + + with ( + patch("deepmd.jax.train.validation.jax.process_count", return_value=2), + patch( + "jax.experimental.multihost_utils.broadcast_one_to_all", + return_value=np.asarray(True), + ) as broadcast_one_to_all, + ): + assert ( + validator.propagate_error(None) + == "Full validation failed on rank 0; see rank-0 logs." 
def test_jax_full_validation_hook_uses_display_step() -> None:
    """The full-validation hook passes the display step on as ``step_id``."""
    validator_calls: list[dict] = []
    saved: list[tuple[Path, float, int]] = []

    def save_checkpoint(path: Path, lr: float = 0.0, step: int = 0) -> None:
        saved.append((path, lr, step))

    class FakeValidator:
        def run(self, **kwargs) -> None:
            validator_calls.append(kwargs)
            # Invoke the callback the trainer supplied, as a real validator would.
            callback = kwargs["save_checkpoint"]
            callback(Path("best.jax"), lr=kwargs["lr"], step=99)

    # Bypass __init__ so only the attributes the hook reads have to exist.
    trainer = DPTrainer.__new__(DPTrainer)
    trainer.full_validator = FakeValidator()
    trainer._save_full_validation_checkpoint = save_checkpoint

    DPTrainer.run_full_validation(
        trainer,
        step=0,
        display_step=1,
        learning_rate=0.25,
    )

    first_call = validator_calls[0]
    assert first_call["step_id"] == 1
    assert first_call["display_step"] == 1
    assert first_call["lr"] == 0.25
    assert saved == [(Path("best.jax"), 0.25, 99)]
+ entrypoint.validate_options(config, options) + + @patch("deepmd.jax.utils.finetune.get_finetune_rules") + def test_train_entrypoint_preprocesses_finetune_rules( + self, get_finetune_rules + ) -> None: + """JAX train preprocesses fine-tuning config through backend rules.""" + get_finetune_rules.return_value = ( + {"type_map": ["O"], "descriptor": {}, "fitting_net": {}}, + {"Default": object()}, + ) + entrypoint = JAXTrainEntrypoint() + config = {"model": {"type_map": ["O"]}, "training": {}} + + updated = entrypoint.preprocess_config( + config, + TrainEntrypointOptions( + input_file="input.json", + finetune="pretrain.jax", + model_branch="head", + use_pretrain_script=True, + ), + ) + + self.assertEqual(updated["model"]["type_map"], ["O"]) + self.assertIsNotNone(entrypoint.finetune_links) + get_finetune_rules.assert_called_once_with( + "pretrain.jax", + {"type_map": ["O"]}, + model_branch="head", + change_model_params=True, + ) + + @patch("deepmd.jax.entrypoints.train.get_data") + @patch("deepmd.jax.utils.update_sel.UpdateSel.get_nbor_stat") + def test_update_sel_supports_multitask(self, get_nbor_stat, get_data) -> None: + """JAX update_sel updates each multi-task branch.""" + get_nbor_stat.return_value = 0.5, [10, 20] + model_config = { + "type_map": ["O", "H"], + "descriptor": { + "type": "se_e2_a", + "rcut": 6.0, + "sel": "auto", + }, + } + jdata = { + "model": { + "model_dict": { + "task_a": json.loads(json.dumps(model_config)), + "task_b": json.loads(json.dumps(model_config)), + } + }, + "training": { + "data_dict": { + "task_a": {"training_data": {"systems": ["a"]}}, + "task_b": {"training_data": {"systems": ["b"]}}, + } + }, + } + + updated, min_nbor_dist = update_sel(jdata, multi_task=True) + + self.assertEqual( + updated["model"]["model_dict"]["task_a"]["descriptor"]["sel"], [12, 24] + ) + self.assertEqual( + updated["model"]["model_dict"]["task_b"]["descriptor"]["sel"], [12, 24] + ) + self.assertEqual( + min_nbor_dist, + {"task_a": 0.5, "task_b": 0.5}, + ) 
def test_jax_finetune_state_copy_requires_matching_leaf_shape() -> None:
    """A source leaf of a different shape must not overwrite the target leaf."""
    target = {"descriptor": {"w": np.zeros(2, dtype=np.float64)}}
    # Shape (3,) on purpose: it cannot match the (2,) target leaf.
    source = {"descriptor": {"w": np.ones(3, dtype=np.float64)}}

    copied = _copy_matching_state_tree(target, source, random_fitting=False)

    expected_leaf = target["descriptor"]["w"]
    np.testing.assert_array_equal(copied["descriptor"]["w"], expected_leaf)
class _LmdbDatasetWithoutTypeMap:
    """Dataset stub that defines only ``lmdb_path`` and no type-map attribute."""

    # Placeholder path string; nothing in this stub opens or reads it.
    lmdb_path = "missing-type-map.lmdb"
Path(new_best_path).mkdir() + validator._reconcile_best_checkpoints() + finally: + os.chdir(old_cwd) + + self.assertEqual(new_best_path, "best.ckpt-3.t-2.jax") + self.assertEqual( + sorted(path.name for path in Path(tmpdir).glob("best.ckpt-*.jax")), + ["best.ckpt-2.t-1.jax", "best.ckpt-3.t-2.jax"], + ) + def test_full_validator_lmdb_full_validation_iterates_nloc_groups(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: lmdb_path = _create_mixed_nloc_lmdb(f"{tmpdir}/mixed.lmdb") @@ -367,6 +425,29 @@ def fake_evaluate_system(data_system): self.assertAlmostEqual(metrics["mae_e_per_atom"], 8.4) self.assertAlmostEqual(metrics["rmse_e_per_atom"], np.sqrt(75.6)) + def test_full_validator_lmdb_snapshot_requires_type_map(self) -> None: + validator = FullValidator( + validating_params={ + "full_validation": True, + "validation_freq": 1, + "save_best": False, + "max_best_ckpt": 1, + "validation_metric": "E:MAE", + "full_val_file": "val.log", + "full_val_start": 0.0, + }, + validation_data=_DummyValidationData(), + model=_DummyModel(), + state_store={}, + num_steps=10, + rank=0, + zero_stage=0, + restart_training=False, + ) + + with self.assertRaisesRegex(TypeError, "LMDB type_map"): + validator._get_lmdb_test_data_snapshot(_LmdbDatasetWithoutTypeMap()) + class TestValidationArgcheck(unittest.TestCase): def test_normalize_rejects_missing_validation_data(self) -> None: diff --git a/source/tests/pt_expt/test_entrypoint.py b/source/tests/pt_expt/test_entrypoint.py new file mode 100644 index 0000000000..aba7fcb339 --- /dev/null +++ b/source/tests/pt_expt/test_entrypoint.py @@ -0,0 +1,233 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import pytest + +from deepmd.dpmodel.train import ( + TrainEntrypointOptions, +) +from deepmd.pt_expt.entrypoints.main import ( + PTExptTrainEntrypoint, + _ensure_pt_expt_model_suffix, + _ensure_stat_file_path, + train, +) + + +@pytest.mark.parametrize( + ("model_path", "expected"), + [ + (None, None), + ("model", "model.pt"), + 
("model.pt", "model.pt"), + ("model.pte", "model.pte"), + ("model.pt2", "model.pt2"), + ], +) +def test_pt_expt_model_suffix_accepts_checkpoint_and_export_suffixes( + model_path: str | None, + expected: str | None, +) -> None: + assert _ensure_pt_expt_model_suffix(model_path) == expected + + +def test_pt_expt_train_entrypoint_normalizes_checkpoint_prefixes() -> None: + options = TrainEntrypointOptions( + input_file="input.json", + init_model="init", + restart="restart.pte", + finetune="pretrain.pt2", + ) + + prepared = PTExptTrainEntrypoint().prepare_options(options) + + assert prepared.init_model == "init.pt" + assert prepared.restart == "restart.pte" + assert prepared.finetune == "pretrain.pt2" + + +def test_pt_expt_train_wrapper_uses_common_entrypoint( + monkeypatch: pytest.MonkeyPatch, +) -> None: + captured: list[TrainEntrypointOptions] = [] + + def fake_run( + self: PTExptTrainEntrypoint, + options: TrainEntrypointOptions, + ) -> None: + captured.append(options) + + monkeypatch.setattr(PTExptTrainEntrypoint, "run", fake_run) + + train( + input_file="input.json", + init_model="init", + restart="restart", + finetune="pretrain.pte", + model_branch="head", + use_pretrain_script=True, + skip_neighbor_stat=True, + output="normalized.json", + ) + + assert captured == [ + TrainEntrypointOptions( + input_file="input.json", + output="normalized.json", + init_model="init", + restart="restart", + finetune="pretrain.pte", + model_branch="head", + use_pretrain_script=True, + skip_neighbor_stat=True, + ) + ] + + +def test_pt_expt_entrypoint_keeps_caller_owned_process_group( + monkeypatch: pytest.MonkeyPatch, +) -> None: + import torch.distributed as dist + + calls: list[str] = [] + monkeypatch.setenv("LOCAL_RANK", "0") + monkeypatch.setattr(dist, "is_available", lambda: True) + monkeypatch.setattr(dist, "is_initialized", lambda: True) + monkeypatch.setattr( + dist, + "init_process_group", + lambda *args, **kwargs: calls.append("init"), + ) + monkeypatch.setattr(dist, 
"destroy_process_group", lambda: calls.append("destroy")) + + entrypoint = PTExptTrainEntrypoint() + entrypoint.setup_run(TrainEntrypointOptions(input_file="input.json"), {}) + entrypoint.teardown_run(TrainEntrypointOptions(input_file="input.json"), {}) + + assert calls == [] + + +def test_pt_expt_entrypoint_destroys_only_owned_process_group( + monkeypatch: pytest.MonkeyPatch, +) -> None: + import torch.distributed as dist + + state = {"initialized": False} + calls: list[str] = [] + monkeypatch.setenv("LOCAL_RANK", "0") + monkeypatch.setattr(dist, "is_available", lambda: True) + monkeypatch.setattr(dist, "is_initialized", lambda: state["initialized"]) + + def init_process_group(*args, **kwargs) -> None: + calls.append("init") + state["initialized"] = True + + def destroy_process_group() -> None: + calls.append("destroy") + state["initialized"] = False + + monkeypatch.setattr(dist, "init_process_group", init_process_group) + monkeypatch.setattr(dist, "destroy_process_group", destroy_process_group) + + entrypoint = PTExptTrainEntrypoint() + entrypoint.setup_run(TrainEntrypointOptions(input_file="input.json"), {}) + entrypoint.teardown_run(TrainEntrypointOptions(input_file="input.json"), {}) + entrypoint.teardown_run(TrainEntrypointOptions(input_file="input.json"), {}) + + assert calls == ["init", "destroy"] + + +def test_pt_expt_entrypoint_rejects_random_model_key( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr( + "deepmd.pt_expt.utils.multi_task.preprocess_shared_params", + lambda model_params: (model_params, None), + ) + entrypoint = PTExptTrainEntrypoint() + + with pytest.raises(ValueError, match="RANDOM"): + entrypoint.preprocess_config( + {"model": {"model_dict": {"RANDOM": {}}}}, + TrainEntrypointOptions(input_file="input.json"), + ) + + +def test_pt_expt_checkpoint_cleanup_keeps_newest_steps(tmp_path) -> None: + from deepmd.pt_expt.train.training import ( + Trainer, + ) + + trainer = Trainer.__new__(Trainer) + trainer.save_ckpt = 
str(tmp_path / "model.ckpt") + trainer.max_ckpt_keep = 2 + for step in (1, 2, 3): + (tmp_path / f"model.ckpt-{step}.pt").write_text("") + (tmp_path / "model.ckpt.pt").symlink_to("model.ckpt-3.pt") + + trainer._cleanup_old_checkpoints() + + assert not (tmp_path / "model.ckpt-1.pt").exists() + assert (tmp_path / "model.ckpt-2.pt").exists() + assert (tmp_path / "model.ckpt-3.pt").exists() + assert (tmp_path / "model.ckpt.pt").exists() + + +def test_pt_expt_latest_checkpoint_link_uses_relative_target(tmp_path) -> None: + from deepmd.pt_expt.train.training import ( + _replace_latest_checkpoint_link, + ) + + ckpt_path = tmp_path / "ckpts" / "model-1.pt" + ckpt_path.parent.mkdir() + ckpt_path.write_text("") + latest = tmp_path / "ckpts" / "model.pt" + + _replace_latest_checkpoint_link(latest, ckpt_path) + + assert latest.is_symlink() + assert latest.resolve() == ckpt_path + assert latest.readlink().as_posix() == "model-1.pt" + + +def test_pt_expt_save_checkpoint_creates_parent_and_latest_link(tmp_path) -> None: + from deepmd.pt_expt.train.training import ( + Trainer, + ) + + class DummyWrapper: + train_infos: dict[str, int] + model: dict[str, object] + + def __init__(self) -> None: + self.train_infos = {} + self.model = {} + + def state_dict(self) -> dict[str, object]: + return {} + + class DummyOptimizer: + def state_dict(self) -> dict[str, object]: + return {} + + trainer = Trainer.__new__(Trainer) + trainer.wrapper = DummyWrapper() + trainer.optimizer = DummyOptimizer() + trainer.save_ckpt = str(tmp_path / "ckpts" / "model") + trainer.max_ckpt_keep = 2 + + trainer.save_checkpoint(1) + + ckpt_path = tmp_path / "ckpts" / "model-1.pt" + latest = tmp_path / "ckpts" / "model.pt" + assert ckpt_path.exists() + assert latest.is_symlink() + assert latest.resolve() == ckpt_path + assert latest.readlink().as_posix() == "model-1.pt" + + +def test_pt_expt_stat_file_path_creates_hdf5_parent(tmp_path) -> None: + stat_file = tmp_path / "stats" / "model_stat.hdf5" + + stat_path = 
_ensure_stat_file_path(str(stat_file)) + + assert stat_file.exists() + assert stat_path is not None diff --git a/source/tests/pt_expt/test_training.py b/source/tests/pt_expt/test_training.py index 45061c084a..c4ee0955a7 100644 --- a/source/tests/pt_expt/test_training.py +++ b/source/tests/pt_expt/test_training.py @@ -311,6 +311,52 @@ def test_training_loop(self) -> None: config = normalize(config) self._run_training(config) + @patch("deepmd.pt.train.validation.FullValidator.evaluate_all_systems") + def test_full_validation_loop(self, mocked_eval) -> None: + """Run pt_expt full validation and verify best-checkpoint outputs.""" + mocked_eval.side_effect = [ + {"mae_e_per_atom": 1.0}, + {"mae_e_per_atom": 0.5}, + ] + config = _make_config(self.data_dir, numb_steps=2) + config["training"]["save_freq"] = 100 + config["validating"] = { + "full_validation": True, + "validation_freq": 1, + "save_best": True, + "max_best_ckpt": 1, + "validation_metric": "E:MAE", + "full_val_file": "val.log", + "full_val_start": 0.0, + } + config = update_deepmd_input(config, warning=False) + config = normalize(config) + + tmpdir = tempfile.mkdtemp(prefix="pt_expt_full_validation_") + try: + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + trainer = get_trainer(config) + self.assertIsNotNone(trainer.full_validator) + trainer.run() + + self.assertTrue(os.path.exists("best.ckpt-2.t-1.pt")) + self.assertFalse(os.path.exists("best.ckpt-1.t-1.pt")) + with open("val.log") as fp: + val_lines = [ + line for line in fp.readlines() if not line.startswith("#") + ] + self.assertEqual(len(val_lines), 2) + self.assertEqual( + trainer._unwrapped.train_infos["full_validation_topk_records"], + [{"metric": 0.5, "step": 2}], + ) + finally: + os.chdir(old_cwd) + finally: + shutil.rmtree(tmpdir, ignore_errors=True) + def test_training_loop_dpa4(self) -> None: """Run a few DPA4/SeZM training steps (model type "dpa4" dispatch).""" config = _make_config(self.data_dir, numb_steps=5)