Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions deepmd/dpmodel/train/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Backend-independent training abstractions."""

from .data import (
TrainingTaskConfig,
iter_training_task_configs,
make_task_maps,
print_data_summaries,
)
from .entrypoint import (
AbstractTrainEntrypoint,
TrainEntrypointOptions,
)
from .trainer import (
DEFAULT_TASK_KEY,
AbstractTrainer,
LearningCurveWriter,
RankContext,
TrainerConfig,
TrainingTask,
TrainingTaskCollection,
TrainStepResult,
)

# Public API of the backend-independent training package. Every name listed
# here is imported above from the ``data``, ``entrypoint`` or ``trainer``
# submodule of this package.
__all__ = [
    "DEFAULT_TASK_KEY",
    "AbstractTrainEntrypoint",
    "AbstractTrainer",
    "LearningCurveWriter",
    "RankContext",
    "TrainEntrypointOptions",
    "TrainStepResult",
    "TrainerConfig",
    "TrainingTask",
    "TrainingTaskCollection",
    "TrainingTaskConfig",
    "iter_training_task_configs",
    "make_task_maps",
    "print_data_summaries",
]
137 changes: 137 additions & 0 deletions deepmd/dpmodel/train/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Shared training-data helpers for backend entrypoints."""

from __future__ import (
annotations,
)

import inspect
from dataclasses import (
dataclass,
)
from typing import (
TYPE_CHECKING,
Any,
)

from .trainer import (
DEFAULT_TASK_KEY,
)

if TYPE_CHECKING:
from collections.abc import (
Callable,
Iterator,
Mapping,
)


@dataclass(frozen=True)
class TrainingTaskConfig:
    """Immutable, normalized description of a single training task.

    Instances are produced by :func:`iter_training_task_configs`, which fills
    the fields from the ``model`` and ``training`` sections of the config.

    Attributes
    ----------
    key
        Task identifier: ``DEFAULT_TASK_KEY`` for single-task input, otherwise
        the task's key in ``model["model_dict"]``.
    model_params
        Model parameters of this task.
    training_data_params
        The task's ``"training_data"`` section.
    validation_data_params
        The task's ``"validation_data"`` section, or ``None`` when absent.
    stat_file
        The task's ``"stat_file"`` entry, or ``None`` when absent.
    valid_numb_batch
        Number of validation batches; the helper that derives it never
        returns less than 1.
    """

    key: str
    model_params: Mapping[str, Any]
    training_data_params: Mapping[str, Any]
    validation_data_params: Mapping[str, Any] | None
    stat_file: str | None
    valid_numb_batch: int


def iter_training_task_configs(
    config: Mapping[str, Any],
) -> Iterator[TrainingTaskConfig]:
    """Yield one :class:`TrainingTaskConfig` per task described by ``config``.

    A config whose ``model`` section has no ``"model_dict"`` entry is treated
    as a single task named ``DEFAULT_TASK_KEY`` whose data settings live
    directly in the ``training`` section. Otherwise one task is yielded per
    entry of ``model["model_dict"]``, with its data settings taken from
    ``training["data_dict"][task_key]``.

    Parameters
    ----------
    config
        Training config containing at least the ``"model"`` and ``"training"``
        sections.

    Yields
    ------
    TrainingTaskConfig
        The normalized view of each task, in ``model_dict`` order for
        multi-task input.

    Raises
    ------
    KeyError
        If a required section (``"model"``, ``"training"``, ``"training_data"``,
        ``"data_dict"`` or a task's entry in ``data_dict``) is missing.
    """
    model_section = config["model"]
    training_section = config["training"]

    # Reduce both layouts to a stream of (task key, model params, data section)
    # triples so that a single loop can build the records.
    if "model_dict" in model_section:
        per_task_data = training_section["data_dict"]
        sources = (
            (name, params, per_task_data[name])
            for name, params in model_section["model_dict"].items()
        )
    else:
        sources = iter([(DEFAULT_TASK_KEY, model_section, training_section)])

    for name, params, data_section in sources:
        valid_params = data_section.get("validation_data")
        yield TrainingTaskConfig(
            key=name,
            model_params=params,
            training_data_params=data_section["training_data"],
            validation_data_params=valid_params,
            stat_file=data_section.get("stat_file"),
            valid_numb_batch=_valid_numb_batch(valid_params),
        )


def make_task_maps(
    config: Mapping[str, Any],
    factory: Callable[[TrainingTaskConfig], tuple[Any, Any | None, Any | None]],
) -> tuple[dict[str, Any], dict[str, Any | None], dict[str, Any | None]]:
    """Run ``factory`` on every task and collect its outputs keyed by task.

    Parameters
    ----------
    config
        Training config, iterated with :func:`iter_training_task_configs`.
    factory
        Callable that receives one task config and returns a
        ``(training, validation, stat)`` triple for that task.

    Returns
    -------
    tuple of dict
        Three dictionaries (training, validation, stat), each mapping the task
        key to the corresponding item returned by ``factory``.
    """
    train_map: dict[str, Any] = {}
    valid_map: dict[str, Any | None] = {}
    stat_map: dict[str, Any | None] = {}
    for task in iter_training_task_configs(config):
        # The factory runs first; its triple is then spread over the maps.
        train_map[task.key], valid_map[task.key], stat_map[task.key] = factory(task)
    return train_map, valid_map, stat_map


