-
Notifications
You must be signed in to change notification settings - Fork 629
feat(dpmodel): add backend-independent trainer abstraction #5603
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
njzjz
wants to merge
7
commits into
deepmodeling:master
Choose a base branch
from
njzjz:feat/dpmodel-abstract-trainer-5229
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+5,542
−1,495
Open
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
9485193
feat(dpmodel): add backend-independent trainer abstraction
njzjz 7280c14
fix(jax): use global lr in multitask trainer
njzjz 476a857
chore: merge origin/master into dpmodel trainer branch
njzjz ca88f22
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 7632fef
feat(train): enable full validation for jax and pt-expt
njzjz 418b4c8
fix(train): address validation review comments
njzjz 4599715
fix(jax): avoid optax init_empty_state dependency
njzjz-bot File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,40 @@ | ||
| # SPDX-License-Identifier: LGPL-3.0-or-later | ||
| """Backend-independent training abstractions.""" | ||
|
|
||
| from .data import ( | ||
| TrainingTaskConfig, | ||
| iter_training_task_configs, | ||
| make_task_maps, | ||
| print_data_summaries, | ||
| ) | ||
| from .entrypoint import ( | ||
| AbstractTrainEntrypoint, | ||
| TrainEntrypointOptions, | ||
| ) | ||
| from .trainer import ( | ||
| DEFAULT_TASK_KEY, | ||
| AbstractTrainer, | ||
| LearningCurveWriter, | ||
| RankContext, | ||
| TrainerConfig, | ||
| TrainingTask, | ||
| TrainingTaskCollection, | ||
| TrainStepResult, | ||
| ) | ||
|
|
||
| __all__ = [ | ||
| "DEFAULT_TASK_KEY", | ||
| "AbstractTrainEntrypoint", | ||
| "AbstractTrainer", | ||
| "LearningCurveWriter", | ||
| "RankContext", | ||
| "TrainEntrypointOptions", | ||
| "TrainStepResult", | ||
| "TrainerConfig", | ||
| "TrainingTask", | ||
| "TrainingTaskCollection", | ||
| "TrainingTaskConfig", | ||
| "iter_training_task_configs", | ||
| "make_task_maps", | ||
| "print_data_summaries", | ||
| ] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,137 @@ | ||
| # SPDX-License-Identifier: LGPL-3.0-or-later | ||
| """Shared training-data helpers for backend entrypoints.""" | ||
|
|
||
| from __future__ import ( | ||
| annotations, | ||
| ) | ||
|
|
||
| import inspect | ||
| from dataclasses import ( | ||
| dataclass, | ||
| ) | ||
| from typing import ( | ||
| TYPE_CHECKING, | ||
| Any, | ||
| ) | ||
|
|
||
| from .trainer import ( | ||
| DEFAULT_TASK_KEY, | ||
| ) | ||
|
|
||
| if TYPE_CHECKING: | ||
| from collections.abc import ( | ||
| Callable, | ||
| Iterator, | ||
| Mapping, | ||
| ) | ||
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
| class TrainingTaskConfig: | ||
| """Normalized config view for one training task.""" | ||
|
|
||
| key: str | ||
| model_params: Mapping[str, Any] | ||
| training_data_params: Mapping[str, Any] | ||
| validation_data_params: Mapping[str, Any] | None | ||
| stat_file: str | None | ||
| valid_numb_batch: int | ||
|
|
||
|
|
||
| def iter_training_task_configs( | ||
| config: Mapping[str, Any], | ||
| ) -> Iterator[TrainingTaskConfig]: | ||
| """Yield task configs, treating single-task input as one ``Default`` task.""" | ||
| model_params = config["model"] | ||
| training_params = config["training"] | ||
| if "model_dict" not in model_params: | ||
| validation_data_params = training_params.get("validation_data") | ||
| yield TrainingTaskConfig( | ||
| key=DEFAULT_TASK_KEY, | ||
| model_params=model_params, | ||
| training_data_params=training_params["training_data"], | ||
| validation_data_params=validation_data_params, | ||
| stat_file=training_params.get("stat_file"), | ||
| valid_numb_batch=_valid_numb_batch(validation_data_params), | ||
| ) | ||
| return | ||
|
|
||
| data_dict = training_params["data_dict"] | ||
| for task_key, task_model_params in model_params["model_dict"].items(): | ||
| task_data_params = data_dict[task_key] | ||
| validation_data_params = task_data_params.get("validation_data") | ||
| yield TrainingTaskConfig( | ||
| key=task_key, | ||
| model_params=task_model_params, | ||
| training_data_params=task_data_params["training_data"], | ||
| validation_data_params=validation_data_params, | ||
| stat_file=task_data_params.get("stat_file"), | ||
| valid_numb_batch=_valid_numb_batch(validation_data_params), | ||
| ) | ||
|
|
||
|
|
||
| def make_task_maps( | ||
| config: Mapping[str, Any], | ||
| factory: Callable[[TrainingTaskConfig], tuple[Any, Any | None, Any | None]], | ||
| ) -> tuple[dict[str, Any], dict[str, Any | None], dict[str, Any | None]]: | ||
| """Build training, validation, and stat maps from normalized task configs.""" | ||
| training_data: dict[str, Any] = {} | ||
| validation_data: dict[str, Any | None] = {} | ||
| stat_data: dict[str, Any | None] = {} | ||
| for task_config in iter_training_task_configs(config): | ||
| train_item, valid_item, stat_item = factory(task_config) | ||
| training_data[task_config.key] = train_item | ||
| validation_data[task_config.key] = valid_item | ||
| stat_data[task_config.key] = stat_item | ||
| return training_data, validation_data, stat_data | ||
|
|
||
|
|
||
| def print_data_summaries( | ||
| training_data: Mapping[str, Any], | ||
| validation_data: Mapping[str, Any | None], | ||
| *, | ||
| probabilities: Mapping[str, float] | None = None, | ||
| ) -> None: | ||
| """Print train/validation data summaries for one or more tasks.""" | ||
| multi_task = len(training_data) > 1 | ||
| for task_key, data in training_data.items(): | ||
| name = f"training data({task_key})" if multi_task else "training" | ||
| _print_summary(data, name, _task_probability(probabilities, task_key)) | ||
| valid_data = validation_data.get(task_key) | ||
| if valid_data is not None: | ||
| name = f"validation data({task_key})" if multi_task else "validation" | ||
| _print_summary(valid_data, name, None) | ||
|
|
||
|
|
||
| def _valid_numb_batch(validation_data_params: Mapping[str, Any] | None) -> int: | ||
| if validation_data_params is None: | ||
| return 1 | ||
| return max(int(validation_data_params.get("numb_btch", 1)), 1) | ||
|
|
||
|
|
||
| def _task_probability( | ||
| probabilities: Mapping[str, float] | None, | ||
| task_key: str, | ||
| ) -> list[float] | None: | ||
| if probabilities is None or task_key not in probabilities: | ||
| return None | ||
| return [float(probabilities[task_key])] | ||
|
|
||
|
|
||
| def _print_summary(data: Any, name: str, prob: list[float] | None) -> None: | ||
| printer = data.print_summary | ||
| try: | ||
| signature = inspect.signature(printer) | ||
| except (TypeError, ValueError): | ||
| printer(name, prob) | ||
| return | ||
| try: | ||
| signature.bind(name, prob) | ||
| except TypeError as exc: | ||
| try: | ||
| signature.bind(name) | ||
| except TypeError: | ||
| raise exc from None | ||
| printer(name) | ||
| else: | ||
| printer(name, prob) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,173 @@ | ||
| # SPDX-License-Identifier: LGPL-3.0-or-later | ||
| """Backend-independent training entrypoint pipeline.""" | ||
|
|
||
| from __future__ import ( | ||
| annotations, | ||
| ) | ||
|
|
||
| import json | ||
| import logging | ||
| from abc import ( | ||
| ABC, | ||
| abstractmethod, | ||
| ) | ||
| from dataclasses import ( | ||
| dataclass, | ||
| ) | ||
| from typing import ( | ||
| Any, | ||
| ) | ||
|
|
||
| from deepmd.common import ( | ||
| j_loader, | ||
| ) | ||
| from deepmd.utils.argcheck import ( | ||
| normalize, | ||
| ) | ||
| from deepmd.utils.compat import ( | ||
| update_deepmd_input, | ||
| ) | ||
|
|
||
| log = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| @dataclass | ||
| class TrainEntrypointOptions: | ||
| """Common command options for backend train entrypoints.""" | ||
|
|
||
| input_file: str | ||
| output: str = "out.json" | ||
| init_model: str | None = None | ||
| restart: str | None = None | ||
| init_frz_model: str | None = None | ||
| finetune: str | None = None | ||
| model_branch: str = "" | ||
| use_pretrain_script: bool = False | ||
| skip_neighbor_stat: bool = False | ||
|
|
||
|
|
||
| class AbstractTrainEntrypoint(ABC): | ||
| """Shared pipeline for backend train entrypoints. | ||
|
|
||
| Backend subclasses keep ownership of backend-specific feature handling, | ||
| neighbor-stat updates, distributed setup, data construction, and trainer | ||
| construction. This pipeline only coordinates the common command flow. | ||
| """ | ||
|
|
||
| def run(self, options: TrainEntrypointOptions) -> None: | ||
| """Run the training entrypoint.""" | ||
| log.info("Configuration path: %s", options.input_file) | ||
| options = self.prepare_options(options) | ||
| config = self.load_config(options.input_file) | ||
| self.validate_options(config, options) | ||
|
|
||
| config = self.preprocess_config(config, options) | ||
| multi_task = self.is_multi_task(config) | ||
| config = self.update_input(config) | ||
| config = self.normalize_config(config, multi_task=multi_task) | ||
|
|
||
| neighbor_stat = None | ||
| if not options.skip_neighbor_stat: | ||
| config, neighbor_stat = self.update_neighbor_stat( | ||
| config, | ||
| options, | ||
| multi_task=multi_task, | ||
| ) | ||
|
|
||
| self.dump_config(config, options.output) | ||
| self.print_summary() | ||
|
|
||
| try: | ||
| self.setup_run(options, config) | ||
| self.run_training(config, options, neighbor_stat) | ||
| finally: | ||
| self.teardown_run(options, config) | ||
|
|
||
| def prepare_options( | ||
| self, | ||
| options: TrainEntrypointOptions, | ||
| ) -> TrainEntrypointOptions: | ||
| """Normalize command options before reading or preprocessing config.""" | ||
| return options | ||
|
|
||
| def load_config(self, input_file: str) -> dict[str, Any]: | ||
| """Load the JSON/YAML training config.""" | ||
| return j_loader(input_file) | ||
|
|
||
| def validate_options( | ||
| self, | ||
| config: dict[str, Any], | ||
| options: TrainEntrypointOptions, | ||
| ) -> None: | ||
| """Validate backend feature support before mutating the config.""" | ||
| return None | ||
|
|
||
| def preprocess_config( | ||
| self, | ||
| config: dict[str, Any], | ||
| options: TrainEntrypointOptions, | ||
| ) -> dict[str, Any]: | ||
| """Apply backend-specific config preprocessing before argcheck.""" | ||
| return config | ||
|
|
||
| def is_multi_task(self, config: dict[str, Any]) -> bool: | ||
| """Return whether the config is in multi-task layout.""" | ||
| return "model_dict" in config.get("model", {}) | ||
|
|
||
| def update_input(self, config: dict[str, Any]) -> dict[str, Any]: | ||
| """Apply DeePMD input-version compatibility conversion.""" | ||
| return update_deepmd_input(config, warning=True, dump="input_v2_compat.json") | ||
|
|
||
| def normalize_config( | ||
| self, | ||
| config: dict[str, Any], | ||
| *, | ||
| multi_task: bool, | ||
| ) -> dict[str, Any]: | ||
| """Run DeePMD argcheck normalization.""" | ||
| return normalize(config, multi_task=multi_task) | ||
|
|
||
| def update_neighbor_stat( | ||
| self, | ||
| config: dict[str, Any], | ||
| options: TrainEntrypointOptions, | ||
| *, | ||
| multi_task: bool, | ||
| ) -> tuple[dict[str, Any], Any]: | ||
| """Update descriptor selections from neighbor statistics.""" | ||
| return config, None | ||
|
|
||
| def dump_config(self, config: dict[str, Any], output: str) -> None: | ||
| """Dump the normalized config used for training.""" | ||
| with open(output, "w") as fp: | ||
| json.dump(config, fp, indent=4) | ||
|
|
||
| def print_summary(self) -> None: | ||
| """Print backend summary information.""" | ||
| return None | ||
|
|
||
| def setup_run( | ||
| self, | ||
| options: TrainEntrypointOptions, | ||
| config: dict[str, Any], | ||
| ) -> None: | ||
| """Set up backend runtime state before trainer execution.""" | ||
| return None | ||
|
|
||
| def teardown_run( | ||
| self, | ||
| options: TrainEntrypointOptions, | ||
| config: dict[str, Any], | ||
| ) -> None: | ||
| """Tear down backend runtime state after trainer execution.""" | ||
| return None | ||
|
|
||
| @abstractmethod | ||
| def run_training( | ||
| self, | ||
| config: dict[str, Any], | ||
| options: TrainEntrypointOptions, | ||
| neighbor_stat: Any, | ||
| ) -> None: | ||
| """Build backend data/trainer objects and run training.""" | ||
| raise NotImplementedError | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.