@Backend.register("pretrained")
class PretrainedBackend(Backend):
    """Virtual backend that routes built-in pretrained model names to deep-eval.

    The class is registered under the key ``"pretrained"`` and advertises only
    the ``DEEP_EVAL`` feature. It is not meant to be picked by users as a real
    compute backend; its entry-point, neighbor-stat, and (de)serialization
    hooks all raise ``NotImplementedError``.

    Every name returned by ``available_model_names()`` is listed, lower-cased,
    in ``suffixes`` so that a bare model name can be passed where a model file
    is expected, e.g. ``DeepPot("DPA-3.2-5M")``.
    """

    name = "Pretrained"
    features: ClassVar[Backend.Feature] = Backend.Feature.DEEP_EVAL
    # One lower-cased alias per built-in model name.
    suffixes: ClassVar[list[str]] = [
        alias.lower() for alias in available_model_names()
    ]

    def is_available(self) -> bool:
        """Report the backend as available unconditionally."""
        return True

    @property
    def entry_point_hook(self) -> Callable[["Namespace"], None]:
        """Not supported: this backend has no command-line entry point."""
        raise NotImplementedError("Unsupported backend: pretrained")

    @property
    def deep_eval(self) -> type["DeepEvalBackend"]:
        """Return the deep-eval adapter class for pretrained aliases."""
        # Imported on access rather than at module import time
        # (presumably to avoid an import cycle -- TODO confirm).
        from deepmd.pretrained import deep_eval as pretrained_deep_eval

        return pretrained_deep_eval.PretrainedDeepEvalBackend

    @property
    def neighbor_stat(self) -> type["NeighborStat"]:
        """Not supported: this backend provides no neighbor statistics."""
        raise NotImplementedError("Unsupported backend: pretrained")

    @property
    def serialize_hook(self) -> Callable[[str], dict]:
        """Not supported: this backend cannot serialize models."""
        raise NotImplementedError("Unsupported backend: pretrained")

    @property
    def deserialize_hook(self) -> Callable[[str, dict], None]:
        """Not supported: this backend cannot deserialize models."""
        raise NotImplementedError("Unsupported backend: pretrained")
class InvalidPretrainedAliasError(ValueError):
    """Raised when a pretrained alias string is malformed."""

    def __init__(self, model_file: str) -> None:
        super().__init__(f"Invalid pretrained model name: {model_file}")


def parse_pretrained_alias(
    model_file: str,
    *,
    registry: dict[str, Any] | None = None,
) -> str:
    """Extract the built-in pretrained model name from an alias string.

    Accepted form:
    - ``<model_name>`` where ``<model_name>`` is a registry name, compared
      case-insensitively. Only the final path component of ``model_file`` is
      considered, so leading directories are ignored.

    Parameters
    ----------
    model_file
        Alias string to resolve.
    registry
        Mapping whose keys are the canonical model names. Defaults to the
        built-in ``MODEL_REGISTRY``; an empty mapping is honoured as "no
        models" rather than replaced by the default.

    Returns
    -------
    str
        The canonical model name exactly as spelled in the registry.

    Raises
    ------
    InvalidPretrainedAliasError
        If the alias matches no registry name (a ``ValueError`` subclass).
    """
    known_models = MODEL_REGISTRY if registry is None else registry
    alias = Path(model_file).name

    # An exact spelling needs no scan.
    if alias in known_models:
        return alias

    lowered = alias.lower()
    canonical = next(
        (name for name in known_models if name.lower() == lowered),
        None,
    )
    if canonical is None:
        raise InvalidPretrainedAliasError(model_file)
    return canonical
def eval(
    self,
    coords: np.ndarray,
    cells: np.ndarray | None,
    atom_types: np.ndarray,
    atomic: bool = False,
    fparam: np.ndarray | None = None,
    aparam: np.ndarray | None = None,
    **kwargs: Any,
) -> dict[str, np.ndarray]:
    """Forward an evaluation request unchanged to the wrapped backend.

    ``coords``, ``cells``, ``atom_types`` and ``atomic`` are passed
    positionally; ``fparam``, ``aparam`` and any extra keyword arguments are
    passed by keyword. The wrapped backend's result is returned as is.
    """
    # fparam/aparam are named parameters, so they can never also appear in kwargs.
    keyword_inputs = {"fparam": fparam, "aparam": aparam, **kwargs}
    return self._backend.eval(coords, cells, atom_types, atomic, **keyword_inputs)
def _sha256sum(path: Path) -> str:
    """Return the hexadecimal SHA256 digest of the file at ``path``.

    The file is read in 1 MiB pieces, so it is never held in memory whole.
    """
    piece_size = 1024 * 1024
    digest = hashlib.sha256()
    with path.open("rb") as stream:
        # read() returns b"" at end of file, which ends the loop.
        while piece := stream.read(piece_size):
            digest.update(piece)
    return digest.hexdigest()