def print_data_summaries(
    training_data: Mapping[str, Any],
    validation_data: Mapping[str, Any | None],
    *,
    probabilities: Mapping[str, float] | None = None,
) -> None:
    """Print the summary of each task's training and validation data.

    With a single task the summaries are labelled ``"training"`` and
    ``"validation"``; with several tasks the task key is embedded in the label.

    Parameters
    ----------
    training_data
        Training data object per task key; each must offer ``print_summary``.
    validation_data
        Validation data object per task key; missing or ``None`` entries are
        skipped.
    probabilities
        Optional per-task probability. When a task has one, it is forwarded to
        the training summary as a one-element list; validation summaries always
        receive ``None``.
    """
    tagged = len(training_data) > 1
    for task_key, train_set in training_data.items():
        if tagged:
            train_label = f"training data({task_key})"
        else:
            train_label = "training"
        _print_summary(
            train_set,
            train_label,
            _task_probability(probabilities, task_key),
        )

        valid_set = validation_data.get(task_key)
        if valid_set is None:
            continue
        if tagged:
            valid_label = f"validation data({task_key})"
        else:
            valid_label = "validation"
        _print_summary(valid_set, valid_label, None)


def _valid_numb_batch(validation_data_params: Mapping[str, Any] | None) -> int:
    """Return the validation batch count, never less than 1.

    The count is read from the ``"numb_btch"`` entry (default 1) and converted
    with ``int``; ``None`` parameters mean one batch.
    """
    if validation_data_params is not None:
        requested = int(validation_data_params.get("numb_btch", 1))
        return requested if requested > 1 else 1
    return 1


def _task_probability(
    probabilities: Mapping[str, float] | None,
    task_key: str,
) -> list[float] | None:
    """Return ``[probability]`` for ``task_key``, or ``None`` if unavailable.

    ``None`` is returned both when no probability mapping is given and when the
    mapping has no entry for ``task_key``; otherwise the entry is converted to
    ``float`` and wrapped in a one-element list.
    """
    if probabilities is not None and task_key in probabilities:
        return [float(probabilities[task_key])]
    return None


def _print_summary(data: Any, name: str, prob: list[float] | None) -> None:
    """Call ``data.print_summary`` with as many arguments as it accepts.

    The two-argument form ``print_summary(name, prob)`` is preferred. If the
    method's signature rejects two positional arguments but accepts one, it is
    called as ``print_summary(name)``. If the signature cannot be inspected, the
    two-argument call is attempted directly.

    Raises
    ------
    TypeError
        The binding error of the two-argument form, when the signature accepts
        neither ``(name, prob)`` nor ``(name)``.
    """
    emit = data.print_summary
    try:
        accepted = inspect.signature(emit)
    except (TypeError, ValueError):
        # Not introspectable: just try the full call.
        emit(name, prob)
        return

    try:
        accepted.bind(name, prob)
    except TypeError as two_arg_error:
        try:
            accepted.bind(name)
        except TypeError:
            # Neither form fits; report the failure of the preferred form.
            raise two_arg_error from None
        emit(name)
        return

    emit(name, prob)