def _rank_download_urls(urls: list[str], probe: Any = None) -> list[str]:
    """Rank candidate URLs by probe latency (fastest first).

    Parameters
    ----------
    urls
        Candidate download URLs.
    probe
        Optional callable that takes one URL and returns its latency in
        seconds, or ``None`` when the URL is unreachable. Defaults to
        ``_probe_download_url``.

    Returns
    -------
    list[str]
        URLs with a measured latency in ascending order, followed by the URLs
        without one in their input order. An input with at most one URL is
        returned as is, without probing.
    """
    if len(urls) <= 1:
        return urls

    probe_fn = _probe_download_url if probe is None else probe
    latencies: dict[str, float] = {}
    with concurrent.futures.ThreadPoolExecutor(max_workers=min(4, len(urls))) as exe:
        future_to_url = {exe.submit(probe_fn, url): url for url in urls}
        for future in concurrent.futures.as_completed(future_to_url):
            url = future_to_url[future]
            try:
                latency = future.result()
            except Exception as exc:
                # Probing only decides the order of attempts. A probe that
                # raises (for example on a URL it refuses to contact) must
                # demote that one URL, not abort ranking for all of them.
                logging.getLogger(__name__).debug("Probe failed for %s: %s", url, exc)
                continue
            if latency is not None:
                latencies[url] = latency

    ranked_ok = sorted(latencies, key=latencies.__getitem__)
    ranked_fail = [url for url in urls if url not in latencies]
    return ranked_ok + ranked_fail
def download_model(
    model_name: str,
    *,
    cache_dir: Path | None = None,
    logger: logging.Logger | None = None,
) -> Path:
    """Download one model and return its local path.

    A cached file whose SHA256 matches the registry entry is returned without
    any network access. Otherwise the configured sources are ranked by probe
    latency and tried in that order until one yields a file with the expected
    checksum.

    Parameters
    ----------
    model_name
        Key of the model in ``MODEL_REGISTRY``.
    cache_dir
        Directory that holds the model file. A falsy value selects
        ``DEFAULT_CACHE_DIR``.
    logger
        Logger for progress messages. Defaults to this module's logger.

    Returns
    -------
    Path
        Path of the verified model file.

    Raises
    ------
    ValueError
        If ``model_name`` is not in the registry.
    RuntimeError
        If no download URL is configured, or every source failed to download
        or failed the checksum.
    """
    # Only the exception type is needed, so the import stays local.
    import http.client

    log = logger or logging.getLogger(__name__)

    model_info = MODEL_REGISTRY.get(model_name)
    if model_info is None:
        available = ", ".join(sorted(MODEL_REGISTRY))
        raise ValueError(f"Unknown model: {model_name}. Available: {available}")

    # Path() also accepts a plain string, which the bare "/" join below would not.
    target_dir = Path(cache_dir) if cache_dir else DEFAULT_CACHE_DIR
    output_path = target_dir / str(model_info["filename"])
    expected_sha256 = str(model_info["sha256"])

    if output_path.exists():
        actual = _sha256sum(output_path)
        if actual == expected_sha256:
            log.info("Model '%s' already exists at: %s", model_name, output_path)
            return output_path
        log.warning(
            "Cached file for '%s' failed SHA256 check, re-downloading...",
            model_name,
        )
        output_path.unlink(missing_ok=True)

    urls = _model_download_urls(model_info)
    if not urls:
        raise RuntimeError(f"No download URL configured for model '{model_name}'")

    # Announce the selection before probing, so the message precedes the wait.
    if len(urls) > 1:
        log.info(
            "Selecting fastest source among %d candidates...",
            len(urls),
        )
    ranked_urls = _rank_download_urls(urls)

    total = len(ranked_urls)
    for idx, url in enumerate(ranked_urls, start=1):
        log.info(
            "Downloading '%s' (source %d/%d): %s",
            model_name,
            idx,
            total,
            url,
        )
        try:
            _download_file(url, output_path)
        except (
            urllib.error.URLError,
            # Protocol-level failures such as a truncated body are not
            # OSError subclasses; they must fall through to the next source too.
            http.client.HTTPException,
            OSError,
            ValueError,
        ) as exc:
            log.warning("Download attempt failed from %s: %s", url, exc)
            continue

        # The file is already at output_path here; remove it again on mismatch
        # so a bad download is never left in the cache.
        actual = _sha256sum(output_path)
        if actual != expected_sha256:
            output_path.unlink(missing_ok=True)
            log.warning("SHA256 verification failed from source: %s", url)
            log.warning("Expected: %s", expected_sha256)
            log.warning("Actual: %s", actual)
            continue

        log.info("Downloaded '%s' to: %s", model_name, output_path)
        return output_path

    raise RuntimeError(f"Failed to download model '{model_name}' from all sources")