173 changes: 173 additions & 0 deletions deepmd/dpmodel/train/entrypoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Backend-independent training entrypoint pipeline."""

from __future__ import (
annotations,
)

import json
import logging
from abc import (
ABC,
abstractmethod,
)
from dataclasses import (
dataclass,
)
from typing import (
Any,
)

from deepmd.common import (
j_loader,
)
from deepmd.utils.argcheck import (
normalize,
)
from deepmd.utils.compat import (
update_deepmd_input,
)

# Module-level logger, named after this module; used by the entrypoint pipeline.
log = logging.getLogger(__name__)


@dataclass
class TrainEntrypointOptions:
    """Command options shared by the backend ``train`` entrypoints.

    Attributes
    ----------
    input_file
        Path of the training config; the pipeline logs it and passes it to
        ``load_config``.
    output
        Path the normalized config is dumped to (default ``"out.json"``).
    init_model, restart, init_frz_model, finetune
        Optional start-up sources, ``None`` by default. The shared pipeline
        does not read them; presumably they are consumed by backend
        subclasses (to be confirmed against the backends).
    model_branch
        Defaults to ``""``; not read by the shared pipeline.
    use_pretrain_script
        Defaults to ``False``; not read by the shared pipeline.
    skip_neighbor_stat
        When true, the pipeline skips the neighbor-statistics update step.
    """

    input_file: str
    output: str = "out.json"
    init_model: str | None = None
    restart: str | None = None
    init_frz_model: str | None = None
    finetune: str | None = None
    model_branch: str = ""
    use_pretrain_script: bool = False
    skip_neighbor_stat: bool = False


class AbstractTrainEntrypoint(ABC):
    """Template for backend ``train`` commands.

    :meth:`run` fixes the order of the common command steps and delegates each
    step to an overridable hook. Backends remain responsible for their own
    feature handling, neighbor-stat updates, distributed setup, data
    construction and trainer construction; only :meth:`run_training` is
    mandatory to implement.
    """

    def run(self, options: TrainEntrypointOptions) -> None:
        """Execute the whole train command described by ``options``."""
        log.info("Configuration path: %s", options.input_file)
        opts = self.prepare_options(options)
        cfg = self.load_config(opts.input_file)
        self.validate_options(cfg, opts)

        # The multi-task layout is detected on the preprocessed config, before
        # the compatibility conversion and argcheck normalization run.
        cfg = self.preprocess_config(cfg, opts)
        multi_task = self.is_multi_task(cfg)
        cfg = self.update_input(cfg)
        cfg = self.normalize_config(cfg, multi_task=multi_task)

        if opts.skip_neighbor_stat:
            neighbor_stat = None
        else:
            cfg, neighbor_stat = self.update_neighbor_stat(
                cfg,
                opts,
                multi_task=multi_task,
            )

        self.dump_config(cfg, opts.output)
        self.print_summary()

        # Teardown runs whether setup or training succeeds or raises.
        try:
            self.setup_run(opts, cfg)
            self.run_training(cfg, opts, neighbor_stat)
        finally:
            self.teardown_run(opts, cfg)

    def prepare_options(
        self,
        options: TrainEntrypointOptions,
    ) -> TrainEntrypointOptions:
        """Return the options to use for the run; the default is a pass-through.

        Called before the config is read or preprocessed.
        """
        return options

    def load_config(self, input_file: str) -> dict[str, Any]:
        """Read the JSON/YAML training config found at ``input_file``."""
        loaded = j_loader(input_file)
        return loaded

    def validate_options(
        self,
        config: dict[str, Any],
        options: TrainEntrypointOptions,
    ) -> None:
        """Check backend feature support before the config is mutated.

        The default accepts everything.
        """

    def preprocess_config(
        self,
        config: dict[str, Any],
        options: TrainEntrypointOptions,
    ) -> dict[str, Any]:
        """Backend-specific preprocessing applied before argcheck.

        The default returns ``config`` unchanged.
        """
        return config

    def is_multi_task(self, config: dict[str, Any]) -> bool:
        """Tell whether ``config`` uses the multi-task layout.

        That is the case when its ``"model"`` section contains ``"model_dict"``.
        """
        model_section = config.get("model", {})
        return "model_dict" in model_section

    def update_input(self, config: dict[str, Any]) -> dict[str, Any]:
        """Run the DeePMD input-version compatibility conversion."""
        converted = update_deepmd_input(
            config,
            warning=True,
            dump="input_v2_compat.json",
        )
        return converted

    def normalize_config(
        self,
        config: dict[str, Any],
        *,
        multi_task: bool,
    ) -> dict[str, Any]:
        """Normalize ``config`` through DeePMD argcheck."""
        normalized = normalize(config, multi_task=multi_task)
        return normalized

    def update_neighbor_stat(
        self,
        config: dict[str, Any],
        options: TrainEntrypointOptions,
        *,
        multi_task: bool,
    ) -> tuple[dict[str, Any], Any]:
        """Update descriptor selections from neighbor statistics.

        Returns the (possibly updated) config together with a neighbor-stat
        value that is later handed to :meth:`run_training`. The default changes
        nothing and reports ``None``.
        """
        return config, None

    def dump_config(self, config: dict[str, Any], output: str) -> None:
        """Write the config used for training to ``output`` as indented JSON."""
        with open(output, "w") as handle:
            json.dump(config, handle, indent=4)

    def print_summary(self) -> None:
        """Print backend summary information; the default prints nothing."""

    def setup_run(
        self,
        options: TrainEntrypointOptions,
        config: dict[str, Any],
    ) -> None:
        """Prepare backend runtime state right before training; no-op by default."""

    def teardown_run(
        self,
        options: TrainEntrypointOptions,
        config: dict[str, Any],
    ) -> None:
        """Release backend runtime state after training; no-op by default."""

    @abstractmethod
    def run_training(
        self,
        config: dict[str, Any],
        options: TrainEntrypointOptions,
        neighbor_stat: Any,
    ) -> None:
        """Build the backend data and trainer objects, then train."""
        raise NotImplementedError
Loading
Loading