pretrained model sources.""" + +from typing import ( + Any, +) + +MODEL_REGISTRY: dict[str, dict[str, Any]] = { + "DPA-3.2-5M": { + "urls": [ + "https://huggingface.co/deepmodelingcommunity/DPA-3.2-5M/resolve/main/DPA-3.2-5M.pt?download=true", + "https://hf-mirror.com/deepmodelingcommunity/DPA-3.2-5M/resolve/main/DPA-3.2-5M.pt?download=true", + ], + "filename": "DPA-3.2-5M.pt", + "sha256": "876354744aeaae17b2639a6a690514470273784f2b4836280850f50cbb799165", + }, + "DPA-3.1-3M": { + "urls": [ + "https://huggingface.co/deepmodelingcommunity/DPA-3.1-3M/resolve/main/DPA-3.1-3M.pt?download=true", + "https://hf-mirror.com/deepmodelingcommunity/DPA-3.1-3M/resolve/main/DPA-3.1-3M.pt?download=true", + ], + "filename": "DPA-3.1-3M.pt", + "sha256": "86dd3a804d78ca5d203ebf98747e8f16dff9713ba8950097ceb760b161e19907", + }, +} + + +def available_model_names() -> list[str]: + """Return available model names from built-in registry.""" + return sorted(MODEL_REGISTRY) diff --git a/doc/model/index.rst b/doc/model/index.rst index 4896ccdc4e..a173732bbc 100644 --- a/doc/model/index.rst +++ b/doc/model/index.rst @@ -29,3 +29,4 @@ Model change-bias precision show-model-info + pretrained diff --git a/doc/model/pretrained.md b/doc/model/pretrained.md new file mode 100644 index 0000000000..e4fd3481de --- /dev/null +++ b/doc/model/pretrained.md @@ -0,0 +1,49 @@ +# Use `dp pretrained` to download built-in models + +The `dp pretrained` command provides a simple way to download built-in pre-trained models and store them in a local cache. + +## Command syntax + +```bash +dp pretrained download <MODEL> [--cache-dir <CACHE_DIR>] +``` + +- `<MODEL>`: the built-in model name. +- `--cache-dir <CACHE_DIR>`: optional cache directory. If omitted, DeePMD-kit uses the default cache path. + +## Available built-in models + +You can run `dp pretrained download -h` to see the currently supported model list in your installed version. 
class TestPretrainedBackend(unittest.TestCase):
    """Checks for pretrained backend registration, detection and alias parsing."""

    @classmethod
    def setUpClass(cls) -> None:
        # Importing the package loads the backend registration side effects.
        importlib.import_module("deepmd.backend")

    def test_detect_backend_by_model_name(self) -> None:
        """A bare built-in model name is detected as the pretrained backend."""
        detected = Backend.detect_backend_by_model("DPA-3.2-5M")
        self.assertIs(detected, PretrainedBackend)

    def test_detect_backend_by_pretrained_suffix_not_supported(self) -> None:
        """A ``.pretrained`` suffix is rejected by backend detection."""
        self.assertRaises(
            ValueError,
            Backend.detect_backend_by_model,
            "DPA-3.2-5M.pretrained",
        )

    def test_parse_pretrained_alias_plain_name(self) -> None:
        """Exact and lower-cased spellings map to the canonical name."""
        for spelling in ("DPA-3.2-5M", "dpa-3.2-5m"):
            self.assertEqual(parse_pretrained_alias(spelling), "DPA-3.2-5M")

    def test_parse_pretrained_alias_invalid(self) -> None:
        """Names carrying a file suffix are not accepted as aliases."""
        for rejected in ("DPA-3.2-5M.pt", "DPA-3.2-5M.pretrained"):
            with self.assertRaises(ValueError):
                parse_pretrained_alias(rejected)

    def test_deep_eval_property(self) -> None:
        """``deep_eval`` exposes the pretrained deep-eval adapter class."""
        from deepmd.pretrained import deep_eval as pretrained_deep_eval

        adapter = PretrainedBackend().deep_eval
        self.assertIs(adapter, pretrained_deep_eval.PretrainedDeepEvalBackend)
class TestPretrainedParser(unittest.TestCase):
    """Argument-parsing checks for the ``dp pretrained`` command."""

    def test_pretrained_download_parser(self) -> None:
        """The minimal ``download`` invocation fills every expected field."""
        first_model = available_model_names()[0]
        parsed = parse_args(["pretrained", "download", first_model])

        expected_fields = {
            "command": "pretrained",
            "pretrained_command": "download",
            "MODEL": first_model,
        }
        for attribute, value in expected_fields.items():
            self.assertEqual(getattr(parsed, attribute), value)
        self.assertIsNone(parsed.cache_dir)

    def test_pretrained_download_with_cache_dir(self) -> None:
        """``--cache-dir`` is stored verbatim as a string."""
        cache_dir = "/tmp/deepmd-pretrained"
        argv = [
            "pretrained",
            "download",
            available_model_names()[0],
            "--cache-dir",
            cache_dir,
        ]

        self.assertEqual(parse_args(argv).cache_dir, cache_dir)

    def test_pretrained_download_rejects_unknown_model(self) -> None:
        """A model name outside the registry makes the parser exit."""
        self.assertRaises(
            SystemExit,
            parse_args,
            ["pretrained", "download", "NOT-EXIST"],
        )