diff --git a/.github/workflows/test_python.yml b/.github/workflows/test_python.yml index 9d338dd4a2..d690cb054c 100644 --- a/.github/workflows/test_python.yml +++ b/.github/workflows/test_python.yml @@ -47,6 +47,7 @@ jobs: HOROVOD_WITH_MPI: 1 # https://cmake.org/cmake/help/latest/variable/CMAKE_POLICY_VERSION_MINIMUM.html CMAKE_POLICY_VERSION_MINIMUM: 3.5 + UV_HTTP_TIMEOUT: 120 - run: dp --version - name: Get durations from cache uses: actions/cache@v6 diff --git a/README.md b/README.md index 5ca8080e77..a07db0cf9c 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ For more information, check the [documentation](https://deepmd.readthedocs.io/). - **implements the Deep Potential series models**, which have been successfully applied to finite and extended systems, including organic molecules, metals, semiconductors, insulators, etc. - **implements MPI and GPU supports**, making it highly efficient for high-performance parallel and distributed computing. - **highly modularized**, easy to adapt to different descriptors for deep learning-based potential energy models. +- **adapts pre-trained DPA models to downstream atomistic property prediction tasks with DPA-ADAPT**, a new Python API and CLI that supports frozen-descriptor scikit-learn heads, frozen property-head training, full end-to-end fine-tuning, and multi-task fine-tuning with an auxiliary force-field task. DPA-ADAPT trains on `deepmd/npy` systems and provides conversion pipelines for SMILES tables and structure or calculation files handled through dpdata. See the [DPA-ADAPT guide](doc/dpa_adapt/overview.md) and supported [input formats](doc/dpa_adapt/input_formats.md). ### License and credits @@ -103,6 +104,7 @@ The code is organized as follows: - `examples`: examples. - `deepmd`: DeePMD-kit python modules. +- `dpa_adapt`: DPA-ADAPT package for adapting pre-trained DPA models; see the [guide](doc/dpa_adapt/overview.md) and [input formats](doc/dpa_adapt/input_formats.md). - `source/lib`: source code of the core library. - `source/op`: Operator (OP) implementation. - `source/api_cc`: source code of DeePMD-kit C++ API. diff --git a/deepmd/__about__.py b/deepmd/__about__.py new file mode 100644 index 0000000000..828f79c7f7 --- /dev/null +++ b/deepmd/__about__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +__version__ = "0.0.0" diff --git a/deepmd/entrypoints/test.py b/deepmd/entrypoints/test.py index 6c66d73a95..dedf8a0711 100644 --- a/deepmd/entrypoints/test.py +++ b/deepmd/entrypoints/test.py @@ -307,8 +307,8 @@ def save_txt_file( append : bool, optional if true file will be appended instead of overwriting, by default False """ - flags = "ab" if append else "w" - with fname.open(flags) as fp: + flags = "a" if append else "w" + with fname.open(flags, encoding="utf-8") as fp: np.savetxt(fp, data, header=header) diff --git a/doc/cli.rst b/doc/cli.rst index 15891369e3..ea3060698d 100644 --- a/doc/cli.rst +++ b/doc/cli.rst @@ -1,7 +1,7 @@ .. _cli: -Command line interface -====================== +DeePMD-kit command line interface +================================= .. argparse:: :module: deepmd.tf.entrypoints.main diff --git a/doc/conf.py b/doc/conf.py index 52c647a20d..c58073d5c0 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -7,6 +7,10 @@ # -- Path setup -------------------------------------------------------------- +from __future__ import ( + annotations, +) + import datetime # If extensions (or modules to document with autodoc) are in another directory, @@ -224,3 +228,54 @@ bibtex_bibfiles = ["../CITATIONS.bib"] remove_from_toctrees = ["autoapi/**/*", "API_CC/*", "api_c/*", "api_core/*"] + + +# Auto-generated CLI reference pages (sphinx-argparse) nest a section per +# subcommand and per argument group. Under the global ``:numbered:`` toctree +# this explodes into unhelpful deep numbers (e.g. ``9.3.3.6.3.1.1.``). Cap the +# section numbering at the given depth (number of dotted components); headings +# deeper than that are left unnumbered. Only the listed pages are affected. +from typing import ( + TYPE_CHECKING, +) + +from docutils import ( + nodes, +) + +if TYPE_CHECKING: + from sphinx.application import ( + Sphinx, + ) + +cli_secnumber_max_depth = { + "dpa_adapt/cli": 5, +} + + +def _cap_cli_secnumbers(app: Sphinx, doctree: nodes.document, docname: str) -> None: + """Drop section numbers below ``cli_secnumber_max_depth`` for CLI pages.""" + max_depth = cli_secnumber_max_depth.get(docname) + if max_depth is None: + return + secnumbers = app.env.toc_secnumbers.get(docname) + if not secnumbers: + return + # The empty anchor "" holds the page chapter number (e.g. ``(9, 3)``). + # It must be dropped from the map, otherwise the writer falls back to it for + # the now-unnumbered deep sections; re-attach it to the page title instead. + page_number = secnumbers.get("") + app.env.toc_secnumbers[docname] = { + anchor: number + for anchor, number in secnumbers.items() + if anchor != "" and len(number) <= max_depth + } + if page_number: + for title in doctree.findall(nodes.title): + title["secnumber"] = page_number + break + + +def setup(app: Sphinx) -> dict[str, bool]: + app.connect("doctree-resolved", _cap_cli_secnumbers) + return {"parallel_read_safe": True, "parallel_write_safe": True} diff --git a/doc/dpa_adapt/cli.rst b/doc/dpa_adapt/cli.rst new file mode 100644 index 0000000000..7b612d6823 --- /dev/null +++ b/doc/dpa_adapt/cli.rst @@ -0,0 +1,10 @@ +DPA-ADAPT command line interface +================================ + +The ``dpaad`` command is a short alias for ``dpa-adapt`` and exposes the same +subcommands and options. + +.. argparse:: + :module: dpa_adapt.cli + :func: get_parser + :prog: dpa-adapt diff --git a/doc/dpa_adapt/index.rst b/doc/dpa_adapt/index.rst new file mode 100644 index 0000000000..8f094af891 --- /dev/null +++ b/doc/dpa_adapt/index.rst @@ -0,0 +1,15 @@ +DPA-ADAPT +========= + +DPA-ADAPT adapts pretrained DPA models to downstream atomistic property +prediction tasks through Python APIs and the standalone ``dpa-adapt`` CLI. +It supports frozen-descriptor scikit-learn heads, frozen property-head +training, full fine-tuning, and multi-task fine-tuning with an auxiliary +force-field task. + +.. toctree:: + :maxdepth: 2 + + overview + input_formats + cli diff --git a/doc/dpa_adapt/input_formats.md b/doc/dpa_adapt/input_formats.md new file mode 100644 index 0000000000..fe68c878f4 --- /dev/null +++ b/doc/dpa_adapt/input_formats.md @@ -0,0 +1,146 @@ +# Input Formats + +> **Project/package name:** `dpa-adapt` +> **Python import:** `dpa_adapt` +> **Main CLI:** `dpa-adapt` +> **Optional short alias:** `dpaad` +> **Display name:** DPA-ADAPT — Atomistic DPA Adaptation for Property Tasks + +`dpa-adapt data convert` and the Python `dpa_adapt.convert()` helper +auto-detect the input type and route it to the correct pipeline: +**SMILES table** → RDKit 3D conformer generation, +**structure files** → dpdata (auto-detect or explicit `--fmt`). + +## SMILES Tables (CSV) + +**Trigger:** file extension `.csv` **and** a SMILES column. +By default, the converter reads `SMILES`/`smiles`; use `--smiles-col` for +other column names such as `smi` or `mol`. Or pass `--fmt smiles` explicitly. + +| Parameter | Default | Description | +| ------------------ | ------------- | ------------------------------------------------------------------------------------------------------------------ | +| `--smiles-col` | `SMILES` | Column name for SMILES strings | +| `--property-col` | `Property` | Input table column to read target values from; also used as the output label name | +| `--train-ratio` | `0.9` | Fraction of samples used for training set | +| `--mol-dir` | — | Directory of pre-generated `.mol`, `.sdf`, `.xyz`, or `.pdb` structure files (skips RDKit 3D conformer generation) | +| `--mol-template` | `id{row}.mol` | Filename template under `--mol-dir`; use `{row}` for the CSV row index | +| `--split-seed` | `42` | Random seed for train/valid splitting | +| `--conformer-seed` | `42` | Random seed for RDKit 3D conformer generation | + +```bash +# Auto-detected via SMILES column +dpa-adapt data convert --input molecules.csv --output ./npy \ + --property-col homo +# Short alias +dpaad data convert --input molecules.csv --output ./npy \ + --property-col homo + +# Explicit fmt + custom column names +dpa-adapt data convert --input data.csv --output ./npy --fmt smiles \ + --smiles-col smi --property-col GAP --train-ratio 0.85 \ + --split-seed 42 --conformer-seed 43 +# Short alias +dpaad data convert --input data.csv --output ./npy --fmt smiles \ + --smiles-col smi --property-col GAP --train-ratio 0.85 \ + --split-seed 42 --conformer-seed 43 +``` + +## Structure Files via dpdata + +**Trigger:** inputs not routed to the SMILES pipeline. This means `--fmt` is +not `smiles`; when `--fmt` is omitted, CSV inputs are routed here only if they +do not contain a recognized SMILES column. +Calls dpdata for format auto-detection or explicit conversion. + +### Common Formats + +| `--fmt` value | Typical file(s) | Notes | +| -------------------------------------------------------- | ----------------------------- | -------------------------------------- | +| `xyz` | `*.xyz` | Plain XYZ | +| `vasp/poscar` / `vasp/contcar` | `POSCAR`, `CONTCAR` | VASP input/final structure | +| `vasp/outcar` | `OUTCAR` | VASP output (energies, forces, stress) | +| `vasp/xml` | `vasprun.xml` | VASP XML output | +| `vasp/string` | VASP structure string | VASP structure from a string | +| `abacus/stru` / `stru` | `STRU` | ABACUS input structure | +| `abacus/scf` / `abacus/pw/scf` / `abacus/lcao/scf` | SCF output | ABACUS SCF calculation | +| `abacus/md` / `abacus/pw/md` / `abacus/lcao/md` | MD output | ABACUS molecular dynamics | +| `abacus/relax` / `abacus/pw/relax` / `abacus/lcao/relax` | Relax output | ABACUS relaxation | +| `cp2k/aimd_output` | CP2K MD output | CP2K AIMD output file | +| `cp2k/output` | CP2K SCF output | CP2K single-point output | +| `deepmd/raw` | `set.*/` dirs | DeePMD-kit raw format | +| `deepmd/comp` / `deepmd/npy` | `set.*/` dirs | DeePMD-kit compressed/npy format | +| `deepmd/npy/mixed` | mixed `deepmd/npy` dir | DeePMD-kit mixed npy format | +| `deepmd/hdf5` | `*.hdf5` | DeePMD-kit HDF5 format | +| `lammps/dump` / `dump` | `dump.*` | LAMMPS dump trajectory | +| `lammps/lmp` / `lmp` | `*.lmp` | LAMMPS data file | +| `qe/cp/traj` | CP trajectory | Quantum ESPRESSO Car-Parrinello MD | +| `qe/pw/scf` | PWscf output | Quantum ESPRESSO PWscf | +| `siesta/output` | Siesta output | SIESTA SCF output | +| `siesta/aimd_output` | Siesta MD output | SIESTA AIMD output | +| `gaussian/log` | `*.log` | Gaussian log file | +| `gaussian/fchk` | `*.fchk` | Gaussian formatted checkpoint | +| `gaussian/md` | Gaussian MD output | Gaussian MD trajectory | +| `gaussian/gjf` | `*.gjf` | Gaussian input file | +| `amber/md` | Amber MD output | Amber MD trajectory | +| `gromacs/gro` / `gro` | `*.gro` | GROMACS coordinate file | +| `pwmat/output` / `pwmat/movement` / `pwmat/mlmd` | `REPORT`, `MOVEMENT`, `MLMD` | PWmat output / movement / MLMD | +| `pwmat/final.config` / `pwmat/atom.config` | `final.config`, `atom.config` | PWmat final/input structure | +| `fhi_aims/output` / `fhi_aims/md` | FHI-aims output/MD | FHI-aims calculation or MD trajectory | +| `fhi_aims/scf` | FHI-aims SCF output | FHI-aims SCF | +| `psi4/out` | Psi4 output | Psi4 calculation output | +| `psi4/inp` | Psi4 input | Psi4 input file | +| `orca/spout` | ORCA output | ORCA single-point output | +| `sqm/out` | SQM output | SQM output | +| `sqm/in` | SQM input | SQM input | +| `openmx/md` | OpenMX MD output | OpenMX MD trajectory | +| `n2p2` | n2p2 output | n2p2/NNPack output | +| `dftbplus` | DFTB+ output | DFTB+ detailed.xml | +| `mol` / `mol_file` | `*.mol` | MDL Molfile | +| `sdf` / `sdf_file` | `*.sdf` | MDL SDFile | +| `ase/structure` | Any ASE format | ASE structure (single frame) | +| `ase/traj` | Any ASE trajectory | ASE trajectory (multi-frame) | +| `pymatgen/structure` | pymatgen objects | pymatgen Structure | +| `pymatgen/molecule` | pymatgen objects | pymatgen Molecule | +| `pymatgen/computedstructureentry` | pymatgen objects | pymatgen ComputedStructureEntry | +| `lmdb` | LMDB dir | DeePMD-kit LMDB format | +| `list` | List-format dir | List of system directories | +| `3dmol` | 3Dmol format | 3Dmol.js format | + +You can omit `--fmt` and let dpdata infer the input format from the file name +or content. For example, files named `POSCAR`, `OUTCAR`, or `*.xyz` are often +recognized automatically. Use `--fmt` when the file name is ambiguous or +auto-detection fails. + +### Single file + +```bash +dpa-adapt data convert --input POSCAR --output ./npy +dpaad data convert --input POSCAR --output ./npy + +dpa-adapt data convert --input OUTCAR --output ./npy --fmt vasp/outcar +dpaad data convert --input OUTCAR --output ./npy --fmt vasp/outcar + +dpa-adapt data convert --input traj.xyz --output ./npy --fmt xyz +dpaad data convert --input traj.xyz --output ./npy --fmt xyz +``` + +### Glob patterns + +When `--input` contains wildcards (`*`, `?`, `[`), conversion uses mirrored +batch output: + +- **1 or more matches** → each matched file is converted into an output + directory that mirrors its path relative to the non-wildcard prefix. +- **0 matches** → `FileNotFoundError`. +- A `manifest.json` is written into the output root, recording converted and + skipped files. + +```bash +# Glob output mirrors the input tree under ./npy_root +dpa-adapt data convert --input "calcs/**/OUTCAR" --output ./npy_root --fmt vasp/outcar +dpaad data convert --input "calcs/**/OUTCAR" --output ./npy_root --fmt vasp/outcar +``` + +For example, `calcs/run1/OUTCAR` is written as `npy_root/run1/OUTCAR/`. +When `--strict` is set, the first conversion error fails immediately. Without +it, errors are skipped and logged in the manifest. diff --git a/doc/dpa_adapt/overview.md b/doc/dpa_adapt/overview.md new file mode 100644 index 0000000000..b742ebefb7 --- /dev/null +++ b/doc/dpa_adapt/overview.md @@ -0,0 +1,397 @@ +# DPA-ADAPT: Atomistic DPA Adaptation for Property Tasks + +**DPA-ADAPT** (`dpa-adapt`, Python import `dpa_adapt`) is a toolkit for adapting pretrained DPA models to downstream atomistic property prediction tasks. The main CLI is `dpa-adapt`; the optional short alias is `dpaad`. No DeePMD-kit JSON configs or `dp train` pipelines to write. + +## Installation + +```bash +pip install deepmd-kit[dpa-adapt] +``` + +Installs `scikit-learn`, `dpdata`, `ase`, `rdkit`, and `e3nn` alongside DeePMD-kit. For GPU PyTorch, install your preferred PyTorch build first. + +## Quickstart + +For a complete runnable example (QM9 HOMO–LUMO gap, ~5 min on CPU), see [`../../examples/dpa_adapt/`](../../examples/dpa_adapt/). + +## Fine-tuning strategies + +The strategy is the core choice. All four share the same pre-trained DPA backbone and differ in how much of it gets updated: + +| Strategy | Core Mechanism | Target Data Size | Primary Use Case | +| :--------------- | :---------------------------------------------- | :--------------- | :---------------------------------------------------------------------------- | +| `frozen_sklearn` | Frozen backbone + scikit-learn regressor | Small (\<1k) | Ultra-fast benchmarking & prototyping | +| `frozen_head` | Frozen backbone + DeePMD property fitting head | Medium (1k–10k) | Train only the property head while keeping the pretrained DPA backbone frozen | +| `finetune` | End-to-end full parameter fine-tuning | Large (>10k) | Maximum accuracy on large datasets | +| `mft` | Multi-task co-training (property + force field) | Small / low-data | Mitigating representation collapse | + +### frozen_sklearn — CPU-only, scikit-learn predictor + +Freezes the DPA backbone as a feature extractor and fits a scikit-learn +regressor on the pooled descriptors. No GPU, no `dp train` — fastest path +for small datasets. + +```python +model = DPAFineTuner( + pretrained="DPA-3.1-3M", + strategy="frozen_sklearn", + predictor="rf", # "rf" | "linear" | "mlp" + pooling="mean", # "mean" | "sum" | "mean+std" | "mean+std+max+min" + model_branch=None, # multi-task branch for descriptor extraction + fparam_dim=0, # > 0 reads set.*/fparam.npy and concatenates to descriptor + seed=42, +) +model.fit(train_data="/data/train/*", target_key="homo") +pred = model.predict(data="/data/test") +metrics = model.evaluate(data="/data/test") # .mae, .rmse, .r2 +``` + +| Parameter | Type | Default | Description | +| -------------- | --------------- | -------------- | --------------------------------------------------------------------- | +| `pretrained` | `str` | `"DPA-3.1-3M"` | Checkpoint path or built-in name | +| `predictor` | `str` | `"rf"` | `"rf"` (random forest), `"linear"` (Ridge), `"mlp"` (MLPRegressor) | +| `pooling` | `str` | `"mean"` | `"mean"`, `"sum"`, `"mean+std"`, `"mean+std+max+min"` | +| `model_branch` | `str` or `None` | `None` | Multi-task branch for descriptor extraction (e.g. `"Domains_Drug"`) | +| `fparam_dim` | `int` | `0` | Dimension of per-frame context features; > 0 reads `set.*/fparam.npy` | +| `seed` | `int` | `42` | Random seed for the sklearn head | + +### frozen_head / finetune — dp train with frozen or trainable backbone + +Both delegate to `dp --pt train` and accept the same parameters. The only +difference: `frozen_head` freezes the DPA backbone (train only the fitting +head), while `finetune` updates all parameters end-to-end. + +`frozen_head` suits medium datasets (1k–10k); `finetune` targets large datasets (>10k, GPU required). + +```python +model = DPAFineTuner( + pretrained="DPA-3.1-3M", + strategy="frozen_head", # "frozen_head" | "finetune" + # ---- task ---- + property_name="homo", + task_dim=1, + intensive=True, # True = intensive (mean-pooled), False = extensive + init_branch="SPICE2", # checkpoint branch for descriptor init + # ---- fitting net ---- + fitting_net_params=None, # dict overriding fitting_net fields, e.g. + # { # {"neuron": [128,128,128], "activation_function": "relu"} + # "neuron": [128, 128], # (default: neuron=[240,240,240], tanh, resnet_dt=True) + # "activation_function": "relu", + # }, + # ---- learning rate ---- + learning_rate=1e-3, # start_lr + stop_lr=1e-5, # end_lr + decay_steps=None, # None → 1000; or explicit int + warmup_steps=0, # linear LR warmup (0 = disabled) + # ---- training ---- + max_steps=100_000, + batch_size="auto:512", # deepmd-kit batch_size spec + loss_function="mse", # "mse" | "smooth_mae" + # ---- optional ---- + fparam_dim=0, # > 0 reads set.*/fparam.npy → numb_fparam + seed=42, + # ---- output ---- + output_dir="./dpa_output", + save_freq=10_000, + disp_freq=1_000, +) +model.fit(train_data="/data/train", valid_data="/data/valid") +pred = model.predict(data="/data/test") +metrics = model.evaluate(data="/data/test") # .mae, .rmse, .r2 +``` + +| Parameter | Type | Default | Description | +| -------------------- | ---------------- | ------------------ | ------------------------------------------------------------------------------------- | +| `pretrained` | `str` | `"DPA-3.1-3M"` | Checkpoint path or built-in name | +| `strategy` | `str` | `"frozen_sklearn"` | `"frozen_head"` (freeze backbone) or `"finetune"` (full update) | +| `property_name` | `str` | `"property"` | Label key under `set.*/`, e.g. `"homo"` reads `set.*/homo.npy` | +| `task_dim` | `int` | `1` | Output dimensionality of the property fitting net | +| `intensive` | `bool` | `True` | `True` = mean-pool over atoms (intensive); `False` = sum (extensive) | +| `init_branch` | `str` | `"SPICE2"` | Checkpoint branch used to initialise the descriptor | +| `fitting_net_params` | `dict` or `None` | `None` | Overrides for fitting-net fields (`neuron`, `activation_function`, `resnet_dt`, etc.) | +| `learning_rate` | `float` | `1e-3` | Start learning rate (`start_lr` in deepmd-kit `exp` scheduler) | +| `stop_lr` | `float` | `1e-5` | End learning rate | +| `decay_steps` | `int` or `None` | `None` | Steps between LR decays; `None` → 1000 | +| `warmup_steps` | `int` | `0` | Linear LR warmup steps; 0 = disabled | +| `max_steps` | `int` | `100_000` | Total training steps (`numb_steps`) | +| `batch_size` | `str` or `int` | `"auto:512"` | deepmd-kit batch_size spec (e.g. `"auto:256"` or `128`) | +| `loss_function` | `str` | `"mse"` | `"mse"` or `"smooth_mae"` | +| `fparam_dim` | `int` | `0` | Dimension of per-frame context features; > 0 reads `set.*/fparam.npy` | +| `seed` | `int` | `42` | Random seed (descriptor, fitting net, training) | +| `output_dir` | `str` | `"./dpa_output"` | Directory for `input.json`, checkpoints, and logs | +| `save_freq` | `int` | `10_000` | Checkpoint save interval in steps | +| `disp_freq` | `int` | `1_000` | Log display interval in steps | + +### mft — Multi-task fine-tuning (property + force field) + +Jointly trains a downstream property head with an auxiliary force/energy head +on a shared DPA descriptor, preventing representation collapse on small +datasets. Requires GPU. Inherits all `frozen_head`/`finetune` parameters +plus the MFT-specific ones below. + +```python +model = DPAFineTuner( + pretrained="/path/to/DPA-3.1-3M.pt", + strategy="mft", + # ---- task (same as frozen_head/finetune) ---- + property_name="homo", + task_dim=1, + intensive=True, + init_branch="SPICE2", + # ---- MFT-specific ---- + aux_branch="MP_traj_v024_alldata_mixu", # checkpoint branch for aux force head + aux_prob=0.5, # aux sampling weight (downstream = 1 - aux_prob) + downstream_task_type="property", # "property" (default) | "ener" (legacy) + type_map=None, # global (shared) type map; must be union of + # both datasets' elements (auto-detect) + aux_batch_size=None, # batch size for aux head (None = auto) + downstream_batch_size=None, # batch size for downstream head (None = auto) + # ---- fitting net (aux head only; downstream uses property defaults) ---- + fitting_net_params=None, # None = auto-read from checkpoint + # ---- learning rate ---- + learning_rate=1e-3, + stop_lr=1e-5, + decay_steps=None, # None → 1000 (property) or 5000 (ener) + warmup_steps=0, + # ---- training ---- + max_steps=50_000, + batch_size="auto:32", + # ---- optional ---- + fparam_dim=0, + seed=42, + # ---- output ---- + output_dir="./mft_output", + save_freq=10_000, + disp_freq=1_000, +) +model.fit(train_data="/data/train", aux_data="/data/spice2") +pred = model.predict(data="/data/test") +metrics = model.evaluate(data="/data/test") # .mae, .rmse, .r2 +``` + +**Shared parameters** — all `frozen_head`/`finetune` parameters above also apply to MFT. + +**MFT-specific parameters:** + +| Parameter | Type | Default | Description | +| ----------------------- | --------------------- | ----------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `aux_branch` | `str` | `"MP_traj_v024_alldata_mixu"` | Checkpoint branch to initialize the auxiliary force/energy head. Use `dp --pt show model-branch` to list options. | +| `aux_prob` | `float` | `0.5` | Sampling weight for the aux branch. Downstream weight = `1.0 - aux_prob`. | +| `downstream_task_type` | `str` | `"property"` | `"property"` (intensive scalar head, e.g. HOMO/LUMO) or `"ener"` (force-field head, legacy mode) | +| `type_map` | `list[str]` or `None` | `None` | Global (shared) type map for MFT. Both branches share a single descriptor, so this must be the **union** of all elements appearing in either dataset. Auto-detected from the pretrained checkpoint if `None`. | +| `aux_batch_size` | `str` or `None` | `None` | Batch size for aux head; auto-selected if `None` | +| `downstream_batch_size` | `int` or `None` | `None` | Batch size for downstream head; auto-selected if `None` | +| `fitting_net_params` | `dict` or `None` | `None` | Overrides for the **aux** fitting net; downstream uses property defaults. `None` = auto-read from checkpoint. | + +## Data preparation + +DPA-ADAPT trains on `deepmd/npy` data. Use `dpa-adapt data convert` (or the Python +`convert` helper) to route common inputs into the right conversion pipeline: + +- **SMILES CSV**: a `.csv` file with a `SMILES`/`smiles` column. RDKit generates 3D + conformers, or existing `.mol`/`.sdf`/`.xyz`/`.pdb` files can be supplied with + `mol_dir`. +- **Structure files / trajectories**: POSCAR, OUTCAR, `*.xyz`, `vasprun.xml`, ABACUS, + CP2K, Gaussian, LAMMPS, ASE, `deepmd/raw`, `deepmd/npy`, LMDB, and other dpdata + formats. Omit `fmt` when dpdata can infer it; set `fmt` explicitly for ambiguous + inputs. + +```python +from dpa_adapt import convert + +# Structure file / trajectory → dpdata → deepmd/npy +convert("POSCAR", "./npy") +convert("OUTCAR", "./npy", fmt="vasp/outcar") + +# Glob patterns: one match is converted as one system; multiple matches are batched. +convert("calcs/**/OUTCAR", "./npy_root", fmt="vasp/outcar") + +# CSV with a SMILES column → RDKit 3D conformers → deepmd/npy. +# property_col names the input target column and output label name. +convert( + "molecules.csv", + "./npy", + fmt="smiles", # optional when a SMILES/smiles column is present + smiles_col="SMILES", + property_col="HOMO", + train_ratio=0.9, +) + +# CSV + pre-generated molecular structures: skip RDKit conformer generation. +convert( + "molecules.csv", + "./npy", + fmt="smiles", + smiles_col="SMILES", + property_col="GAP", + mol_dir="./mol_files", + mol_template="id{row}.sdf", +) +``` + +CLI equivalents: + +```bash +# SMILES table +dpa-adapt data convert --input molecules.csv --output ./npy \ + --fmt smiles --smiles-col SMILES --property-col HOMO --train-ratio 0.9 + +# Structure file or glob of calculation outputs +dpa-adapt data convert --input POSCAR --output ./npy +dpa-adapt data convert --input "calcs/**/OUTCAR" --output ./npy_root --fmt vasp/outcar +``` + +Lower-level helpers: + +```python +from dpa_adapt import convert, attach_labels, check_data + +convert("OUTCAR", "./npy", fmt="vasp/outcar") +convert("calcs/**/OUTCAR", "./npy_root", fmt="vasp/outcar") + +# Single system +attach_labels("./npy/", head="bandgap", values=np.array([1.0, 2.0, 3.0])) + +# Multiple systems: values[i] → sorted(glob("npy/*/"))[i] +labels = np.load("labels.npy") # shape (n_systems,) +attach_labels("./npy/", head="bandgap", values=labels) + +check_data("/data/system") # → list[Issue] +``` + +For the full option list and supported dpdata formats, see +[`input_formats.md`](input_formats.md). + +### Context features (fparam) + +fparam lets you condition the model on system-level context such as temperature, humidity, pressure, or any per-frame scalar. All strategies use the same interface: place `fparam.npy` of shape `(n_frames, fparam_dim)` in each `set.*/` directory alongside `coord.npy` and declare the dimension at construction. + +```python +# works identically for frozen_sklearn, frozen_head, finetune, and mft +model = DPAFineTuner(strategy="frozen_sklearn", fparam_dim=2) +model.fit(train_data="data/train", target_key="property") +# fparam.npy is read automatically — no conditions= dict needed +``` + +| Strategy | How fparam is used | +| ---------------------------------- | ---------------------------------------------------------------------------------- | +| `frozen_sklearn` | columns are standardized via `ConditionManager` and concatenated to the descriptor | +| `frozen_head` / `finetune` / `mft` | passed into the fitting net as `numb_fparam` | + +## Inference and uncertainty + +After training, save a portable frozen bundle and load it with `DPAPredictor` — no training dependencies required: + +```python +model.freeze("model.pth") + +from dpa_adapt import DPAPredictor + +pred = DPAPredictor("model.pth") +result = pred.predict("/data/test") # DotDict: .predictions +metrics = pred.evaluate("/data/test") # DotDict: .mae, .rmse, .r2 +``` + +Uncertainty estimation is available for `frozen_sklearn` models: + +- **RF**: native out-of-bag variance, always available +- **MLP**: committee of N independently-seeded clones; set `n_committee` at load time +- **Ridge**: not supported + +```python +pred = DPAPredictor("model.pth", n_committee=5) +result = pred.predict("/data/test", return_uncertainty=True) +# result.predictions — shape (n,) +# result.uncertainty — shape (n,), std across committee members +``` + +Uncertainty estimates can drive active learning (query most uncertain candidates) or feed into Bayesian optimization over composition space. + +## Cross-validation + +Formula-grouped splitting prevents same-composition leakage between folds. +`group_by` accepts `"formula"` (uses each system's directory name as the group +key — requires directories named by formula, e.g. `H2O/`, `CH4/`) or a list +of labels the same length as `systems`: + +```python +from dpa_adapt import cross_validate, train_test_split, load_dataset + +systems = load_dataset("/data/root", label_key="energy") + +# Case 1: directory names are formulas (e.g. data/H2O/, data/CH4/) +train, valid, test = train_test_split(systems, group_by="formula", seed=42) + +# Case 2: directory names are not formulas (e.g. QM9's sys_0000, sys_0001, …) +formulas = ["H2O", "H2O", "CH4", "CH4", ...] # one label per system +train, valid, test = train_test_split(systems, group_by=formulas, seed=42) + +# Cross-validate (same group_by options apply) +result = cross_validate(model, systems, label_key="energy", cv=5, group_by=formulas) +# → {"aggregate": {"mae_mean": ..., "rmse_std": ...}, ...} +``` + +## Python API + +```python +from dpa_adapt import ( + DPAFineTuner, # fine-tune (strategies: frozen_sklearn, frozen_head, finetune, mft) + DPAPredictor, # inference from frozen bundles + extract_descriptors, # standalone descriptor extraction + cross_validate, # leak-proof cross-validation + train_test_split, # formula-grouped splitting + convert, # format-sniffing data conversion + smiles_to_npy, # CSV+SMILES → deepmd/npy + check_data, # data sanity checks + attach_labels, # inject label arrays + load_dataset, # label-filtered data loading +) +``` + +Standalone descriptor extraction: + +```python +X = extract_descriptors( + "/data/systems", + pretrained="/path/to/DPA-3.1-3M.pt", + pooling="mean+std", +) +# → np.ndarray (n_frames, feat_dim * 2) +``` + +## CLI + +| Command | Description | +| ------------------------------------------------------------- | ------------------------------------------------------------------- | +| `dpa-adapt fit` / `dpaad fit` | Fine-tune (`--strategy frozen_sklearn\|frozen_head\|finetune\|mft`) | +| `dpa-adapt predict` / `dpaad predict` | Predict with a frozen `.pth` bundle | +| `dpa-adapt evaluate` / `dpaad evaluate` | Evaluate against stored labels | +| `dpa-adapt extract-descriptors` / `dpaad extract-descriptors` | Extract pooled DPA descriptors to `.npy` | +| `dpa-adapt cv` / `dpaad cv` | Cross-validate | +| `dpa-adapt data convert` / `dpaad data convert` | Convert structure / CSV → `deepmd/npy` | +| `dpa-adapt data validate` / `dpaad data validate` | Sanity-check `deepmd/npy` directories | +| `dpa-adapt data attach-labels` / `dpaad data attach-labels` | Inject `.npy` label arrays | + +```bash +# Data conversion +# Structure file +dpa-adapt data convert --input POSCAR --output ./npy + +# SMILES CSV: --property-col names the input target column and output label name. +dpaad data convert --input data.csv --output ./npy --fmt smiles \ + --property-col homo + +# Fine-tune +dpa-adapt fit --train-data ./npy/train --pretrained DPA-3.1-3M \ + --strategy frozen_sklearn --predictor rf --target-key homo --output model.pth + +# MFT +dpaad fit --train-data /data/qm9 --aux-data /data/spice2 \ + --pretrained /path/to/DPA-3.1-3M.pt --strategy mft --target-key homo + +# Predict / evaluate +dpa-adapt predict --model model.pth --data ./npy/test --output pred.npy +dpa-adapt evaluate --model model.pth --data ./npy/test +``` + +`dpa-adapt --help` and `dpaad --help` do not load torch — all heavy imports are lazy. diff --git a/doc/index.rst b/doc/index.rst index d914782247..6e698be273 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -43,6 +43,7 @@ DeePMD-kit is a package written in Python/C++, designed to minimize the effort r freeze/index test/index inference/index + dpa_adapt/index cli third-party/index agent-skills diff --git a/dpa_adapt/__init__.py b/dpa_adapt/__init__.py new file mode 100644 index 0000000000..5af6edff26 --- /dev/null +++ b/dpa_adapt/__init__.py @@ -0,0 +1,43 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""DPA tools — fine-tuning, descriptor extraction, cross-validation, and data +utilities for DPA-3 pretrained models. + +All public names are lazily imported: ``import dpa_adapt`` does not load +torch, dpdata, or any other heavy dependency until you actually access +a specific class or function. +""" + +__version__ = "0.1.0" + +_LAZY = { + "ConditionManager": (".conditions", "ConditionManager"), + "DPAConditionError": (".conditions", "DPAConditionError"), + "cross_validate": (".cv", "cross_validate"), + "train_test_split": (".cv", "train_test_split"), + "SmilesDataResult": (".data", "SmilesDataResult"), + "attach_labels": (".data", "attach_labels"), + "check_data": (".data", "check_data"), + "convert": (".data", "convert"), + "load_dataset": (".data", "load_dataset"), + "smiles_to_npy": (".data", "smiles_to_npy"), + "DPAFineTuner": (".finetuner", "DPAFineTuner"), + "extract_descriptors": (".finetuner", "extract_descriptors"), + "MFTFineTuner": (".mft", "MFTFineTuner"), + "DPAPredictor": (".predictor", "DPAPredictor"), + "DPATrainer": (".trainer", "DPATrainer"), +} + +__all__ = list(_LAZY) + + +def __getattr__(name: str) -> object: + if name in _LAZY: + import importlib + + mod_name, attr_name = _LAZY[name] + mod = importlib.import_module(mod_name, __package__) + attr = getattr(mod, attr_name) + # Cache in the module namespace so __getattr__ is only called once + globals()[name] = attr + return attr + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/dpa_adapt/_backend.py b/dpa_adapt/_backend.py new file mode 100644 index 0000000000..99b2cf9725 --- /dev/null +++ b/dpa_adapt/_backend.py @@ -0,0 +1,246 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Single chokepoint for all ``deepmd`` internal API and ``torch`` calls. + +Every import from ``deepmd.pt.*``, ``deepmd.utils.model_branch_dict``, or +``torch`` that is needed by the rest of ``dpa_adapt`` must go through +this module. No other file in ``dpa_adapt`` may import those packages directly. + +All functions that load ``torch`` or ``deepmd.pt`` keep the import inside the +function body so that importing this module is cheap. +""" + +from __future__ import ( + annotations, +) + +import logging +from typing import ( + Any, +) + +# ``get_model_dict`` is backend-agnostic and lightweight — safe at module level. +from deepmd.utils.model_branch_dict import get_model_dict as _get_model_dict + +_LOG = logging.getLogger("dpa_adapt") + + +def resolve_dp_command() -> str: + """Return the ``dp`` executable associated with the current Python env.""" + import os as _os + import shutil as _shutil + import sys as _sys + import sysconfig as _sysconfig + from pathlib import Path as _Path + + exe_name = "dp.exe" if _os.name == "nt" else "dp" + scripts_dir = _sysconfig.get_path("scripts") + candidates = [ + _Path(_sys.executable).parent / exe_name, + ] + if scripts_dir: + candidates.append(_Path(scripts_dir) / exe_name) + for candidate in candidates: + if candidate.is_file(): + return _os.fspath(candidate) + + found = _shutil.which("dp") + if found: + return found + + return "dp" + + +# --------------------------------------------------------------------------- +# torch I/O +# --------------------------------------------------------------------------- + + +def _is_url_or_name(path: str) -> bool: + """Return True if *path* looks like a URL or a built-in model name rather + than a local file path. + """ + import os as _os + + return not _os.path.exists(path) + + +def resolve_pretrained_path(pretrained: str, cache_dir: str | None = None) -> str: + """Resolve *pretrained* to a local file path, downloading if necessary. + + If *pretrained* is a local checkpoint path, it is returned unchanged. This + includes non-existing path-like values so callers can raise their own + context-specific ``not found`` errors or tests can monkeypatch checkpoint + loading. Bare names (e.g. ``"DPA-3.1-3M"``) are resolved via + :func:`deepmd.pretrained.download.resolve_model_path`. + """ + import os as _os + from pathlib import Path as _Path + + pretrained = _os.fspath(pretrained) + + if _os.path.isfile(pretrained): + return pretrained + + p = _Path(pretrained) + is_path_like = ( + p.is_absolute() + or any(sep and sep in pretrained for sep in (_os.sep, _os.altsep)) + or p.suffix.lower() in {".pt", ".pth"} + ) + if is_path_like: + return pretrained + + from deepmd.pretrained.download import resolve_model_path as _download + + path = _download(pretrained, cache_dir=cache_dir) + _LOG.info("Resolved pretrained model: %s", path) + return _os.fspath(path) + + +def load_torch_file(path: str, map_location: str = "cpu") -> dict[str, Any]: + """Load a PyTorch checkpoint or frozen bundle. + + Always uses ``weights_only=False`` because deepmd checkpoints carry + ``_extra_state`` (non-tensor metadata) and dpa_adapt frozen bundles + carry ``sklearn`` pipeline objects. + """ + import torch + + try: + return torch.load(path, map_location=map_location, weights_only=False) + except RuntimeError as exc: + if "Invalid magic number" not in str(exc): + raise + import pickle + + with open(path, "rb") as f: + return pickle.load(f) + + +# --------------------------------------------------------------------------- +# model construction +# --------------------------------------------------------------------------- + + +def build_model_from_config(input_param: dict[str, Any]) -> Any: + """Build a (non-JIT) DPA model from an input-parameter dict. + + Returns a ``ModelWrapper`` whose inner model is accessible as + ``wrapper.model["Default"]``. + """ + from deepmd.pt.model.model import ( + get_model, + ) + from deepmd.pt.train.wrapper import ( + ModelWrapper, + ) + + model = get_model(input_param) + return ModelWrapper(model) + + +# --------------------------------------------------------------------------- +# multi-task branch helpers +# --------------------------------------------------------------------------- + + +def resolve_model_branch(model_dict: dict[str, Any]) -> tuple[dict[str, str], str]: + """Resolve multi-task model-branch aliases. + + Returns ``(alias_dict, model_dict)`` — the same tuple shape as the + upstream ``get_model_dict``. + """ + return _get_model_dict(model_dict) + + +# --------------------------------------------------------------------------- +# device +# --------------------------------------------------------------------------- + + +def get_torch_device() -> Any: + """Return ``torch.device("cuda")`` if a GPU is available, else CPU.""" + import torch + + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +# --------------------------------------------------------------------------- +# descriptor extraction (the fragile chain) +# --------------------------------------------------------------------------- + + +class _DescriptorExtraction: + """Thin wrapper around a loaded model that runs a *single* forward pass + with ``eval_descriptor_hook`` enabled and returns per-atom descriptors. + + This is the lowest-level building block. Callers (like + ``DPAFineTuner._extract_features``) are responsible for pooling, + batching, and tensor creation. + """ + + def __init__(self, wrapper: Any) -> None: + inner = wrapper.model["Default"] + self._inner_model = inner + self._atomic_model = inner.atomic_model + self._descriptor_hook_model = self._resolve_descriptor_hook_model() + + def _resolve_descriptor_hook_model(self) -> Any | None: + # Prefer atomic_model — it owns eval_descriptor_list; dp_model + # delegates set_eval_descriptor_hook / eval_descriptor to it but + # lacks the list attribute, so _clear_accumulator was a no-op. + for model in (self._atomic_model, self._inner_model): + if hasattr(model, "set_eval_descriptor_hook") and hasattr( + model, "eval_descriptor" + ): + return model + return None + + def _enable_hook(self) -> None: + if self._descriptor_hook_model is None: + return + self._descriptor_hook_model.set_eval_descriptor_hook(True) + + def _disable_hook(self) -> None: + if self._descriptor_hook_model is None: + return + self._descriptor_hook_model.set_eval_descriptor_hook(False) + + def _clear_accumulator(self) -> None: + if self._descriptor_hook_model is None: + return + if hasattr(self._descriptor_hook_model, "eval_descriptor_list"): + self._descriptor_hook_model.eval_descriptor_list.clear() + + def _run_forward(self, coord: Any, atype: Any, box: Any) -> Any: + """Run ``forward_common`` and return per-atom descriptors (detached). + + Parameters + ---------- + coord : torch.Tensor + (n_frames, n_atoms*3), float64, requires_grad. + atype : torch.Tensor + (n_frames, n_atoms), int64. + box : torch.Tensor + (n_frames, 9), float64. + + Returns + ------- + torch.Tensor + (n_frames, n_atoms, feat_dim), detached. + """ + if not coord.requires_grad: + raise RuntimeError( + "forward_common requires coord to have requires_grad=True" + ) + if self._descriptor_hook_model is None: + if not hasattr(self._inner_model, "forward_embedding"): + raise AttributeError( + "Loaded model exposes neither descriptor hook methods nor " + "forward_embedding()." + ) + result = self._inner_model.forward_embedding(coord, atype, box) + return result["descriptor"].detach() + self._clear_accumulator() + self._inner_model.forward_common(coord, atype, box) + return self._descriptor_hook_model.eval_descriptor().detach() diff --git a/dpa_adapt/_validation.py b/dpa_adapt/_validation.py new file mode 100644 index 0000000000..966974f883 --- /dev/null +++ b/dpa_adapt/_validation.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +# dpa_adapt/_validation.py +# +# Small shared argument validators for the fine-tuning entry points +# (DPATrainer, MFTFineTuner, DPAFineTuner) so the same checks are not +# copy-pasted across constructors. + +from __future__ import ( + annotations, +) + + +def validate_fparam_dim(fparam_dim: int) -> None: + """Raise ``ValueError`` unless *fparam_dim* is a non-negative int. + + ``0`` means "no fparam conditioning"; any positive value is the width of + the per-frame ``fparam.npy`` arrays. + """ + if not isinstance(fparam_dim, int) or fparam_dim < 0: + raise ValueError(f"fparam_dim must be a non-negative int; got {fparam_dim!r}.") diff --git a/dpa_adapt/cli.py b/dpa_adapt/cli.py new file mode 100644 index 0000000000..b5a7b1f104 --- /dev/null +++ b/dpa_adapt/cli.py @@ -0,0 +1,727 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""CLI entry point for the ``dpa-adapt`` and ``dpaad`` commands. + +Unlike the deepmd-kit ``dp`` command, ``dpa-adapt`` is a standalone CLI that +focuses solely on DPA model fine-tuning, descriptor extraction, +cross-validation, prediction, evaluation, and data preparation. + +``dpa-adapt --help`` and ``dpaad --help`` do not load torch — the parser is +pure argparse and the handlers (and the DPA stack) are imported lazily only +when a subcommand actually runs. +""" + +from __future__ import ( + annotations, +) + +import argparse +import json +import logging +import os +import sys +from typing import ( + TYPE_CHECKING, +) + +import numpy as np + +if TYPE_CHECKING: + from collections.abc import ( + Sequence, + ) + +_LOG = logging.getLogger("dpa_adapt") + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _get_ll(log_level: str) -> int: + """Convert string to python logging level. + + Parameters + ---------- + log_level : str + allowed input values are: DEBUG, INFO, WARNING, ERROR, 3, 2, 1, 0 + + Returns + ------- + int + one of python logging module log levels - 10, 20, 30 or 40 + """ + if log_level.isdigit(): + int_level = (4 - int(log_level)) * 10 + else: + int_level = getattr(logging, log_level) + return int_level + + +def _set_log_handles(level: int, log_path: str | None = None) -> None: + """Set up logging to console and optionally a file.""" + logger = logging.getLogger("dpa_adapt") + logger.setLevel(level) + # Avoid duplicate handlers on repeated calls + if logger.handlers: + return + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + console = logging.StreamHandler(sys.stdout) + console.setLevel(level) + console.setFormatter(formatter) + logger.addHandler(console) + if log_path: + os.makedirs(os.path.dirname(log_path) or ".", exist_ok=True) + file_handler = logging.FileHandler(log_path) + file_handler.setLevel(level) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + +def _maybe_split_list(val: str | Sequence[str] | None) -> list[str] | None: + """Normalize comma-separated strings or string sequences to a flat list.""" + if val is None: + return None + if isinstance(val, str): + values = [val] + else: + values = val + return [ + item + for value in values + for item in (part.strip() for part in value.split(",")) + if item + ] + + +def _parse_batch_size(val: str) -> str | int: + """Parse DeePMD batch-size specs, preserving strings like ``auto:512``.""" + text = val.strip() + if not text: + raise argparse.ArgumentTypeError("batch size must not be empty") + try: + return int(text) + except ValueError: + return text + + +class _RawTextArgDefaultsHelpFormatter( + argparse.RawTextHelpFormatter, argparse.ArgumentDefaultsHelpFormatter +): + """Formatter for multi-line help with default values.""" + + +# --------------------------------------------------------------------------- +# Subcommand handlers — each lazy-imports its dependencies +# --------------------------------------------------------------------------- + + +def _cmd_fit(args: argparse.Namespace) -> int: + from dpa_adapt import ( + DPAFineTuner, + ) + + train = _maybe_split_list(args.train_data) or [] + valid = _maybe_split_list(args.valid_data) if args.valid_data else None + type_map = _maybe_split_list(args.type_map) + + # Parse target_key: comma-separated → list[str] (multi-property), + # single value → str (single-property, backward compat). + target_keys = _maybe_split_list(args.target_key) + if target_keys is None: + target_key = "property" + prop_name = "property" + elif len(target_keys) == 1: + target_key = target_keys[0] + prop_name = target_key + else: + target_key = target_keys + prop_name = target_keys[0] + + model = DPAFineTuner( + pretrained=args.pretrained, + model_branch=args.model_branch, + predictor=args.predictor, + pooling=args.pooling, + seed=args.seed, + strategy=args.strategy, + property_name=prop_name, + task_dim=args.task_dim, + intensive=args.intensive, + learning_rate=args.learning_rate, + stop_lr=args.stop_lr, + max_steps=args.max_steps, + batch_size=args.batch_size, + output_dir=args.output_dir, + save_freq=args.save_freq, + disp_freq=args.disp_freq, + # MFT + aux_branch=args.aux_branch, + aux_prob=args.aux_prob, + type_map=_maybe_split_list(args.mft_type_map), + downstream_task_type=args.downstream_task_type, + aux_batch_size=args.aux_batch_size, + downstream_batch_size=args.downstream_batch_size, + fparam_dim=args.fparam_dim, + ) + aux_data = _maybe_split_list(args.aux_data) if args.aux_data else None + model.fit( + train_data=train, + valid_data=valid, + type_map=type_map, + target_key=target_key, + aux_data=aux_data, + ) + if args.strategy == "frozen_sklearn": + out = model.freeze(args.output) + _LOG.info("Frozen model → %s", out) + else: + _LOG.info("Checkpoint → %s", args.output_dir) + return 0 + + +def _cmd_cv(args: argparse.Namespace) -> int: + from dpa_adapt import ( + DPAFineTuner, + cross_validate, + load_dataset, + ) + + systems = load_dataset(args.data, label_key=args.label_key) + _LOG.info("%s systems", len(systems)) + + model = DPAFineTuner( + pretrained=args.pretrained, + model_branch=args.model_branch, + predictor=args.predictor, + pooling=args.pooling, + seed=args.seed, + ) + result = cross_validate( + model, + systems, + label_key=args.label_key, + cv=args.cv if args.cv == "holdout" else int(args.cv), + group_by=args.group_by or "formula", + granularity=args.granularity, + seed=args.seed, + ) + a = result["aggregate"] + _LOG.info( + "R² = %.4f ± %.4f", + a.get("r2_mean", float("nan")), + a.get("r2_std", float("nan")), + ) + _LOG.info( + "MAE = %.4f ± %.4f", + a.get("mae_mean", float("nan")), + a.get("mae_std", float("nan")), + ) + _LOG.info( + "RMSE= %.4f ± %.4f", + a.get("rmse_mean", float("nan")), + a.get("rmse_std", float("nan")), + ) + _LOG.info("n = %s independent groups", result["n_independent"]) + for w in result.get("warnings", []): + _LOG.warning("%s", w) + return 0 + + +def _cmd_extract_descriptors(args: argparse.Namespace) -> int: + from dpa_adapt.finetuner import ( + extract_descriptors, + ) + + X = extract_descriptors( + args.data, + pretrained=args.pretrained, + model_branch=args.model_branch, + pooling=args.pooling, + cache=not args.no_cache, + ) + np.save(args.output, X) + _LOG.info("Descriptors shape=%s → %s", X.shape, args.output) + return 0 + + +def _cmd_predict(args: argparse.Namespace) -> int: + from dpa_adapt import ( + DPAPredictor, + ) + + predictor = DPAPredictor(args.model) + result = predictor.predict(args.data) + np.save(args.output, result.predictions) + _LOG.info("Predictions shape=%s → %s", result.predictions.shape, args.output) + return 0 + + +def _cmd_evaluate(args: argparse.Namespace) -> int: + from dpa_adapt import ( + DPAPredictor, + ) + + predictor = DPAPredictor(args.model) + metrics = predictor.evaluate(args.data) + _LOG.info("MAE : %.6f", metrics.mae) + _LOG.info("RMSE : %.6f", metrics.rmse) + _LOG.info("R² : %.6f", metrics.r2) + _LOG.info("N : %s", metrics.predictions.shape[0]) + return 0 + + +def _cmd_data_convert(args: argparse.Namespace) -> int: + + type_map = _maybe_split_list(args.type_map) + + from dpa_adapt import ( + convert, + ) + + result = convert( + input_path=args.input, + output_dir=args.output, + fmt=args.fmt, + type_map=type_map, + property_name=args.property_name or args.property_col, + property_col=args.property_col, + train_ratio=args.train_ratio, + smiles_col=args.smiles_col, + mol_dir=args.mol_dir, + mol_template=args.mol_template, + split_seed=args.split_seed, + conformer_seed=args.conformer_seed, + overwrite=args.overwrite, + validate=args.validate, + strict=args.strict, + verbose=False, + ) + if result["method"] == "smiles": + _LOG.info("Train systems: %s", len(result["train_systems"])) + _LOG.info("Valid systems: %s", len(result["valid_systems"])) + _LOG.info("Type map : %s", result["type_map"]) + _LOG.info("Samples used : %s", result["samples_used"]) + _LOG.info("Failed rows : %s", len(result["failed_rows"])) + _LOG.info("Skipped zero : %s", result["skipped_zero"]) + _LOG.info("Skipped overlap: %s", result["skipped_overlap"]) + elif result["method"] == "batch_dpdata": + _LOG.info("Output dirs : %s", len(result["output_dirs"])) + _LOG.info("Manifest : %s", result["manifest"]) + else: + _LOG.info("Wrote deepmd/npy → %s", result["output_dir"]) + return 0 + + +def _cmd_data_validate(args: argparse.Namespace) -> int: + from dpa_adapt import ( + check_data, + ) + from dpa_adapt.data.loader import ( + load_data, + ) + + systems = load_data(args.data) + issues = check_data(systems, strict=False) + if not issues: + _LOG.info("OK: %s system(s) clean.", len(systems)) + return 0 + n_err = sum(1 for i in issues if i.severity == "error") + for i in issues: + log = _LOG.error if i.severity == "error" else _LOG.warning + log("%s/%s :: %s", i.system, i.set_dir, i.description) + _LOG.info( + "%s issue(s): %s error, %s warning", len(issues), n_err, len(issues) - n_err + ) + return 1 if (n_err > 0 or (args.strict and issues)) else 0 + + +def _cmd_data_attach_labels(args: argparse.Namespace) -> int: + from dpa_adapt import ( + attach_labels, + ) + + values = np.load(args.values) + if args.head_json: + head = json.loads(args.head) + else: + head = args.head + attach_labels(args.data, head=head, values=values) + _LOG.info("Labels attached to %s", args.data) + return 0 + + +# --------------------------------------------------------------------------- +# Dispatch table +# --------------------------------------------------------------------------- + +_DISPATCH = { + "extract-descriptors": _cmd_extract_descriptors, + "fit": _cmd_fit, + "cv": _cmd_cv, + "predict": _cmd_predict, + "evaluate": _cmd_evaluate, +} + +_DATA_DISPATCH = { + "convert": _cmd_data_convert, + "validate": _cmd_data_validate, + "attach-labels": _cmd_data_attach_labels, +} + + +# --------------------------------------------------------------------------- +# Argument parser +# --------------------------------------------------------------------------- + + +def get_parser() -> argparse.ArgumentParser: + """Build the standalone ``dpa-adapt`` / ``dpaad`` argument parser. + + Returns + ------- + argparse.ArgumentParser + The fully configured parser for the ``dpa-adapt`` / ``dpaad`` CLI. + """ + try: + from dpa_adapt import ( + __version__, + ) + except ImportError: + __version__ = "unknown" + + parser = argparse.ArgumentParser( + description="DPA tools — fine-tune pre-trained DPA models, extract descriptors, " + "cross-validate, predict, evaluate, and prepare data.", + formatter_class=_RawTextArgDefaultsHelpFormatter, + ) + + # Logging options (shared across all subcommands) + parser_log = argparse.ArgumentParser( + add_help=False, formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser_log.add_argument( + "-v", + "--log-level", + choices=["DEBUG", "3", "INFO", "2", "WARNING", "1", "ERROR", "0"], + default="INFO", + help="set verbosity level by string or number, 0=ERROR, 1=WARNING, " + "2=INFO and 3=DEBUG", + ) + parser_log.add_argument( + "-l", + "--log-path", + type=str, + default=None, + help="set log file to log messages to disk, if not specified, " + "the logs will only be output to console", + ) + + parser.add_argument( + "--version", action="version", version=f"dpa-adapt v{__version__}" + ) + + subparsers = parser.add_subparsers(title="subcommands", dest="command") + + # -- extract-descriptors ------------------------------------------------- + parser_extract = subparsers.add_parser( + "extract-descriptors", + help="Extract pooled DPA descriptors to .npy", + parents=[parser_log], + ) + parser_extract.add_argument( + "--data", required=True, nargs="+", help="System directories." + ) + parser_extract.add_argument( + "--pretrained", required=True, help="Path to DPA checkpoint (.pt)." + ) + parser_extract.add_argument("--model-branch", default=None) + parser_extract.add_argument( + "--pooling", + default="mean", + choices=["mean", "sum", "mean+std", "mean+std+max+min"], + ) + parser_extract.add_argument("--output", required=True, help="Output .npy path.") + parser_extract.add_argument( + "--no-cache", action="store_true", help="Bypass descriptor cache." + ) + + # -- fit ----------------------------------------------------------------- + parser_fit = subparsers.add_parser( + "fit", + help="Train a model (any strategy)", + parents=[parser_log], + ) + parser_fit.add_argument( + "--train-data", required=True, nargs="+", help="Training system directories." + ) + parser_fit.add_argument( + "--valid-data", default=None, nargs="+", help="Validation system directories." + ) + parser_fit.add_argument( + "--pretrained", default="DPA-3.1-3M", help="Path to DPA checkpoint (.pt)." + ) + parser_fit.add_argument("--model-branch", default=None) + parser_fit.add_argument( + "--strategy", + default="frozen_sklearn", + choices=["frozen_sklearn", "frozen_head", "finetune", "mft"], + ) + parser_fit.add_argument( + "--predictor", default="rf", choices=["rf", "linear", "ridge", "mlp"] + ) + parser_fit.add_argument( + "--pooling", + default="mean", + choices=["mean", "sum", "mean+std", "mean+std+max+min"], + ) + parser_fit.add_argument( + "--target-key", + default=None, + help="Label key under set.*/ (e.g. energy, homo, bandgap).", + ) + parser_fit.add_argument("--output", default="frozen_model.pth") + parser_fit.add_argument("--type-map", default=None) + parser_fit.add_argument("--task-dim", type=int, default=1) + parser_fit.add_argument( + "--intensive", action=argparse.BooleanOptionalAction, default=True + ) + parser_fit.add_argument("--max-steps", type=int, default=100_000) + parser_fit.add_argument("--learning-rate", type=float, default=1e-3) + parser_fit.add_argument("--stop-lr", type=float, default=1e-5) + parser_fit.add_argument("--batch-size", type=_parse_batch_size, default="auto:512") + parser_fit.add_argument("--seed", type=int, default=42) + parser_fit.add_argument("--output-dir", default="./dpa_output") + parser_fit.add_argument("--save-freq", type=int, default=10_000) + parser_fit.add_argument("--disp-freq", type=int, default=1_000) + # MFT-only flags + parser_fit.add_argument( + "--aux-data", + default=None, + nargs="+", + help="(mft) Auxiliary system directories.", + ) + parser_fit.add_argument( + "--aux-branch", + default="MP_traj_v024_alldata_mixu", + help="(mft) Aux branch name in checkpoint.", + ) + parser_fit.add_argument( + "--aux-prob", + type=float, + default=0.5, + help="(mft) Sampling weight for aux branch.", + ) + parser_fit.add_argument( + "--mft-type-map", + default=None, + help="(mft) Global (shared) type map for MFT. Must be the union of " + "elements in both aux and downstream datasets. " + "Comma-separated, e.g. 'H,C,N,O'. " + "Auto-detected from checkpoint if omitted.", + ) + parser_fit.add_argument( + "--downstream-task-type", + default="property", + choices=["ener", "property"], + help="(mft) Downstream head type.", + ) + parser_fit.add_argument( + "--aux-batch-size", + type=_parse_batch_size, + default=None, + help="(mft) Batch size for aux branch.", + ) + parser_fit.add_argument( + "--downstream-batch-size", + type=_parse_batch_size, + default=None, + help="(mft) Batch size for downstream.", + ) + parser_fit.add_argument( + "--fparam-dim", + type=int, + default=0, + help="(frozen_head/finetune/mft) Dimensionality of per-frame condition " + "inputs (fparam). Requires set.*/fparam.npy in training data. Default: 0.", + ) + + # -- cv ------------------------------------------------------------------ + parser_cv = subparsers.add_parser( + "cv", + help="Cross-validate frozen_sklearn baseline", + parents=[parser_log], + ) + parser_cv.add_argument( + "--data", required=True, nargs="+", help="System directories." + ) + parser_cv.add_argument("--label-key", default="energy") + parser_cv.add_argument( + "--pretrained", default="DPA-3.1-3M", help="Path to DPA checkpoint (.pt)." + ) + parser_cv.add_argument("--model-branch", default=None) + parser_cv.add_argument( + "--predictor", default="rf", choices=["rf", "linear", "ridge", "mlp"] + ) + parser_cv.add_argument( + "--pooling", + default="mean", + choices=["mean", "sum", "mean+std", "mean+std+max+min"], + ) + parser_cv.add_argument("--cv", default="5") + parser_cv.add_argument("--group-by", default="formula") + parser_cv.add_argument( + "--granularity", default="composition", choices=["frame", "composition"] + ) + parser_cv.add_argument("--seed", type=int, default=42) + + # -- predict ------------------------------------------------------------- + parser_predict = subparsers.add_parser( + "predict", + help="Predict with a frozen .pth bundle", + parents=[parser_log], + ) + parser_predict.add_argument("--model", required=True, help="Path to frozen .pth.") + parser_predict.add_argument( + "--data", required=True, nargs="+", help="System directories." + ) + parser_predict.add_argument("--output", required=True, help="Output .npy path.") + + # -- evaluate ------------------------------------------------------------ + parser_evaluate = subparsers.add_parser( + "evaluate", + help="Evaluate a frozen .pth against stored labels", + parents=[parser_log], + ) + parser_evaluate.add_argument("--model", required=True, help="Path to frozen .pth.") + parser_evaluate.add_argument( + "--data", required=True, nargs="+", help="System directories." + ) + + # -- data (nested group) ------------------------------------------------- + parser_data = subparsers.add_parser( + "data", + help="Data conversion and validation tools", + parents=[parser_log], + ) + data_subparsers = parser_data.add_subparsers( + dest="data_command", + required=True, + ) + + # data convert + parser_data_convert = data_subparsers.add_parser( + "convert", + help="Convert structure/CSV file → deepmd/npy (format auto-detected)", + parents=[parser_log], + ) + parser_data_convert.add_argument("--input", required=True) + parser_data_convert.add_argument("--output", required=True) + parser_data_convert.add_argument( + "--fmt", + default=None, + help="Format hint (auto-detected if omitted). " + "Use 'smiles' for CSV+SMILES, otherwise " + "dpdata format string (vasp/poscar, vasp/outcar, …).", + ) + parser_data_convert.add_argument("--type-map", default=None) + parser_data_convert.add_argument( + "--no-validate", dest="validate", action="store_false" + ) + parser_data_convert.add_argument("--strict", action="store_true") + parser_data_convert.add_argument("--property-name", default=None) + parser_data_convert.add_argument("--property-col", default="Property") + parser_data_convert.add_argument("--smiles-col", default="SMILES") + parser_data_convert.add_argument("--mol-dir", default=None) + parser_data_convert.add_argument( + "--mol-template", + default="id{row}.mol", + help="Filename template under --mol-dir; use {row} for the CSV row index.", + ) + parser_data_convert.add_argument("--train-ratio", type=float, default=0.9) + parser_data_convert.add_argument( + "--split-seed", + type=int, + default=None, + help="Random seed for train/valid split (SMILES input).", + ) + parser_data_convert.add_argument( + "--conformer-seed", + type=int, + default=None, + help="Random seed for RDKit conformer generation (SMILES input).", + ) + parser_data_convert.add_argument("--overwrite", action="store_true") + + # data validate + parser_data_validate = data_subparsers.add_parser( + "validate", + help="Sanity-check deepmd/npy directories", + parents=[parser_log], + ) + parser_data_validate.add_argument("--data", required=True, nargs="+") + parser_data_validate.add_argument("--strict", action="store_true") + + # data attach-labels + parser_data_attach = data_subparsers.add_parser( + "attach-labels", + help="Attach .npy labels to deepmd/npy directory", + parents=[parser_log], + ) + parser_data_attach.add_argument("--data", required=True) + parser_data_attach.add_argument("--head", required=True) + parser_data_attach.add_argument("--head-json", action="store_true") + parser_data_attach.add_argument("--values", required=True) + + return parser + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + + +def main(args: Sequence[str] | None = None) -> None: + """Entry point for the ``dpa-adapt`` / ``dpaad`` CLI. + + Parameters + ---------- + args : list[str], optional + Command-line arguments. If ``None``, ``sys.argv[1:]`` is used. + """ + parser = get_parser() + parsed_args = parser.parse_args(args) + + if parsed_args.command is None: + parser.print_help() + return + + # Set up logging after subcommand parsing; subcommands provide these options. + log_level = _get_ll(getattr(parsed_args, "log_level", "INFO")) + _set_log_handles(log_level, getattr(parsed_args, "log_path", None)) + + try: + if parsed_args.command == "data": + handler = _DATA_DISPATCH.get(parsed_args.data_command) + if handler is None: + _LOG.error("Unknown data command: %s", parsed_args.data_command) + sys.exit(1) + sys.exit(handler(parsed_args)) + else: + handler = _DISPATCH.get(parsed_args.command) + if handler is None: + _LOG.error("Unknown dpa-adapt command: %s", parsed_args.command) + sys.exit(1) + sys.exit(handler(parsed_args)) + except Exception as exc: + # Lazy-import DPADataError so that --help doesn't trigger heavy imports. + from dpa_adapt.data.errors import ( + DPADataError, + ) + + if isinstance(exc, DPADataError): + _LOG.error("%s", exc) + sys.exit(1) + raise diff --git a/dpa_adapt/conditions.py b/dpa_adapt/conditions.py new file mode 100644 index 0000000000..a0765d196b --- /dev/null +++ b/dpa_adapt/conditions.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +# dpa_adapt/conditions.py +"""Condition manager for scalar condition inputs (e.g. temperature, pressure).""" + +import pickle + +import numpy as np + + +class DPAConditionError(Exception): + """Raised when conditions are missing, mismatched, or used before fit.""" + + pass + + +class ConditionManager: + """Fit a StandardScaler per named condition key, then produce a + normalized (n, d_total) array for downstream concatenation. + """ + + def __init__(self) -> None: + self._scalers = None + self._keys = None + + def fit(self, conditions: dict[str, np.ndarray]) -> None: + from sklearn.preprocessing import ( + StandardScaler, + ) + + self._scalers = {} + self._keys = sorted(conditions.keys()) + for key in self._keys: + scaler = StandardScaler() + scaler.fit(np.asarray(conditions[key]).reshape(-1, 1)) + self._scalers[key] = scaler + + def transform(self, conditions: dict[str, np.ndarray]) -> np.ndarray: + if self._scalers is None: + raise DPAConditionError("ConditionManager.transform() called before fit().") + parts = [] + for key in self._keys: + if key not in conditions: + raise DPAConditionError( + f"Condition key {key!r} was present at fit time " + f"but is missing from transform()." + ) + x = self._scalers[key].transform(np.asarray(conditions[key]).reshape(-1, 1)) + parts.append(x) + return np.hstack(parts) + + def fit_transform(self, conditions: dict[str, np.ndarray]) -> np.ndarray: + self.fit(conditions) + return self.transform(conditions) + + def save(self, path: str) -> None: + with open(path, "wb") as f: + pickle.dump({"scalers": self._scalers, "keys": self._keys}, f) + + @classmethod + def load(cls, path: str) -> "ConditionManager": + with open(path, "rb") as f: + data = pickle.load(f) + obj = cls() + obj._scalers = data["scalers"] + obj._keys = data["keys"] + return obj diff --git a/dpa_adapt/config/__init__.py b/dpa_adapt/config/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/dpa_adapt/config/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/dpa_adapt/config/manager.py b/dpa_adapt/config/manager.py new file mode 100644 index 0000000000..2d267a57a1 --- /dev/null +++ b/dpa_adapt/config/manager.py @@ -0,0 +1,284 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json +import os +from typing import ( + Any, +) + +from dpa_adapt._backend import ( + resolve_dp_command, +) + +# Default property-head architecture for MFT DOWNSTREAM when +# downstream_task_type="property". Mirrors DPATrainer.DEFAULT_FITTING_NET +# (trainer.py L64-70) plus dim_case_embd=31, which the DPA-3.1-3M ckpt +# requires for the case-embedding layer in multi-task mode. (DPATrainer is +# single-task and doesn't need this field; in MFT the descriptor is shared +# across branches so the property head must declare it.) +_PROPERTY_FITTING_NET_BASE = { + "type": "property", + "neuron": [240, 240, 240], + "activation_function": "tanh", + "resnet_dt": True, + "precision": "float32", + "dim_case_embd": 31, +} + + +def _build_property_fitting_net(t: Any) -> dict: + """Construct a property fitting_net dict from a tuner's property params. + The property head is independent of the aux branch's ener fitting_net + that came out of the ckpt — reusing the ener config silently introduces + a force-field bias layer (Bug root cause). + """ + fn = dict(_PROPERTY_FITTING_NET_BASE) + fn.update( + { + "property_name": t.property_name, + "task_dim": t.task_dim, + "intensive": t.intensive, + "seed": t.seed, + } + ) + if getattr(t, "fparam_dim", 0) > 0: + fn["numb_fparam"] = t.fparam_dim + return fn + + +def _build_property_loss() -> dict: + """Property-task loss for DOWNSTREAM. Notes: + - No start_pref_f / start_pref_v: HOMO/LUMO data has no forces/virials. + - property_name MUST NOT appear here: deepmd 3.1.3 strict-mode dargs + rejects unknown keys inside loss_property (it belongs on fitting_net). + """ + return { + "type": "property", + "loss_func": "mse", + "metric": ["mae", "rmse"], + "beta": 1.0, + } + + +_ENER_LOSS = { + "type": "ener", + "start_pref_e": 0.2, + "limit_pref_e": 20, + "start_pref_f": 100, + "limit_pref_f": 60, + "start_pref_v": 0.02, + "limit_pref_v": 1, +} + + +class MFTConfigManager: + def __init__(self, tuner: Any) -> None: + self.t = tuner + + def build(self) -> dict: + t = self.t + aux_fitting_net = ( + t.fitting_net_params + if getattr(t, "fitting_net_params", None) + else {"type": "ener"} + ) + # DOWNSTREAM branch: ener (legacy, sensitivity-analysis callers) or + # property (paper-faithful BOOM eval). Default 'ener' for back-compat + # with FakeTuners and existing callers that don't set the attr. + downstream_task_type = getattr(t, "downstream_task_type", "ener") + is_property = downstream_task_type == "property" + # Branch key for the downstream head. Paper qm9_gap/mft uses "property"; + # legacy ener mode keeps "DOWNSTREAM" so mp_data sensitivity-analysis + # configs stay byte-for-byte unchanged (renaming would break the branch + # name in their already-trained ckpts). + downstream_key = "property" if is_property else "DOWNSTREAM" + if is_property: + downstream_fitting_net = _build_property_fitting_net(t) + downstream_loss = _build_property_loss() + else: + downstream_fitting_net = aux_fitting_net + downstream_loss = dict(_ENER_LOSS) + + # Paper qm9_gap/mft alignment is applied ONLY in property mode. The + # legacy ener path (mp_data sensitivity analysis) stays byte-for-byte + # unchanged. + descriptor = { + "type": "dpa3", + "repflow": { + "n_dim": 128, + "e_dim": 64, + "a_dim": 32, + "nlayers": 16, + "e_rcut": 6.0, + "e_rcut_smth": 5.3, + "e_sel": 1200, + "a_rcut": 4.0, + "a_rcut_smth": 3.5, + "a_sel": 300, + "axis_neuron": 4, + "skip_stat": True, + "a_compress_rate": 1, + "a_compress_e_rate": 2, + "a_compress_use_split": True, + "update_angle": True, + "smooth_edge_update": True, + "use_dynamic_sel": True, + "sel_reduce_factor": 10.0, + "update_style": "res_residual", + "update_residual": 0.1, + "update_residual_init": "const", + "n_multi_edge_message": 1, + "optim_update": True, + "use_exp_switch": True, + }, + "activation_function": "silut:3.0" if is_property else "custom_silu:3.0", + "precision": "float32", + "use_tebd_bias": False, + "concat_output_tebd": False, + "exclude_types": [], + "env_protection": 0.0, + "trainable": True, + "use_econf_tebd": False, + } + if is_property: + descriptor["repflow"]["fix_stat_std"] = 0.3 + + # MFT branch heads. In property mode the paper pins finetune_head: + # the aux head loads from its named branch, the downstream property + # head is RANDOM-initialized (paper Eq 12). Legacy ener mode keeps the + # original layout (no finetune_head on aux; downstream = aux branch), + # including key order, so the emitted JSON is byte-for-byte unchanged. + if is_property: + aux_head = { + "type_map": "type_map", + "descriptor": "dpa3_descriptor", + "fitting_net": aux_fitting_net, + "finetune_head": t.aux_branch, + } + downstream_head = { + "finetune_head": "RANDOM", + "type_map": "type_map", + "descriptor": "dpa3_descriptor", + "fitting_net": downstream_fitting_net, + } + else: + aux_head = { + "type_map": "type_map", + "descriptor": "dpa3_descriptor", + "fitting_net": aux_fitting_net, + } + downstream_head = { + "finetune_head": t.aux_branch, + "type_map": "type_map", + "descriptor": "dpa3_descriptor", + "fitting_net": downstream_fitting_net, + } + + decay_steps = ( + t.decay_steps + if getattr(t, "decay_steps", None) is not None + else (1000 if is_property else 5000) + ) + # Per-branch batch sizes: explicit override wins, then paper defaults + # for property mode, then the single batch_size for legacy ener mode. + aux_batch = getattr(t, "aux_batch_size", None) or ( + "auto:128" if is_property else t.batch_size + ) + downstream_batch = getattr(t, "downstream_batch_size", None) or ( + "auto:512" if is_property else t.batch_size + ) + # Paper default 0.5/0.5; aux_prob (default 0.5) controls the split, the + # downstream share is the complement. Legacy keeps downstream at 1.0. + aux_prob = float(t.aux_prob) + if not 0.0 <= aux_prob <= 1.0: + raise ValueError(f"aux_prob must be in [0, 1]; got {t.aux_prob!r}.") + downstream_prob = (1.0 - aux_prob) if is_property else 1.0 + + aux_systems = t.aux_data if isinstance(t.aux_data, list) else [t.aux_data] + train_systems = ( + t.train_data if isinstance(t.train_data, list) else [t.train_data] + ) + valid_systems = None + if getattr(t, "valid_data", None) is not None: + valid_systems = ( + t.valid_data if isinstance(t.valid_data, list) else [t.valid_data] + ) + + training = { + "model_prob": {t.aux_branch: aux_prob, downstream_key: downstream_prob}, + "data_dict": { + t.aux_branch: { + "training_data": {"systems": aux_systems, "batch_size": aux_batch} + }, + downstream_key: { + "training_data": { + "systems": train_systems, + "batch_size": downstream_batch, + } + }, + }, + "numb_steps": t.max_steps, + "save_freq": t.save_freq, + # Pin the checkpoint prefix under output_dir (matching DPATrainer), + # so DeePMD writes model.ckpt-*.pt there regardless of the process + # cwd. Otherwise _freeze_ckpt()/evaluate()/predict() — which look + # under output_dir — cannot find the checkpoint after a successful + # fit() launched from another directory. + "save_ckpt": os.path.join(t.output_dir, "model.ckpt"), + "disp_freq": t.disp_freq, + "seed": t.seed, + } + if valid_systems is not None: + training["data_dict"][downstream_key]["validation_data"] = { + "systems": valid_systems, + "batch_size": downstream_batch, + } + + if is_property: + # Paper qm9_gap: gradient clipping at 5.0. + training["gradient_max_norm"] = 5.0 + + return { + "model": { + "shared_dict": { + "dpa3_descriptor": descriptor, + "type_map": t.type_map, + }, + "model_dict": {t.aux_branch: aux_head, downstream_key: downstream_head}, + }, + "learning_rate": { + "type": "exp", + "start_lr": t.learning_rate, + "stop_lr": t.stop_lr, + "decay_steps": decay_steps, + **( + {"warmup_steps": t.warmup_steps} + if getattr(t, "warmup_steps", 0) > 0 + else {} + ), + }, + "loss_dict": { + t.aux_branch: dict(_ENER_LOSS), + downstream_key: downstream_loss, + }, + "training": training, + } + + def save(self, config: dict, path: str) -> str: + with open(path, "w") as f: + json.dump(config, f, indent=2) + return path + + def build_cmd(self, input_json_path: str) -> list[str]: + t = self.t + # MFT mode: do not pass --model-branch (branches are keyed by model_dict). + # The full descriptor config is already in the JSON, so + # --use-pretrain-script is not needed. + return [ + resolve_dp_command(), + "--pt", + "train", + input_json_path, + "--skip-neighbor-stat", + "--finetune", + t.pretrained, + ] diff --git a/dpa_adapt/cv.py b/dpa_adapt/cv.py new file mode 100644 index 0000000000..b5b731e982 --- /dev/null +++ b/dpa_adapt/cv.py @@ -0,0 +1,613 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +# cv.py +# +# sklearn-style split and cross-validation for dpdata systems. +# Leak-proof: all operations group by formula / user-provided groups so that +# the same formula never appears in both train and validation/test. + +from __future__ import ( + annotations, +) + +import json +import logging +from pathlib import ( + Path, +) +from typing import ( + Any, +) + +import numpy as np +from sklearn.pipeline import ( + make_pipeline, +) +from sklearn.preprocessing import ( + StandardScaler, +) + +from dpa_adapt.data.loader import ( + _get_source, + _resolve_label_key, +) + +_LOG = logging.getLogger("dpa_adapt.cv") + + +# --------------------------------------------------------------------------- +# internal: formula / group helpers +# --------------------------------------------------------------------------- + + +def _extract_formula(system: Any) -> str: + """Extract the formula name from a system. + + Uses the source path stored during loading (``_dpa_source`` attribute). + Falls back to a system hash when no source path is available. + """ + source = _get_source(system) + if source is not None: + return Path(source).resolve().parent.name + return f"sys_{id(system)}" + + +def _formula_to_group(systems: list) -> list[str]: + """Return one group label per system, derived from its path formula.""" + return [_extract_formula(s) for s in systems] + + +def _group_indices(groups: list[str]) -> dict[str, list[int]]: + """Map each unique group to the list of system indices belonging to it.""" + mapping: dict[str, list[int]] = {} + for i, g in enumerate(groups): + mapping.setdefault(g, []).append(i) + return mapping + + +# --------------------------------------------------------------------------- +# internal: manifest parsing +# --------------------------------------------------------------------------- + + +def _build_fold_groups( + manifest_path: str, +) -> tuple[list[set[str]], set[str]]: + """Parse a split_manifest.json into fold groups and test set. + + Returns + ------- + folds : list[set[str]] + One set of formula names per fold. + test : set[str] + Held-out test formulas (may be empty). + """ + m = json.loads(Path(manifest_path).read_text()) + folds: list[set[str]] = [] + test: set[str] = set() + + for tag in ("co", "ni"): + tag_data = m.get(tag, {}) + test.update(tag_data.get("test", [])) + parts = tag_data.get("parts", []) + for i, part in enumerate(parts): + if i >= len(folds): + folds.append(set()) + folds[i].update(part) + + folds = [f for f in folds if f] + return folds, test + + +# --------------------------------------------------------------------------- +# internal: sklearn head builder (delegates to shared factory) +# --------------------------------------------------------------------------- + + +def _build_sklearn_head(predictor_type: str, seed: int = 42) -> Any: + """Map a predictor type string to an sklearn estimator. + + Delegates to ``dpa_adapt.utils.sklearn_heads.build_sklearn_head``. + """ + from dpa_adapt.utils.sklearn_heads import ( + build_sklearn_head, + ) + + return build_sklearn_head(predictor_type, seed=seed) + + +# --------------------------------------------------------------------------- +# internal: per-system lazy assembly (avoids loading all descriptors at once) +# --------------------------------------------------------------------------- + + +def _load_system_labels(system: Any, label_key: str) -> np.ndarray: + """Load labels for a single system, shape (n_frames, ...).""" + resolved = _resolve_label_key(label_key) + return np.asarray(system.data[resolved]) + + +def _assemble_from_per_system_cache( + systems: list, + groups: list[str], + selected_groups: set[str], + label_key: str, + granularity: str, + pretrained: str, + model_branch: str | None, + pooling: str, + type_map: list[str] | tuple[str, ...] | None = None, +) -> tuple[np.ndarray, np.ndarray]: + """Build X, y for systems whose group is in *selected_groups*. + + Reads one system's descriptors at a time from the per-system cache. + Peak memory is proportional to the fold, not the full dataset. + + Parameters + ---------- + systems : list[dpdata.System] + All systems (same order as *groups*). + groups : list[str] + Group label per system. + selected_groups : set[str] + Which groups to include. + label_key : str + Label key in system data (e.g. ``"energies"``). + granularity : str + ``"frame"`` or ``"composition"``. + pretrained : str + Path to the pretrained model checkpoint. + model_branch : str or None + Model branch name for descriptor extraction. + pooling : str + Pooling strategy for descriptor aggregation. + type_map : list[str] or tuple[str, ...] or None + Optional type map for the system. + + Returns + ------- + X : np.ndarray + y : np.ndarray (1D) + """ + from dpa_adapt.data.desc_cache import ( + get_per_system_descriptor, + ) + + X_list, y_list = [], [] + + for system, grp in zip(systems, groups, strict=True): + if grp not in selected_groups: + continue + desc = get_per_system_descriptor( + system, + pretrained=pretrained, + model_branch=model_branch, + pooling=pooling, + type_map=type_map, + ) # (n_frames, feat_dim) + lab = _load_system_labels(system, label_key) # (n_frames, ...) + if granularity == "composition": + desc = desc.mean(axis=0, keepdims=True) + lab = lab.mean(axis=0, keepdims=True) + X_list.append(desc) + y_list.append(lab) + + if not X_list: + return np.empty((0, 0)), np.empty((0,)) + + X = np.concatenate(X_list, axis=0) + y = np.concatenate(y_list, axis=0).ravel() + return X, y + + +# --------------------------------------------------------------------------- +# train_test_split +# --------------------------------------------------------------------------- + + +def train_test_split( + systems: list, + manifest: str | None = None, + group_by: str | list[str] | None = None, + test_size: float = 0.1, + valid_size: float = 0.1, + seed: int = 42, +) -> tuple[list, list, list]: + """Split systems into train / valid / test, leak-proof by group. + + Exactly one of *manifest* or *group_by* must be provided. + + Parameters + ---------- + systems : list + dpdata systems (from ``load_data()`` or ``load_dataset()``). + manifest : str, optional + Path to a ``split_manifest.json``. When provided, the splits are read + from the manifest. + group_by : str or list[str], optional + ``"formula"`` — extract formula from each system's source path. + ``list[str]`` — explicit group label per system (same length as + *systems*). + test_size : float + Fraction of groups held out for test (ignored when *manifest* used). + valid_size : float + Fraction of remaining groups held out for validation. + seed : int + Random seed. + + Returns + ------- + train, valid, test : list + Three disjoint lists of systems. + """ + n = len(systems) + if n == 0: + return [], [], [] + + # --- manifest path --- + if manifest is not None: + folds, test_formulas = _build_fold_groups(manifest) + if not folds: + raise ValueError("Manifest contains no non-empty folds.") + + valid_formulas = folds[-1] + train_formulas: set[str] = set() + for f in folds[:-1]: + train_formulas.update(f) + + grp = _formula_to_group(systems) + train = [s for s, g in zip(systems, grp, strict=True) if g in train_formulas] + valid = [s for s, g in zip(systems, grp, strict=True) if g in valid_formulas] + test = [s for s, g in zip(systems, grp, strict=True) if g in test_formulas] + return train, valid, test + + # --- group_by --- + if group_by is None: + raise ValueError( + "Either manifest= or group_by= must be provided " + "to ensure leak-proof splitting." + ) + + if isinstance(group_by, str) and group_by == "formula": + groups = _formula_to_group(systems) + elif isinstance(group_by, (list, tuple)): + if len(group_by) != n: + raise ValueError( + f"group_by list length ({len(group_by)}) must match systems ({n})." + ) + groups = list(group_by) + else: + raise ValueError( + f"group_by must be 'formula' or a list of strings; got {group_by!r}." + ) + + unique_groups = sorted(set(groups)) + n_groups = len(unique_groups) + if n_groups <= 1: + raise ValueError(f"Only {n_groups} unique group(s) found; cannot split.") + + rng = np.random.default_rng(seed) + perm = rng.permutation(n_groups) + shuffled = [unique_groups[i] for i in perm] + + n_test = max(1, int(np.ceil(n_groups * test_size))) + n_valid = max(1, int(np.ceil((n_groups - n_test) * valid_size))) + + test_groups = set(shuffled[:n_test]) + valid_groups = set(shuffled[n_test : n_test + n_valid]) + train_groups = set(shuffled[n_test + n_valid :]) + + train = [s for s, g in zip(systems, groups, strict=True) if g in train_groups] + valid = [s for s, g in zip(systems, groups, strict=True) if g in valid_groups] + test = [s for s, g in zip(systems, groups, strict=True) if g in test_groups] + + return train, valid, test + + +# --------------------------------------------------------------------------- +# cross_validate +# --------------------------------------------------------------------------- + + +def cross_validate( + model: Any, + systems: list, + label_key: str = "energy", + cv: str | int = 5, + group_by: str | list[str] | None = "formula", + granularity: str = "frame", + allow_expensive_cv: bool = False, + min_groups_warn: int = 30, + seed: int = 42, + manifest: str | None = None, +) -> dict: + """Leak-proof cross-validation for dpdata systems. + + For ``frozen_sklearn`` (the default code path for now), descriptors are + extracted **once** and a cheap sklearn head is trained per fold — even + ``cv=5`` completes in seconds. + + Training paradigms (``frozen_head`` / ``finetune`` / ``mft``) + are expensive: each fold re-trains a full DeePMD model. To prevent + accidental hour-long runs, *allow_expensive_cv* must be explicitly set + to ``True`` for those strategies when *cv* is an integer >= 2. Otherwise + a ``ValueError`` is raised. Non-blocking warnings about estimated runtime + are printed regardless. + + Parameters + ---------- + model : DPAFineTuner + Estimator instance with a ``strategy`` attribute. + systems : list[str] + Validated system directory paths. + label_key : str + Label filename under ``set.*/`` (default ``"energy"``). + cv : str or int + ``"holdout"`` — single train/valid split. Training paradigms default + to this. + ``int >= 2`` — k-fold GroupKFold CV. ``frozen_sklearn`` defaults to 5. + group_by : str or list[str] or None + ``"formula"`` (default) — extract formula from system path. + ``list[str]`` — explicit groups. + ``None`` — no grouping (random split; not recommended for small data). + granularity : str + ``"frame"`` (default) — one prediction per frame. + ``"composition"`` — mean-pool descriptors and labels per formula, + yielding one prediction per independent sample. + allow_expensive_cv : bool + Must be ``True`` to run k-fold CV on a training paradigm. Ignored + for ``frozen_sklearn``. + min_groups_warn : int + Emit a warning when the number of independent groups is below this + threshold. Default 30 is an empirical guideline (small-sample CV + variance is large; see Hastie et al. ESL §7.10). Set to 0 to disable. + seed : int + Random seed for sklearn heads. + manifest : str, optional + Path to a ``split_manifest.json``. When provided, fold definitions + are read from the manifest (deterministic, reproducible). The *cv* + parameter is ignored — the number of folds equals the number of parts + in the manifest. Test formulas in the manifest are excluded from CV. + + Returns + ------- + dict + Keys: ``train_mae``, ``test_mae``, ``test_rmse``, ``test_r2``, + ``aggregate`` (mean/std dict), ``n_independent``, ``warnings`` + (list[str]), ``granularity``. + """ + # ---- resolve strategy ---- + strategy = getattr(model, "strategy", "frozen_sklearn") + is_cheap = strategy == "frozen_sklearn" + + if granularity not in ("frame", "composition"): + raise ValueError( + f"granularity must be 'frame' or 'composition'; got {granularity!r}." + ) + + # ---- resolve groups ---- + if group_by is None: + groups = [f"sys_{i}" for i in range(len(systems))] + elif isinstance(group_by, str) and group_by == "formula": + groups = _formula_to_group(systems) + elif isinstance(group_by, (list, tuple)): + if len(group_by) != len(systems): + raise ValueError( + f"group_by list length ({len(group_by)}) must match " + f"systems ({len(systems)})." + ) + groups = list(group_by) + else: + raise ValueError(f"Invalid group_by: {group_by!r}") + + gmap = _group_indices(groups) + unique_groups = sorted(gmap.keys()) + n_groups = len(unique_groups) + + # ---- resolve cv ---- + if cv == "holdout": + n_splits = 1 + elif isinstance(cv, int) and cv >= 2: + n_splits = cv + else: + raise ValueError(f"cv must be 'holdout' or an int >= 2; got {cv!r}.") + + # ---- expensive-cv guard (NON-interactive!) ---- + if not is_cheap and n_splits >= 2 and not allow_expensive_cv: + raise ValueError( + f"{strategy} {n_splits}-fold CV requires re-training the model " + f"{n_splits} times, which may take hours on a single GPU. " + f"Pass allow_expensive_cv=True to proceed, or use " + f"cv='holdout' for a single train/valid split." + ) + if not is_cheap and n_splits >= 2: + _LOG.warning( + "%s %d-fold CV will train %d models. " + "Estimated %s. This is a non-blocking warning — training proceeds.", + strategy, + n_splits, + n_splits, + _estimate_runtime(strategy, n_splits), + ) + + # ---- build fold assignments ---- + fold_assignments: list[tuple[set[str], set[str]]] = [] + + if manifest is not None: + # Deterministic folds from split_manifest.json. + # Each part is a validation fold; test formulas are excluded. + manifest_folds, test_formulas = _build_fold_groups(manifest) + if not manifest_folds: + raise ValueError("Manifest contains no non-empty folds.") + + # Exclude test formulas from CV + if test_formulas: + _LOG.info( + "Excluding %d test formula(s) from cross_validate: %s", + len(test_formulas), + sorted(test_formulas)[:10], + ) + + for fi, fold_formulas in enumerate(manifest_folds): + val_groups = set(fold_formulas) + train_groups: set[str] = set() + for fj, other in enumerate(manifest_folds): + if fj != fi: + train_groups.update(other) + # Remove test formulas from both sides + val_groups -= test_formulas + train_groups -= test_formulas + if val_groups and train_groups: + fold_assignments.append((train_groups, val_groups)) + else: + # Deterministic GroupKFold: sort groups, split by index (no shuffle). + # Reproducible given the same set of systems and groups. + groups_sorted = list(unique_groups) # already sorted from dict keys + + if n_splits == 1: + n_val = max(1, n_groups // 5) + val_groups = set(groups_sorted[:n_val]) + train_groups = set(groups_sorted[n_val:]) + fold_assignments.append((train_groups, val_groups)) + else: + fold_size = n_groups // n_splits + for fi in range(n_splits): + start = fi * fold_size + end = start + fold_size if fi < n_splits - 1 else n_groups + val_groups = set(groups_sorted[start:end]) + train_groups = set(groups_sorted[:start]) | set(groups_sorted[end:]) + fold_assignments.append((train_groups, val_groups)) + + # ---- ensure per-system descriptor cache (once, lazy) ---- + # This reuses existing desc_mean.npy when present, extracts only missing + # systems one-by-one. Peak memory is one system's descriptors at a time. + if is_cheap: + from dpa_adapt.finetuner import ( + ensure_per_system_cache, + ) + + ensure_per_system_cache( + systems, + pretrained=model.pretrained, + model_branch=model.model_branch, + pooling=model.pooling, + type_map=getattr(model, "type_map", None), + ) + + # ---- per-fold loop (reads per-system cache on demand) ---- + train_mae_list, test_mae_list = [], [] + test_rmse_list, test_r2_list = [], [] + + for train_groups, val_groups in fold_assignments: + if is_cheap: + Xtr, ytr = _assemble_from_per_system_cache( + systems, + groups, + train_groups, + label_key, + granularity, + pretrained=model.pretrained, + model_branch=model.model_branch, + pooling=model.pooling, + type_map=getattr(model, "type_map", None), + ) + Xva, yva = _assemble_from_per_system_cache( + systems, + groups, + val_groups, + label_key, + granularity, + pretrained=model.pretrained, + model_branch=model.model_branch, + pooling=model.pooling, + type_map=getattr(model, "type_map", None), + ) + if Xtr.shape[0] == 0 or Xva.shape[0] == 0: + continue + + predictor_type = getattr(model, "_predictor_type", None) + if predictor_type is None: + predictor_type = getattr(model, "predictor", "linear") + # Map the public API name to the internal _predictor_type + if predictor_type == "ridge": + predictor_type = "linear" + head = make_pipeline( + StandardScaler(), + _build_sklearn_head(predictor_type, seed=seed), + ) + head.fit(Xtr, ytr) + + pred_tr = head.predict(Xtr) + pred_va = head.predict(Xva) + + train_mae_list.append(float(np.mean(np.abs(pred_tr - ytr)))) + test_mae_list.append(float(np.mean(np.abs(pred_va - yva)))) + test_rmse_list.append(float(np.sqrt(np.mean((pred_va - yva) ** 2)))) + if len(yva) >= 3: + ss_res = np.sum((pred_va - yva) ** 2) + ss_tot = np.sum((yva - yva.mean()) ** 2) + r2 = float(1.0 - ss_res / ss_tot) if ss_tot > 0 else float("nan") + else: + r2 = float("nan") + test_r2_list.append(r2) + + # Release fold arrays before the next fold + del Xtr, ytr, Xva, yva, pred_tr, pred_va + else: + # Training paradigms — delegate to per-fold fit/evaluate. + # Phase 2 will wire this to DPATrainer / MFTFineTuner. + raise NotImplementedError( + "cross_validate for training paradigms " + "(frozen_head / finetune / mft) is not yet " + "implemented. Use frozen_sklearn for now." + ) + + # ---- warnings ---- + warnings: list[str] = [] + if min_groups_warn > 0 and n_groups < min_groups_warn: + warnings.append( + f"Only {n_groups} independent groups; CV metrics have high " + f"variance. Report per-fold values, not just mean ± std. " + f"(min_groups_warn={min_groups_warn}, set to 0 to suppress)" + ) + if granularity == "frame" and n_groups < 100: + warnings.append( + "granularity='frame': labels repeat within each group. " + "n_independent is the true sample size." + ) + + # ---- aggregate ---- + agg = {} + for name, lst in [ + ("mae", test_mae_list), + ("rmse", test_rmse_list), + ("r2", test_r2_list), + ]: + vals = [v for v in lst if not np.isnan(v)] + if vals: + agg[f"{name}_mean"] = float(np.mean(vals)) + agg[f"{name}_std"] = float(np.std(vals)) + + return { + "train_mae": train_mae_list, + "test_mae": test_mae_list, + "test_rmse": test_rmse_list, + "test_r2": test_r2_list, + "aggregate": agg, + "n_independent": n_groups, + "warnings": warnings, + "granularity": granularity, + } + + +# --------------------------------------------------------------------------- +# internal: runtime estimate +# --------------------------------------------------------------------------- + + +def _estimate_runtime(strategy: str, n_splits: int) -> str: + per_run = { + "frozen_head": "~5-15 min/run", + "finetune": "~10-30 min/run", + "mft": "~20-60 min/run", + }.get(strategy, "unknown") + return f"{n_splits} x {per_run}" diff --git a/dpa_adapt/data/__init__.py b/dpa_adapt/data/__init__.py new file mode 100644 index 0000000000..1a9a601b5e --- /dev/null +++ b/dpa_adapt/data/__init__.py @@ -0,0 +1,40 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Data loading, conversion, validation, and SMILES/type-map utilities. + +All public names are lazily imported so that ``import dpa_adapt.data`` +(and therefore ``dpa-adapt --help`` / ``dpaad --help``) does not pull in +dpdata, torch, or rdkit. +""" + +_LAZY = { + "load_data": (".loader", "load_data"), + "load_dataset": (".dataset", "load_dataset"), + "read_checkpoint_type_map": (".type_map", "read_checkpoint_type_map"), + "read_data_type_map_union": (".type_map", "read_data_type_map_union"), + "validate_type_map_subset": (".type_map", "validate_type_map_subset"), + "convert": (".convert", "convert"), + "attach_labels": (".convert", "attach_labels"), + "check_data": (".validate", "check_data"), + "Issue": (".validate", "Issue"), + "DPADataError": (".errors", "DPADataError"), + "SmilesDataResult": (".smiles", "SmilesDataResult"), + "read_mol_coords": (".smiles", "read_mol_coords"), + "smiles_to_3d_coords": (".smiles", "smiles_to_3d_coords"), + "smiles_to_npy": (".smiles", "smiles_to_npy"), + "predict_records_from_data": (".smiles", "predict_records_from_data"), + "records_from_direct_data": (".smiles", "records_from_direct_data"), +} + +__all__ = list(_LAZY) + + +def __getattr__(name: str) -> object: + if name in _LAZY: + import importlib + + mod_name, attr_name = _LAZY[name] + mod = importlib.import_module(mod_name, __package__) + attr = getattr(mod, attr_name) + globals()[name] = attr + return attr + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/dpa_adapt/data/convert.py b/dpa_adapt/data/convert.py new file mode 100644 index 0000000000..6c9ca4628f --- /dev/null +++ b/dpa_adapt/data/convert.py @@ -0,0 +1,655 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Format-agnostic data conversion. + +Public entry point: ``convert()`` — sniffs the input and routes to the +appropriate pipeline: SMILES tables, single structure files, +or globbed batches of structure files. +""" + +from __future__ import ( + annotations, +) + +import csv +import glob as _glob +import json +import logging +from pathlib import ( + Path, +) + +import numpy as np + +from dpa_adapt.data.validate import ( + check_data, +) + +_LOG = logging.getLogger("dpa_adapt") + +# Recognised SMILES / molecule column names (case-insensitive). +_SMILES_COLUMNS = frozenset({"smiles", "smi", "mol"}) + + +def _sniff_csv(path: str) -> set[str] | None: + """Return the set of column names from a CSV file, or ``None`` if + the file does not look like a table. + """ + try: + with open(path, newline="", encoding="utf-8") as fh: + reader = csv.DictReader(fh) + if reader.fieldnames is None: + return None + + columns = [] + for header in reader.fieldnames: + if header is None: + return None + header = header.strip() + if not header: + return None + # Reject binary/malformed files that csv.DictReader otherwise + # treats as a one-column header, e.g. b"\x00\x01\x02". + if any(ord(ch) < 32 for ch in header): + return None + columns.append(header.lower()) + return set(columns) + except Exception: + return None + + +def _sniff_xlsx(path: str) -> set[str]: + """Return the set of column names from the first sheet of an Excel file, + or ``None`` if pandas / openpyxl is not available. + """ + try: + import pandas as pd + except ImportError: + return None + try: + df = pd.read_excel(path, nrows=0, engine="openpyxl") + return {str(h).lower() for h in df.columns} + except Exception: + return None + + +def _is_smiles_input(path: str) -> bool: + """Return True if *path* looks like a CSV / Excel file whose columns + contain at least one recognised SMILES / molecule identifier. + """ + suffix = Path(path).suffix.lower() + columns: set[str] | None = None + if suffix == ".csv": + columns = _sniff_csv(path) + elif suffix in (".xlsx", ".xls"): + columns = _sniff_xlsx(path) + if columns is None: + return False + return bool(columns & _SMILES_COLUMNS) + + +# --------------------------------------------------------------------------- +# convert — the single public entry point +# --------------------------------------------------------------------------- + + +def convert( + input_path: str, + output_dir: str, + *, + fmt: str | None = None, + type_map: list[str] | None = None, + property_name: str = "Property", + property_col: str = "Property", + train_ratio: float = 0.9, + smiles_col: str = "SMILES", + mol_dir: str | None = None, + mol_template: str = "id{row}.mol", + split_seed: int | None = None, + conformer_seed: int | None = None, + overwrite: bool = False, + validate: bool = True, + strict: bool = False, + verbose: bool = True, +) -> dict: + """Convert any supported input to ``deepmd/npy``, auto-detecting the format. + + *If the input is a CSV / Excel file with SMILES columns* the call + delegates to :func:`~dpa_adapt.data.smiles.smiles_to_npy`, which + generates 3D conformers (via RDKit), splits into train/valid, and writes + the standard ``deepmd/npy`` layout. + + *If the input is a glob pattern* the call converts each matched structure + file into a mirrored output tree and writes ``manifest.json``. + + *Otherwise* the call delegates to ``dpdata`` with auto-detection (or the + explicit *fmt* if provided), converting a single structure file into + ``deepmd/npy``. + + Returns a dict with ``"method"`` and additional metadata from the chosen + backend. + """ + # --- explicit SMILES hint, or auto-sniff --- + is_smiles_fmt = isinstance(fmt, str) and fmt.lower() == "smiles" + if is_smiles_fmt or (fmt is None and _is_smiles_input(input_path)): + from dpa_adapt.data.smiles import ( + smiles_to_npy, + ) + + result = smiles_to_npy( + data={"dataset": input_path, "mol_dir": mol_dir}, + output_dir=output_dir, + property_name=property_name, + property_col=property_col, + train_ratio=train_ratio, + smiles_col=smiles_col, + mol_template=mol_template, + split_seed=split_seed, + conformer_seed=conformer_seed, + overwrite=overwrite, + ) + converted = { + "method": "smiles", + "train_systems": result.train_systems, + "valid_systems": result.valid_systems, + "type_map": result.type_map, + "samples_used": result.samples_used, + "failed_rows": result.failed_rows, + "skipped_zero": result.skipped_zero, + "skipped_overlap": result.skipped_overlap, + } + if verbose: + _LOG.info("RDKit converted samples: %s", converted["samples_used"]) + _LOG.info("RDKit failed rows : %s", len(converted["failed_rows"])) + return converted + + # --- structure glob → batch dpdata --- + input_str = str(input_path) + if any(ch in input_str for ch in "*?["): + outputs = _batch_convert( + glob_pattern=input_str, + output_dir=output_dir, + fmt=fmt or "auto", + type_map=type_map, + validate=validate, + strict=strict, + ) + return { + "method": "batch_dpdata", + "output_dirs": outputs, + "manifest": str(Path(output_dir).resolve() / "manifest.json"), + } + + # --- single structure file → dpdata --- + out = _convert_dpdata( + input_path=input_path, + output_dir=output_dir, + fmt=fmt, + type_map=type_map, + validate=validate, + strict=strict, + ) + return {"method": "dpdata", "output_dir": out} + + +# --------------------------------------------------------------------------- +# _convert_dpdata() — thin dpdata wrapper +# --------------------------------------------------------------------------- + + +def _convert_dpdata( + input_path: str, + output_dir: str, + fmt: str | None = None, + type_map: list[str] | None = None, + validate: bool = True, + strict: bool = False, +) -> str: + """Convert one structure file to ``deepmd/npy`` via ``dpdata``.""" + _convert_one( + input_path=input_path, + output_dir=str(Path(output_dir).resolve()), + fmt=fmt, + type_map=type_map, + validate=validate, + strict=strict, + ) + + return str(Path(output_dir).resolve()) + + +# --------------------------------------------------------------------------- +# _convert_one() — single-file dpdata conversion (internal helper) +# --------------------------------------------------------------------------- + + +def _convert_one( + input_path: str, + output_dir: str, + fmt: str | None = None, + type_map: list[str] | None = None, + validate: bool = True, + strict: bool = False, +) -> str: + """Convert a single structure file to ``deepmd/npy`` format. + + Internal helper called by :func:`_convert_dpdata` — do not use directly. + """ + try: + import dpdata + except ImportError as e: + raise ImportError( + "dpdata is required for format conversion. " + "Install it with: pip install dpdata" + ) from e + + output_dir = str(Path(output_dir).resolve()) + Path(output_dir).mkdir(parents=True, exist_ok=True) + + to_kwargs: dict = {} + if type_map: + to_kwargs["type_map"] = type_map + + # Try labeled first; dpdata auto-detects when fmt is None. + load_kwargs = {"fmt": fmt} if fmt and fmt != "auto" else {} + try: + sys = dpdata.LabeledSystem(str(input_path), **load_kwargs) + except Exception: + sys = dpdata.System(str(input_path), **load_kwargs) + + sys.to("deepmd/npy", output_dir, **to_kwargs) + + if validate: + # Re-load the newly-written directory to validate via dpdata API. + try: + loaded = dpdata.LabeledSystem(output_dir, fmt="deepmd/npy") + except Exception: + loaded = dpdata.System(output_dir, fmt="deepmd/npy") + for issue in check_data(loaded, strict=strict): + _LOG.warning("[Validation] %s", issue.description) + + return output_dir + + +# --------------------------------------------------------------------------- +# _batch_convert() — glob many inputs into a mirrored deepmd/npy tree +# --------------------------------------------------------------------------- + + +def _glob_base(pattern: str) -> Path: + """The fixed (non-wildcard) directory prefix of a glob pattern. + + Used to compute each match's path relative to the part of the pattern the + user actually typed, so the output tree mirrors the input tree. For + ``./calcs/**/OUTCAR`` the base is ``./calcs``. + """ + base_parts: list[str] = [] + for part in Path(pattern).parts: + if any(ch in part for ch in "*?["): + break + base_parts.append(part) + base = Path(*base_parts) if base_parts else Path(".") + # A pattern with no wildcard at all resolves to a file; mirror from its + # parent so the single match still lands in its own subdirectory. + if base.is_file(): + base = base.parent + return base + + +def _batch_convert( + glob_pattern: str, + output_dir: str, + fmt: str, + type_map: list[str] | None = None, + validate: bool = True, + strict: bool = False, + recursive: bool = True, +) -> list[str]: + """ + Convert every file matching a glob pattern to deepmd/npy in one call. + + The input directory tree is mirrored under ``output_dir``: a match at + ``/sub/run/OUTCAR`` (where ```` is the non-wildcard prefix of + ``glob_pattern``) is written to ``/sub/run/OUTCAR/``. Using + the file stem as the leaf directory keeps the layout collision-free even + when one input directory holds several convertible files. + + A ``manifest.json`` recording inputs, outputs, and skipped files is + written into ``output_dir``. + + Parameters + ---------- + glob_pattern : str + Glob pattern for the input files, e.g. ``"./calcs/**/OUTCAR"``. + output_dir : str + Root directory for the mirrored deepmd/npy output tree. + fmt : str + dpdata format string, applied to every match. + type_map : list[str], optional + Ordered element symbol list, passed through to ``convert()``. + validate : bool + Passed through to the dpdata converter. + strict : bool + If True, the first failure (a conversion error or, when ``validate`` + is on, a validation issue) raises instead of being skipped. If False + (default), failures are logged and skipped, and conversion continues. + recursive : bool + If True (default), ``**`` in the pattern matches across directories. + + Returns + ------- + list[str] + Resolved paths of the successfully created deepmd/npy directories, + in sorted input order. Feeds directly into ``load_data()``. + """ + output_root = Path(output_dir).resolve() + output_root.mkdir(parents=True, exist_ok=True) + + base = _glob_base(glob_pattern) + matches = sorted(_glob.glob(glob_pattern, recursive=recursive)) + if not matches: + raise FileNotFoundError(f"No files matched pattern: {glob_pattern}") + + converted: list[dict] = [] + skipped: list[dict] = [] + + for input_path in matches: + in_path = Path(input_path) + if not in_path.is_file(): + skipped.append( + { + "input": str(in_path), + "error": "matched path is not a file", + } + ) + continue + try: + rel = in_path.relative_to(base) + except ValueError: + rel = Path(in_path.name) + # Mirror the input tree; the file stem is the leaf system directory. + out_sub = output_root / rel.parent / in_path.stem + try: + out = _convert_dpdata( + input_path=str(in_path), + output_dir=str(out_sub), + fmt=fmt, + type_map=type_map, + validate=validate, + strict=strict, + ) + converted.append({"input": str(in_path), "output": out}) + except Exception as e: + if strict: + raise + # Drop the output subdir if conversion created it but wrote + # nothing — an empty dir would just make load_data() and the + # split_* helpers choke later, and keeps the return value in + # sync with what's actually on disk. A half-written dir (dpdata + # crashed mid-write) is kept for debugging. + if out_sub.exists() and not any(out_sub.iterdir()): + try: + out_sub.rmdir() + except OSError: + pass # races / permissions — don't block the batch + _LOG.warning("[convert] skipping %s: %s", in_path, e) + skipped.append({"input": str(in_path), "error": str(e)}) + + manifest = { + "glob_pattern": glob_pattern, + "fmt": fmt, + "type_map": type_map, + "converted": converted, + "skipped": skipped, + } + manifest_path = output_root / "manifest.json" + manifest_path.write_text(json.dumps(manifest, indent=2)) + + _LOG.info( + "[convert] %d converted, %d skipped — manifest: %s", + len(converted), + len(skipped), + manifest_path, + ) + + return [c["output"] for c in converted] + + +# --------------------------------------------------------------------------- +# attach_labels() — property label injection using fit()'s head language +# --------------------------------------------------------------------------- + +# Dict head types we know how to map to a DeePMD-kit data key. +# Anything outside this set is likely a typo; users should pass a plain string +# (e.g. head="force") for ad-hoc keys not listed here. +_KNOWN_DICT_HEAD_TYPES = frozenset({"property", "dos", "dipole", "polar"}) + + +def _key_from_head(head: str | dict) -> str: + """Derive the deepmd/npy filename key from a head specification. + + DeePMD-kit stores label ``key`` as ``set.*/key.npy``. This function maps + the same ``head`` vocabulary used by ``DPAFineTuner.fit()`` to that key. + + - ``str`` → key is the string itself (``"energy"`` → ``energy.npy``) + - ``dict`` with ``"property_name"`` + → key is ``head["property_name"]`` + (used with ``"type": "property"`` heads; confirmed by DeePMD-kit + ``PropertyFittingNet`` docstring: "If the data file is named + ``humo.npy``, this parameter should be ``'humo'``.") + - ``{"type": "dos", ...}`` → ``dos.npy`` + - ``{"type": "dipole", ...}`` → ``dipole.npy`` + - ``{"type": "polar", ...}`` → ``polar.npy`` + + Unknown dict ``type`` values raise ``ValueError`` with the supported list, + rather than silently writing a file DeePMD-kit will never find. + """ + if isinstance(head, str): + return head + + if isinstance(head, dict): + # property_name present → that IS the data key (overrides type check) + if "property_name" in head: + return head["property_name"] + + htype = head.get("type") + if htype is None: + raise ValueError( + "head dict must contain 'property_name' or 'type'. " + f"Got keys: {sorted(head.keys())}" + ) + + if htype not in _KNOWN_DICT_HEAD_TYPES: + raise ValueError( + f"Unknown dict head type {htype!r}. " + f"Supported types: {sorted(_KNOWN_DICT_HEAD_TYPES)}. " + f"For ad-hoc keys, pass a plain string instead: head={htype!r}" + ) + + if htype == "property": + # "property" is a meta-type: the real key comes from property_name. + # We already handled property_name above, so if we're here it's missing. + raise ValueError( + "head type 'property' requires a 'property_name' key " + "(DeePMD-kit will read '{property_name}.npy'). " + "Example: head={'type': 'property', 'property_name': 'bandgap', 'task_dim': 1}" + ) + + # dos / dipole / polar: key == type name + return htype + + raise TypeError(f"head must be str or dict, got {type(head).__name__!r}") + + +def _attach_single( + sys_path: str | Path, + head: str | dict, + values: np.ndarray, +) -> None: + """Write label values to the set.*/ directory of a single deepmd/npy system. + + Parameters + ---------- + sys_path : str | Path + Path to a single deepmd/npy system directory containing set.*/ subdirs. + head : str | dict + Property head specification — resolved to a .npy filename via + :func:`_key_from_head`. + values : np.ndarray + Per-frame label array. First axis must match the frame count in + ``set.*/coord.npy``. + + Raises + ------ + ValueError + If *sys_path* is not a directory, no set.*/ dirs are found, + coord.npy is missing, or the frame count mismatches. + NotImplementedError + If more than one set.*/ directory exists (multi-set not yet supported). + """ + sys_path = Path(sys_path) + if not sys_path.is_dir(): + raise ValueError(f"System path is not a directory: {sys_path}") + + key = _key_from_head(head) + values = np.asarray(values, dtype=np.float64) + + set_dirs = sorted(sys_path.glob("set.*")) + if not set_dirs: + raise ValueError( + f"No set.* directories found in {sys_path} — " + "is this a valid deepmd/npy system directory?" + ) + if len(set_dirs) > 1: + raise NotImplementedError( + f"Multiple set.* directories found in {sys_path}. " + "attach_labels currently supports single-set systems only. " + f"Found: {[d.name for d in set_dirs]}" + ) + + set_dir = set_dirs[0] + coord_path = set_dir / "coord.npy" + if not coord_path.is_file(): + raise ValueError(f"coord.npy not found in {set_dir}. Expected at: {coord_path}") + + coords = np.load(coord_path) + n_frames = coords.shape[0] + + if values.shape[0] != n_frames: + raise ValueError( + f"values has {values.shape[0]} frames but system " + f"contains {n_frames} frames (from {coord_path})." + ) + + np.save(str(set_dir / f"{key}.npy"), values) + + +def attach_labels( + data: str | Path, + head: str | dict, + values: np.ndarray, +) -> None: + """Inject label values into one or more deepmd/npy systems. + + Auto-detects single vs multi-system input: + + - **Single system**: *data* contains ``set.*/`` directories directly. + *values* must match the frame count (``values.shape[0] == n_frames``). + - **Multi system**: *data* contains subdirectories (``sys_0000/``, + ``sys_0001/``, …); systems are matched to *values* in ``sorted()`` + order. *values* must have ``values.shape[0] == n_systems`` and + each element is written to the corresponding system's ``set.*/`` dir. + + Labels are written as ``set.*/{key}.npy`` on disk, where *key* is + resolved from *head* via :func:`_key_from_head`. + + Parameters + ---------- + data : str | Path + Path to a single deepmd/npy system (contains ``set.*/`` subdirs) or + a parent directory containing system subdirectories. + + head : str | dict + Property head specification — same vocabulary as + ``DPAFineTuner(head=...)``: + + - ``"energy"`` → writes ``set.*/energy.npy`` + - ``"bandgap"`` (any plain string) → writes ``set.*/bandgap.npy`` + - ``{"type": "property", "property_name": "bandgap", "task_dim": 1}`` + → writes ``set.*/bandgap.npy`` + - ``{"type": "dos", "numb_dos": 250}`` → writes ``set.*/dos.npy`` + + values : np.ndarray + For single-system: shape ``(n_frames,)`` or ``(n_frames, dim)``. + For multi-system: shape ``(n_systems,)`` or ``(n_systems, dim)``; + each element is assigned to the corresponding system directory + (in ``sorted()`` order). + + Raises + ------ + ValueError + If *data* is not a directory, has an unrecognised structure, + or the frame / system count mismatches. + NotImplementedError + If a system has more than one ``set.*/`` directory. + + Notes + ----- + **Idempotency**: calling ``attach_labels`` twice with the *same* head on + the same system overwrites the existing file. Calling with *different* + heads writes separate ``.npy`` files. + + Examples + -------- + Single system: + + >>> attach_labels("sys_0000/", head="bandgap", values=np.array([1.0])) + + Multi system — ``values[i]`` → ``sorted(glob("npy/*/"))[i]``: + + >>> labels = np.load("labels.npy") # shape (n_systems,) + >>> attach_labels("./npy/", head="bandgap", values=labels) + + CLI (works for both single and multi-system): + + .. code-block:: bash + + dpaad data attach-labels --data ./npy/ --head bandgap --values labels.npy + """ + data = Path(data) + if not data.is_dir(): + raise ValueError(f"Data path is not a directory: {data}") + + # Detect single-system: set.*/ subdirs directly under data + has_set_dirs = any(p.is_dir() and p.name.startswith("set.") for p in data.iterdir()) + + if has_set_dirs: + _attach_single(data, head, values) + return + + # Multi-system: glob non-hidden subdirectories as system dirs + sys_dirs = sorted( + p for p in data.iterdir() if p.is_dir() and not p.name.startswith(".") + ) + if not sys_dirs: + raise ValueError( + f"No set.* directories or system subdirectories found " + f"in {data}.\n" + "Expected either:\n" + " (a) a single system with set.*/ subdirs, or\n" + " (b) a parent directory containing system subdirectories\n" + " (each with their own set.*/)." + ) + + values_arr = np.asarray(values) + if values_arr.shape[0] != len(sys_dirs): + raise ValueError( + f"values has {values_arr.shape[0]} entries along the first " + f"axis but found {len(sys_dirs)} system directories in {data}. " + "In multi-system mode, values.shape[0] must equal the number " + "of system subdirectories (sorted alphabetically)." + ) + + for sys_dir, sub_vals in zip(sys_dirs, values_arr, strict=True): + if np.isscalar(sub_vals): + sub_vals = np.asarray([sub_vals]) + _attach_single(sys_dir, head, sub_vals) diff --git a/dpa_adapt/data/dataset.py b/dpa_adapt/data/dataset.py new file mode 100644 index 0000000000..234fc4fa6f --- /dev/null +++ b/dpa_adapt/data/dataset.py @@ -0,0 +1,113 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +# data/dataset.py +# +# Label-aware data loading for supervised training / fine-tuning. +# Thin layer on top of load_data() that additionally verifies every +# system carries the requested label key (e.g. "energy", "homo"). + +from __future__ import ( + annotations, +) + +import logging +from pathlib import ( + Path, +) + +import dpdata + +from dpa_adapt.data.errors import ( + DPADataError, +) +from dpa_adapt.data.loader import ( + _resolve_label_key, + load_data, +) + +_LOG = logging.getLogger("dpa_adapt.data.dataset") + +_DataInput = ( + str + | Path + | dpdata.System + | dpdata.LabeledSystem + | list[str | Path | dpdata.System | dpdata.LabeledSystem] +) + + +def load_dataset( + data: _DataInput, + label_key: str = "energy", +) -> list[dpdata.LabeledSystem]: + """ + Load systems and keep only those that carry *label_key*. + + Internally calls ``load_data()`` to normalise input, then inspects each + system's ``data`` dict for the requested label. Systems that lack the + label are skipped with a warning rather than raising, so a partial + dataset (e.g. a directory tree where only some systems have energies) + does not block downstream work. + + Parameters + ---------- + data : str | Path | dpdata.System | dpdata.LabeledSystem | list + Any input accepted by ``load_data()`` — single path, glob string, + dpdata object, or heterogeneous list of the above. + label_key : str + Label key to check in each system's ``data`` dict (e.g. + ``"energy"``, ``"force"``, ``"homo"``). Default ``"energy"``. + + Returns + ------- + list[dpdata.LabeledSystem] + Systems that passed label validation. May be empty only if + *every* candidate was skipped, in which case a ``DPADataError`` + is raised (fail-fast for training workflows). + """ + from dpa_adapt.data.loader import ( + _find_label_npys, + _get_source, + ) + + systems = load_data(data) + + resolved_key = _resolve_label_key(label_key) + + validated: list[dpdata.LabeledSystem] = [] + skipped: list[str] = [] + + for i, system in enumerate(systems): + # dpdata stores everything (coords, energies, forces, ...) in the + # ``data`` dict; label_key (after alias resolution) presence is the litmus test. + if resolved_key in system.data: + validated.append(system) + continue + + # Fallback: custom labels such as "homo.npy"/"bandgap.npy" under set.*/ + # are not generally loaded into dpdata.System.data, so check for the + # label file directly (shared discovery with _load_labels()). + source = _get_source(system) + if source is not None and _find_label_npys(source, resolved_key): + validated.append(system) + continue + + identifier = getattr(system, "_dpa_source", f"system[{i}]") + skipped.append(f"{identifier} (missing {resolved_key!r})") + + if skipped: + _LOG.warning( + "load_dataset: %d system(s) skipped (missing label key %r):\n %s", + len(skipped), + resolved_key, + "\n ".join(skipped), + ) + + if not validated: + raise DPADataError( + f"load_dataset: no valid systems found with label_key={label_key!r} " + f"(resolved to {resolved_key!r}). " + f"Skipped {len(skipped)} candidate(s). " + "Check that the path and label_key are correct." + ) + + return validated diff --git a/dpa_adapt/data/desc_cache.py b/dpa_adapt/data/desc_cache.py new file mode 100644 index 0000000000..92bd240539 --- /dev/null +++ b/dpa_adapt/data/desc_cache.py @@ -0,0 +1,174 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +# data/desc_cache.py +# +# Transparent on-disk cache for extracted DPA descriptors. +# Two-tier: (1) per-system cache keyed by lightweight content hash, +# (2) bulk cache under ``~/.cache/dpa_adapt/desc_cache/`` keyed by +# (aggregate data fingerprint, checkpoint identity, branch, pooling). +# +# Systems are ``dpdata.System`` objects; cache keys are computed from +# data fingerprints and resolved checkpoint metadata. +# +# Note: ``load_or_extract()`` and ``ensure_per_system_cache()`` live in +# ``dpa_adapt.finetuner`` to avoid an import cycle (those functions need +# ``DPAFineTuner``, while ``finetuner`` imports cache helpers from here). + +from __future__ import ( + annotations, +) + +import hashlib +import os +from pathlib import ( + Path, +) +from typing import ( + TYPE_CHECKING, +) + +import numpy as np + +from dpa_adapt._backend import ( + resolve_pretrained_path, +) + +if TYPE_CHECKING: + import dpdata + + +# --------------------------------------------------------------------------- +# cache directory +# --------------------------------------------------------------------------- + + +def _cache_dir() -> Path: + base = os.environ.get("XDG_CACHE_HOME", os.path.join(str(Path.home()), ".cache")) + return Path(base) / "dpa_adapt" / "desc_cache" + + +# --------------------------------------------------------------------------- +# system fingerprint (O(n) over the full descriptor-relevant arrays) +# --------------------------------------------------------------------------- + + +def _hash_array(h: hashlib._Hash, arr: np.ndarray) -> None: + """Fold an array's shape, dtype, and full byte content into *h*. + + The contiguous buffer is fed to :meth:`hashlib._Hash.update` directly via + the buffer protocol, so no large intermediate ``bytes`` copy is made. + """ + arr = np.ascontiguousarray(arr) + h.update(str(arr.shape).encode()) + h.update(str(arr.dtype).encode()) + h.update(arr) + + +def _system_fingerprint(system: dpdata.System) -> str: + """Return a hex fingerprint for a dpdata System. + + Hashes the *full* contents of the descriptor-relevant arrays — ``coords``, + ``cells`` and ``atom_types`` — together with ``atom_names``. Sampling + only the first/last few entries (as an earlier version did) let any change + in the middle of a long trajectory keep the same key, so the cache could + return descriptors extracted from a different structure. Hashing every + element costs O(total array size), but that is negligible next to the + descriptor extraction the cache guards, and it makes the key collision-safe + for changed systems. + """ + d = system.data + + h = hashlib.sha1() + # atom-type identity + _hash_array(h, np.asarray(d["atom_types"])) + # atom_names (if present) + names = d.get("atom_names", []) + h.update("|".join(str(n) for n in names).encode()) + # full geometry + _hash_array(h, np.asarray(d["coords"])) + if "cells" in d: + _hash_array(h, np.asarray(d["cells"])) + return h.hexdigest()[:16] + + +def _data_fingerprint(systems: list) -> str: + """Aggregate fingerprint for a list of systems in request order.""" + fps = [_system_fingerprint(s) for s in systems] + h = hashlib.sha1() + for fp in fps: + h.update(fp.encode()) + return h.hexdigest() + + +def _checkpoint_fingerprint(pretrained: str) -> str: + resolved = Path(resolve_pretrained_path(pretrained)).resolve() + stat = resolved.stat() + payload = f"{resolved}|{stat.st_mtime_ns}|{stat.st_size}" + return hashlib.sha1(payload.encode()).hexdigest()[:16] + + +def _type_map_payload(type_map: list[str] | tuple[str, ...] | None) -> str: + if not type_map: + return "" + return "\x1f".join(str(item) for item in type_map) + + +def _cache_key( + systems: list, + pretrained: str, + model_branch: str | None, + pooling: str, + *, + type_map: list[str] | tuple[str, ...] | None = None, +) -> str: + fp = _data_fingerprint(systems) + ckpt_fp = _checkpoint_fingerprint(pretrained) + tm = _type_map_payload(type_map) + payload = f"{fp}|{ckpt_fp}|{model_branch or ''}|{pooling}|{tm}" + return hashlib.sha1(payload.encode()).hexdigest()[:16] + + +# --------------------------------------------------------------------------- +# per-system cache path helpers +# --------------------------------------------------------------------------- + + +def _per_system_cache_path( + system: dpdata.System, + pretrained: str, + model_branch: str | None = None, + pooling: str = "mean", + type_map: list[str] | tuple[str, ...] | None = None, +) -> Path: + """Return the cache path for one system under a descriptor identity.""" + system_fp = _system_fingerprint(system) + ckpt_fp = _checkpoint_fingerprint(pretrained) + tm = _type_map_payload(type_map) + payload = f"{system_fp}|{ckpt_fp}|{model_branch or ''}|{pooling}|{tm}" + fp = hashlib.sha1(payload.encode()).hexdigest()[:16] + return _cache_dir() / "per_system" / f"{fp}.npy" + + +def get_per_system_descriptor( + system: dpdata.System, + pretrained: str, + model_branch: str | None = None, + pooling: str = "mean", + type_map: list[str] | tuple[str, ...] | None = None, +) -> np.ndarray: + """Read cached descriptors for one system and descriptor identity. + + Raises ``FileNotFoundError`` if the cache file does not exist. + """ + cache_path = _per_system_cache_path( + system, + pretrained, + model_branch, + pooling, + type_map, + ) + if not cache_path.is_file(): + raise FileNotFoundError( + f"Per-system descriptor cache not found: {cache_path}\n" + f"Run ensure_per_system_cache() first." + ) + return np.load(cache_path) diff --git a/dpa_adapt/data/errors.py b/dpa_adapt/data/errors.py new file mode 100644 index 0000000000..aeabad8229 --- /dev/null +++ b/dpa_adapt/data/errors.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +# data/errors.py + + +class DPADataError(Exception): + """Raised when data loading or format detection fails.""" + + pass diff --git a/dpa_adapt/data/loader.py b/dpa_adapt/data/loader.py new file mode 100644 index 0000000000..56f39cd765 --- /dev/null +++ b/dpa_adapt/data/loader.py @@ -0,0 +1,152 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +# data/loader.py +# +# Polymorphic entry point: normalises str / Path / glob / dpdata objects +# into a flat list[dpdata.System]. Disk I/O and format detection are +# delegated to dpdata. + +from __future__ import ( + annotations, +) + +import glob as _glob +from pathlib import ( + Path, +) + +import dpdata + +from dpa_adapt.data.errors import ( + DPADataError, +) + +_SOURCE_ATTR = "_dpa_source" + +# Backward-compat key aliases: old code used "energy"/"force" but dpdata +# stores them as "energies"/"forces". Single source of truth — all other +# modules import from here. +_LABEL_KEY_ALIASES = { + "energy": "energies", + "force": "forces", +} + + +def _resolve_label_key(key: str) -> str: + """Map legacy label keys to dpdata's canonical names.""" + return _LABEL_KEY_ALIASES.get(key, key) + + +# Type alias covering every form the public API accepts. +_SystemLike = str | Path | dpdata.System | dpdata.LabeledSystem +_DataInput = _SystemLike | list[_SystemLike] + + +def _get_source(system: dpdata.System) -> str | None: + """Return the source path stored on a system, or None.""" + return getattr(system, _SOURCE_ATTR, None) + + +def _find_label_npys(source: str | Path, key: str) -> list[Path]: + """Return existing ``set.*/{key}.npy`` paths under *source*, sorted by set. + + Custom labels (e.g. ``homo.npy``, ``bandgap.npy``) are not loaded into + ``dpdata.System.data``. This is the shared label-discovery used by both + ``dataset.load_dataset`` (existence check) and ``finetuner._load_labels`` + (loading), so the two stay in sync. + """ + source_path = Path(source) + return [ + npy + for set_dir in sorted(source_path.glob("set.*")) + if (npy := set_dir / f"{key}.npy").exists() + ] + + +def load_data( + data: _DataInput, + fmt: str | None = None, +) -> list[dpdata.System]: + """ + Normalise arbitrary data input into a flat list of ``dpdata.System``. + + This is the single polymorphic entry point for all data in dpa_adapt. + Every internal consumer receives its data through this function so that + disk-access logic lives in exactly one place. + + Parameters + ---------- + data : str | Path | dpdata.System | dpdata.LabeledSystem | list + - **str / Path** — a deepmd/npy system directory (or any path that + dpdata can open). If the string contains glob wildcards (``*``, + ``?``) it is expanded and every match is loaded. + - **dpdata.System / dpdata.LabeledSystem** — passed through as-is + (no deep copy). + - **list** — each element is processed recursively and the results + are flattened into a single list. + fmt : str, optional + dpdata format string. Defaults to ``"deepmd/npy"`` for paths; + ignored when *data* is already a dpdata object. + + Returns + ------- + list[dpdata.System] + One ``dpdata.System`` (or ``LabeledSystem``) per resolved input. + """ + # 1. List → recurse and flatten + if isinstance(data, list): + result: list[dpdata.System] = [] + for item in data: + result.extend(load_data(item, fmt=fmt)) + return result + + # 2. Glob string → expand, then recurse + if isinstance(data, str) and _glob.has_magic(data): + matches = sorted(Path(p) for p in _glob.glob(data)) + if not matches: + raise DPADataError( + f"Glob pattern {data!r} matched no files or directories." + ) + + # Fail-fast: deepmd/npy (the default) only works on directories. + load_fmt = fmt if fmt is not None else "deepmd/npy" + if load_fmt == "deepmd/npy": + non_dirs = [str(m) for m in matches if not m.is_dir()] + if non_dirs: + raise DPADataError( + f"Glob pattern {data!r} matched non-directory paths " + f"incompatible with fmt={load_fmt!r}: {non_dirs}. " + "Pass fmt= explicitly or load these separately." + ) + + result: list[dpdata.System] = [] + for match in matches: + result.extend(load_data(match, fmt=fmt)) + return result + + # 3. dpdata object → pass through (no copy) + if isinstance(data, (dpdata.System, dpdata.LabeledSystem)): + return [data] + + # 4. str / Path → delegate to dpdata + path = str(data) + if not Path(path).exists(): + raise DPADataError(f"Path does not exist: {path!r}") + + load_fmt = fmt if fmt is not None else "deepmd/npy" + + # Try labeled first so that training labels are preserved when present. + try: + system: dpdata.System = dpdata.LabeledSystem(path, fmt=load_fmt) + except Exception: + try: + system = dpdata.System(path, fmt=load_fmt) + except Exception as exc: + raise DPADataError( + f"Failed to load {path!r} via dpdata (fmt={load_fmt!r}): {exc}" + ) from exc + + # Stamp source path so downstream consumers (e.g. cv formula extraction) + # can recover the original filesystem location. + setattr(system, _SOURCE_ATTR, str(Path(path).resolve())) + + return [system] diff --git a/dpa_adapt/data/smiles.py b/dpa_adapt/data/smiles.py new file mode 100644 index 0000000000..955827f9c0 --- /dev/null +++ b/dpa_adapt/data/smiles.py @@ -0,0 +1,815 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""SMILES → 3D coordinates → deepmd/npy conversion. + +Provides the molecular data ingestion pipeline originally from +``dpa_adapt``: + +- Parse CSV files with SMILES (or pre-generated MOL files) and property labels +- Generate 3D conformers via RDKit (ETKDGv3 + MMFF/UFF optimisation) +- Validate structures (zero-coordinate rejection, overlapping-atom detection) +- Write ``deepmd/npy`` directories consumable by ``DPAFineTuner`` and friends +""" + +from __future__ import ( + annotations, +) + +import csv +import random +import re +import shutil +import warnings +from dataclasses import ( + dataclass, +) +from pathlib import ( + Path, +) +from typing import ( + Any, +) + +import numpy as np + +# Period table, used to build a consistent per-checkpoint type_map. +ELEMENTS = np.array( + [ + "H", + "He", + "Li", + "Be", + "B", + "C", + "N", + "O", + "F", + "Ne", + "Na", + "Mg", + "Al", + "Si", + "P", + "S", + "Cl", + "Ar", + "K", + "Ca", + "Sc", + "Ti", + "V", + "Cr", + "Mn", + "Fe", + "Co", + "Ni", + "Cu", + "Zn", + "Ga", + "Ge", + "As", + "Se", + "Br", + "Kr", + "Rb", + "Sr", + "Y", + "Zr", + "Nb", + "Mo", + "Tc", + "Ru", + "Rh", + "Pd", + "Ag", + "Cd", + "In", + "Sn", + "Sb", + "Te", + "I", + "Xe", + "Cs", + "Ba", + "La", + "Ce", + "Pr", + "Nd", + "Pm", + "Sm", + "Eu", + "Gd", + "Tb", + "Dy", + "Ho", + "Er", + "Tm", + "Yb", + "Lu", + "Hf", + "Ta", + "W", + "Re", + "Os", + "Ir", + "Pt", + "Au", + "Hg", + "Tl", + "Pb", + "Bi", + "Po", + "At", + "Rn", + "Fr", + "Ra", + "Ac", + "Th", + "Pa", + "U", + "Np", + "Pu", + "Am", + "Cm", + "Bk", + "Cf", + "Es", + "Fm", + "Md", + "No", + "Lr", + "Rf", + "Db", + "Sg", + "Bh", + "Hs", + "Mt", + "Ds", + "Rg", + "Cn", + "Nh", + "Fl", + "Mc", + "Lv", + "Ts", + "Og", + ] +) +ELEMENT_INDEX: dict[str, int] = {name: i for i, name in enumerate(ELEMENTS)} + + +# --------------------------------------------------------------------------- +# helpers +# --------------------------------------------------------------------------- + + +def _find_column(columns: list[str], choices: list[str]) -> str: + lower_map = {col.lower(): col for col in columns} + for choice in choices: + if choice.lower() in lower_map: + return lower_map[choice.lower()] + raise KeyError(f"None of columns {choices} found in {columns}") + + +def _parse_property_value(raw_value: object) -> float: + if isinstance(raw_value, (int, float)): + return float(raw_value) + text = str(raw_value).strip() + try: + return float(text) + except ValueError: + match = re.search(r"[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?", text) + if match: + return float(match.group(0)) + raise + + +# --------------------------------------------------------------------------- +# Pre-generated structure readers +# --------------------------------------------------------------------------- + + +def read_mol_coords(path: str | Path) -> tuple[list[str], np.ndarray]: + """Parse a V2000/V3000 MOL file, returning element symbols and (natoms,3) coords.""" + mol_path = Path(path) + lines = mol_path.read_text(encoding="utf-8", errors="ignore").splitlines() + if len(lines) < 4: + raise ValueError(f"Bad MOL file (too short): {mol_path}") + + counts = lines[3] + try: + natoms = int(counts[0:3]) + except ValueError: + parts = counts.split() + if not parts: + raise ValueError(f"Bad MOL counts line: {mol_path}") from None + natoms = int(parts[0]) + + atom_lines = lines[4 : 4 + natoms] + if len(atom_lines) != natoms: + raise ValueError(f"Bad MOL atom block length: {mol_path}") + + symbols: list[str] = [] + coords: list[list[float]] = [] + for atom_line in atom_lines: + if len(atom_line) >= 34: + x = float(atom_line[0:10]) + y = float(atom_line[10:20]) + z = float(atom_line[20:30]) + symbol = atom_line[31:34].strip() + else: + parts = atom_line.split() + if len(parts) < 4: + raise ValueError(f"Bad MOL atom line: {mol_path}") + x, y, z = float(parts[0]), float(parts[1]), float(parts[2]) + symbol = parts[3] + if symbol not in ELEMENT_INDEX: + raise ValueError(f"Unknown element {symbol!r} in {mol_path}") + symbols.append(symbol) + coords.append([x, y, z]) + + return symbols, np.asarray(coords, dtype=np.float32) + + +def _read_xyz_coords(path: str | Path) -> tuple[list[str], np.ndarray]: + xyz_path = Path(path) + lines = xyz_path.read_text(encoding="utf-8", errors="ignore").splitlines() + if len(lines) < 2: + raise ValueError(f"Bad XYZ file (too short): {xyz_path}") + try: + natoms = int(lines[0].strip()) + except ValueError: + raise ValueError(f"Bad XYZ atom count line: {xyz_path}") from None + atom_lines = lines[2 : 2 + natoms] + if len(atom_lines) != natoms: + raise ValueError(f"Bad XYZ atom block length: {xyz_path}") + + symbols: list[str] = [] + coords: list[list[float]] = [] + for atom_line in atom_lines: + parts = atom_line.split() + if len(parts) < 4: + raise ValueError(f"Bad XYZ atom line: {xyz_path}") + symbol = parts[0] + if symbol not in ELEMENT_INDEX: + raise ValueError(f"Unknown element {symbol!r} in {xyz_path}") + symbols.append(symbol) + coords.append([float(parts[1]), float(parts[2]), float(parts[3])]) + return symbols, np.asarray(coords, dtype=np.float32) + + +def _read_rdkit_coords(path: str | Path) -> tuple[list[str], np.ndarray]: + structure_path = Path(path) + try: + from rdkit import ( + Chem, + ) + except ImportError as exc: + raise ImportError( + "RDKit is required to read .sdf and .pdb files from mol_dir." + ) from exc + + suffix = structure_path.suffix.lower() + if suffix == ".sdf": + supplier = Chem.SDMolSupplier(str(structure_path), removeHs=False) + mol = next((m for m in supplier if m is not None), None) + elif suffix == ".pdb": + mol = Chem.MolFromPDBFile(str(structure_path), removeHs=False) + else: + raise ValueError(f"Unsupported structure file extension: {structure_path}") + if mol is None: + raise ValueError(f"Could not read structure file: {structure_path}") + if mol.GetNumConformers() == 0: + raise ValueError(f"Structure file has no 3D conformer: {structure_path}") + + conf = mol.GetConformer() + symbols: list[str] = [] + coords: list[list[float]] = [] + for atom in mol.GetAtoms(): + symbol = atom.GetSymbol() + if symbol not in ELEMENT_INDEX: + raise ValueError(f"Unknown element {symbol!r} in {structure_path}") + pos = conf.GetAtomPosition(atom.GetIdx()) + symbols.append(symbol) + coords.append([pos.x, pos.y, pos.z]) + return symbols, np.asarray(coords, dtype=np.float32) + + +def read_structure_coords(path: str | Path) -> tuple[list[str], np.ndarray]: + structure_path = Path(path) + suffix = structure_path.suffix.lower() + if suffix == ".mol": + return read_mol_coords(structure_path) + if suffix == ".xyz": + return _read_xyz_coords(structure_path) + if suffix in {".sdf", ".pdb"}: + return _read_rdkit_coords(structure_path) + raise ValueError( + f"Unsupported pre-generated structure file extension {suffix!r}; " + "expected .mol, .sdf, .xyz, or .pdb" + ) + + +# --------------------------------------------------------------------------- +# SMILES → 3D (RDKit, lazy import) +# --------------------------------------------------------------------------- + + +def smiles_to_3d_coords( + smiles: str, + *, + random_seed: int = 42, +) -> tuple[list[str], np.ndarray]: + """Generate a 3D conformer from a SMILES string via RDKit ETKDGv3.""" + try: + from rdkit import ( + Chem, + ) + from rdkit.Chem import ( + AllChem, + ) + except ImportError as exc: + raise ImportError( + "RDKit is required to generate 3D coordinates from SMILES. " + "Install rdkit, or provide mol_dir with pre-generated MOL files." + ) from exc + + mol = Chem.MolFromSmiles(str(smiles)) + if mol is None: + raise ValueError(f"Invalid SMILES: {smiles!r}") + mol = Chem.AddHs(mol) + params = AllChem.ETKDGv3() + params.randomSeed = int(random_seed) + if hasattr(params, "maxAttempts"): + params.maxAttempts = 1000 + status = AllChem.EmbedMolecule(mol, params) + if status != 0: + params.useRandomCoords = True + status = AllChem.EmbedMolecule(mol, params) + if status != 0: + status = AllChem.EmbedMolecule( + mol, + randomSeed=int(random_seed), + useRandomCoords=True, + maxAttempts=2000, + ignoreSmoothingFailures=True, + enforceChirality=False, + ) + if status != 0: + raise ValueError(f"RDKit failed to embed 3D coordinates for SMILES: {smiles!r}") + try: + if AllChem.MMFFHasAllMoleculeParams(mol): + AllChem.MMFFOptimizeMolecule(mol, maxIters=500) + else: + AllChem.UFFOptimizeMolecule(mol, maxIters=500) + except Exception: + # MMFF optimization failed; fall back to UFF. + try: + AllChem.UFFOptimizeMolecule(mol, maxIters=500) + except Exception: + # Even UFF failed — proceed with unoptimized conformer. + pass + + conf = mol.GetConformer() + symbols: list[str] = [] + coords: list[list[float]] = [] + for atom in mol.GetAtoms(): + pos = conf.GetAtomPosition(atom.GetIdx()) + symbol = atom.GetSymbol() + if symbol not in ELEMENT_INDEX: + raise ValueError( + f"Unknown element {symbol!r} generated from SMILES {smiles!r}" + ) + symbols.append(symbol) + coords.append([pos.x, pos.y, pos.z]) + return symbols, np.asarray(coords, dtype=np.float32) + + +# --------------------------------------------------------------------------- +# validation +# --------------------------------------------------------------------------- + + +def _has_overlapping_atoms(coords: np.ndarray, tol: float) -> bool: + if coords.shape[0] < 2: + return False + diff = coords[:, np.newaxis, :] - coords[np.newaxis, :, :] + dist2 = np.sum(diff * diff, axis=-1) + np.fill_diagonal(dist2, np.inf) + return float(np.min(dist2)) < tol * tol + + +def _build_type_map_from_elements(used_elements: set[str]) -> list[str]: + return [el for el in ELEMENTS.tolist() if el in used_elements] + + +# --------------------------------------------------------------------------- +# CSV record extractors +# --------------------------------------------------------------------------- + +_Record = tuple[list[str], np.ndarray, float, int] # symbols, coords, value, row_idx + + +def _records_from_csv_mol( + dataset: str | Path, + mol_dir: str | Path, + property_col: str, + mol_template: str = "id{row}.mol", + overlap_tol: float = 1e-6, +) -> tuple[list[_Record], list[tuple[int, str, str]], int, int, list[dict[str, Any]]]: + with Path(dataset).open("r", encoding="utf-8") as fp: + rows = list(csv.DictReader(fp)) + if not rows: + raise ValueError(f"No rows found in dataset: {dataset}") + prop_col = _find_column( + list(rows[0].keys()), [property_col, "Property", "property"] + ) + + records: list[_Record] = [] + failed_rows: list[tuple[int, str, str]] = [] + skipped_zero = 0 + skipped_overlap = 0 + kept_rows: list[dict[str, Any]] = [] + for row_idx, row in enumerate(rows): + structure_path = (Path(mol_dir) / mol_template.format(row=row_idx)).resolve() + try: + symbols, coords = read_structure_coords(structure_path) + if np.allclose(coords, 0.0): + skipped_zero += 1 + continue + if _has_overlapping_atoms(coords, overlap_tol): + skipped_overlap += 1 + continue + records.append( + (symbols, coords, _parse_property_value(row[prop_col]), row_idx) + ) + kept_rows.append(dict(row)) + except Exception as exc: + failed_rows.append((row_idx, str(structure_path), str(exc))) + return records, failed_rows, skipped_zero, skipped_overlap, kept_rows + + +def _records_from_csv_smiles( + dataset: str | Path, + property_col: str, + smiles_col: str = "SMILES", + overlap_tol: float = 1e-6, + conformer_seed: int = 42, +) -> tuple[list[_Record], list[tuple[int, str, str]], int, int, list[dict[str, Any]]]: + with Path(dataset).open("r", encoding="utf-8") as fp: + rows = list(csv.DictReader(fp)) + if not rows: + raise ValueError(f"No rows found in dataset: {dataset}") + prop_col = _find_column( + list(rows[0].keys()), [property_col, "Property", "property"] + ) + smiles_column = _find_column(list(rows[0].keys()), [smiles_col, "SMILES", "smiles"]) + + records: list[_Record] = [] + failed_rows: list[tuple[int, str, str]] = [] + skipped_zero = 0 + skipped_overlap = 0 + kept_rows: list[dict[str, Any]] = [] + for row_idx, row in enumerate(rows): + smiles = row[smiles_column] + try: + symbols, coords = smiles_to_3d_coords( + smiles, random_seed=conformer_seed + row_idx + ) + if np.allclose(coords, 0.0): + skipped_zero += 1 + continue + if _has_overlapping_atoms(coords, overlap_tol): + skipped_overlap += 1 + continue + records.append( + (symbols, coords, _parse_property_value(row[prop_col]), row_idx) + ) + kept_rows.append(dict(row)) + except Exception as exc: + failed_rows.append((row_idx, smiles, str(exc))) + return records, failed_rows, skipped_zero, skipped_overlap, kept_rows + + +# --------------------------------------------------------------------------- +# public: full pipeline +# --------------------------------------------------------------------------- + + +@dataclass +class SmilesDataResult: + output_dir: Path + train_systems: list[str] + valid_systems: list[str] + type_map: list[str] + failed_rows: list[tuple[int, str, str]] + samples_used: int + skipped_zero: int + skipped_overlap: int + + +def smiles_to_npy( + data: dict[str, Any] | str | Path, + *, + output_dir: str | Path, + property_name: str = "Property", + property_col: str = "Property", + train_ratio: float = 0.9, + mol_dir: str | Path | None = None, + mol_template: str = "id{row}.mol", + smiles_col: str = "SMILES", + overlap_tol: float = 1e-6, + split_seed: int | None = None, + conformer_seed: int | None = None, + overwrite: bool = False, +) -> SmilesDataResult: + """Convert a CSV of molecules (SMILES or pre-generated structures) into ``deepmd/npy``. + + Parameters + ---------- + data + Path to a CSV file, or a dict with ``"dataset"`` key. + output_dir + Root directory for ``train/`` and ``valid/`` subdirectories. + property_name + Name of the property label (stored as ``set.*/{property_name}.npy``). + property_col + CSV column containing the target value. + train_ratio + Fraction of samples used for training (remainder = validation). + mol_dir + Directory containing pre-generated structure files. When omitted, + SMILES are converted to 3D via RDKit. + mol_template + Template for structure filenames, e.g. ``"id{row}.mol"``. Supported + extensions are ``.mol``, ``.sdf``, ``.xyz``, and ``.pdb``. + smiles_col + CSV column containing SMILES strings. + overlap_tol + Minimum inter-atomic distance (Å) below which a structure is rejected. + split_seed : int, optional + Random seed for train/valid splitting. Defaults to 42. + conformer_seed : int, optional + Random seed for RDKit 3D conformer generation. Defaults to 42. + overwrite + If True, remove *output_dir* before writing. + + Returns + ------- + SmilesDataResult + """ + import dpdata + from dpdata.data_type import ( + Axis, + DataType, + ) + + if split_seed is None: + split_seed = 42 + if conformer_seed is None: + conformer_seed = 42 + + # Register the custom property + stru_id dtypes with dpdata. + datatypes = [ + DataType(property_name, np.ndarray, shape=(Axis.NFRAMES, 1), required=False), + DataType("stru_id", np.ndarray, shape=(Axis.NFRAMES, 1), required=False), + ] + for dtype in datatypes: + dpdata.System.register_data_type(dtype) + dpdata.LabeledSystem.register_data_type(dtype) + + # --- ingest --- + if isinstance(data, (str, Path)) or (isinstance(data, dict) and "dataset" in data): + dataset = Path(data if isinstance(data, (str, Path)) else data["dataset"]) + mol_dir_value = ( + mol_dir + if mol_dir is not None + else data.get("mol_dir") + if isinstance(data, dict) + else None + ) + smiles_col_value = ( + data.get("smiles_col", smiles_col) if isinstance(data, dict) else smiles_col + ) + if mol_dir_value is None: + records, failed_rows, skipped_zero, skipped_overlap, _raw = ( + _records_from_csv_smiles( + dataset=dataset, + property_col=property_col, + smiles_col=smiles_col_value, + overlap_tol=overlap_tol, + conformer_seed=conformer_seed, + ) + ) + else: + records, failed_rows, skipped_zero, skipped_overlap, _raw = ( + _records_from_csv_mol( + dataset=dataset, + mol_dir=mol_dir_value, + property_col=property_col, + mol_template=mol_template, + overlap_tol=overlap_tol, + ) + ) + else: + atoms = data.get("atoms") + coordinates = data.get("coordinates") + targets = data.get("target", data.get("targets")) + if atoms is None or coordinates is None or targets is None: + raise ValueError("Direct data requires atoms, coordinates, and target") + records = [ + (list(s), np.asarray(c, dtype=np.float32), float(t), i) + for i, (s, c, t) in enumerate( + zip(atoms, coordinates, targets, strict=False) + ) + ] + failed_rows, skipped_zero, skipped_overlap = [], 0, 0 + + for row_idx, source, error in failed_rows: + warnings.warn( + f"Skipping row {row_idx}: {source!r} — {error}", + RuntimeWarning, + ) + + # --- deduplicate elements → type_map --- + used_elements = {symbol for symbols, _, _, _ in records for symbol in symbols} + type_map = _build_type_map_from_elements(used_elements) + if not type_map: + raise RuntimeError("No usable elements found after filtering.") + type_index = {el: i for i, el in enumerate(type_map)} + + # --- build dpdata systems --- + systems: list[dpdata.LabeledSystem] = [] + for symbols, coords, property_value, row_idx in records: + natoms = len(symbols) + if coords.shape != (natoms, 3): + raise ValueError(f"coords shape mismatch for row {row_idx}: {coords.shape}") + atom_types = np.array([type_index[s] for s in symbols], dtype=np.int32) + frame_data = { + "orig": np.array([0, 0, 0], dtype=np.int32), + "atom_names": type_map, + "atom_numbs": [ + np.count_nonzero(atom_types == i) for i in range(len(type_map)) + ], + "atom_types": atom_types, + "cells": np.array( + [[[100.0, 0.0, 0.0], [0.0, 100.0, 0.0], [0.0, 0.0, 100.0]]] + ), + "nopbc": True, + "coords": coords[np.newaxis, :, :].astype(np.float32), + "energies": np.zeros((1,), dtype=np.float32), + "forces": np.zeros((1, natoms, 3), dtype=np.float32), + property_name: np.array([[property_value]], dtype=np.float32), + "stru_id": np.array([[row_idx]], dtype=np.int64), + } + systems.append(dpdata.LabeledSystem(data=frame_data, type_map=type_map)) + + n_total = len(systems) + if n_total < 2: + raise RuntimeError(f"Not enough usable samples: {n_total}") + + # --- train / valid split --- + output_path = Path(output_dir).resolve() + if overwrite and output_path.exists(): + shutil.rmtree(output_path) + output_path.mkdir(parents=True, exist_ok=True) + + rng = random.Random(split_seed) + indices = list(range(n_total)) + rng.shuffle(indices) + train_count = max(1, min(int(n_total * train_ratio), n_total - 1)) + + ms_train = dpdata.MultiSystems() + ms_valid = dpdata.MultiSystems() + for idx in indices[:train_count]: + ms_train.append(systems[idx]) + for idx in indices[train_count:]: + ms_valid.append(systems[idx]) + + train_dir = output_path / "train" + valid_dir = output_path / "valid" + ms_train.to_deepmd_npy_mixed(str(train_dir)) + ms_valid.to_deepmd_npy_mixed(str(valid_dir)) + + train_systems = sorted(str(p) for p in train_dir.iterdir() if p.is_dir()) + valid_systems = sorted(str(p) for p in valid_dir.iterdir() if p.is_dir()) + + return SmilesDataResult( + output_dir=output_path, + train_systems=train_systems, + valid_systems=valid_systems, + type_map=type_map, + failed_rows=failed_rows, + samples_used=n_total, + skipped_zero=skipped_zero, + skipped_overlap=skipped_overlap, + ) + + +def records_from_direct_data( + data: dict[str, Any], +) -> tuple[list[_Record], list[dict[str, Any]]]: + atoms = data.get("atoms") + coordinates = data.get("coordinates") + targets = data.get("target", data.get("targets")) + if atoms is None or coordinates is None or targets is None: + raise ValueError("Direct training data requires atoms, coordinates, and target") + if not (len(atoms) == len(coordinates) == len(targets)): + raise ValueError("atoms, coordinates, and target must have the same length") + records = [] + rows = [] + for idx, (symbols, coords, target) in enumerate( + zip(atoms, coordinates, targets, strict=False) + ): + records.append( + (list(symbols), np.asarray(coords, dtype=np.float32), float(target), idx) + ) + rows.append({"sample_id": idx, "target": float(target)}) + return records, rows + + +def predict_records_from_data( + data: dict[str, Any] | str | Path, + *, + property_col: str | None = "Property", + mol_dir: str | Path | None = None, + mol_template: str = "id{row}.mol", + smiles_col: str = "SMILES", +) -> tuple[list[list[str]], list[np.ndarray], list[dict[str, Any]]]: + if isinstance(data, (str, Path)) or (isinstance(data, dict) and "dataset" in data): + dataset = Path(data if isinstance(data, (str, Path)) else data["dataset"]) + mol_dir_value = ( + mol_dir + if mol_dir is not None + else data.get("mol_dir") + if isinstance(data, dict) + else None + ) + smiles_col_value = ( + data.get("smiles_col", smiles_col) if isinstance(data, dict) else smiles_col + ) + with dataset.open("r", encoding="utf-8") as fp: + rows = list(csv.DictReader(fp)) + if rows and property_col is not None: + _find_column(list(rows[0].keys()), [property_col, "Property", "property"]) + smiles_column = None + if mol_dir_value is None and rows: + smiles_column = _find_column( + list(rows[0].keys()), [smiles_col_value, "SMILES", "smiles"] + ) + atoms: list[list[str]] = [] + coords: list[np.ndarray] = [] + kept_rows: list[dict[str, Any]] = [] + for row_idx, row in enumerate(rows): + if mol_dir_value is None: + try: + symbols, coord = smiles_to_3d_coords( + row[smiles_column], random_seed=42 + row_idx + ) + except Exception as exc: + warnings.warn( + f"Skipping row {row_idx} during prediction because RDKit failed " + f"to generate coordinates: {exc}", + RuntimeWarning, + ) + continue + else: + symbols, coord = read_mol_coords( + Path(mol_dir_value) / mol_template.format(row=row_idx) + ) + atoms.append(symbols) + coords.append(coord) + kept_rows.append(dict(row)) + return atoms, coords, kept_rows + + atoms_raw = data.get("atoms") + coords_raw = data.get("coordinates") + if atoms_raw is None or coords_raw is None: + raise ValueError("Prediction data requires atoms and coordinates") + atoms = [list(symbols) for symbols in atoms_raw] + coords = [np.asarray(coord, dtype=np.float32) for coord in coords_raw] + if len(atoms) != len(coords): + raise ValueError("atoms and coordinates must have the same length") + rows = [{"sample_id": idx} for idx in range(len(atoms))] + return atoms, coords, rows + + +# --------------------------------------------------------------------------- +# tiny utility +# --------------------------------------------------------------------------- + + +def _deep_merge(base: dict, updates: dict) -> dict: + """Recursively merge *updates* into a shallow copy of *base*.""" + import copy + + result = copy.deepcopy(base) + _deep_update(result, updates) + return result + + +def _deep_update(target: dict, updates: dict) -> None: + for key, value in updates.items(): + if isinstance(value, dict) and isinstance(target.get(key), dict): + _deep_update(target[key], value) + else: + target[key] = value diff --git a/dpa_adapt/data/type_map.py b/dpa_adapt/data/type_map.py new file mode 100644 index 0000000000..477b173ca3 --- /dev/null +++ b/dpa_adapt/data/type_map.py @@ -0,0 +1,166 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +# data/type_map.py +# +# Automatic type_map resolution: read from checkpoint, union from data, +# validate subsets. Users should never need to touch ``_extra_state``. + +from __future__ import ( + annotations, +) + + +def read_checkpoint_type_map( + pretrained: str, + branch: str | None = None, +) -> list[str]: + """Read the global type_map from a DPA checkpoint. + + For multi-task checkpoints the type_map lives in + ``shared_dict..type_map`` or falls back to the branch's + own ``type_map``. For single-task checkpoints it is at the model root. + + Parameters + ---------- + pretrained : str + Path to the ``.pt`` checkpoint. + branch : str, optional + Branch name for multi-task checkpoints. If not given the first + available branch is used. + + Returns + ------- + list[str] + Element symbols. + """ + from dpa_adapt._backend import ( + load_torch_file, + resolve_pretrained_path, + ) + + pretrained = resolve_pretrained_path(pretrained) + sd = load_torch_file(pretrained) + if "model" in sd: + sd = sd["model"] + + params = sd["_extra_state"]["model_params"] + + # Multi-task: type_map is in shared_dict or per-branch + model_dict = params.get("model_dict", {}) + if model_dict: + shared = params.get("shared_dict", {}) + # shared_dict values are descriptor/fitting_net dicts; some may + # contain a type_map list directly, some use a "type_map" key that + # points to a name in shared_dict. + for v in shared.values(): + if isinstance(v, list) and len(v) > 0 and isinstance(v[0], str): + return v + # Fall back to the branch's own type_map + if branch and branch in model_dict: + tm = model_dict[branch].get("type_map") + else: + first = next(iter(model_dict.values())) + tm = first.get("type_map") + if isinstance(tm, str): + tm = shared.get(tm) + if isinstance(tm, list): + return tm + + # Single-task: type_map at model root + tm = params.get("type_map") + if isinstance(tm, list): + return tm + + raise ValueError( + f"Could not locate type_map in checkpoint {pretrained}. " + "Pass type_map=[...] explicitly." + ) + + +def _is_placeholder_type_map(names: list[str] | tuple[str, ...]) -> bool: + """Return ``True`` if *names* is dpdata's all-``Type_N`` placeholder map. + + dpdata invents ``Type_0``, ``Type_1``, ... when the source data had no + ``type_map.raw``. Such a map carries no real element identity, so callers + treat it as "no atom_names" and fall back to raw atom indices. Shared by + ``read_data_type_map_union`` here and ``_read_data_type_map`` in + ``finetuner`` so both apply the same rule. + """ + return bool(names) and all(str(n).startswith("Type_") for n in names) + + +def read_data_type_map_union(systems: list) -> list[str]: + """Read ``atom_names`` from every system and return the union. + + Each system may declare a subset of elements (different dopants per + formula). The union covers all elements present across the dataset. + + Parameters + ---------- + systems : list[dpdata.System] + Systems to scan for element names. + + Returns + ------- + list[str] + Sorted union of all element symbols appearing in any system. + """ + elems: set[str] = set() + for sys in systems: + names = sys.data.get("atom_names", []) + # Skip dpdata's all-"Type_N" placeholder maps so callers fall back to + # raw atom indices instead of rejecting valid data as unsupported + # elements (consistent with _read_data_type_map in finetuner.py). + if _is_placeholder_type_map(names): + continue + for name in names: + if name: + elems.add(str(name)) + if not elems: + raise ValueError( + "No atom_names found in any system. " + "Ensure data has been loaded with dpdata correctly." + ) + return sorted(elems) + + +def validate_type_map_subset( + data_elements: list[str], + checkpoint_elements: list[str], + *, + label: str = "data", +) -> None: + """Raise ``ValueError`` if *data_elements* is not a subset of *checkpoint_elements*. + + Parameters + ---------- + data_elements : list[str] + Element symbols appearing in the data (typically from + ``read_data_type_map_union``). + checkpoint_elements : list[str] + Element symbols covered by the checkpoint (from + ``read_checkpoint_type_map``). + label : str + Human-readable label for the error message (e.g. ``"OER data"``). + + Raises + ------ + ValueError + If any data element is not in the checkpoint type_map. + """ + ckpt_set = set(checkpoint_elements) + unsupported = [e for e in data_elements if e not in ckpt_set] + if unsupported: + ckpt_repr = ( + f"{checkpoint_elements[:3]}...{checkpoint_elements[-1:]} " + f"({len(checkpoint_elements)} elements)" + if len(checkpoint_elements) > 8 + else str(checkpoint_elements) + ) + raise ValueError( + f"Element(s) in {label} are not covered by the checkpoint.\n" + f" {label} type_map: {data_elements}\n" + f" Unsupported elements: {unsupported}\n" + f" Checkpoint covers: {ckpt_repr}\n" + "Use a checkpoint whose type_map includes these elements, " + "or filter the data to remove unsupported elements." + ) diff --git a/dpa_adapt/data/validate.py b/dpa_adapt/data/validate.py new file mode 100644 index 0000000000..071955cb1f --- /dev/null +++ b/dpa_adapt/data/validate.py @@ -0,0 +1,214 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +# data/validate.py +# +# Content-level sanity checks for dpdata systems. +# +# Scope: flag things that are almost certainly bugs (NaN/Inf, degenerate +# cells, misaligned frame counts) plus two coarse magnitude bounds. This is +# NOT anomaly detection — it does not look for statistical outliers. + +from __future__ import ( + annotations, +) + +from typing import ( + TYPE_CHECKING, + Literal, + NamedTuple, +) + +import numpy as np + +from dpa_adapt.data.errors import ( + DPADataError, +) + +if TYPE_CHECKING: + import dpdata + +# Magnitude sanity thresholds — values past these are almost never real. +_ENERGY_MAX_EV_PER_ATOM = 1000.0 +_FORCE_MAX_EV_PER_ANGSTROM = 100.0 + +# A box matrix with |det| below this is treated as degenerate. +_BOX_DET_TOLERANCE = 1e-10 + + +class Issue(NamedTuple): + """A single data-quality finding from check_data().""" + + severity: Literal["warning", "error"] + system: str # system identifier (source path or hash) + set_dir: str # always "" for dpdata systems (no set.* granularity) + file: str # data key the issue concerns, e.g. "energies" + description: str # human-readable explanation + + +def _check_system( + system: dpdata.System, + identifier: str, + box_det_tol: float, +) -> list[Issue]: + """Run all content checks on a single dpdata system.""" + issues: list[Issue] = [] + name = identifier + + def _issue(severity: str, file: str, description: str) -> Issue: + return Issue(severity, name, "", file, description) + + d = system.data + coords = np.asarray(d.get("coords")) + cells_raw = d.get("cells") + energies = d.get("energies") + forces = d.get("forces") + + # --- normalise cells to (n_frames, 3, 3) --- + # dpdata versions differ: some return (n_frames, 9), others (n_frames, 3, 3). + # Reshape explicitly so downstream checks see a uniform layout. + cells = None + if cells_raw is not None: + cells = np.asarray(cells_raw) + if cells.ndim == 2 and cells.shape[1] == 9: + try: + cells = cells.reshape(-1, 3, 3) + except ValueError as exc: + raise DPADataError( + f"Cannot reshape cells of shape {cells_raw.shape} to " + f"(-1, 3, 3): {exc}" + ) from exc + elif cells.ndim == 3 and cells.shape[1:] == (3, 3): + pass # already canonical + else: + raise DPADataError( + f"Unexpected cells shape {cells_raw.shape!r}. " + "Expected (n_frames, 9) or (n_frames, 3, 3)." + ) + + # --- NaN / Inf --- + for key, arr in [("energies", energies), ("forces", forces), ("cells", cells)]: + if arr is None: + continue + arr = np.asarray(arr) + if not np.all(np.isfinite(arr)): + n_bad = int(np.count_nonzero(~np.isfinite(arr))) + issues.append( + _issue( + "error", + key, + f"{key}: contains {n_bad} non-finite value(s) (NaN or Inf).", + ) + ) + + # --- degenerate box (|det| below tolerance) --- + if cells is not None and np.all(np.isfinite(cells)): + dets = np.abs(np.linalg.det(cells)) + for fi in np.where(dets < box_det_tol)[0]: + issues.append( + _issue( + "error", + "cells", + f"cells: frame {int(fi)} has |det| = {dets[fi]:.2e} " + f"(< tol {box_det_tol:.0e}), likely degenerate box.", + ) + ) + + # --- energy magnitude (per atom) --- + if energies is not None and coords is not None and coords.ndim >= 2: + energies = np.asarray(energies) + if np.all(np.isfinite(energies)): + n_atoms = coords.shape[1] # dpdata coords: (n_frames, n_atoms, 3) + if n_atoms > 0: + per_atom = np.abs(energies) / n_atoms + for fi in np.where(per_atom > _ENERGY_MAX_EV_PER_ATOM)[0]: + issues.append( + _issue( + "warning", + "energies", + f"energies: frame {int(fi)} has |E/atom| = " + f"{per_atom[fi]:.1f} eV/atom " + f"(> {_ENERGY_MAX_EV_PER_ATOM:.0f}); suspicious magnitude.", + ) + ) + + # --- force magnitude (per component) --- + if forces is not None: + forces = np.asarray(forces) + if np.all(np.isfinite(forces)): + abs_f = np.abs(forces) + per_frame_max = abs_f.max(axis=tuple(range(1, abs_f.ndim))) + for fi in np.where(per_frame_max > _FORCE_MAX_EV_PER_ANGSTROM)[0]: + issues.append( + _issue( + "warning", + "forces", + f"forces: frame {int(fi)} has a force component of " + f"{per_frame_max[fi]:.1f} eV/Ang " + f"(> {_FORCE_MAX_EV_PER_ANGSTROM:.0f}); suspicious magnitude.", + ) + ) + + # --- frame-count alignment --- + ref = coords.shape[0] if coords.ndim >= 2 else 0 + for key in ("cells", "energies", "forces"): + arr = d.get(key) + if arr is not None: + arr = np.asarray(arr) + if arr.ndim >= 1 and arr.shape[0] != ref and ref > 0: + issues.append( + _issue( + "error", + key, + f"{key} has {arr.shape[0]} frame(s) but coords has " + f"{ref}; frame counts must align.", + ) + ) + + return issues + + +def check_data( + data: dpdata.System | list[dpdata.System], + strict: bool = False, + box_det_tol: float = _BOX_DET_TOLERANCE, +) -> list[Issue]: + """ + Content-level sanity check of one or more dpdata systems. + + Checks for NaN/Inf, degenerate (zero-volume) cells, misaligned frame + counts, and coarse magnitude bounds. + + Parameters + ---------- + data : dpdata.System | list[dpdata.System] + Systems to check. + strict : bool + If True, raise ``DPADataError`` on the first issue. + box_det_tol : float + A cell matrix with ``|det|`` below this is reported as degenerate. + + Returns + ------- + list[Issue] + """ + import dpdata + + if isinstance(data, (dpdata.System, dpdata.LabeledSystem)): + systems = [data] + elif isinstance(data, (list, tuple)): + systems = list(data) + else: + raise TypeError( + f"check_data expects dpdata.System or list, got {type(data).__name__}" + ) + + issues: list[Issue] = [] + + for i, system in enumerate(systems): + source = getattr(system, "_dpa_source", None) + identifier = source if source else f"system[{i}]" + for issue in _check_system(system, identifier, box_det_tol): + if strict: + raise DPADataError(f"check_data (strict): {issue.description}") + issues.append(issue) + + return issues diff --git a/dpa_adapt/finetuner.py b/dpa_adapt/finetuner.py new file mode 100644 index 0000000000..9c409a8eda --- /dev/null +++ b/dpa_adapt/finetuner.py @@ -0,0 +1,1730 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +# dpa_adapt/finetuner.py +# +# frozen_sklearn architecture: frozen DPA descriptor → sklearn predictor +# DPA checkpoint is used purely as a feature extractor (no dp train). + +import logging +import os +import re +import shutil +import subprocess +from pathlib import ( + Path, +) +from typing import ( + Any, + ClassVar, +) + +import dpdata +import numpy as np + +from dpa_adapt._backend import ( + _DescriptorExtraction, + build_model_from_config, + get_torch_device, + load_torch_file, + resolve_dp_command, + resolve_model_branch, + resolve_pretrained_path, +) +from dpa_adapt._validation import ( + validate_fparam_dim, +) +from dpa_adapt.conditions import ( + ConditionManager, + DPAConditionError, +) +from dpa_adapt.data.errors import ( + DPADataError, +) +from dpa_adapt.data.loader import ( + _find_label_npys, + _get_source, + _resolve_label_key, + load_data, +) +from dpa_adapt.data.type_map import ( + _is_placeholder_type_map, +) +from dpa_adapt.utils.dotdict import ( + DotDict, +) + +_LOG = logging.getLogger("dpa_adapt") + +# --------------------------------------------------------------------------- +# Module-level helpers +# --------------------------------------------------------------------------- + + +def _load_labels( + systems: list[dpdata.System], + target_key: str | list[str], +) -> np.ndarray: + """Load and concatenate labels from dpdata systems. + + *target_key* may be a single string (existing behaviour) or a list of + strings (new: multi-property). When a list is given each key is loaded + independently and the results are stacked column-wise into a 2-D array + of shape ``(n_frames, len(target_key))``. + + Each key is resolved through ``_LABEL_KEY_ALIASES`` so that + ``"energy"`` → ``"energies"`` for backward compatibility. + + When a resolved key is not present in ``system.data`` (dpdata only + loads standard DeePMD keys), this function falls back to reading + ``set.*/{key}.npy`` directly from the system source directory. + """ + keys = [target_key] if isinstance(target_key, str) else list(target_key) + columns = [] + + for key in keys: + resolved = _resolve_label_key(key) + all_labels = [] + for system in systems: + if resolved in system.data: + all_labels.append(np.asarray(system.data[resolved])) + continue + + # Fallback: load set.*/{key}.npy directly from the system directory. + source = _get_source(system) + if source is not None: + npy_paths = _find_label_npys(source, resolved) + if npy_paths: + all_labels.append( + np.concatenate([np.load(p) for p in npy_paths], axis=0) + ) + continue + + # Neither dpdata nor direct .npy found — build a clear error. + available = sorted(system.data.keys()) + if source is not None: + set_dirs = sorted(Path(source).glob("set.*")) + available_npy = sorted( + {p.name for sd in set_dirs for p in sd.glob("*.npy")} + ) + else: + available_npy = [] + msg = ( + f"Label key {resolved!r} not found. " + f"Checked system.data keys: {available}." + ) + if available_npy: + msg += f" Checked set.*/npy files: {available_npy}." + else: + msg += " No system source path for direct .npy fallback." + msg += f" (target_key={key!r})." + raise DPADataError(msg) + + columns.append(np.concatenate(all_labels, axis=0)) + + if len(columns) == 1: + return columns[0] + return np.column_stack(columns) + + +def _set_nframes(set_dir: Path) -> int | None: + """Frame count of a deepmd/npy ``set.*`` directory. + + Read from the ``coord.npy`` header (memory-mapped, so no array data is + loaded). Returns ``None`` when the set has no ``coord.npy`` to count + against. + """ + coord = set_dir / "coord.npy" + if not coord.is_file(): + return None + return int(np.load(str(coord), mmap_mode="r").shape[0]) + + +def _read_fparam_from_systems( + systems: list[dpdata.System], + expected_dim: int | None = None, +) -> dict[str, np.ndarray] | None: + """Auto-read fparam.npy from each system's ``set.*/`` directories. + + Returns a dict mapping ``"fparam_0"``, ``"fparam_1"``, ... to 1-D + arrays of length ``n_frames_total``, suitable for passing as + ``conditions=`` to :meth:`ConditionManager.fit_transform`. + + Returns ``None`` when no system has a ``set.*/fparam.npy`` file and + *expected_dim* is not set. When *expected_dim* is set, every system must + provide fparams with exactly that width. + """ + all_fparams = [] + for idx, system in enumerate(systems): + source = _get_source(system) + if source is None: + if expected_dim is not None: + raise DPAConditionError( + "fparam_dim was requested, but system " + f"{idx} has no source directory for set.*/fparam.npy." + ) + continue + source_path = Path(source) + set_dirs = sorted(source_path.glob("set.*")) + fps = [sd / "fparam.npy" for sd in set_dirs] + missing = [fp for fp in fps if not fp.is_file()] + if expected_dim is not None and missing: + raise DPAConditionError( + f"fparam_dim={expected_dim} but fparam.npy is missing under " + f"{source_path}: {[str(fp) for fp in missing]}" + ) + present = [ + (sd, fp) for sd, fp in zip(set_dirs, fps, strict=False) if fp.is_file() + ] + if not present: + continue + arrs = [] + for set_dir, fp in present: + arr = np.load(str(fp)) + if arr.ndim != 2: + raise DPAConditionError( + f"fparam.npy at {fp} has shape {arr.shape}; expected " + "(n_frames, fparam_dim)." + ) + if expected_dim is not None and arr.shape[1] != expected_dim: + raise DPAConditionError( + f"fparam.npy at {fp} has shape {arr.shape}; expected " + f"(n_frames, {expected_dim})." + ) + nframes_set = _set_nframes(set_dir) + if nframes_set is not None and arr.shape[0] != nframes_set: + raise DPAConditionError( + f"fparam.npy at {fp} has {arr.shape[0]} rows, but set " + f"{set_dir.name} has {nframes_set} frames; expected one " + "fparam row per frame." + ) + arrs.append(arr) + all_fparams.append(np.concatenate(arrs, axis=0)) + if not all_fparams: + if expected_dim is not None: + raise DPAConditionError( + f"fparam_dim={expected_dim} but no set.*/fparam.npy files " + "were found in the data." + ) + return None + combined = np.concatenate(all_fparams, axis=0) # (n_frames, fparam_dim) + if expected_dim is not None and combined.shape[1] != expected_dim: + raise DPAConditionError( + f"Combined fparam width is {combined.shape[1]}, expected {expected_dim}." + ) + return {f"fparam_{i}": combined[:, i] for i in range(combined.shape[1])} + + +def _read_data_type_map(system: dpdata.System) -> list[str]: + """Read element symbols from a dpdata System's ``atom_names``. + + Returns an empty list when the names are dpdata's auto-generated + ``Type_0`` / ``Type_1`` placeholders (which appear when the source + data had no ``type_map.raw``). + """ + names = list(system.data.get("atom_names", [])) + if not names or _is_placeholder_type_map(names): + return [] + return names + + +def _load_npy_system( + system: dpdata.System, +) -> tuple[np.ndarray, np.ndarray | None, np.ndarray]: + """Extract (coords, boxes, atom_types) from a dpdata System. + + Adapts dpdata's native shapes to the format expected by + ``_extract_features``: + + - coords : (n_frames, n_atoms*3) (flattened) + - boxes : (n_frames, 9) or None for non-periodic + - atom_types : (n_atoms,) int + + Returns + ------- + coords : np.ndarray, shape (n_frames, n_atoms*3) + boxes : np.ndarray, shape (n_frames, 9), or None + atom_types : np.ndarray, shape (n_atoms,) + """ + d = system.data + coords = np.asarray(d["coords"]) # (n_frames, n_atoms, 3) + n_atoms = coords.shape[1] + coords = coords.reshape(coords.shape[0], n_atoms * 3) + + cells = np.asarray(d["cells"]) # (n_frames, 3, 3) + boxes = cells.reshape(cells.shape[0], 9) + + atom_types = np.asarray(d["atom_types"]) # (n_atoms,) + + if d.get("nopbc", False) or np.allclose(boxes, 0): + boxes = None + + return coords, boxes, atom_types + + +# --------------------------------------------------------------------------- +# Descriptor-cache extraction helpers (here, not desc_cache.py, to avoid +# an import cycle: these need DPAFineTuner while finetuner imports cache +# helpers from desc_cache). +# --------------------------------------------------------------------------- + + +def load_or_extract( + systems: list, + pretrained: str, + model_branch: str | None = None, + pooling: str = "mean", + cache: bool = True, + type_map: list[str] | tuple[str, ...] | None = None, +) -> np.ndarray: + """Return descriptors for *systems*, using the cache when possible. + + Parameters + ---------- + systems : list[dpdata.System] + Systems to extract descriptors from. + pretrained : str + Path to the DPA checkpoint. + model_branch : str, optional + Branch name. + pooling : str + Pooling strategy. + cache : bool + If False the cache is bypassed entirely. + type_map : list[str] or tuple[str, ...], optional + Element symbols used to build the descriptor model and cache key. + + Returns + ------- + np.ndarray, shape ``(n_frames_total, feat_dim)`` + """ + from dpa_adapt.data.desc_cache import ( + _cache_dir, + _cache_key, + ) + + if cache: + key = _cache_key( + systems, + pretrained, + model_branch, + pooling, + type_map=type_map, + ) + cache_path = _cache_dir() / f"{key}.npy" + if cache_path.is_file(): + _LOG.info("Descriptor cache hit: %s", cache_path.name) + return np.load(cache_path) + _LOG.info("Descriptor cache miss; extracting...") + else: + _LOG.info("Descriptor cache bypassed (cache=False).") + + extractor = DPAFineTuner( + pretrained=pretrained, + model_branch=model_branch, + predictor="linear", + pooling=pooling, + type_map=list(type_map) if type_map else None, + ) + descriptors = extractor._extract_features(systems) + + if cache: + cache_path.parent.mkdir(parents=True, exist_ok=True) + np.save(cache_path, descriptors) + _LOG.info("Cached descriptors to %s", cache_path) + + return descriptors + + +def ensure_per_system_cache( + systems: list, + pretrained: str, + model_branch: str | None = None, + pooling: str = "mean", + type_map: list[str] | tuple[str, ...] | None = None, +) -> None: + """Ensure every system has its descriptors cached to disk. + + Existing cache files are reused as-is. Missing ones are extracted one + system at a time for low peak memory. + """ + from dpa_adapt.data.desc_cache import ( + _per_system_cache_path, + ) + + missing: list = [] + for system in systems: + if not _per_system_cache_path( + system, + pretrained, + model_branch, + pooling, + type_map, + ).is_file(): + missing.append(system) + + if not missing: + _LOG.info( + "All %d systems have per-system cache; nothing to extract.", len(systems) + ) + return + + import torch + + _LOG.info( + "%d/%d systems missing per-system cache; extracting one by one...", + len(missing), + len(systems), + ) + + extractor = DPAFineTuner( + pretrained=pretrained, + model_branch=model_branch, + predictor="linear", + pooling=pooling, + type_map=list(type_map) if type_map else None, + ) + + for i, system in enumerate(missing): + cache_path = _per_system_cache_path( + system, + pretrained, + model_branch, + pooling, + type_map, + ) + cache_path.parent.mkdir(parents=True, exist_ok=True) + desc = extractor._extract_features([system]) + np.save(cache_path, desc) + if extractor._device is not None and extractor._device.type == "cuda": + torch.cuda.empty_cache() + if i > 0 and i % 50 == 0: + _LOG.info(" per-system cache: %d/%d done", i, len(missing)) + + _LOG.info("Per-system cache ready (%d systems).", len(systems)) + + +# --------------------------------------------------------------------------- +# Public descriptor extraction +# --------------------------------------------------------------------------- + + +def extract_descriptors( + data: str | list[str], + pretrained: str, + model_branch: str | None = None, + pooling: str = "mean", + cache: bool = True, +) -> np.ndarray: + """ + Extract pooled DPA descriptors for one or more deepmd/npy systems. + + This is the same feature extraction pipeline ``DPAFineTuner.fit()`` uses + internally, exposed as a standalone function so downstream tools (e.g. + multi-task fine-tuning, auxiliary-data selection) can share it without + constructing a finetuner. + + Parameters + ---------- + data : str | list[str] + Path(s) to deepmd/npy system directories. + pretrained : str + Path to the pretrained DPA checkpoint (.pt). + model_branch : str, optional + Branch name for multi-task checkpoints (e.g. ``"Omat24"``). + pooling : str + Pooling strategy. One of ``"mean"``, ``"sum"``, ``"mean+std"``, + ``"mean+std+max+min"``. + cache : bool + If True (default), cache the extracted descriptors on disk so + repeated calls with the same data + checkpoint + pooling are + instant. The cache is invalidated when any ``coord.npy`` or the + checkpoint changes (mtime-based fingerprint). + + Returns + ------- + np.ndarray + Pooled descriptor features, shape ``(n_frames_total, feat_dim)``. + ``feat_dim`` depends on the pooling strategy. + """ + systems = load_data(data) + return load_or_extract( + systems=systems, + pretrained=pretrained, + model_branch=model_branch, + pooling=pooling, + cache=cache, + ) + + +# --------------------------------------------------------------------------- +# Internal: frozen-sklearn pipeline (extracted from DPAFineTuner) +# +# Refactored: all descriptor-loading, feature-extraction, and sklearn-fitting +# logic moved into this helper so DPAFineTuner is a thin dispatcher. +# --------------------------------------------------------------------------- + + +class _FrozenSklearnPipeline: + """Internal helper: frozen DPA descriptor → sklearn predictor pipeline. + + Encapsulates descriptor model loading, feature extraction (with + caching), type-map validation / remapping, and sklearn fitting / + prediction / evaluation / freeze. DPAFineTuner holds one of these + when ``strategy='frozen_sklearn'`` and delegates public API calls to it. + + Refactored: extracted from ``DPAFineTuner`` to separate the sklearn + code path from the training-paradigm and MFT dispatch logic. + """ + + _VALID_POOLING: ClassVar[set[str]] = {"mean", "sum", "mean+std", "mean+std+max+min"} + + def __init__( + self, + pretrained: str, + model_branch: str | None, + predictor_type: str, + pooling: str, + seed: int, + ) -> None: + self.pretrained = pretrained + self.model_branch = model_branch + self._predictor_type = predictor_type + self.pooling = pooling + self.seed = seed + + # Populated during fit / extraction + self._model = None + self._device = None + self._checkpoint_type_map = [] + self.predictor = None + self._task_dim = 1 + self._target_key = None + self._condition_manager = None + self._fitted = False + self.type_map = [] + + # ------------------------------------------------------------------ + # Descriptor model loading + # ------------------------------------------------------------------ + + def load_descriptor_model(self) -> Any: + """Load the pretrained DPA checkpoint and return a (non-JIT) ModelWrapper. + + If *pretrained* is a built-in model name (e.g. ``"DPA-3.1-3M"``) + rather than a local path, it is automatically downloaded. + """ + resolved = resolve_pretrained_path(self.pretrained) + state_dict = load_torch_file(resolved) + if "model" in state_dict: + state_dict = state_dict["model"] + + input_param = state_dict["_extra_state"]["model_params"] + + if "model_dict" in input_param: + # Multi-task checkpoint: select the right branch + model_alias_dict, _ = resolve_model_branch(input_param["model_dict"]) + head = self.model_branch or "Omat24" + + # Case-insensitive fallback + if head not in model_alias_dict: + head_lower = head.lower() + for mk in model_alias_dict: + if mk.lower() == head_lower: + head = mk + break + assert head in model_alias_dict, ( + f"Branch '{head}' not found. Available: {list(model_alias_dict)}" + ) + head = model_alias_dict[head] + + # Build single-task input_param from the selected branch + input_param = input_param["model_dict"][head] + + # Remap state dict keys: model.{head}.xxx → model.Default.xxx + new_sd = {"_extra_state": state_dict["_extra_state"]} + for key, val in state_dict.items(): + prefix = f"model.{head}." + if key.startswith(prefix): + new_sd[key.replace(prefix, "model.Default.", 1)] = val + state_dict = new_sd + + self._checkpoint_type_map = list(input_param.get("type_map", [])) + + # Build model WITHOUT JIT so that eval_descriptor_hook works + wrapper = build_model_from_config(input_param) + wrapper.load_state_dict(state_dict) + wrapper.eval() + + device = get_torch_device() + wrapper = wrapper.to(device) + self._device = device + return wrapper + + # ------------------------------------------------------------------ + # Type-map helpers + # ------------------------------------------------------------------ + + def validate_type_map( + self, user_type_map: list[str], systems: list[dpdata.System] + ) -> None: + """Raise DPADataError if any data element is not in the checkpoint type_map. + + The data type_map can be any subset of the checkpoint's type_map — order + and contiguity are irrelevant. Local indices are remapped to checkpoint + global indices in ``extract_features``. + """ + ckpt = self._checkpoint_type_map + if not ckpt: + return # checkpoint has no type_map metadata → skip + + ckpt_set = set(ckpt) + + def _check(candidate: list[str], source: str) -> None: + unsupported = [e for e in candidate if e not in ckpt_set] + if unsupported: + ckpt_repr = ( + f"{[*ckpt[:3], '...', *ckpt[-1:]]} ({len(ckpt)} elements)" + if len(ckpt) > 8 + else str(ckpt) + ) + raise DPADataError( + f"Element(s) in {source} not supported by this checkpoint.\n" + f" Data type_map : {candidate}\n" + f" Checkpoint covers : {ckpt_repr}\n" + f" Unsupported : {unsupported}\n" + "Please re-convert your data with a supported element set." + ) + + if user_type_map: + _check(user_type_map, "user-provided type_map") + + for system in systems: + data_tm = _read_data_type_map(system) + if data_tm: + identifier = system.orig if hasattr(system, "orig") else "system" + _check(data_tm, f"atom_names of {identifier}") + + def remap_atom_types( + self, atom_types: np.ndarray, system: dpdata.System + ) -> np.ndarray: + """Map local atom-type indices to checkpoint-global indices. + + ``atom_types`` are 0-based indices into the system's type_map. + The model expects indices into the checkpoint's ``type_map``. + """ + ckpt = self._checkpoint_type_map + + data_tm = _read_data_type_map(system) or list(self.type_map) + + identifier = system.orig if hasattr(system, "orig") else "system" + + if not data_tm: + if ckpt and atom_types.size and int(atom_types.max()) >= len(ckpt): + raise DPADataError( + f"No atom_names in system and no type_map provided, " + f"but atom type index {int(atom_types.max())} " + f"is out of range for the checkpoint type_map " + f"(size {len(ckpt)}). " + "Pass type_map=[...] to fit()." + ) + return atom_types + + if not ckpt: + return atom_types + + try: + local_to_global = np.array( + [ckpt.index(elem) for elem in data_tm], + dtype=np.int64, + ) + except ValueError as e: + unsupported = [e for e in data_tm if e not in set(ckpt)] + raise DPADataError( + f"Element(s) in data type_map for {identifier!r} not " + f"supported by this checkpoint.\n" + f" Data type_map : {data_tm}\n" + f" Unsupported : {unsupported}" + ) from e + + if atom_types.size and int(atom_types.max()) >= len(local_to_global): + raise DPADataError( + f"atom type index {int(atom_types.max())} in {identifier!r} " + f"exceeds the data type_map size ({len(local_to_global)}). " + "Check that type_map and atom_types are consistent." + ) + + return local_to_global[atom_types] + + # ------------------------------------------------------------------ + # Feature extraction (extract_features_cached is on DPAFineTuner + # so that patches on DPAFineTuner._extract_features are honoured) + # ------------------------------------------------------------------ + + def extract_features(self, systems: list[dpdata.System]) -> np.ndarray: + """Extract per-structure descriptor features by pooling over atoms. + + The pooling strategy is controlled by ``self.pooling``: + - ``"mean"`` → shape (n_frames, feat_dim) + - ``"sum"`` → shape (n_frames, feat_dim) + - ``"mean+std"`` → shape (n_frames, feat_dim*2) + - ``"mean+std+max+min"`` → shape (n_frames, feat_dim*4) + + Parameters + ---------- + systems : list[dpdata.System] + dpdata systems to extract descriptors from. + + Returns + ------- + np.ndarray, shape (n_frames_total, feature_dim) + """ + import torch + + if self._model is None: + self._model = self.load_descriptor_model() + + extractor = _DescriptorExtraction(self._model) + extractor._enable_hook() + + all_features = [] + + for system in systems: + coords, boxes, atom_types = _load_npy_system(system) + n_frames = coords.shape[0] + n_atoms = len(atom_types) + + # Remap local atom-type indices to checkpoint-global indices. + atom_types_global = self.remap_atom_types(atom_types, system) + + # Non-periodic structures must NOT use all-zero box: + # the descriptor produces NaN in that case. + # Use a large 100 Å cubic box instead. + if boxes is None: + boxes = np.tile(np.eye(3) * 100.0, (n_frames, 1)).reshape(n_frames, 9) + + # coord requires grad: forward_common calls autograd.grad + # internally to compute forces, which fails under no_grad. + coord_t = torch.tensor( + coords.reshape(n_frames, n_atoms * 3), + dtype=torch.float64, + device=self._device, + ).requires_grad_(True) + atype_t = torch.tensor( + np.tile(atom_types_global, (n_frames, 1)), + dtype=torch.long, + device=self._device, + ) + box_t = torch.tensor(boxes, dtype=torch.float64, device=self._device) + + # Shape: (n_frames, n_atoms, feat_dim) + descrpt = extractor._run_forward(coord_t, atype_t, box_t) + if self.pooling == "mean": + feat = descrpt.mean(dim=1) + elif self.pooling == "sum": + feat = descrpt.sum(dim=1) + elif self.pooling == "mean+std": + mean = descrpt.mean(dim=1) + std = torch.nan_to_num(descrpt.std(dim=1), nan=0.0) + feat = torch.cat([mean, std], dim=-1) + elif self.pooling == "mean+std+max+min": + mean = descrpt.mean(dim=1) + std = torch.nan_to_num(descrpt.std(dim=1), nan=0.0) + feat = torch.cat( + [ + mean, + std, + descrpt.max(dim=1).values, + descrpt.min(dim=1).values, + ], + dim=-1, + ) + feat = torch.nan_to_num(feat, nan=0.0, posinf=0.0, neginf=0.0) + all_features.append(feat.detach().cpu().numpy()) + + extractor._disable_hook() + return np.concatenate(all_features, axis=0) + + +# --------------------------------------------------------------------------- +# Main class +# --------------------------------------------------------------------------- + + +class DPAFineTuner: + """Adapt a pretrained DPA model to a downstream property via transfer learning. + + Four strategies, selected by *strategy*. + + ==================== ===================================================== + ``frozen_sklearn`` (default, CPU) Freeze the DPA backbone, extract + descriptors once, pool, and fit a scikit-learn + regressor (Ridge, KRR, or MLP). No GPU needed; + fastest for small datasets. + ``frozen_head`` Freeze the backbone, train only a neural property + fitting net via ``dp --pt train``. + ``finetune`` Fine-tune the full network (descriptor + fitting + net) end-to-end via ``dp --pt train``. + ``mft`` Multi-task fine-tuning: a downstream property head + is trained jointly with an auxiliary force/energy + head to regularise the representation. Requires + *aux_data* at ``fit()`` time. + ==================== ===================================================== + + Parameters + ---------- + pretrained : str + Path to the pretrained DPA checkpoint (``.pt``), or a built-in name + such as ``"DPA-3.1-3M"`` that could be auto-downloaded. + model_branch : str or None + Multi-task branch for descriptor extraction (e.g. ``"Domains_Drug"``). + predictor : str + (frozen_sklearn only) scikit-learn head: ``"rf"``, ``"linear"`` / + ``"ridge"``, or ``"mlp"``. + pooling : str + (frozen_sklearn only) Descriptor pooling: ``"mean"`` (default), + ``"sum"``, ``"mean+std"``, or ``"mean+std+max+min"``. + seed : int + Random seed for the head or for full training. + strategy : str + ``"frozen_sklearn"`` (default), ``"frozen_head"``, ``"finetune"``, + or ``"mft"``. + + property_name : str + Label key written under ``set.*/`` (e.g. ``"bandgap"``). Used by + all non-``frozen_sklearn`` strategies, and by ``frozen_sklearn`` + when *target_key* is not passed explicitly to ``fit()``. + task_dim : int + Output dimensionality of the property fitting net. + intensive : bool + If True (default), the property is intensive and frame-averaged; + if False it is extensive (summed). + init_branch : str + Checkpoint branch used to initialise the descriptor (LP / FT only). + learning_rate, stop_lr : float + Start and end points of the exponential learning-rate schedule + (training paradigms). + decay_steps : int or None + Steps between LR decays for the ``exp`` scheduler (DeePMD-kit + native). ``None`` (default) auto-selects: 1000 for + ``frozen_head``/``finetune``; 1000 for MFT property mode, + 5000 for MFT ener mode. + warmup_steps : int + Linear LR warmup steps (deepmd-kit native). 0 = disabled. + max_steps : int + Total training steps (LP / FT / MFT). + batch_size : str or int + DeePMD-kit batch-size spec (e.g. ``"auto:512"`` or 128). + loss_function : str + ``"mse"`` or ``"smooth_mae"`` (training paradigms). + fitting_net_params : dict or None + Extra kwargs merged into the fitting-net config (e.g. + ``{"neuron": [128, 128]}``). Applies to ``frozen_head``, + ``finetune``, and ``mft`` strategies. + fparam_dim : int + Dimension of per-frame context features (e.g. temperature, + humidity). When > 0, ``set.*/fparam.npy`` of shape + ``(n_frames, fparam_dim)`` is read automatically for all + strategies. For ``frozen_sklearn``, fparam columns are + standardized and concatenated to the descriptor via + ``ConditionManager``. Default 0 (disabled). + output_dir : str + Directory for ``input.json``, checkpoints, and logs. + save_freq, disp_freq : int + Checkpoint save and log-display intervals (steps). + + aux_branch : str + (MFT only) Pre-trained branch for the auxiliary force/energy head. + aux_prob : float + (MFT only) Probability of sampling an auxiliary batch at each step. + type_map : list[str] or None + (MFT only) The global (shared) type map. Both branches share a single + descriptor, so this must be the union of elements in both datasets. + Auto-detected from the checkpoint if not provided. + downstream_task_type : str + (MFT only) Task type of the downstream head (``"property"`` etc.). + aux_batch_size : str or int or None + (MFT only) Batch-size spec for the auxiliary head. + downstream_batch_size : str or int or None + (MFT only) Batch-size spec for the downstream head. + """ + + _VALID_POOLING: ClassVar[set[str]] = {"mean", "sum", "mean+std", "mean+std+max+min"} + _VALID_STRATEGIES: ClassVar[set[str]] = { + "frozen_sklearn", + "frozen_head", + "finetune", + "mft", + } + + def __init__( + self, + pretrained: str = "DPA-3.1-3M", + model_branch: str | None = None, + predictor: str = "rf", + pooling: str = "mean", + seed: int = 42, + # ---- training paradigms ---- + strategy: str = "frozen_sklearn", + property_name: str = "property", + task_dim: int = 1, + intensive: bool = True, + init_branch: str = "SPICE2", + learning_rate: float = 1e-3, + stop_lr: float = 1e-5, + decay_steps: int + | None = None, # None → auto: 1000 for training, MFT auto-detect + warmup_steps: int = 0, + max_steps: int = 100_000, + batch_size: str | int = "auto:512", + loss_function: str = "mse", + fitting_net_params: dict | None = None, + fparam_dim: int = 0, + output_dir: str = "./dpa_output", + save_freq: int = 10_000, + disp_freq: int = 1_000, + # ---- mft-only ---- + aux_branch: str = "MP_traj_v024_alldata_mixu", + aux_prob: float = 0.5, + type_map: list[str] | None = None, + downstream_task_type: str = "property", + aux_batch_size: str | int | None = None, + downstream_batch_size: str | int | None = None, + ) -> None: + if pooling not in self._VALID_POOLING: + raise ValueError( + f"pooling must be one of {sorted(self._VALID_POOLING)}, got {pooling!r}" + ) + if strategy not in self._VALID_STRATEGIES: + raise ValueError( + f"strategy must be one of {sorted(self._VALID_STRATEGIES)}; " + f"got {strategy!r}" + ) + validate_fparam_dim(fparam_dim) + + self.strategy = strategy + + self.pretrained = pretrained + self.model_branch = model_branch + self._predictor_type = predictor + self.pooling = pooling + self.seed = seed + + # Training-paradigm params (unused by frozen_sklearn). + self.property_name = property_name + self.task_dim = task_dim + self.intensive = intensive + self.init_branch = init_branch + self.learning_rate = learning_rate + self.stop_lr = stop_lr + self.decay_steps = decay_steps + self.warmup_steps = warmup_steps + self.max_steps = max_steps + self.batch_size = batch_size + self.loss_function = loss_function + self.fitting_net_params = fitting_net_params + self.fparam_dim = fparam_dim + self.output_dir = output_dir + self.save_freq = save_freq + self.disp_freq = disp_freq + + # MFT-only parameters. + self.aux_branch = aux_branch + self.aux_prob = aux_prob + self.type_map = type_map + self.downstream_task_type = downstream_task_type + self.aux_batch_size = aux_batch_size + self.downstream_batch_size = downstream_batch_size + + if strategy == "mft": + if not 0.0 <= float(aux_prob) <= 1.0: + raise ValueError( + f"aux_prob must be in [0, 1] when strategy='mft'; got {aux_prob!r}." + ) + if not isinstance(property_name, str) or not property_name.isidentifier(): + raise ValueError( + "property_name is required when strategy='mft' and must be a " + f"valid Python identifier; got {property_name!r}." + ) + + # ---- frozen_sklearn pipeline (created lazily by fit()) ---- + self._sklearn: _FrozenSklearnPipeline | None = None + self._mft = None + + # ---- backward-compat state mirrors (delegated to pipeline) ---- + if self.type_map is None: + self.type_map = [] + self._target_key = None + self._task_dim = 1 + self.predictor = None # sklearn object after fit() + self._fitted = False + self._model = None # lazy-loaded descriptor model (cached) + self._device = None # set when model is first loaded + self._checkpoint_type_map = [] # set by _load_descriptor_model + self._condition_manager = None + + # ------------------------------------------------------------------ + # Frozen-sklearn pipeline helpers (thin delegators) + # + # Each method forwards to the corresponding method on + # ``_FrozenSklearnPipeline``. State set directly on DPAFineTuner + # (e.g. ``_checkpoint_type_map`` by tests) is propagated into the + # pipeline on each call so that direct setters continue to work. + # ------------------------------------------------------------------ + + def _ensure_sklearn(self) -> _FrozenSklearnPipeline: + """Create the pipeline on first use if it doesn't exist yet.""" + if self._sklearn is None: + self._sklearn = _FrozenSklearnPipeline( + pretrained=self.pretrained, + model_branch=self.model_branch, + predictor_type=self._predictor_type, + pooling=self.pooling, + seed=self.seed, + ) + # Sync state that external code may have set on DPAFineTuner directly, + # without clobbering values loaded lazily by the pipeline. + if self._model is not None: + self._sklearn._model = self._model + elif self._sklearn._model is not None: + self._model = self._sklearn._model + if self._device is not None: + self._sklearn._device = self._device + elif self._sklearn._device is not None: + self._device = self._sklearn._device + if self._checkpoint_type_map: + self._sklearn._checkpoint_type_map = self._checkpoint_type_map + elif self._sklearn._checkpoint_type_map: + self._checkpoint_type_map = list(self._sklearn._checkpoint_type_map) + self._sklearn.type_map = self.type_map + return self._sklearn + + def _load_descriptor_model(self) -> Any: + p = self._ensure_sklearn() + model = p.load_descriptor_model() + self._model = model + self._device = p._device + self._checkpoint_type_map = list(p._checkpoint_type_map) + return model + + def _validate_type_map( + self, user_type_map: list[str], systems: list[dpdata.System] + ) -> None: + return self._ensure_sklearn().validate_type_map(user_type_map, systems) + + def _remap_atom_types( + self, atom_types: np.ndarray, system: dpdata.System + ) -> np.ndarray: + return self._ensure_sklearn().remap_atom_types(atom_types, system) + + def _extract_features_cached(self, systems: list[dpdata.System]) -> np.ndarray: + """Call ``_extract_features`` with descriptor-cache lookup. + + Kept on DPAFineTuner (not delegated) so that patches on + ``DPAFineTuner._extract_features`` are honoured through the + ``self._extract_features()`` call below. + """ + try: + # Lazy import to avoid circular dependency: finetuner → desc_cache → finetuner. + from dpa_adapt.data.desc_cache import ( + _cache_dir, + _cache_key, + ) + + key = _cache_key( + systems, + self.pretrained, + self.model_branch, + self.pooling, + type_map=tuple(self.type_map or ()), + ) + cache_path = _cache_dir() / f"{key}.npy" + if cache_path.is_file(): + return np.load(cache_path) + except Exception: + # Cache read failed (e.g. corrupted file, permissions) — + # fall through and recompute features from scratch. + pass + + features = self._extract_features(systems) + try: + cache_path.parent.mkdir(parents=True, exist_ok=True) + np.save(cache_path, features) + except Exception: + # Cache write is best-effort — silently skip on permission errors + # or disk-full conditions; the features are already in memory. + pass + return features + + def _extract_features(self, systems: list[dpdata.System]) -> np.ndarray: + return self._ensure_sklearn().extract_features(systems) + + # ------------------------------------------------------------------ + # The heavy implementations of the following methods now live in + # _FrozenSklearnPipeline (see class docstring above). The thin + # delegators at the top of this class forward calls to the pipeline. + # ------------------------------------------------------------------ + + # ----------------------------------------------------------------------- + # Public API + # ----------------------------------------------------------------------- + + # ------------------------------------------------------------------- + # Type-map auto-inference (shared with MFTFineTuner via data/type_map.py) + # ------------------------------------------------------------------- + + def _resolve_type_maps(self, train_data: str | list[str]) -> list[str]: + """Auto-infer the global type_map from the checkpoint and validate + *train_data* element set is a subset. + + Returns the checkpoint's type_map (e.g. 118-element full periodic + table for DPA-3.1-3M). + """ + from dpa_adapt.data.type_map import ( + read_checkpoint_type_map, + read_data_type_map_union, + validate_type_map_subset, + ) + + try: + systems = load_data(train_data) + except DPADataError: + # Data paths may not exist during testing; fall back gracefully. + return read_checkpoint_type_map( + self.pretrained, + branch=self.init_branch, + ) + + tm = read_checkpoint_type_map( + self.pretrained, + branch=self.init_branch, + ) + + try: + elements = read_data_type_map_union(systems) + except ValueError: + pass # no atom_names — deepmd uses raw atom indices + else: + validate_type_map_subset(elements, tm, label="train data") + + return tm + + # ------------------------------------------------------------------- + # Training-paradigm fit (frozen_head / finetune) + # ------------------------------------------------------------------- + + def _fit_training( + self, + train_data: str | list[str], + valid_data: str | list[str] | None, + type_map: list[str], + ) -> str: + """Delegate to DPATrainer for single-task ``dp --pt train``.""" + from dpa_adapt.trainer import ( + DPATrainer, + ) + + freeze = self.strategy == "frozen_head" + trainer = DPATrainer( + pretrained=self.pretrained, + init_branch=self.init_branch, + freeze_backbone=freeze, + property_name=self.property_name, + task_dim=self.task_dim, + intensive=self.intensive, + train_systems=train_data, + valid_systems=valid_data, + type_map=type_map, + fitting_net_params=self.fitting_net_params, + learning_rate=self.learning_rate, + stop_lr=self.stop_lr, + decay_steps=self.decay_steps if self.decay_steps is not None else 1000, + warmup_steps=self.warmup_steps, + max_steps=self.max_steps, + batch_size=self.batch_size, + loss_function=self.loss_function, + fparam_dim=self.fparam_dim, + seed=self.seed, + output_dir=self.output_dir, + save_freq=self.save_freq, + disp_freq=self.disp_freq, + ) + ckpt_path = trainer.fit() + self._fitted = True + return ckpt_path + + def _latest_training_checkpoint(self) -> str: + ckpts = list(Path(self.output_dir).glob("model.ckpt-*.pt")) + if not ckpts: + raise RuntimeError( + f"No model.ckpt-*.pt found in {self.output_dir}; call fit() first." + ) + + def step_of(path: Path) -> int: + return int(path.stem.split("-")[-1]) + + return str(max(ckpts, key=step_of)) + + @staticmethod + def _expand_system_specs(data: str | list[str]) -> list[str]: + import glob + + patterns = [data] if isinstance(data, str) else list(data) + systems = [] + for pattern in patterns: + matches = sorted(glob.glob(str(pattern))) + systems.extend(matches or [str(pattern)]) + + seen = set() + systems = [s for s in systems if not (s in seen or seen.add(s))] + if not systems: + raise DPADataError(f"No systems matched {data!r}.") + return systems + + def _freeze_training_checkpoint(self, output_path: str = "frozen_model.pth") -> str: + """Freeze a single-task DeePMD checkpoint via ``dp --pt freeze``.""" + ckpt = self._latest_training_checkpoint() + output_path = os.path.abspath(str(output_path)) + output_dir = os.path.abspath(self.output_dir) + os.makedirs(output_dir, exist_ok=True) + + freeze_name = os.path.basename(output_path) + produced = os.path.join(output_dir, freeze_name) + cmd = [ + resolve_dp_command(), + "--pt", + "freeze", + "-c", + ".", + "-o", + freeze_name, + ] + result = subprocess.run( + cmd, + capture_output=True, + text=True, + cwd=output_dir, + ) + if result.returncode != 0: + raise RuntimeError( + f"dp --pt freeze failed (return code {result.returncode}).\n" + f"cmd: {' '.join(cmd)}\n" + f"cwd: {output_dir}\n" + f"checkpoint: {ckpt}\n" + f"stdout:\n{result.stdout}\n" + f"stderr:\n{result.stderr}" + ) + if not os.path.exists(produced): + raise RuntimeError( + f"dp --pt freeze reported success but {produced} was not " + f"created.\nstdout:\n{result.stdout}\nstderr:\n{result.stderr}" + ) + if os.path.abspath(produced) != output_path: + os.makedirs(os.path.dirname(output_path), exist_ok=True) + shutil.copyfile(produced, output_path) + return output_path + + def _run_training_predict( + self, data: str | list[str], fmt: str | None = None + ) -> DotDict: + """Run ``dp --pt test`` and parse property predictions from detail files.""" + from dpa_adapt.trainer import ( + DPATrainer, + ) + + if fmt is not None: + raise ValueError( + "fmt is not supported for frozen_head/finetune predict(); " + "provide deepmd/npy system directories." + ) + + ckpt = self._latest_training_checkpoint() + systems = self._expand_system_specs(data) + + output_dir = Path(self.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + datafile = output_dir / "predict_systems.txt" + datafile.write_text("\n".join(systems) + "\n") + + detail_prefix = output_dir / "predict_detail" + for old in output_dir.glob(f"{detail_prefix.name}.property.out.*"): + old.unlink() + + cmd = [ + resolve_dp_command(), + "--pt", + "test", + "-m", + ckpt, + "-f", + str(datafile), + "-n", + "999999", + "-d", + str(detail_prefix), + ] + result = subprocess.run(cmd, capture_output=True, text=True) + combined = result.stdout + "\n" + result.stderr + if result.returncode != 0: + raise RuntimeError( + f"dp --pt test failed (return code {result.returncode}).\n" + f"cmd: {' '.join(cmd)}\n" + f"stdout:\n{result.stdout}\n" + f"stderr:\n{result.stderr}" + ) + + detail_files = sorted( + output_dir.glob(f"{detail_prefix.name}.property.out.*"), + key=lambda p: int(p.name.rsplit(".", 1)[-1]), + ) + if not detail_files: + raise RuntimeError( + "dp --pt test completed but no property detail files were written. " + f"Command was: {' '.join(cmd)}" + ) + + rows = [] + for path in detail_files: + arr = np.loadtxt(path) + arr = np.asarray(arr, dtype=float) + if arr.ndim == 1: + arr = arr.reshape(1, -1) + if arr.shape[1] < 2: + raise RuntimeError( + f"Expected at least two columns in {path}, got shape {arr.shape}." + ) + rows.append(arr[:, :2]) + + values = np.concatenate(rows, axis=0) + if values.shape[0] % self.task_dim != 0: + raise RuntimeError( + f"Could not reshape property detail rows {values.shape[0]} " + f"into task_dim={self.task_dim}." + ) + + values = values.reshape(-1, self.task_dim, 2) + labels = values[:, :, 0] + predictions = values[:, :, 1] + if self.task_dim == 1: + labels = labels.reshape(-1, 1) + predictions = predictions.reshape(-1, 1) + + metrics = DPATrainer._parse_test_output(combined) + n_sys_match = re.search( + r"number of systems\s*[:=]?\s*(\d+)", combined, re.IGNORECASE + ) + n_systems = int(n_sys_match.group(1)) if n_sys_match else len(systems) + return DotDict( + { + "predictions": predictions, + "labels": labels, + "mae": metrics["mae"], + "rmse": metrics["rmse"], + "n_frames": metrics["n_frames"], + "n_systems": n_systems, + "detail_prefix": str(detail_prefix), + "_raw_stdout": combined, + } + ) + + # ------------------------------------------------------------------- + # fit (dispatch) + # ------------------------------------------------------------------- + + def fit( + self, + train_data: str | list[str], + valid_data: str | list[str] | None = None, + type_map: list[str] | None = None, + target_key: str | list[str] | None = None, + labels: np.ndarray | None = None, + fmt: str | None = None, + aux_data: str | list[str] | None = None, + ) -> str | None: + """Train the model. + + *frozen_sklearn* (default): extract descriptors, fit sklearn head. + *frozen_head* / *finetune*: run ``dp --pt train``. + *mft*: multi-task fine-tuning (property head + force-field head). + + Parameters + ---------- + train_data : str | list[str] + Path(s) to deepmd/npy system directories. + valid_data : str | list[str], optional + Validation system directories. Required for training paradigms; + ignored by ``frozen_sklearn``. + type_map : list[str], optional + Element symbols. Auto-inferred from the checkpoint and data + ``type_map.raw`` when not provided. + target_key : str, optional + (frozen_sklearn) Label key, e.g. ``"energy"``. + labels : np.ndarray, optional + (frozen_sklearn) Pre-computed labels. + fmt : str, optional + Reserved for future format support. + aux_data : str | list[str], optional + (mft only) Auxiliary training system directories. Required when + ``strategy='mft'``; must be absent otherwise. + """ + if self.strategy == "frozen_sklearn": + return self._fit_sklearn(train_data, type_map, target_key, labels, fmt) + + if self.strategy == "mft": + if aux_data is None: + raise ValueError( + "strategy='mft' requires aux_data. " + "Provide auxiliary system directories for the force-field head." + ) + if type_map is not None: + self.type_map = type_map + if self._mft is not None: + self._mft.type_map = type_map + return self._fit_mft(train_data, aux_data, valid_data) + + # ---- single-task training paradigms ---- + if aux_data is not None: + raise ValueError( + f"aux_data is only valid when strategy='mft'; " + f"got strategy={self.strategy!r}." + ) + + if type_map is None: + type_map = self._resolve_type_maps(train_data) + + self.type_map = type_map + return self._fit_training(train_data, valid_data, type_map) + + def _fit_mft( + self, + train_data: str | list[str], + aux_data: str | list[str], + valid_data: str | list[str] | None = None, + ) -> str: + """Delegate to MFTFineTuner for multi-task fine-tuning.""" + mft = self._ensure_mft() + mft.fit(train_data=train_data, aux_data=aux_data, valid_data=valid_data) + self._fitted = True + return self.output_dir + + def _ensure_mft(self) -> Any: + """Create the MFT delegate on first use.""" + from dpa_adapt.mft import ( + MFTFineTuner, + ) + + if self._mft is None: + self._mft = MFTFineTuner( + pretrained=self.pretrained, + aux_branch=self.aux_branch, + aux_prob=self.aux_prob, + # Preserve "omitted" (None) for MFT: __init__ normalizes an + # unset type_map to [] for the frozen-sklearn path, but MFT + # treats an empty list as user-provided and would skip + # checkpoint auto-detection. Pass None so MFT auto-detects. + type_map=self.type_map or None, + fitting_net_params=self.fitting_net_params, + downstream_task_type=self.downstream_task_type, + property_name=self.property_name, + task_dim=self.task_dim, + intensive=self.intensive, + learning_rate=self.learning_rate, + stop_lr=self.stop_lr, + decay_steps=self.decay_steps, + warmup_steps=self.warmup_steps, + max_steps=self.max_steps, + batch_size=self.batch_size, + aux_batch_size=self.aux_batch_size, + downstream_batch_size=self.downstream_batch_size, + seed=self.seed, + fparam_dim=self.fparam_dim, + output_dir=self.output_dir, + save_freq=self.save_freq, + disp_freq=self.disp_freq, + ) + return self._mft + + def _fit_sklearn( + self, + data: str | list[str], + type_map: list[str] | None = None, + target_key: str | list[str] | None = None, + labels: np.ndarray | None = None, + fmt: str | None = None, + ) -> None: + """Fit the frozen-sklearn pipeline (delegates to ``_FrozenSklearnPipeline``). + + Refactored: logic extracted to ``_FrozenSklearnPipeline``; this method + now orchestrates the pipeline and mirrors its state for backward compat. + """ + if target_key is not None and labels is not None: + raise ValueError( + "target_key and labels are mutually exclusive; provide only one." + ) + if target_key is None and labels is None: + raise ValueError("Either target_key or labels must be provided.") + + p = self._ensure_sklearn() + + self.type_map = type_map or [] + p.type_map = self.type_map + self._target_key = target_key if target_key is not None else "property" + + systems = load_data(data, fmt=fmt) + if self._model is None: + self._model = self._load_descriptor_model() + self._validate_type_map(type_map or [], systems) + + features = self._extract_features_cached(systems) + + self._condition_manager = None + if self.fparam_dim > 0: + conditions = _read_fparam_from_systems( + systems, + expected_dim=self.fparam_dim, + ) + self._condition_manager = ConditionManager() + X_cond = self._condition_manager.fit_transform(conditions) + features = np.concatenate([features, X_cond], axis=1) + + if labels is not None: + y = np.asarray(labels) + else: + y = _load_labels(systems, self._target_key) + + self._task_dim = 1 if y.ndim == 1 else y.shape[-1] + y_flat = y.ravel() if self._task_dim == 1 else y + + from sklearn.pipeline import ( + make_pipeline, + ) + from sklearn.preprocessing import ( + StandardScaler, + ) + + from dpa_adapt.utils.sklearn_heads import ( + build_sklearn_head, + ) + + head = build_sklearn_head( + self._predictor_type, + seed=self.seed, + n_outputs=self._task_dim, + ) + self.predictor = make_pipeline(StandardScaler(), head) + self.predictor.fit(features, y_flat) + self._fitted = True + + # Mirror pipeline state for backward compat. + p.predictor = self.predictor + p.type_map = self.type_map + p._target_key = self._target_key + p._task_dim = self._task_dim + p._condition_manager = self._condition_manager + p._fitted = True + + def predict(self, data: str | list[str], fmt: str | None = None) -> DotDict: + """ + Predict with the adapted model. + + ``frozen_sklearn`` extracts features and runs the fitted sklearn + predictor. Training strategies run ``dp --pt test`` and parse the + property predictions from DeePMD's detail files. + + Parameters + ---------- + data : str | list[str] + Path(s) to deepmd/npy system directories. + fmt : str, optional + Reserved for future format support. + + Returns + ------- + DotDict + ``predictions`` : np.ndarray, shape (n_frames, task_dim) + """ + if self.strategy in {"frozen_head", "finetune"}: + return self._run_training_predict(data, fmt=fmt) + if self.strategy == "mft": + if fmt is not None: + raise ValueError( + "fmt is not supported for mft predict(); " + "provide deepmd/npy system directories." + ) + return self._ensure_mft().predict(data) + + if not self._fitted: + raise RuntimeError( + "predict() was called before fit(). Train the model with fit() first." + ) + + systems = load_data(data, fmt=fmt) + features = self._extract_features(systems) + + if self._condition_manager is not None: + try: + conditions = _read_fparam_from_systems( + systems, + expected_dim=self.fparam_dim if self.fparam_dim > 0 else None, + ) + except DPAConditionError as e: + raise DPAConditionError( + "This model was fit with fparam but set.*/fparam.npy " + f"could not be read from the prediction data: {e}" + ) from e + X_cond = self._condition_manager.transform(conditions) + features = np.concatenate([features, X_cond], axis=1) + + raw = self.predictor.predict(features) + predictions = np.asarray(raw).reshape(-1, self._task_dim) + return DotDict({"predictions": predictions}) + + def evaluate(self, data: str | list[str], fmt: str | None = None) -> DotDict: + """ + Predict on ``data`` and compute evaluation metrics against stored labels. + + Parameters + ---------- + data : str | list[str] + Path(s) to deepmd/npy system directories with label files. + fmt : str, optional + Reserved for future format support. + + Returns + ------- + DotDict + mae, rmse, r2 : float + predictions : np.ndarray, shape (n_frames, task_dim) + labels : np.ndarray, shape (n_frames, task_dim) + """ + if self.strategy in {"frozen_head", "finetune"}: + result = self._run_training_predict(data, fmt=fmt) + labels = result.labels + predictions = result.predictions + err = predictions - labels + ss_res = np.sum(err**2) + ss_tot = np.sum((labels - labels.mean()) ** 2) + result["r2"] = float(1.0 - ss_res / ss_tot) if ss_tot > 0 else float("nan") + return result + if self.strategy == "mft": + if fmt is not None: + raise ValueError( + "fmt is not supported for mft evaluate(); " + "provide deepmd/npy system directories." + ) + mft = self._ensure_mft() + if getattr(mft, "downstream_task_type", "property") == "ener": + return DotDict(mft.evaluate(data)) + result = mft.predict(data) + labels = result.labels + predictions = result.predictions + err = predictions - labels + ss_res = np.sum(err**2) + ss_tot = np.sum((labels - labels.mean()) ** 2) + result["r2"] = float(1.0 - ss_res / ss_tot) if ss_tot > 0 else float("nan") + return result + + result = self.predict(data, fmt=fmt) + predictions = result.predictions + + systems = load_data(data, fmt=fmt) + labels = _load_labels(systems, self._target_key) + labels = labels.reshape(predictions.shape) + + if predictions.shape != labels.shape: + raise DPADataError( + f"Shape mismatch: predictions {predictions.shape} vs " + f"labels {labels.shape}." + ) + + err = predictions - labels + if isinstance(self._target_key, list): + # Per-property metrics + keys = self._target_key + mae, rmse, r2 = {}, {}, {} + for i, key in enumerate(keys): + e_i = err[:, i] + mae[key] = float(np.mean(np.abs(e_i))) + rmse[key] = float(np.sqrt(np.mean(e_i**2))) + ss_res_i = np.sum(e_i**2) + ss_tot_i = np.sum((labels[:, i] - labels[:, i].mean()) ** 2) + r2[key] = ( + float(1.0 - ss_res_i / ss_tot_i) if ss_tot_i > 0 else float("nan") + ) + else: + mae = float(np.mean(np.abs(err))) + rmse = float(np.sqrt(np.mean(err**2))) + ss_res = np.sum(err**2) + ss_tot = np.sum((labels - labels.mean()) ** 2) + r2 = float(1.0 - ss_res / ss_tot) if ss_tot > 0 else float("nan") + + return DotDict( + { + "mae": mae, + "rmse": rmse, + "r2": r2, + "predictions": predictions, + "labels": labels, + } + ) + + def freeze(self, output_path: str = "frozen_model.pth") -> str: + """ + Freeze or serialize the fitted model for inference. + + ``frozen_sklearn`` writes a dpa_adapt bundle containing the sklearn + predictor and descriptor metadata. ``frozen_head`` / ``finetune`` use + ``dp --pt freeze`` on the latest training checkpoint. ``mft`` freezes + the downstream MFT head. + + Parameters + ---------- + output_path : str + Destination file path. + + Returns + ------- + str + The resolved ``output_path``. + """ + if not self._fitted: + raise RuntimeError( + "freeze() was called before fit(). Train the model with fit() first." + ) + + if self.strategy in {"frozen_head", "finetune"}: + return self._freeze_training_checkpoint(output_path) + + if self.strategy == "mft": + frozen_path = self._ensure_mft()._freeze_ckpt() + output_path = os.path.abspath(str(output_path)) + if os.path.abspath(frozen_path) != output_path: + os.makedirs(os.path.dirname(output_path), exist_ok=True) + shutil.copyfile(frozen_path, output_path) + return output_path + + if self.predictor is None: + raise RuntimeError( + "freeze() expected a fitted sklearn predictor, but none was found." + ) + + bundle = { + "format_version": 1, + "pretrained": self.pretrained, + "model_branch": self.model_branch, + "predictor": self.predictor, + "target_key": self._target_key, + "type_map": self.type_map, + "task_dim": self._task_dim, + "predictor_type": self._predictor_type, + "pooling": self.pooling, + "condition_manager": self._condition_manager, + "fparam_dim": self.fparam_dim, + } + + output_path = str(output_path) + os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True) + import torch + + torch.save(bundle, output_path) + _LOG = logging.getLogger("dpa_adapt") + _LOG.info("Frozen model saved to: %s", output_path) + return output_path diff --git a/dpa_adapt/main.py b/dpa_adapt/main.py new file mode 100644 index 0000000000..e08201b96b --- /dev/null +++ b/dpa_adapt/main.py @@ -0,0 +1,12 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Entry point for the ``dpa-adapt`` and ``dpaad`` CLIs. + +This is the console_script target registered in pyproject.toml. +""" + +from dpa_adapt.cli import ( + main, +) + +if __name__ == "__main__": + main() diff --git a/dpa_adapt/mft.py b/dpa_adapt/mft.py new file mode 100644 index 0000000000..731c9623ff --- /dev/null +++ b/dpa_adapt/mft.py @@ -0,0 +1,767 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import glob as _glob +import logging +import os +import re +import subprocess +import sys + +import numpy as np + +from dpa_adapt._backend import ( + load_torch_file, + resolve_dp_command, + resolve_pretrained_path, +) +from dpa_adapt._validation import ( + validate_fparam_dim, +) +from dpa_adapt.utils.dotdict import ( + DotDict, +) + +_LOG = logging.getLogger("dpa_adapt.mft") + + +class MFTFineTuner: + """ + Multi-task fine-tuning via dp --pt train. + + Jointly optimizes a downstream property head and an aux force-field head + on a shared DPA descriptor, preventing representation collapse (per + arXiv:2601.08486). + + Refactored: ``fitting_net_params`` is now lazily resolved from the + checkpoint on first access rather than eagerly in ``__init__``, so + constructing an ``MFTFineTuner`` no longer triggers ``torch.load`` + unless ``fit()`` (or any other accessor) actually needs the value. + + Parameters + ---------- + pretrained : str + Path to the DPA pretrained checkpoint (.pt). + aux_branch : str + Branch name in the checkpoint to initialize the aux head. + Default: 'MP_traj_v024_alldata_mixu' (general materials coverage). + Run `dp --pt show model-branch` to list all options. + aux_prob : float + Sampling probability for the aux branch. Must be in ``[0, 1]``; the + downstream branch uses the complementary probability ``1 - aux_prob``. + This is the primary experimental variable for sensitivity analysis. + Example: aux_prob=0.5 → aux:downstream = 1:1 sampling ratio. + type_map : list[str], optional + The global (shared) type map for MFT training. Both the aux and + downstream branches share a single descriptor, which uses this + type_map to map element symbols to integer indices. It must be a + superset (union) of the elements appearing in both datasets. When + omitted, it is auto-detected from the pretrained checkpoint (which + covers the full periodic table for DPA-3.1-3M). + fitting_net_params : dict, optional + Fitting net architecture for the aux branch. Must match the + checkpoint exactly. When omitted (the default), it is read + automatically from the pretrained checkpoint at + ``sd['model']['_extra_state']['model_params']['model_dict'][aux_branch]['fitting_net']``. + Pass an explicit dict only if you need to override the checkpoint's + config (e.g. for experiments). + downstream_task_type : str + Either ``"property"`` (intensive scalar head, e.g. HOMO/LUMO, the + default) or ``"ener"`` (force-field head, legacy mode). Selects how + the DOWNSTREAM branch's fitting_net and loss are built: + + * ``"property"`` — DOWNSTREAM gets a fresh ``type: property`` + fitting_net (using ``property_name``, ``task_dim``, ``intensive``) + and a property-style MSE loss with no force/virial prefs. This + is what arXiv:2601.08486 Table 3 / Fig 2 reports for HOMO/LUMO. + * ``"ener"`` — DOWNSTREAM reuses the aux fitting_net dict and an + ener-style loss with force/virial prefs. This is the legacy mode + used by earlier mp_data sensitivity-analysis MFT experiments. + property_name : str, optional + Required when ``downstream_task_type="property"``. Name of the + per-system property file (e.g. ``"homo"`` reads ``set.*/homo.npy``). + Must be a valid Python identifier. + task_dim : int + Output dimensionality of the property head. Default ``1``. + intensive : bool + Whether the property is intensive (mean-pool) or extensive (sum). + Default ``True`` (correct for HOMO/LUMO and most molecular + properties). + learning_rate : float + Initial learning rate. + stop_lr : float + Final learning rate. + decay_steps : int + Steps between LR decays for the ``exp`` scheduler (deepmd-kit native). + Default 1000 (property mode) or 5000 (ener mode). + warmup_steps : int + Linear LR warmup steps (deepmd-kit native). 0 = disabled. + max_steps : int + Total training steps. + batch_size : str | int + Batch size (e.g. "auto:32" or 32). + seed : int + Random seed. + output_dir : str + Directory for checkpoints and logs. + save_freq : int + Checkpoint save interval (steps). + disp_freq : int + Log display interval (steps). + """ + + def __init__( + self, + pretrained: str, + aux_branch: str = "MP_traj_v024_alldata_mixu", + aux_prob: float = 0.5, + type_map: list[str] | None = None, + fitting_net_params: dict | None = None, + downstream_task_type: str = "property", + property_name: str | None = None, + task_dim: int = 1, + intensive: bool = True, + learning_rate: float = 1e-3, + stop_lr: float = 1e-5, + decay_steps: int | None = None, # None → auto: 1000 for property, 5000 for ener + warmup_steps: int = 0, + max_steps: int = 50000, + batch_size: str | int = "auto:32", + aux_batch_size: str | int | None = None, + downstream_batch_size: str | int | None = None, + seed: int = 42, + fparam_dim: int = 0, + output_dir: str = "./mft_output", + save_freq: int = 10000, + disp_freq: int = 1000, + ) -> None: + if downstream_task_type not in ("ener", "property"): + raise ValueError( + f"downstream_task_type must be 'ener' or 'property'; " + f"got {downstream_task_type!r}." + ) + if downstream_task_type == "property": + if not isinstance(property_name, str) or not property_name.isidentifier(): + raise ValueError( + "property_name is required when " + "downstream_task_type='property' and must be a valid " + f"Python identifier; got {property_name!r}." + ) + if not isinstance(task_dim, int) or task_dim < 1: + raise ValueError(f"task_dim must be an int >= 1; got {task_dim!r}.") + validate_fparam_dim(fparam_dim) + try: + aux_prob = float(aux_prob) + except (TypeError, ValueError) as exc: + raise ValueError( + f"aux_prob must be a number in [0, 1]; got {aux_prob!r}." + ) from exc + if not 0.0 <= aux_prob <= 1.0: + raise ValueError(f"aux_prob must be in [0, 1]; got {aux_prob!r}.") + + self.type_map = type_map + self.pretrained = resolve_pretrained_path(pretrained) + self.aux_branch = aux_branch + self.aux_prob = aux_prob + # Lazy: only load from ckpt when fitting_net_params is first accessed. + self._fitting_net_params = fitting_net_params + self._fitting_net_params_resolved = fitting_net_params is not None + self.downstream_task_type = downstream_task_type + self.property_name = property_name + self.task_dim = task_dim + self.intensive = intensive + self.learning_rate = learning_rate + self.stop_lr = stop_lr + self.decay_steps = decay_steps + self.warmup_steps = warmup_steps + self.max_steps = max_steps + self.batch_size = batch_size + self.aux_batch_size = aux_batch_size + self.downstream_batch_size = downstream_batch_size + self.seed = seed + self.fparam_dim = fparam_dim + self.output_dir = output_dir + self.save_freq = save_freq + self.disp_freq = disp_freq + + # populated by fit() + self.train_data = None + self.aux_data = None + self.valid_data = None + + # ------------------------------------------------------------------ + # Lazy fitting_net_params resolution + # + # Refactored: torch.load is deferred from __init__ to first access + # so that constructing an MFTFineTuner is cheap. The checkpoint is + # only read when fit() (via MFTConfigManager) or user code accesses + # fitting_net_params and the value was not explicitly provided. + # ------------------------------------------------------------------ + + @property + def fitting_net_params(self) -> dict | None: + if self._fitting_net_params is None and not self._fitting_net_params_resolved: + self._fitting_net_params = self._read_fitting_net_from_ckpt( + self.pretrained, self.aux_branch + ) + self._fitting_net_params_resolved = True + return self._fitting_net_params + + @fitting_net_params.setter + def fitting_net_params(self, value: dict | None) -> None: + self._fitting_net_params = value + + @staticmethod + def _read_fitting_net_from_ckpt(pretrained: str, aux_branch: str) -> dict: + """ + Pull fitting_net config for ``aux_branch`` out of a DPA multi-task + checkpoint. Raises ValueError listing available branches if + ``aux_branch`` isn't present. + """ + sd = load_torch_file(resolve_pretrained_path(pretrained)) + try: + model_dict = sd["model"]["_extra_state"]["model_params"]["model_dict"] + except (KeyError, TypeError) as e: + raise RuntimeError( + f"Could not locate model_dict in checkpoint {pretrained}: " + f"missing key {e!r}. Expected path " + "sd['model']['_extra_state']['model_params']['model_dict']." + ) from e + if aux_branch not in model_dict: + available = sorted(model_dict.keys()) + raise ValueError( + f"aux_branch {aux_branch!r} not found in checkpoint {pretrained}. " + f"Available branches: {available}. " + f"Run `dp --pt show {pretrained} model-branch` to inspect." + ) + return model_dict[aux_branch]["fitting_net"] + + def _validate_and_resolve_type_map( + self, train_data: str | list[str], aux_data: str | list[str] + ) -> None: + """Validate and resolve the global type_map for MFT training. + + Always called by ``fit()`` — whether ``type_map`` is user-provided + or auto-detected. + + - If ``type_map`` was not provided, auto-detect it from the + pretrained checkpoint (which covers the full periodic table for + DPA-3.1-3M, so it is always a superset). + - If ``type_map`` was provided, validate that it covers all elements + appearing in both the downstream and aux datasets (i.e. it must + be the union of the two datasets' element sets). + - In both cases, validate that each dataset's elements are a subset + of the global type_map. + """ + from dpa_adapt.data.loader import ( + load_data, + ) + from dpa_adapt.data.type_map import ( + read_checkpoint_type_map, + read_data_type_map_union, + validate_type_map_subset, + ) + + # Read elements from both datasets. + # If data cannot be loaded (e.g. glob hasn't resolved yet, or the + # data directory does not exist), fall back to empty lists — the + # type_map will still be resolved from the checkpoint below. + try: + train_systems = load_data(train_data) + except Exception: + train_systems = [] + try: + aux_systems = load_data(aux_data) + except Exception: + aux_systems = [] + + if not self.type_map: + # Not provided (None) or empty list — auto-detect from the + # checkpoint, which is always a superset. + self.type_map = read_checkpoint_type_map( + self.pretrained, + branch=self.aux_branch, + ) + else: + # User-provided: validate that it covers both datasets. + downstream_elems = [] + aux_elems = [] + try: + downstream_elems = read_data_type_map_union(train_systems) + except ValueError: + pass # no atom_names — deepmd uses raw atom indices + try: + aux_elems = read_data_type_map_union(aux_systems) + except ValueError: + pass + + required = set(downstream_elems) | set(aux_elems) + missing = required - set(self.type_map) + if missing: + raise ValueError( + "The provided type_map is missing elements " + "required by the training data.\n" + f" Missing elements: {sorted(missing)}\n" + f" Downstream data elements: " + f"{sorted(downstream_elems) if downstream_elems else '(none)'}\n" + f" Aux data elements: " + f"{sorted(aux_elems) if aux_elems else '(none)'}\n" + f" Provided type_map: {self.type_map}\n" + "The type_map must be the union (superset) of both " + "datasets' elements." + ) + + # Validate both datasets are subsets of the global type_map. + for label, systems in [ + ("downstream", train_systems), + ("aux", aux_systems), + ]: + if not systems: + continue + try: + elements = read_data_type_map_union(systems) + except ValueError: + continue # no atom_names — deepmd uses raw atom indices + validate_type_map_subset( + elements, + self.type_map, + label=f"{label} data", + ) + + def fit( + self, + train_data: str | list[str], + aux_data: str | list[str], + valid_data: str | list[str] | None = None, + ) -> None: + """ + Run MFT training. + + Parameters + ---------- + train_data : str or list[str] + Downstream deepmd/npy directory (or list of directories). + DeePMD-kit requires the standard label filename ``energy.npy`` + under each ``set.*`` subdir. If the raw data uses a custom name + like ``e_form.npy``, create a symlink before calling fit(): + + ln -sf set.000/e_form.npy set.000/energy.npy + + force.npy is optional (loss weight applies regardless; set to 0 + if absent). + + aux_data : str or list[str] + Aux deepmd/npy directory. Must have energy.npy + force.npy. + + valid_data : str, optional + Validation deepmd/npy directory. + """ + self.train_data = train_data + self.aux_data = aux_data + self.valid_data = valid_data + + if self.fparam_dim > 0: + from dpa_adapt.trainer import ( + DPATrainer, + ) + + DPATrainer._validate_fparam(train_data, self.fparam_dim) + if valid_data is not None: + DPATrainer._validate_fparam(valid_data, self.fparam_dim) + + import glob + + train_dirs = train_data if isinstance(train_data, list) else [train_data] + for sys_path in train_dirs: + e_form_sets = glob.glob(os.path.join(sys_path, "set.*", "e_form.npy")) + for e_form_path in e_form_sets: + energy_path = os.path.join(os.path.dirname(e_form_path), "energy.npy") + if not os.path.exists(energy_path): + _LOG.warning( + "%s exists but %s is missing. DeePMD-kit expects " + "energy.npy — create a symlink: ln -sf e_form.npy %s", + e_form_path, + energy_path, + energy_path, + ) + + os.makedirs(self.output_dir, exist_ok=True) + + # Validate and resolve type_map — always runs, whether type_map + # is user-provided or auto-detected. + self._validate_and_resolve_type_map(train_data, aux_data) + + from dpa_adapt.config.manager import ( + MFTConfigManager, + ) + + cm = MFTConfigManager(self) + config = cm.build() + input_json = os.path.abspath(os.path.join(self.output_dir, "mft_input.json")) + cm.save(config, input_json) + cmd = cm.build_cmd(input_json) + + log_path = os.path.abspath(os.path.join(self.output_dir, "train.log")) + _LOG.info("Running: %s", " ".join(cmd)) + _LOG.info("Log: %s", log_path) + + with open(log_path, "w") as log_f: + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + ) + for line in process.stdout: + sys.stdout.write(line) + sys.stdout.flush() + log_f.write(line) + log_f.flush() + process.wait() + + if process.returncode != 0: + raise RuntimeError( + f"dp --pt train failed (return code {process.returncode}).\n" + f"cmd: {cmd}\n" + f"See {log_path} for full output." + ) + + # ----- evaluate ----- + # `dp --pt test` for a multi-task ckpt requires a frozen .pth produced by + # `dp --pt freeze --head ` (property | DOWNSTREAM). + # Feeding the raw .pt silently yields all-zero predictions. The frozen file + # is cached in `output_dir` so a second evaluate() call is fast. + # + # The "Energy MAE/Natoms" line is per-atom; downstream BOOM analysis wants + # per-molecule "Energy MAE". The regex below requires whitespace between + # "MAE" and ":" so the "/Natoms" variant cannot match. dp prints per-system + # blocks followed by a "weighted average of errors" block — we use findall + # and take the LAST occurrence. + _ENERGY_MAE_RE = re.compile( + r"Energy\s+MAE\s+:\s*([0-9eE.+-]+)\s*\S+", re.IGNORECASE + ) + _ENERGY_RMSE_RE = re.compile( + r"Energy\s+RMSE\s+:\s*([0-9eE.+-]+)\s*\S+", re.IGNORECASE + ) + _PROPERTY_MAE_RE = re.compile( + r"PROPERTY\s+MAE\s+:\s*([0-9eE.+-]+)\s*\S*", re.IGNORECASE + ) + _PROPERTY_RMSE_RE = re.compile( + r"PROPERTY\s+RMSE\s+:\s*([0-9eE.+-]+)\s*\S*", re.IGNORECASE + ) + _N_SYSTEMS_RE = re.compile(r"number of systems\s*[:=]?\s*(\d+)", re.IGNORECASE) + + @property + def _downstream_head(self) -> str: + """Branch/head name of the downstream task. Paper property mode uses + "property" (matching MFTConfigManager); legacy ener mode keeps + "DOWNSTREAM". + """ + return ( + "property" + if getattr(self, "downstream_task_type", "ener") == "property" + else "DOWNSTREAM" + ) + + def _freeze_ckpt(self) -> str: + """ + Freeze ``model.ckpt-{max_steps}.pt`` to ``frozen_.pth`` in + ``output_dir`` (head = "property" or "DOWNSTREAM"). Skips if the frozen + file already exists. + + Returns the absolute path to the frozen .pth. + """ + head = self._downstream_head + frozen_name = f"frozen_{head}.pth" + frozen_path = os.path.join(self.output_dir, frozen_name) + if os.path.exists(frozen_path): + return frozen_path + + ckpt = os.path.join(self.output_dir, f"model.ckpt-{self.max_steps}.pt") + if not os.path.isfile(ckpt): + raise RuntimeError( + f"Expected checkpoint {ckpt} not found; cannot freeze. " + f"Did fit() complete successfully?" + ) + + # `dp --pt freeze -c .` picks up the checkpoint file from cwd, so we + # must cd into output_dir. + freeze_cmd = [ + resolve_dp_command(), + "--pt", + "freeze", + "-c", + ".", + "-o", + frozen_name, + "--head", + head, + ] + result = subprocess.run( + freeze_cmd, + capture_output=True, + text=True, + cwd=self.output_dir, + ) + if result.returncode != 0: + raise RuntimeError( + f"dp --pt freeze failed (return code {result.returncode}).\n" + f"cmd: {' '.join(freeze_cmd)}\n" + f"cwd: {self.output_dir}\n" + f"stdout:\n{result.stdout}\n" + f"stderr:\n{result.stderr}" + ) + if not os.path.exists(frozen_path): + raise RuntimeError( + f"dp --pt freeze reported success but {frozen_path} was not " + f"created.\nstdout:\n{result.stdout}\nstderr:\n{result.stderr}" + ) + return frozen_path + + @staticmethod + def _resolve_test_data(test_data: str | list[str]) -> list[str]: + """ + Normalize ``test_data`` (single path, glob string, or list of paths/ + globs) to a flat list of system directories. + """ + if isinstance(test_data, str): + patterns = [test_data] + else: + patterns = list(test_data) + + resolved = [] + for pat in patterns: + if _glob.has_magic(pat): + matches = sorted(_glob.glob(pat)) + if not matches: + raise RuntimeError(f"Glob pattern {pat!r} resolved to 0 systems.") + resolved.extend(matches) + else: + resolved.append(pat) + + # de-duplicate preserving order + seen = set() + unique = [] + for p in resolved: + if p not in seen: + seen.add(p) + unique.append(p) + if not unique: + raise RuntimeError(f"test_data {test_data!r} resolved to 0 systems.") + return unique + + def evaluate(self, test_data: str | list[str]) -> dict: + """ + Evaluate the downstream head of the MFT checkpoint via ``dp --pt test``. + + Pipeline: + 1. ``dp --pt freeze --head `` to produce ``frozen_.pth`` + (head = "property" in paper property mode, "DOWNSTREAM" in legacy + ener mode; cached in ``output_dir``). + 2. Resolve ``test_data`` (str path, glob string, or list) to a flat + list of system directories. + 3. Write the list to a datafile and call ``dp --pt test -m + -f -n 999999`` once. (Spawning one dp test per system + is unacceptably slow — ~9s/process x hundreds of systems.) + 4. Parse the LAST occurrence of MAE / RMSE from the combined + stdout+stderr — this is the weighted average across all systems. + For ener tasks the keywords are ``Energy MAE`` / ``Energy RMSE`` + (the "Energy MAE/Natoms" variant is rejected by requiring + whitespace between MAE and ``:``). For property tasks the + keywords are ``PROPERTY MAE`` / ``PROPERTY RMSE``. The parser + auto-detects the format from the output. + + Parameters + ---------- + test_data : str or list[str] + Either a single system path, a glob string, or a list of paths / + globs. + + Returns + ------- + dict + ``{"mae": float, "rmse": float, "n_systems": int, + "_parser_pattern_used": str, "_raw_stdout": str}``. + + Notes + ----- + The DeePMD-kit output labels the unit as ``eV`` regardless of the + actual training units; callers using Hartree-trained checkpoints + should treat the returned numbers as Hartree. + """ + frozen_path = self._freeze_ckpt() + + systems = self._resolve_test_data(test_data) + + os.makedirs(self.output_dir, exist_ok=True) + datafile = os.path.join(self.output_dir, "test_systems.txt") + with open(datafile, "w") as f: + f.write("\n".join(systems) + "\n") + + cmd = [ + resolve_dp_command(), + "--pt", + "test", + "-m", + frozen_path, + "-f", + datafile, + "-n", + "999999", + ] + result = subprocess.run(cmd, capture_output=True, text=True) + combined = result.stdout + "\n" + result.stderr + if result.returncode != 0: + raise RuntimeError( + f"dp --pt test failed (return code {result.returncode}).\n" + f"cmd: {' '.join(cmd)}\n" + f"stdout:\n{result.stdout}\n" + f"stderr:\n{result.stderr}" + ) + + return self._parse_test_output(combined, n_resolved=len(systems)) + + def predict(self, test_data: str | list[str]) -> DotDict: + """ + Predict property labels with the downstream MFT property head. + + This uses the same frozen downstream head as ``evaluate()``, but passes + ``-d`` to ``dp --pt test`` and parses the generated property detail + files so callers get frame-level labels and predictions. + """ + if self._downstream_head != "property": + raise RuntimeError( + "MFT predict() is only supported for downstream_task_type='property'. " + "Energy-mode MFT can still use evaluate() for aggregate metrics." + ) + + frozen_path = self._freeze_ckpt() + systems = self._resolve_test_data(test_data) + + os.makedirs(self.output_dir, exist_ok=True) + datafile = os.path.join(self.output_dir, "predict_systems.txt") + with open(datafile, "w") as f: + f.write("\n".join(systems) + "\n") + + detail_prefix = os.path.join(self.output_dir, "predict_detail") + detail_name = os.path.basename(detail_prefix) + for old in _glob.glob( + os.path.join(self.output_dir, f"{detail_name}.property.out.*") + ): + os.remove(old) + + cmd = [ + resolve_dp_command(), + "--pt", + "test", + "-m", + frozen_path, + "-f", + datafile, + "-n", + "999999", + "-d", + detail_prefix, + ] + result = subprocess.run(cmd, capture_output=True, text=True) + combined = result.stdout + "\n" + result.stderr + if result.returncode != 0: + raise RuntimeError( + f"dp --pt test failed (return code {result.returncode}).\n" + f"cmd: {' '.join(cmd)}\n" + f"stdout:\n{result.stdout}\n" + f"stderr:\n{result.stderr}" + ) + + detail_files = sorted( + _glob.glob(os.path.join(self.output_dir, f"{detail_name}.property.out.*")), + key=lambda p: int(os.path.basename(p).rsplit(".", 1)[-1]), + ) + if not detail_files: + raise RuntimeError( + "dp --pt test completed but no property detail files were written. " + f"Command was: {' '.join(cmd)}" + ) + + rows = [] + for path in detail_files: + arr = np.loadtxt(path) + arr = np.asarray(arr, dtype=float) + if arr.ndim == 1: + arr = arr.reshape(1, -1) + if arr.shape[1] < 2: + raise RuntimeError( + f"Expected at least two columns in {path}, got shape {arr.shape}." + ) + rows.append(arr[:, :2]) + + values = np.concatenate(rows, axis=0) + if values.shape[0] % self.task_dim != 0: + raise RuntimeError( + f"Could not reshape property detail rows {values.shape[0]} " + f"into task_dim={self.task_dim}." + ) + + values = values.reshape(-1, self.task_dim, 2) + labels = values[:, :, 0] + predictions = values[:, :, 1] + if self.task_dim == 1: + labels = labels.reshape(-1, 1) + predictions = predictions.reshape(-1, 1) + + metrics = self._parse_test_output(combined, n_resolved=len(systems)) + metrics.update( + { + "predictions": predictions, + "labels": labels, + "detail_prefix": detail_prefix, + } + ) + return DotDict(metrics) + + @classmethod + def _parse_test_output(cls, combined: str, n_resolved: int = 0) -> dict: + """ + Extract weighted-average ``mae`` / ``rmse`` (last match) and + ``n_systems`` from ``dp --pt test`` output. + + Auto-detects output format: "PROPERTY MAE" / "PROPERTY RMSE" for + property tasks, "Energy MAE" / "Energy RMSE" for ener tasks. + + Raises ``RuntimeError`` with diagnostic context if neither MAE nor + RMSE can be parsed — silent NaN returns previously masked the Bug-1 + all-zero failure for months, so we fail loudly instead. + """ + if "PROPERTY MAE" in combined or "PROPERTY RMSE" in combined: + mae_matches = cls._PROPERTY_MAE_RE.findall(combined) + rmse_matches = cls._PROPERTY_RMSE_RE.findall(combined) + tag = "PROPERTY" + else: + mae_matches = cls._ENERGY_MAE_RE.findall(combined) + rmse_matches = cls._ENERGY_RMSE_RE.findall(combined) + tag = "Energy" + + if not mae_matches and not rmse_matches: + tail = "\n".join(combined.splitlines()[-100:]) + raise RuntimeError( + "Could not parse Energy MAE or RMSE from `dp --pt test` " + "output. The most common cause is feeding a raw .pt ckpt " + "instead of a frozen .pth, which silently produces zero " + "predictions and no MAE/RMSE lines. Re-check the freeze " + "step.\n----- last 100 lines of combined stdout+stderr -----\n" + f"{tail}\n----------------------" + ) + + mae = float(mae_matches[-1]) if mae_matches else float("nan") + rmse = float(rmse_matches[-1]) if rmse_matches else float("nan") + + n_sys_match = cls._N_SYSTEMS_RE.search(combined) + n_systems = int(n_sys_match.group(1)) if n_sys_match else n_resolved + + pattern_used = f"{tag} MAE (last); {tag} RMSE (last)" + return { + "mae": mae, + "rmse": rmse, + "n_systems": n_systems, + "_parser_pattern_used": pattern_used, + "_raw_stdout": combined, + } diff --git a/dpa_adapt/predictor.py b/dpa_adapt/predictor.py new file mode 100644 index 0000000000..970bef7088 --- /dev/null +++ b/dpa_adapt/predictor.py @@ -0,0 +1,413 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +# dpa_adapt/predictor.py + +from typing import ( + Any, +) + +import numpy as np + +from dpa_adapt.conditions import ( + DPAConditionError, +) +from dpa_adapt.data.loader import ( + load_data, +) +from dpa_adapt.utils.dotdict import ( + DotDict, +) + + +def _unwrap_multioutput(est: Any) -> Any: + """If *est* is a ``MultiOutputRegressor``, return the wrapped estimator.""" + from sklearn.multioutput import ( + MultiOutputRegressor, + ) + + if isinstance(est, MultiOutputRegressor): + return est.estimator + return est + + +def _is_rf(est: Any) -> bool: + from sklearn.ensemble import ( + RandomForestRegressor, + ) + + return isinstance(_unwrap_multioutput(est), RandomForestRegressor) + + +def _is_ridge(est: Any) -> bool: + from sklearn.linear_model import ( + Ridge, + ) + + return isinstance(_unwrap_multioutput(est), Ridge) + + +def _is_mlp(est: Any) -> bool: + from sklearn.neural_network import ( + MLPRegressor, + ) + + return isinstance(est, MLPRegressor) + + +def _rf_tree_predictions(est: Any, features: np.ndarray) -> np.ndarray: + """Return RF per-tree predictions with shape ``(n_trees, n_frames, dim)``.""" + from sklearn.ensemble import ( + RandomForestRegressor, + ) + from sklearn.multioutput import ( + MultiOutputRegressor, + ) + + if isinstance(est, MultiOutputRegressor): + per_output = [] + for rf in est.estimators_: + if not isinstance(rf, RandomForestRegressor): + raise TypeError( + "Expected MultiOutputRegressor(RandomForestRegressor), " + f"got wrapped estimator {type(rf).__name__!r}." + ) + per_output.append( + np.array([tree.predict(features) for tree in rf.estimators_]) + ) + return np.stack(per_output, axis=-1) + + if isinstance(est, RandomForestRegressor): + tree_preds = np.array([tree.predict(features) for tree in est.estimators_]) + return tree_preds.reshape(len(est.estimators_), -1, 1) + + raise TypeError( + "RF uncertainty requires RandomForestRegressor or " + f"MultiOutputRegressor(RandomForestRegressor), got {type(est).__name__!r}." + ) + + +class DPAPredictor: + """ + Read-only inference wrapper for a frozen DPA+sklearn bundle. + + Parameters + ---------- + model_path : str + Path to a frozen model file produced by ``DPAFineTuner.freeze()``. + n_committee : int + Number of committee members for uncertainty estimation. + Default 1 uses the single estimator from the bundle unchanged. + """ + + def __init__(self, model_path: str, n_committee: int = 1) -> None: + from dpa_adapt._backend import ( + load_torch_file, + ) + + bundle = load_torch_file(model_path) + + # Reject bundles from future versions we cannot read. + fmt = bundle.get("format_version") + if fmt is not None and fmt != 1: + raise ValueError( + f"Unsupported frozen-model format version {fmt}. " + "This version of dpa_adapt only supports format_version 1. " + "Re-freeze the model with the current dpa_adapt version." + ) + + # Detect models frozen with dpa_adapt <0.2 (missing modern metadata). + if "predictor" in bundle and "pooling" not in bundle: + raise ValueError( + "This model was frozen with dpa_adapt <0.2. " + "Re-freeze with the current version: " + "model.freeze(output_dir)." + ) + + self._predictor = bundle["predictor"] + self._target_key = bundle["target_key"] # str or list[str] + self._type_map = bundle["type_map"] + self._task_dim = bundle["task_dim"] + self._pretrained = bundle["pretrained"] + self._model_branch = bundle.get("model_branch") + self._pooling = bundle["pooling"] + self._condition_manager = bundle.get("condition_manager") + self._fparam_dim = bundle.get("fparam_dim", 0) + self.n_committee = n_committee + + # Detect estimator type from the final pipeline step. + final_est = self._predictor.steps[-1][1] + if _is_rf(final_est): + self._estimator_type = "rf" + elif _is_ridge(final_est): + self._estimator_type = "ridge" + elif _is_mlp(final_est): + self._estimator_type = "mlp" + else: + self._estimator_type = "unknown" + + from dpa_adapt.finetuner import ( + DPAFineTuner, + ) + + # TODO: replace with dedicated DescriptorExtractor class after refactor. + # For now, DPAFineTuner is reused purely as a descriptor feature extractor. + self._extractor = DPAFineTuner( + pretrained=self._pretrained, + model_branch=self._model_branch, + predictor="linear", + pooling=self._pooling, + type_map=self._type_map, + fparam_dim=self._fparam_dim, + ) + + def fit( + self, + data: str | list[str], + target_key: str | list[str] | None = None, + labels: np.ndarray | None = None, + fmt: str | None = None, + ) -> None: + """Train committee members for uncertainty estimation. + + Only valid when *n_committee* > 1. Clones the frozen sklearn + pipeline *n_committee* times with different random seeds and + stores the ensemble as ``self.estimators_``. Also computes + ``self.uncertainty_threshold_`` (95th-percentile train-set std). + """ + if self.n_committee <= 1: + raise RuntimeError( + "fit() requires n_committee > 1. " + "The single-estimator predictor is ready to use as-is." + ) + + from sklearn.base import ( + clone, + ) + + from dpa_adapt.finetuner import ( + _load_labels, + _read_fparam_from_systems, + ) + + if target_key is not None and labels is not None: + raise ValueError("target_key and labels are mutually exclusive") + if target_key is None and labels is None: + raise ValueError("Either target_key or labels must be provided") + + systems = load_data(data, fmt=fmt) + if self._extractor._model is None: + self._extractor._model = self._extractor._load_descriptor_model() + self._extractor._validate_type_map(self._type_map, systems) + features = self._extractor._extract_features(systems) + + if self._condition_manager is not None: + try: + conditions = _read_fparam_from_systems( + systems, + expected_dim=self._fparam_dim if self._fparam_dim else None, + ) + except DPAConditionError as e: + raise DPAConditionError( + "This model was fit with fparam but set.*/fparam.npy " + f"could not be read from the prediction data: {e}" + ) from e + X_cond = self._condition_manager.transform(conditions) + features = np.concatenate([features, X_cond], axis=1) + + if labels is not None: + y = np.asarray(labels) + else: + y = _load_labels(systems, target_key) + + y_flat = y.ravel() if y.ndim == 1 or y.shape[-1] == 1 else y + + self.estimators_ = [] + for seed in range(self.n_committee): + est = clone(self._predictor) + try: + est[-1].set_params(random_state=seed) + except ValueError: + # Estimator does not support random_state (e.g. KNeighborsRegressor); + # training with the default clone is fine — no ensemble diversity needed. + pass + est.fit(features, y_flat) + self.estimators_.append(est) + + preds = np.array([e.predict(features) for e in self.estimators_]) + preds = preds.reshape(self.n_committee, -1, self._task_dim) + self.uncertainty_threshold_ = float(np.percentile(np.std(preds, axis=0), 95)) + + def _extract_and_condition( + self, data: str | list[str], fmt: str | None + ) -> np.ndarray: + """Shared feature extraction + fparam auto-read.""" + from dpa_adapt.finetuner import ( + _read_fparam_from_systems, + ) + + systems = load_data(data, fmt=fmt) + if self._extractor._model is None: + self._extractor._model = self._extractor._load_descriptor_model() + self._extractor._validate_type_map(self._type_map, systems) + features = self._extractor._extract_features(systems) + + if self._condition_manager is not None: + conditions = _read_fparam_from_systems( + systems, + expected_dim=self._fparam_dim if self._fparam_dim else None, + ) + X_cond = self._condition_manager.transform(conditions) + features = np.concatenate([features, X_cond], axis=1) + + return features + + def predict( + self, + data: str | list[str], + fmt: str | None = None, + return_uncertainty: bool = False, + ) -> DotDict: + """ + Run inference on ``data``. + + fparam is automatically read from ``set.*/fparam.npy`` when the + model was fit with fparam. + + Parameters + ---------- + data : str | list[str] + Path(s) to deepmd/npy system directories. + fmt : str, optional + Reserved for future format support. + return_uncertainty : bool + When True, include ``"uncertainty"`` (per-sample std) in the + result. Behaviour depends on estimator type and committee + configuration. + + Returns + ------- + DotDict + ``predictions`` : np.ndarray, shape (n_frames, task_dim) + ``uncertainty`` : np.ndarray, shape (n_frames, task_dim) (if requested) + """ + features = self._extract_and_condition(data, fmt) + + if return_uncertainty: + return self._predict_with_uncertainty(features) + + if self.n_committee > 1: + preds = np.array([e.predict(features) for e in self.estimators_]) + preds = preds.reshape(self.n_committee, -1, self._task_dim) + return DotDict({"predictions": np.mean(preds, axis=0)}) + + raw = self._predictor.predict(features) + predictions = np.asarray(raw).reshape(-1, self._task_dim) + return DotDict({"predictions": predictions}) + + def _predict_with_uncertainty(self, features: np.ndarray) -> DotDict: + """Per-estimator uncertainty dispatch.""" + if self._estimator_type == "rf": + X_t = features + for _, step in self._predictor.steps[:-1]: + X_t = step.transform(X_t) + rf = self._predictor.steps[-1][1] + tree_preds = _rf_tree_predictions(rf, X_t) + tree_preds = tree_preds.reshape(tree_preds.shape[0], -1, self._task_dim) + return DotDict( + { + "predictions": np.mean(tree_preds, axis=0), + "uncertainty": np.std(tree_preds, axis=0), + } + ) + + if self._estimator_type in ("ridge", "linear"): + raise ValueError( + "Ridge regression has a unique closed-form solution and " + "cannot produce uncertainty estimates. " + "Use estimator='rf' or estimator='mlp' for uncertainty." + ) + + if self.n_committee > 1: + preds = np.array([e.predict(features) for e in self.estimators_]) + preds = preds.reshape(self.n_committee, -1, self._task_dim) + return DotDict( + { + "predictions": np.mean(preds, axis=0), + "uncertainty": np.std(preds, axis=0), + } + ) + + raise RuntimeError( + f"Uncertainty estimation requires either estimator='rf' " + f"or n_committee > 1 (for committee-based uncertainty). " + f"Got estimator_type={self._estimator_type!r} " + f"with n_committee={self.n_committee}." + ) + + def evaluate(self, data: str | list[str], fmt: str | None = None) -> DotDict: + """ + Predict on ``data`` and compute evaluation metrics against stored labels. + + Parameters + ---------- + data : str | list[str] + Path(s) to deepmd/npy system directories with label files. + fmt : str, optional + Reserved for future format support. + + Returns + ------- + DotDict + mae, rmse, r2 : float + predictions : np.ndarray, shape (n_frames, task_dim) + labels : np.ndarray, shape (n_frames, task_dim) + """ + from dpa_adapt.data.errors import ( + DPADataError, + ) + from dpa_adapt.finetuner import ( + _load_labels, + ) + + result = self.predict(data, fmt=fmt) + predictions = result.predictions + + systems = load_data(data, fmt=fmt) + labels = _load_labels(systems, self._target_key) + labels = labels.reshape(predictions.shape) + + if predictions.shape != labels.shape: + raise DPADataError( + f"Shape mismatch: predictions {predictions.shape} vs " + f"labels {labels.shape}." + ) + + err = predictions - labels + if isinstance(self._target_key, list): + # Per-property metrics + keys = self._target_key + mae, rmse, r2 = {}, {}, {} + for i, key in enumerate(keys): + e_i = err[:, i] + mae[key] = float(np.mean(np.abs(e_i))) + rmse[key] = float(np.sqrt(np.mean(e_i**2))) + ss_res_i = np.sum(e_i**2) + ss_tot_i = np.sum((labels[:, i] - labels[:, i].mean()) ** 2) + r2[key] = ( + float(1.0 - ss_res_i / ss_tot_i) if ss_tot_i > 0 else float("nan") + ) + else: + mae = float(np.mean(np.abs(err))) + rmse = float(np.sqrt(np.mean(err**2))) + ss_res = np.sum(err**2) + ss_tot = np.sum((labels - labels.mean()) ** 2) + r2 = float(1.0 - ss_res / ss_tot) if ss_tot > 0 else float("nan") + + return DotDict( + { + "mae": mae, + "rmse": rmse, + "r2": r2, + "predictions": predictions, + "labels": labels, + } + ) diff --git a/dpa_adapt/trainer.py b/dpa_adapt/trainer.py new file mode 100644 index 0000000000..36ce953c1d --- /dev/null +++ b/dpa_adapt/trainer.py @@ -0,0 +1,777 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +# dpa_adapt/trainer.py +""" +DPATrainer: drives ``dp --pt train`` for Scratch / FT / LP adaptation modes, +mirroring the comparison setup of arXiv:2601.08486 (Table 3 / Fig 2). + +Mode is selected by constructor arguments: + +| Mode | ``pretrained`` | ``freeze_backbone`` | +| ------- | ---------------- | ------------------- | +| Scratch | ``None`` | ``False`` | +| FT | path to ckpt | ``False`` | +| LP | path to ckpt | ``True`` | + +MFT lives in :class:`dpa_adapt.mft.MFTFineTuner`; the sklearn-head +(frozen_sklearn strategy) lives in +:class:`dpa_adapt.finetuner.DPAFineTuner`. +""" + +from __future__ import ( + annotations, +) + +import copy +import glob as _glob +import json +import logging +import os +import re +import subprocess +from typing import ( + ClassVar, +) + +from dpa_adapt._backend import ( + resolve_dp_command, + resolve_pretrained_path, +) +from dpa_adapt._validation import ( + validate_fparam_dim, +) + +_LOG = logging.getLogger("dpa_adapt.trainer") + + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +# Fallback descriptor config used when pretrained=None (Scratch mode). +# Must match DPA-3.1-3M exactly. Source: ckpt _extra_state.model_params.shared_dict. +DPA3_DESCRIPTOR_DEFAULT = { + "type": "dpa3", + "repflow": { + "n_dim": 128, + "e_dim": 64, + "a_dim": 32, + "nlayers": 16, + "e_rcut": 6.0, + "e_rcut_smth": 5.3, + "e_sel": 1200, + "a_rcut": 4.0, + "a_rcut_smth": 3.5, + "a_sel": 300, + "axis_neuron": 4, + "skip_stat": True, + "a_compress_rate": 1, + "a_compress_e_rate": 2, + "a_compress_use_split": True, + "update_angle": True, + "smooth_edge_update": True, + "use_dynamic_sel": True, + "sel_reduce_factor": 10.0, + "update_style": "res_residual", + "update_residual": 0.1, + "update_residual_init": "const", + "n_multi_edge_message": 1, + "optim_update": True, + "use_exp_switch": True, + "fix_stat_std": 0.3, + }, + # Paper qm9_gap input.json uses "silut:3.0" (alias of "custom_silu:3.0"; + # verified identical output in deepmd-kit 3.1.3). + "activation_function": "silut:3.0", + "precision": "float32", + "use_tebd_bias": False, + "concat_output_tebd": False, + "exclude_types": [], + "env_protection": 0.0, + "trainable": True, + "use_econf_tebd": False, +} + +DEFAULT_FITTING_NET = { + "type": "property", + "neuron": [240, 240, 240], + "activation_function": "tanh", # paper Table 8 + "resnet_dt": True, + "precision": "float32", +} + +_VALID_LOSSES = ("mse", "smooth_mae") + + +# --------------------------------------------------------------------------- +# DPATrainer +# --------------------------------------------------------------------------- + + +class DPATrainer: + """ + Drive ``dp --pt train`` for Scratch / FT / LP downstream adaptation. + + Parameters + ---------- + pretrained : str or None + Path to a DPA pretrained checkpoint (.pt). ``None`` means Scratch. + init_branch : str + Branch name in the checkpoint used to initialize the descriptor. + Only consulted when ``pretrained`` is given. + freeze_backbone : bool + If True, freeze the descriptor (LP mode). Requires ``pretrained``. + property_name : str + Name of the property npy file under ``set.000/`` (e.g. ``"homo"``). + Must be a valid Python identifier. + task_dim : int + Output dimensionality of the property head. Must be ``>= 1``. + intensive : bool + Whether the property is intensive (mean-pool) or extensive (sum). + train_systems, valid_systems : str or list[str] + Globs (or list of globs) resolving to deepmd/npy system directories. + Both required. + type_map : list[str] + Element symbols. Required; no auto-inference. + fitting_net_params : dict, optional + Overrides for the property head config (shallow-merged onto the + defaults). The defaults are ``DEFAULT_FITTING_NET`` plus + ``property_name``, ``task_dim``, ``intensive``, ``seed``. + learning_rate, stop_lr : float + Exp-decay LR endpoints. + decay_steps : int + Steps between LR decays (DeePMD-kit ``exp`` scheduler). Default 1000. + warmup_steps : int + Linear LR warmup steps (DeePMD-kit native). 0 = disabled. + max_steps : int + Total training steps. + batch_size : str or int + DeePMD-kit batch_size spec (e.g. ``"auto:512"``). + loss_function : str + ``"mse"`` or ``"smooth_mae"``. + seed : int + Random seed. + output_dir : str + Directory for checkpoints, input.json, and manifests. + save_freq, disp_freq : int + DeePMD-kit save/display intervals. + """ + + def __init__( + self, + # ---- pretraining / freezing ---- + pretrained: str | None = None, + init_branch: str = "SPICE2", + freeze_backbone: bool = False, + # ---- downstream task ---- + property_name: str = "homo", + task_dim: int = 1, + intensive: bool = True, + # ---- data ---- + train_systems: str | list | None = None, + valid_systems: str | list | None = None, + type_map: list | None = None, + # ---- model overrides ---- + fitting_net_params: dict | None = None, + fparam_dim: int = 0, + # ---- training ---- + learning_rate: float = 1e-3, + stop_lr: float = 1e-5, + decay_steps: int = 1000, + warmup_steps: int = 0, + max_steps: int = 100_000, + batch_size: str | int = "auto:512", + loss_function: str = "mse", + seed: int = 42, + # ---- output ---- + output_dir: str = "./dpa_output", + save_freq: int = 10_000, + disp_freq: int = 1_000, + ) -> None: + # ---- validation ---- + if train_systems is None: + raise ValueError("train_systems is required (got None).") + if valid_systems is None: + raise ValueError("valid_systems is required (got None).") + if type_map is None: + raise ValueError( + "type_map is required. Pass an explicit list of element " + "symbols (e.g. the SPICE2 full periodic table). " + "Auto-inference is intentionally not supported." + ) + if not isinstance(type_map, list) or not all( + isinstance(x, str) for x in type_map + ): + raise ValueError("type_map must be a list of element symbol strings.") + if freeze_backbone and pretrained is None: + raise ValueError( + "LP requires a pretrained checkpoint to freeze. " + "Set freeze_backbone=False for Scratch, or pass a pretrained ckpt." + ) + if pretrained is not None: + pretrained = resolve_pretrained_path(pretrained) + if not os.path.isfile(pretrained): + raise ValueError(f"pretrained checkpoint not found: {pretrained!r}.") + if not isinstance(property_name, str) or not property_name.isidentifier(): + raise ValueError( + f"property_name must be a valid Python identifier " + f"(no spaces or slashes); got {property_name!r}." + ) + if not isinstance(task_dim, int) or task_dim < 1: + raise ValueError(f"task_dim must be an int >= 1; got {task_dim!r}.") + if loss_function not in _VALID_LOSSES: + raise ValueError( + f"loss_function must be one of {_VALID_LOSSES}; got {loss_function!r}." + ) + validate_fparam_dim(fparam_dim) + + self.pretrained = pretrained + self.init_branch = init_branch + self.freeze_backbone = freeze_backbone + self.property_name = property_name + self.task_dim = task_dim + self.intensive = intensive + self.train_systems = train_systems + self.valid_systems = valid_systems + self.type_map = type_map + self.fitting_net_params = fitting_net_params + self.fparam_dim = fparam_dim + self.learning_rate = learning_rate + self.stop_lr = stop_lr + self.decay_steps = decay_steps + self.warmup_steps = warmup_steps + self.max_steps = max_steps + self.batch_size = batch_size + self.loss_function = loss_function + self.seed = seed + self.output_dir = output_dir + self.save_freq = save_freq + self.disp_freq = disp_freq + + # ----- mode label (debugging convenience) ----- + @property + def mode(self) -> str: + return "LP" if self.freeze_backbone else "FT" + + # ----- descriptor sourcing ----- + def _read_descriptor_from_ckpt(self) -> dict: + import torch + + sd = torch.load(self.pretrained, map_location="cpu", weights_only=False) + try: + descriptor = sd["model"]["_extra_state"]["model_params"]["shared_dict"][ + "dpa3_descriptor" + ] + except (KeyError, TypeError) as e: + raise RuntimeError( + f"Could not locate dpa3_descriptor in checkpoint {self.pretrained}: " + f"missing key {e!r}. Expected path sd['model']['_extra_state']" + "['model_params']['shared_dict']['dpa3_descriptor']." + ) from e + return copy.deepcopy(descriptor) + + def _get_descriptor(self) -> dict: + if self.pretrained is not None: + descriptor = self._read_descriptor_from_ckpt() + else: + descriptor = copy.deepcopy(DPA3_DESCRIPTOR_DEFAULT) + # Paper alignment (qm9_gap input.json): silut:3.0 activation (alias of + # the ckpt's custom_silu:3.0) + explicit fix_stat_std=0.3. + descriptor["activation_function"] = "silut:3.0" + descriptor["repflow"]["fix_stat_std"] = 0.3 + # LP: freeze the descriptor by setting trainable=False on the descriptor + # block. DeePMD-kit 3.1.3 honors this field in the `--finetune` code path + # (verified by reading deepmd.pt.train.training; the descriptor's + # `requires_grad_` is set from this flag at init). If a future deepmd-kit + # version changes this, switch to passing `--freeze-descriptor` to the + # CLI or use `dp --pt freeze` as a post-processing step. + descriptor["trainable"] = not self.freeze_backbone + return descriptor + + # ----- glob expansion ----- + @staticmethod + def _expand_systems(spec: str | list[str], label: str) -> list: + if isinstance(spec, str): + patterns = [spec] + else: + patterns = list(spec) + resolved: list = [] + for pat in patterns: + matches = sorted(_glob.glob(pat)) + resolved.extend(matches) + # de-duplicate while preserving order + seen = set() + unique = [] + for p in resolved: + if p not in seen: + seen.add(p) + unique.append(p) + if not unique: + raise ValueError( + f"{label} resolved to 0 systems from patterns={patterns!r}. " + f"Check the glob and that the directories exist." + ) + if len(unique) < 50: + _LOG.warning( + "%s resolved to only %d systems (patterns=%r). " + "MFT-paper BOOM splits typically yield 500/300 for train/valid.", + label, + len(unique), + patterns, + ) + return unique + + # ----- config build ----- + def _build_fitting_net(self) -> dict: + fn = copy.deepcopy(DEFAULT_FITTING_NET) + fn.update( + { + "property_name": self.property_name, + "task_dim": self.task_dim, + "intensive": self.intensive, + # verified: deepmd.utils.argcheck.fitting_property() accepts seed + # (inspect.getsource shows Argument("seed", [int, None], optional=True)) + "seed": self.seed, + } + ) + # NB: dim_case_embd is intentionally NOT injected for FT/LP. The paper + # qm9_gap input.json omits it: single-task `--finetune` (without + # --model-branch) copies only the backbone and random-inits the + # property head at [128, 240], so there is no [159, 240] checkpoint + # head to size-match against. An explicit user value still wins. + if self.fparam_dim > 0: + fn["numb_fparam"] = self.fparam_dim + if self.fitting_net_params: + fn.update(self.fitting_net_params) + return fn + + def _build_config(self) -> dict: + # Seed propagation in DeePMD-kit v3.1.3 (deepmd/utils/argcheck.py): + # - model.descriptor.seed verified: descrpt_dpa3_args() L1428 + # - model.fitting_net.seed verified: fitting_property() L1966 + # - training.seed verified: training_args() L3856 + # A top-level "seed" was previously added as a "v0/v1 compat default" + # but deepmd 3.1.3 dargs is strict-mode and rejects unknown root keys + # (ArgumentKeyError at root location). Do NOT re-add it. + train_sys = self._expand_systems(self.train_systems, "train_systems") + valid_sys = self._expand_systems(self.valid_systems, "valid_systems") + self._resolved_train_systems = train_sys + self._resolved_valid_systems = valid_sys + + descriptor = self._get_descriptor() + descriptor["seed"] = self.seed # verified: descrpt_dpa3_args (deepmd v3.1.3) + fitting_net = self._build_fitting_net() + + return { + "model": { + "type_map": self.type_map, + "descriptor": descriptor, + "fitting_net": fitting_net, + }, + "loss": { + "type": "property", + "loss_func": self.loss_function, + "metric": ["mae", "rmse"], + }, + "learning_rate": { + "type": "exp", + "start_lr": self.learning_rate, + "stop_lr": self.stop_lr, + "decay_steps": self.decay_steps, + **( + {"warmup_steps": self.warmup_steps} if self.warmup_steps > 0 else {} + ), + }, + "training": { + "training_data": { + "systems": train_sys, + "batch_size": self.batch_size, + }, + "validation_data": { + "systems": valid_sys, + "batch_size": self.batch_size, + }, + "numb_steps": self.max_steps, + "seed": self.seed, # verified: training_args (deepmd v3.1.3) + # Paper qm9_gap: gradient_max_norm=5.0 (gradient clipping). + "gradient_max_norm": 5.0, + "disp_freq": self.disp_freq, + "save_freq": self.save_freq, + # Absolute path so checkpoints land in output_dir without + # depending on the caller's cwd (we no longer pass --output). + "save_ckpt": os.path.join(self.output_dir, "model.ckpt"), + }, + } + + # ----- CLI build ----- + def _build_cmd(self, input_json: str) -> list: + # Paper qm9_gap uses `dp --pt train --finetune ` with NO + # --model-branch: single-task fine-tune copies the backbone and + # random-inits the property head. Passing --model-branch would try to + # copy a branch's [159, 240] property head and fail with a size + # mismatch. `--skip-neighbor-stat` is kept (paper omits it, but our + # data-stat pass is too slow); deepmd honors `training.save_ckpt` from + # the JSON so no `--output` flag is needed. + cmd = [resolve_dp_command(), "--pt", "train", str(input_json)] + cmd += ["--skip-neighbor-stat"] + if self.pretrained is not None: + cmd += ["--finetune", str(self.pretrained)] + return cmd + + # ----- checkpoint discovery ----- + def _find_latest_checkpoint(self) -> tuple: + """ + Return ``(Path | None, int)`` for the checkpoint with the largest + step in ``output_dir``, or ``(None, 0)`` if none exist. + """ + from pathlib import ( + Path, + ) + + ckpts = list(Path(self.output_dir).glob("model.ckpt-*.pt")) + if not ckpts: + return None, 0 + + def step_of(p: Path) -> int: + return int(p.stem.split("-")[-1]) + + latest = max(ckpts, key=step_of) + return latest, step_of(latest) + + def _final_ckpt_path(self) -> str | None: + latest, _ = self._find_latest_checkpoint() + return str(latest) if latest is not None else None + + # ----- fparam validation ----- + @staticmethod + def _validate_fparam(systems_spec: str | list[str], fparam_dim: int) -> None: + """Check that every set.* directory contains fparam.npy with correct shape. + + Parameters + ---------- + systems_spec : str or list[str] + Glob patterns or paths to system directories. + fparam_dim : int + Expected second dimension of the fparam array (must be > 0). + + Raises + ------ + DPADataError + If any set.* directory is missing fparam.npy or its shape[1] + does not match *fparam_dim*. + """ + import glob + + import numpy as np + + from dpa_adapt.data.errors import ( + DPADataError, + ) + + # Expand globs to system directories (same logic as _expand_systems + # but without logging warnings — this is pure validation). + if isinstance(systems_spec, str): + patterns = [systems_spec] + else: + patterns = list(systems_spec) + + system_dirs: list = [] + for pat in patterns: + matches = sorted(glob.glob(pat)) + system_dirs.extend(matches) + + # De-duplicate while preserving order. + seen = set() + system_dirs = [d for d in system_dirs if not (d in seen or seen.add(d))] + + for sys_dir in system_dirs: + set_dirs = sorted(glob.glob(os.path.join(sys_dir, "set.*"))) + for sd in set_dirs: + fpath = os.path.join(sd, "fparam.npy") + if not os.path.isfile(fpath): + raise DPADataError( + f"fparam_dim={fparam_dim} but {fpath} is missing. " + f"Every set.* directory under {sys_dir} must contain " + f"fparam.npy of shape (n_frames, {fparam_dim})." + ) + shape = np.load(fpath).shape + if len(shape) != 2: + raise DPADataError( + f"fparam.npy at {fpath} has shape {shape}; expected a " + f"2-D array (n_frames, {fparam_dim})." + ) + if shape[1] != fparam_dim: + raise DPADataError( + f"fparam.npy at {fpath} has shape {shape} " + f"but fparam_dim={fparam_dim}. " + f"Expected shape (n_frames, {fparam_dim})." + ) + # Preflight: one fparam row per frame (coord.npy frame count). + coord_path = os.path.join(sd, "coord.npy") + if os.path.isfile(coord_path): + n_frames = np.load(coord_path, mmap_mode="r").shape[0] + if shape[0] != n_frames: + raise DPADataError( + f"fparam.npy at {fpath} has {shape[0]} rows but set " + f"{sd} has {n_frames} frames (coord.npy); expected " + f"one fparam row per frame." + ) + + # ----- fit ----- + def fit(self) -> str: + """ + Run ``dp --pt train``. + + Returns + ------- + str + Path to the final ``model.ckpt-.pt``. + + Notes + ----- + Idempotency: training is skipped if a checkpoint at step + ``>= max_steps`` exists in ``output_dir``. If ``max_steps`` is + increased between runs (i.e. only a shorter checkpoint exists), + training is restarted from ``pretrained`` — checkpoint resumption is + not supported. + """ + os.makedirs(self.output_dir, exist_ok=True) + + latest, step = self._find_latest_checkpoint() + if latest is not None and step >= self.max_steps: + _LOG.info( + "Skipping training: found %s (step %d) >= max_steps=%d", + latest, + step, + self.max_steps, + ) + return str(latest) + + if self.fparam_dim > 0: + self._validate_fparam(self.train_systems, self.fparam_dim) + if self.valid_systems is not None: + self._validate_fparam(self.valid_systems, self.fparam_dim) + + config = self._build_config() + input_json = os.path.join(self.output_dir, "input.json") + with open(input_json, "w") as f: + json.dump(config, f, indent=2) + + manifest_train = os.path.join(self.output_dir, "manifest_train.txt") + with open(manifest_train, "w") as f: + f.write("\n".join(self._resolved_train_systems) + "\n") + manifest_valid = os.path.join(self.output_dir, "manifest_valid.txt") + with open(manifest_valid, "w") as f: + f.write("\n".join(self._resolved_valid_systems) + "\n") + + cmd = self._build_cmd(input_json) + # fit() deliberately echoes the CLI so the user can rerun it manually. + _LOG.info("Running: %s", " ".join(cmd)) + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + raise RuntimeError( + f"dp --pt train failed (return code {result.returncode}).\n" + f"cmd: {' '.join(cmd)}\n" + f"stdout:\n{result.stdout}\n" + f"stderr:\n{result.stderr}" + ) + + ckpt = self._final_ckpt_path() + if ckpt is None: + raise RuntimeError( + f"Training finished but no model.ckpt-*.pt was found in " + f"{self.output_dir}." + ) + return ckpt + + # ----- evaluate ----- + def evaluate(self, test_systems: str | list) -> dict: + """ + Run ``dp --pt test`` on the trained checkpoint. + + Parameters + ---------- + test_systems : str or list[str] + Glob (or list of globs) resolving to deepmd/npy system dirs. + + Returns + ------- + dict + ``{'rmse': float, 'mae': float, 'n_frames': int, 'n_systems': int, + '_raw_stdout': str, '_parser_pattern_used': str}``. + Raises ``RuntimeError`` if neither RMSE nor MAE can be parsed. + + Notes + ----- + Uses ``dp --pt test -f `` (single-value flag taking a path + to a file listing one system per line). Previously used multiple + ``-s`` flags, but argparse honored only the last one and the parser + silently succeeded with a single-system result. + """ + ckpt = self._final_ckpt_path() + if ckpt is None: + raise RuntimeError( + f"No checkpoint found in {self.output_dir}; call fit() first." + ) + systems = self._expand_systems(test_systems, "test_systems") + + # Write the resolved system paths to a datafile and pass via -f. + # This is dp --pt test's native multi-system input mode (see + # `dp --pt test --help`). + os.makedirs(self.output_dir, exist_ok=True) + datafile = os.path.join(self.output_dir, "test_systems.txt") + with open(datafile, "w") as f: + f.write("\n".join(systems) + "\n") + + cmd = [ + resolve_dp_command(), + "--pt", + "test", + "-m", + ckpt, + "-f", + datafile, + "-n", + "999999", + ] + _LOG.info( + "Running: %s (with %d systems listed in %s)", + " ".join(cmd), + len(systems), + datafile, + ) + + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + raise RuntimeError( + f"dp --pt test failed (return code {result.returncode}).\n" + f"cmd: {' '.join(cmd)}\n" + f"stdout:\n{result.stdout}\n" + f"stderr:\n{result.stderr}" + ) + # DeePMD-kit logs PROPERTY MAE/RMSE to stderr (Python logging default). + # Feed both streams to the parser. + combined = result.stdout + "\n" + result.stderr + + parsed = self._parse_test_output(combined) + + # Sanity check: extract the "# number of systems : N" line and verify + # it matches our resolved list. + n_sys_match = re.search( + r"number of systems\s*[:=]?\s*(\d+)", combined, re.IGNORECASE + ) + if n_sys_match: + n_found = int(n_sys_match.group(1)) + parsed["n_systems"] = n_found + if n_found != len(systems): + _LOG.warning( + "dp test reports %d systems but %d were resolved; " + "some systems may have been skipped (missing labels?)", + n_found, + len(systems), + ) + else: + parsed["n_systems"] = 0 + _LOG.warning( + "Could not extract 'number of systems' from dp test output; " + "inspect _raw_stdout." + ) + + return parsed + + # ----- test-output parsing ----- + # Calibrated against real deepmd-kit 3.1.3 `dp --pt test` stderr (property + # task). Sample line: "PROPERTY RMSE : 6.065579e-02 units" + # The output appears twice — once per system, once in "weighted average of + # errors" — so the parser uses findall and takes the LAST match (Fix 3). + # + # Refactored: replaced fragile multi-pattern regex fallback chain with a + # single well-anchored regex per metric type, auto-detected from the output. + # Generic \brmse\b / \bmae\b fallback patterns removed; unparseable output + # now raises RuntimeError with the last 50 lines of stdout+stderr. + _PROPERTY_RMSE_RE = re.compile( + r"PROPERTY\s+RMSE\s+:\s*([0-9eE.+-]+)", re.IGNORECASE + ) + _PROPERTY_MAE_RE = re.compile(r"PROPERTY\s+MAE\s+:\s*([0-9eE.+-]+)", re.IGNORECASE) + _ENERGY_RMSE_RE = re.compile( + r"Energy\s+RMSE\s+:\s*([0-9eE.+-]+)\s*\S+", re.IGNORECASE + ) + _ENERGY_MAE_RE = re.compile( + r"Energy\s+MAE\s+:\s*([0-9eE.+-]+)\s*\S+", re.IGNORECASE + ) + _N_FRAMES_PATTERNS: ClassVar[list] = [ + re.compile(r"number of test data\s*[:=]?\s*(\d+)", re.IGNORECASE), + re.compile(r"#\s*of test data\s*[:=]?\s*(\d+)", re.IGNORECASE), + re.compile(r"\bn_frames\b\s*[:=]?\s*(\d+)", re.IGNORECASE), + ] + + @classmethod + def _parse_test_output(cls, stdout: str) -> dict: + """ + Extract ``rmse``, ``mae``, ``n_frames`` from ``dp --pt test`` stdout. + + Auto-detects output format — ``PROPERTY MAE`` / ``PROPERTY RMSE`` for + property tasks, ``Energy MAE`` / ``Energy RMSE`` for ener tasks — + and applies a single well-anchored regex per metric type. No generic + fallback patterns are used; if parsing fails a ``RuntimeError`` is + raised with the last 50 lines of the combined output. + + Refactored: replaced fragile multi-pattern regex fallback chain with + format-aware, single-pattern-per-metric parsing. + """ + # Auto-detect output format from the presence of known metric labels. + if "PROPERTY MAE" in stdout or "PROPERTY RMSE" in stdout: + mae_re = cls._PROPERTY_MAE_RE + rmse_re = cls._PROPERTY_RMSE_RE + tag = "PROPERTY" + elif "Energy MAE" in stdout or "Energy RMSE" in stdout: + mae_re = cls._ENERGY_MAE_RE + rmse_re = cls._ENERGY_RMSE_RE + tag = "Energy" + else: + tail = "\n".join(stdout.splitlines()[-50:]) + raise RuntimeError( + "Could not parse MAE or RMSE from `dp --pt test` output. " + "No PROPERTY MAE/RMSE or Energy MAE/RMSE lines found.\n" + "----- last 50 lines of combined stdout+stderr -----\n" + f"{tail}\n" + "----------------------" + ) + + # Take the LAST match. dp --pt test prints per-system errors followed by + # a "weighted average of errors" block; the weighted average is what we + # want when multiple systems are evaluated together. For a single-system + # test, the per-system and weighted lines have the same value. + mae_matches = mae_re.findall(stdout) + rmse_matches = rmse_re.findall(stdout) + + if not mae_matches and not rmse_matches: + tail = "\n".join(stdout.splitlines()[-50:]) + raise RuntimeError( + f"Detected {tag} output format but could not extract numeric " + "MAE or RMSE values.\n" + "----- last 50 lines of combined stdout+stderr -----\n" + f"{tail}\n" + "----------------------" + ) + + mae = float(mae_matches[-1]) if mae_matches else float("nan") + rmse = float(rmse_matches[-1]) if rmse_matches else float("nan") + + # TODO: for the total across systems we'd need to sum all matches; + # here we take the last (per-system) match. `n_frames` is currently + # only used for logging, so this approximation is acceptable. + n_frames = 0 + for pat in cls._N_FRAMES_PATTERNS: + matches = pat.findall(stdout) + if matches: + n_frames = int(matches[-1]) + break + + pattern_used = f"{tag} MAE (last); {tag} RMSE (last)" + return { + "rmse": rmse, + "mae": mae, + "n_frames": n_frames, + "_raw_stdout": stdout, + "_parser_pattern_used": pattern_used, + } diff --git a/dpa_adapt/utils/__init__.py b/dpa_adapt/utils/__init__.py new file mode 100644 index 0000000000..93ef3210cf --- /dev/null +++ b/dpa_adapt/utils/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from .dotdict import ( + DotDict, +) + +__all__ = ["DotDict"] diff --git a/dpa_adapt/utils/dotdict.py b/dpa_adapt/utils/dotdict.py new file mode 100644 index 0000000000..dc03800813 --- /dev/null +++ b/dpa_adapt/utils/dotdict.py @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +# utils/dotdict.py + +from typing import ( + Any, +) + + +class DotDict(dict): + """A dict subclass that allows attribute-style access.""" + + def __getattr__(self, name: str) -> Any: + try: + return self[name] + except KeyError: + raise AttributeError(f"'DotDict' has no attribute '{name}'") from None + + def __setattr__(self, name: str, value: Any) -> None: + self[name] = value + + def __delattr__(self, name: str) -> None: + try: + del self[name] + except KeyError: + raise AttributeError(f"'DotDict' has no attribute '{name}'") from None diff --git a/dpa_adapt/utils/sklearn_heads.py b/dpa_adapt/utils/sklearn_heads.py new file mode 100644 index 0000000000..b18398e560 --- /dev/null +++ b/dpa_adapt/utils/sklearn_heads.py @@ -0,0 +1,83 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +# utils/sklearn_heads.py +# +# Single source of truth for building sklearn predictor heads. +# Used by DPAFineTuner._fit_sklearn() and cv._build_sklearn_head(). + +from typing import ( + Any, +) + + +def build_sklearn_head(predictor_type: str, seed: int = 42, n_outputs: int = 1) -> Any: + """Build an sklearn estimator for the given predictor type. + + Parameters + ---------- + predictor_type : str + One of ``"rf"``, ``"linear"`` / ``"ridge"``, or ``"mlp"``. + seed : int + Random seed for reproducibility. + n_outputs : int + Number of output dimensions. When > 1, ``"rf"`` and ``"ridge"`` + are automatically wrapped in ``MultiOutputRegressor``. ``"mlp"`` + supports multi-output natively and ignores this parameter. + + Returns + ------- + estimator + An sklearn-compatible regressor (NOT wrapped in a Pipeline). + + Raises + ------ + ValueError + If *predictor_type* is not recognised. + """ + if predictor_type in ("linear", "ridge"): + from sklearn.linear_model import ( + Ridge, + ) + + est = Ridge(alpha=1.0, random_state=seed) + if n_outputs > 1: + from sklearn.multioutput import ( + MultiOutputRegressor, + ) + + return MultiOutputRegressor(est) + return est + + if predictor_type == "rf": + from sklearn.ensemble import ( + RandomForestRegressor, + ) + + est = RandomForestRegressor(n_estimators=100, random_state=seed) + if n_outputs > 1: + from sklearn.multioutput import ( + MultiOutputRegressor, + ) + + return MultiOutputRegressor(est) + return est + + if predictor_type == "mlp": + from sklearn.neural_network import ( + MLPRegressor, + ) + + return MLPRegressor( + hidden_layer_sizes=(512, 512, 256), + max_iter=2000, + alpha=0.0, + learning_rate_init=1e-3, + random_state=seed, + early_stopping=True, + validation_fraction=0.1, + n_iter_no_change=20, + ) + + raise ValueError( + f"Unknown predictor type: {predictor_type!r}. " + "Supported: 'rf', 'linear'/'ridge', 'mlp'." + ) diff --git a/examples/dpa_adapt/README.md b/examples/dpa_adapt/README.md new file mode 100644 index 0000000000..744daeaad6 --- /dev/null +++ b/examples/dpa_adapt/README.md @@ -0,0 +1,81 @@ +# ADAPT example + +This directory contains a small ready-to-run example for `dpa_adapt`. +The example uses 8 pre-processed QM9 molecules to fine-tune and evaluate a +DPA-based HOMO-LUMO gap predictor. + +The processed data is already included, so you can run the demo directly. + +## Directory layout + +```text +examples/dpa_adapt/ +|-- data/ # ready-to-use processed data +| |-- train/ # 5 training systems in deepmd/npy format +| |-- test/ # 3 test systems in deepmd/npy format +| |-- train_labels.npy +| `-- test_labels.npy +|-- scripts/ +| |-- run_evaluate_frozen_sklearn.py # frozen_sklearn demo: DPA-3.1-3M + Ridge +| |-- run_evaluate_frozen_head.py # frozen_head demo: DPA-3.1-3M fine-tuning +| `-- prepare_data.py # regenerate data/ from raw GDB9 data +`-- README.md +``` + +## Run the example + +Two evaluation scripts are provided, demonstrating different adaptation strategies. + +From this directory, run either (or both): + +```bash +# frozen_sklearn strategy - extract DPA features, fit a Ridge regressor +python scripts/run_evaluate_frozen_sklearn.py + +# frozen_head strategy - fine-tune the prediction head with gradient steps +python scripts/run_evaluate_frozen_head.py +``` + +If you do not activate the virtual environment, run the same commands with the +environment's Python executable, for example: + +```bash +../../../.venv/Scripts/python.exe scripts/run_evaluate_frozen_head.py +``` + +### `run_evaluate_frozen_sklearn.py` + +Uses the `frozen_sklearn` strategy with the `Domains_Drug` model branch. +DPA-3.1-3M features are extracted from the training systems and a Ridge (`linear`) +regressor is fitted on top. Prints MAE, RMSE, and R2 on the test set. + +### `run_evaluate_frozen_head.py` + +Uses the `frozen_head` strategy. A fresh prediction head is trained on top of +frozen DPA-3.1-3M features with `learning_rate=1e-3`, `batch_size=128`, +`max_steps=5`. Prints predictions and evaluation metrics (MAE, RMSE, R2) on the +test set. + +## About the included data + +The `data/` directory already contains the processed example dataset. Each system +is stored in `deepmd/npy` format. The included split has 5 training systems and +3 test systems. Each `set.000/` directory contains a `gap.npy` label file. The +label key used by the example is `gap`. + +In normal use, you do not need to run any data preparation step. + +## Regenerating the data + +`scripts/prepare_data.py` is provided only for reproducibility. It rebuilds the +included `data/` directory from raw GDB9/QM9 files. + +Run it only if you want to recreate the processed data: + +```bash +python scripts/prepare_data.py +``` + +The script downloads `gdb9.tar.gz`, extracts the raw SDF and CSV files into +`raw/`, converts the first 8 molecules to `deepmd/npy`, and writes HOMO-LUMO gap +labels as `gap.npy`. diff --git a/examples/dpa_adapt/data/test/sys_0000/set.000/box.npy b/examples/dpa_adapt/data/test/sys_0000/set.000/box.npy new file mode 100644 index 0000000000..0ffa6656ca Binary files /dev/null and b/examples/dpa_adapt/data/test/sys_0000/set.000/box.npy differ diff --git a/examples/dpa_adapt/data/test/sys_0000/set.000/coord.npy b/examples/dpa_adapt/data/test/sys_0000/set.000/coord.npy new file mode 100644 index 0000000000..745ee99891 Binary files /dev/null and b/examples/dpa_adapt/data/test/sys_0000/set.000/coord.npy differ diff --git a/examples/dpa_adapt/data/test/sys_0000/set.000/gap.npy b/examples/dpa_adapt/data/test/sys_0000/set.000/gap.npy new file mode 100644 index 0000000000..5cee9edd37 Binary files /dev/null and b/examples/dpa_adapt/data/test/sys_0000/set.000/gap.npy differ diff --git a/examples/dpa_adapt/data/test/sys_0000/type.raw b/examples/dpa_adapt/data/test/sys_0000/type.raw new file mode 100644 index 0000000000..dfc30a5ba4 --- /dev/null +++ b/examples/dpa_adapt/data/test/sys_0000/type.raw @@ -0,0 +1,12 @@ +1 +1 +3 +1 +0 +0 +0 +0 +0 +0 +0 +0 diff --git a/examples/dpa_adapt/data/test/sys_0000/type_map.raw b/examples/dpa_adapt/data/test/sys_0000/type_map.raw new file mode 100644 index 0000000000..9f0af9e987 --- /dev/null +++ b/examples/dpa_adapt/data/test/sys_0000/type_map.raw @@ -0,0 +1,5 @@ +H +C +N +O +F diff --git a/examples/dpa_adapt/data/test/sys_0001/set.000/box.npy b/examples/dpa_adapt/data/test/sys_0001/set.000/box.npy new file mode 100644 index 0000000000..0ffa6656ca Binary files /dev/null and b/examples/dpa_adapt/data/test/sys_0001/set.000/box.npy differ diff --git a/examples/dpa_adapt/data/test/sys_0001/set.000/coord.npy b/examples/dpa_adapt/data/test/sys_0001/set.000/coord.npy new file mode 100644 index 0000000000..0858c44a1d Binary files /dev/null and b/examples/dpa_adapt/data/test/sys_0001/set.000/coord.npy differ diff --git a/examples/dpa_adapt/data/test/sys_0001/set.000/gap.npy b/examples/dpa_adapt/data/test/sys_0001/set.000/gap.npy new file mode 100644 index 0000000000..c8600e5b19 Binary files /dev/null and b/examples/dpa_adapt/data/test/sys_0001/set.000/gap.npy differ diff --git a/examples/dpa_adapt/data/test/sys_0001/type.raw b/examples/dpa_adapt/data/test/sys_0001/type.raw new file mode 100644 index 0000000000..7a4f9bbd93 --- /dev/null +++ b/examples/dpa_adapt/data/test/sys_0001/type.raw @@ -0,0 +1,10 @@ +3 +1 +1 +3 +0 +0 +0 +0 +0 +0 diff --git a/examples/dpa_adapt/data/test/sys_0001/type_map.raw b/examples/dpa_adapt/data/test/sys_0001/type_map.raw new file mode 100644 index 0000000000..9f0af9e987 --- /dev/null +++ b/examples/dpa_adapt/data/test/sys_0001/type_map.raw @@ -0,0 +1,5 @@ +H +C +N +O +F diff --git a/examples/dpa_adapt/data/test/sys_0002/set.000/box.npy b/examples/dpa_adapt/data/test/sys_0002/set.000/box.npy new file mode 100644 index 0000000000..0ffa6656ca Binary files /dev/null and b/examples/dpa_adapt/data/test/sys_0002/set.000/box.npy differ diff --git a/examples/dpa_adapt/data/test/sys_0002/set.000/coord.npy b/examples/dpa_adapt/data/test/sys_0002/set.000/coord.npy new file mode 100644 index 0000000000..3ec8d5e643 Binary files /dev/null and b/examples/dpa_adapt/data/test/sys_0002/set.000/coord.npy differ diff --git a/examples/dpa_adapt/data/test/sys_0002/set.000/gap.npy b/examples/dpa_adapt/data/test/sys_0002/set.000/gap.npy new file mode 100644 index 0000000000..46200504aa Binary files /dev/null and b/examples/dpa_adapt/data/test/sys_0002/set.000/gap.npy differ diff --git a/examples/dpa_adapt/data/test/sys_0002/type.raw b/examples/dpa_adapt/data/test/sys_0002/type.raw new file mode 100644 index 0000000000..947d132b92 --- /dev/null +++ b/examples/dpa_adapt/data/test/sys_0002/type.raw @@ -0,0 +1,12 @@ +1 +1 +1 +1 +0 +0 +0 +0 +0 +0 +0 +0 diff --git a/examples/dpa_adapt/data/test/sys_0002/type_map.raw b/examples/dpa_adapt/data/test/sys_0002/type_map.raw new file mode 100644 index 0000000000..9f0af9e987 --- /dev/null +++ b/examples/dpa_adapt/data/test/sys_0002/type_map.raw @@ -0,0 +1,5 @@ +H +C +N +O +F diff --git a/examples/dpa_adapt/data/test_labels.npy b/examples/dpa_adapt/data/test_labels.npy new file mode 100644 index 0000000000..f723b764b0 Binary files /dev/null and b/examples/dpa_adapt/data/test_labels.npy differ diff --git a/examples/dpa_adapt/data/train/sys_0000/set.000/box.npy b/examples/dpa_adapt/data/train/sys_0000/set.000/box.npy new file mode 100644 index 0000000000..0ffa6656ca Binary files /dev/null and b/examples/dpa_adapt/data/train/sys_0000/set.000/box.npy differ diff --git a/examples/dpa_adapt/data/train/sys_0000/set.000/coord.npy b/examples/dpa_adapt/data/train/sys_0000/set.000/coord.npy new file mode 100644 index 0000000000..5a6e5d4e8b Binary files /dev/null and b/examples/dpa_adapt/data/train/sys_0000/set.000/coord.npy differ diff --git a/examples/dpa_adapt/data/train/sys_0000/set.000/gap.npy b/examples/dpa_adapt/data/train/sys_0000/set.000/gap.npy new file mode 100644 index 0000000000..a093b5dbe6 Binary files /dev/null and b/examples/dpa_adapt/data/train/sys_0000/set.000/gap.npy differ diff --git a/examples/dpa_adapt/data/train/sys_0000/type.raw b/examples/dpa_adapt/data/train/sys_0000/type.raw new file mode 100644 index 0000000000..533994c2f9 --- /dev/null +++ b/examples/dpa_adapt/data/train/sys_0000/type.raw @@ -0,0 +1,5 @@ +1 +0 +0 +0 +0 diff --git a/examples/dpa_adapt/data/train/sys_0000/type_map.raw b/examples/dpa_adapt/data/train/sys_0000/type_map.raw new file mode 100644 index 0000000000..9f0af9e987 --- /dev/null +++ b/examples/dpa_adapt/data/train/sys_0000/type_map.raw @@ -0,0 +1,5 @@ +H +C +N +O +F diff --git a/examples/dpa_adapt/data/train/sys_0001/set.000/box.npy b/examples/dpa_adapt/data/train/sys_0001/set.000/box.npy new file mode 100644 index 0000000000..0ffa6656ca Binary files /dev/null and b/examples/dpa_adapt/data/train/sys_0001/set.000/box.npy differ diff --git a/examples/dpa_adapt/data/train/sys_0001/set.000/coord.npy b/examples/dpa_adapt/data/train/sys_0001/set.000/coord.npy new file mode 100644 index 0000000000..f14d8166a3 Binary files /dev/null and b/examples/dpa_adapt/data/train/sys_0001/set.000/coord.npy differ diff --git a/examples/dpa_adapt/data/train/sys_0001/set.000/gap.npy b/examples/dpa_adapt/data/train/sys_0001/set.000/gap.npy new file mode 100644 index 0000000000..ff560acb77 Binary files /dev/null and b/examples/dpa_adapt/data/train/sys_0001/set.000/gap.npy differ diff --git a/examples/dpa_adapt/data/train/sys_0001/type.raw b/examples/dpa_adapt/data/train/sys_0001/type.raw new file mode 100644 index 0000000000..f3b28367b7 --- /dev/null +++ b/examples/dpa_adapt/data/train/sys_0001/type.raw @@ -0,0 +1,4 @@ +2 +0 +0 +0 diff --git a/examples/dpa_adapt/data/train/sys_0001/type_map.raw b/examples/dpa_adapt/data/train/sys_0001/type_map.raw new file mode 100644 index 0000000000..9f0af9e987 --- /dev/null +++ b/examples/dpa_adapt/data/train/sys_0001/type_map.raw @@ -0,0 +1,5 @@ +H +C +N +O +F diff --git a/examples/dpa_adapt/data/train/sys_0002/set.000/box.npy b/examples/dpa_adapt/data/train/sys_0002/set.000/box.npy new file mode 100644 index 0000000000..0ffa6656ca Binary files /dev/null and b/examples/dpa_adapt/data/train/sys_0002/set.000/box.npy differ diff --git a/examples/dpa_adapt/data/train/sys_0002/set.000/coord.npy b/examples/dpa_adapt/data/train/sys_0002/set.000/coord.npy new file mode 100644 index 0000000000..a5c7d56af0 Binary files /dev/null and b/examples/dpa_adapt/data/train/sys_0002/set.000/coord.npy differ diff --git a/examples/dpa_adapt/data/train/sys_0002/set.000/gap.npy b/examples/dpa_adapt/data/train/sys_0002/set.000/gap.npy new file mode 100644 index 0000000000..efc503ab0d Binary files /dev/null and b/examples/dpa_adapt/data/train/sys_0002/set.000/gap.npy differ diff --git a/examples/dpa_adapt/data/train/sys_0002/type.raw b/examples/dpa_adapt/data/train/sys_0002/type.raw new file mode 100644 index 0000000000..6c9eabe634 --- /dev/null +++ b/examples/dpa_adapt/data/train/sys_0002/type.raw @@ -0,0 +1,3 @@ +3 +0 +0 diff --git a/examples/dpa_adapt/data/train/sys_0002/type_map.raw b/examples/dpa_adapt/data/train/sys_0002/type_map.raw new file mode 100644 index 0000000000..9f0af9e987 --- /dev/null +++ b/examples/dpa_adapt/data/train/sys_0002/type_map.raw @@ -0,0 +1,5 @@ +H +C +N +O +F diff --git a/examples/dpa_adapt/data/train/sys_0003/set.000/box.npy b/examples/dpa_adapt/data/train/sys_0003/set.000/box.npy new file mode 100644 index 0000000000..0ffa6656ca Binary files /dev/null and b/examples/dpa_adapt/data/train/sys_0003/set.000/box.npy differ diff --git a/examples/dpa_adapt/data/train/sys_0003/set.000/coord.npy b/examples/dpa_adapt/data/train/sys_0003/set.000/coord.npy new file mode 100644 index 0000000000..662b7b1266 Binary files /dev/null and b/examples/dpa_adapt/data/train/sys_0003/set.000/coord.npy differ diff --git a/examples/dpa_adapt/data/train/sys_0003/set.000/gap.npy b/examples/dpa_adapt/data/train/sys_0003/set.000/gap.npy new file mode 100644 index 0000000000..eca7dffff1 Binary files /dev/null and b/examples/dpa_adapt/data/train/sys_0003/set.000/gap.npy differ diff --git a/examples/dpa_adapt/data/train/sys_0003/type.raw b/examples/dpa_adapt/data/train/sys_0003/type.raw new file mode 100644 index 0000000000..d9ff83f194 --- /dev/null +++ b/examples/dpa_adapt/data/train/sys_0003/type.raw @@ -0,0 +1,4 @@ +1 +1 +0 +0 diff --git a/examples/dpa_adapt/data/train/sys_0003/type_map.raw b/examples/dpa_adapt/data/train/sys_0003/type_map.raw new file mode 100644 index 0000000000..9f0af9e987 --- /dev/null +++ b/examples/dpa_adapt/data/train/sys_0003/type_map.raw @@ -0,0 +1,5 @@ +H +C +N +O +F diff --git a/examples/dpa_adapt/data/train/sys_0004/set.000/box.npy b/examples/dpa_adapt/data/train/sys_0004/set.000/box.npy new file mode 100644 index 0000000000..0ffa6656ca Binary files /dev/null and b/examples/dpa_adapt/data/train/sys_0004/set.000/box.npy differ diff --git a/examples/dpa_adapt/data/train/sys_0004/set.000/coord.npy b/examples/dpa_adapt/data/train/sys_0004/set.000/coord.npy new file mode 100644 index 0000000000..78981c8acc Binary files /dev/null and b/examples/dpa_adapt/data/train/sys_0004/set.000/coord.npy differ diff --git a/examples/dpa_adapt/data/train/sys_0004/set.000/gap.npy b/examples/dpa_adapt/data/train/sys_0004/set.000/gap.npy new file mode 100644 index 0000000000..c35d40c7a8 Binary files /dev/null and b/examples/dpa_adapt/data/train/sys_0004/set.000/gap.npy differ diff --git a/examples/dpa_adapt/data/train/sys_0004/type.raw b/examples/dpa_adapt/data/train/sys_0004/type.raw new file mode 100644 index 0000000000..a384d6e471 --- /dev/null +++ b/examples/dpa_adapt/data/train/sys_0004/type.raw @@ -0,0 +1,3 @@ +1 +2 +0 diff --git a/examples/dpa_adapt/data/train/sys_0004/type_map.raw b/examples/dpa_adapt/data/train/sys_0004/type_map.raw new file mode 100644 index 0000000000..9f0af9e987 --- /dev/null +++ b/examples/dpa_adapt/data/train/sys_0004/type_map.raw @@ -0,0 +1,5 @@ +H +C +N +O +F diff --git a/examples/dpa_adapt/data/train_labels.npy b/examples/dpa_adapt/data/train_labels.npy new file mode 100644 index 0000000000..c516e814cf Binary files /dev/null and b/examples/dpa_adapt/data/train_labels.npy differ diff --git a/examples/dpa_adapt/scripts/prepare_data.py b/examples/dpa_adapt/scripts/prepare_data.py new file mode 100644 index 0000000000..08709cfc1b --- /dev/null +++ b/examples/dpa_adapt/scripts/prepare_data.py @@ -0,0 +1,253 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: LGPL-3.0-or-later +# One-time data preparation script. Data is already included in +# examples/dpa_adapt/data/. Only re-run if you need to regenerate from raw GDB9. +"""Download QM9 GDB9 and prepare deepmd/npy systems for the quickstart demo. + +Reads molecules 1-8 from the SDF, reads HOMO-LUMO gaps from the companion +CSV file, stages a small 8-row dataset, converts it with ``dpa_adapt.convert``, +and splits into 5 training and 3 test systems. + +Usage:: + + python scripts/prepare_data.py + +Can be run from anywhere; all paths are resolved relative to the +``examples/dpa_adapt/`` directory (the parent of this script). +""" + +from __future__ import ( + annotations, +) + +import csv +import logging +import shutil +import tarfile +import urllib.request +from pathlib import ( + Path, +) + +import numpy as np + +from dpa_adapt import ( + convert, +) + +logger = logging.getLogger(__name__) + +# This script lives in examples/dpa_adapt/scripts/; resolve data and raw dirs +# against examples/dpa_adapt/. +DEMO_DIR = Path(__file__).resolve().parent.parent +RAW_DIR = DEMO_DIR / "raw" +DATA_DIR = DEMO_DIR / "data" +STAGED_DIR = RAW_DIR / "qm9_8" +STAGED_MOL_DIR = STAGED_DIR / "mol" +STAGED_CSV_PATH = STAGED_DIR / "qm9_8.csv" +SDF_PATH = RAW_DIR / "gdb9.sdf" +CSV_PATH = RAW_DIR / "gdb9.sdf.csv" +TAR_PATH = RAW_DIR / "gdb9.tar.gz" +TAR_URL = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/gdb9.tar.gz" + +N_TRAIN = 5 +N_TEST = 3 +N_TOTAL = N_TRAIN + N_TEST +BOX_LENGTH = 100.0 # Angstrom, cubic box for non-periodic systems +TYPE_MAP = ["H", "C", "N", "O", "F"] + +# Hartree to eV conversion factor +HARTREE_TO_EV = 27.211386245988 + + +# --------------------------------------------------------------------------- +# helpers +# --------------------------------------------------------------------------- + + +def _download_and_extract(force: bool = False) -> None: + """Download and extract gdb9.tar.gz if the data files don't already exist.""" + if SDF_PATH.exists() and CSV_PATH.exists() and not force: + logger.info("SDF already present: %s", SDF_PATH) + logger.info("CSV already present: %s", CSV_PATH) + return + + RAW_DIR.mkdir(parents=True, exist_ok=True) + + if not TAR_PATH.exists() or force: + logger.info("Downloading %s ...", TAR_URL) + urllib.request.urlretrieve(TAR_URL, TAR_PATH) + logger.info("Downloaded -> %s", TAR_PATH) + + logger.info("Extracting from tarball ...") + with tarfile.open(TAR_PATH, "r:gz") as tar: + for member in tar.getmembers(): + name = Path(member.name).name + if name in ("gdb9.sdf", "gdb9.sdf.csv"): + if not (RAW_DIR / name).exists() or force: + logger.info( + " Extracting %s (%s MB) ...", + name, + f"{member.size / 1024 / 1024:.1f}", + ) + tar.extract(member, path=str(RAW_DIR)) + logger.info("Extraction complete.") + + +def _load_gaps_from_csv(n: int) -> dict[int, float]: + """Read the first *n* rows from the GDB9 CSV, return {index: gap_ev}. + + The CSV columns include ``mol_id``, ``homo``, ``lumo``, ``gap``. + Values are in Hartree; returned values are in eV. + The *mol_id* is ``gdb_N``; we map to 0-based index N-1. + """ + gaps: dict[int, float] = {} + with open(CSV_PATH, newline="", encoding="utf-8") as fh: + reader = csv.DictReader(fh) + for row in reader: + mol_id = row["mol_id"] # e.g. "gdb_1" + idx = int(mol_id.split("_")[1]) - 1 # 0-based + if idx >= n: + break + # Use pre-computed gap if available; otherwise lumo - homo. + if row.get("gap"): + gap_ha = float(row["gap"]) + else: + gap_ha = float(row["lumo"]) - float(row["homo"]) + gaps[idx] = gap_ha * HARTREE_TO_EV + return gaps + + +def _read_sdf_blocks(n: int) -> list[str]: + """Read the first *n* molecule blocks from the SDF file. + + GDB9 molecules are separated by ``$$$$``. + """ + logger.info("Reading %s ...", SDF_PATH) + raw_text = SDF_PATH.read_text(encoding="utf-8") + + blocks = raw_text.split("$$$$") + blocks = [b.strip() for b in blocks if b.strip()] + logger.info("Found %s molecules in SDF.", len(blocks)) + + if len(blocks) < n: + raise RuntimeError(f"Expected at least {n} molecules, found {len(blocks)}") + return blocks[:n] + + +def _stage_qm9_subset( + mol_blocks: list[str], + gaps: np.ndarray, +) -> None: + """Write an 8-row CSV plus one single-molecule SDF per row.""" + if STAGED_DIR.exists(): + shutil.rmtree(STAGED_DIR) + STAGED_MOL_DIR.mkdir(parents=True) + + with STAGED_CSV_PATH.open("w", newline="", encoding="utf-8") as fh: + writer = csv.DictWriter(fh, fieldnames=["mol_id", "gap"]) + writer.writeheader() + for i, (block, gap) in enumerate(zip(mol_blocks, gaps, strict=True)): + (STAGED_MOL_DIR / f"id{i}.sdf").write_text( + block.strip() + "\n$$$$\n", + encoding="utf-8", + ) + writer.writerow({"mol_id": f"gdb_{i + 1}", "gap": f"{float(gap):.10f}"}) + + +def _collect_labels(system_dirs: list[str]) -> np.ndarray: + """Collect all gap labels from generated system directories.""" + chunks = [] + for sys_dir in sorted(Path(p) for p in system_dirs): + for set_dir in sorted(sys_dir.glob("set.*")): + chunks.append(np.load(set_dir / "gap.npy").reshape(-1)) + if not chunks: + return np.asarray([], dtype=np.float32) + return np.concatenate(chunks).astype(np.float32) + + +# --------------------------------------------------------------------------- +# main +# --------------------------------------------------------------------------- + + +def main() -> None: + logger.info("=" * 60) + logger.info("DPA Tools - Quickstart Data Preparation") + logger.info("=" * 60) + + # 1. Download & extract -------------------------------------------------- + _download_and_extract() + + # 2. Read gaps from CSV -------------------------------------------------- + all_gaps = _load_gaps_from_csv(N_TOTAL) + gaps = np.array([all_gaps[i] for i in range(N_TOTAL)], dtype=np.float32) + + logger.info( + "Gap stats (all %d): mean=%.4f eV, std=%.4f eV", + N_TOTAL, + gaps.mean(), + gaps.std(), + ) + + # 3. Read molecules from SDF --------------------------------------------- + mol_blocks = _read_sdf_blocks(N_TOTAL) + + # 4. Stage the 8-row raw subset ------------------------------------------ + _stage_qm9_subset(mol_blocks, gaps) + + # 5. Convert to deepmd/npy via dpa_adapt.convert -------------------------- + if DATA_DIR.exists(): + shutil.rmtree(DATA_DIR) + result = convert( + str(STAGED_CSV_PATH), + str(DATA_DIR), + fmt="smiles", + mol_dir=str(STAGED_MOL_DIR), + mol_template="id{row}.sdf", + property_col="gap", + property_name="gap", + train_ratio=N_TRAIN / N_TOTAL, + split_seed=42, + overwrite=True, + verbose=False, + ) + + # Keep the historical demo layout: data/test rather than data/valid. + valid_dir = DATA_DIR / "valid" + test_dir = DATA_DIR / "test" + valid_dir.rename(test_dir) + train_systems = sorted(result["train_systems"]) + test_systems = sorted(str(p) for p in test_dir.iterdir() if p.is_dir()) + + # 6. Write aggregated labels in generated-system order -------------------- + train_labels = _collect_labels(train_systems) + test_labels = _collect_labels(test_systems) + np.save(str(DATA_DIR / "train_labels.npy"), train_labels) + np.save(str(DATA_DIR / "test_labels.npy"), test_labels) + logger.info( + " train systems -> %s (%s dirs, %s samples)", + DATA_DIR / "train", + len(train_systems), + train_labels.shape[0], + ) + logger.info( + " test systems -> %s (%s dirs, %s samples)", + test_dir, + len(test_systems), + test_labels.shape[0], + ) + + # 7. Summary -------------------------------------------------------------- + logger.info("") + logger.info("=" * 60) + logger.info("n_train : %s", N_TRAIN) + logger.info("n_test : %s", N_TEST) + logger.info("gap mean: %.4f eV", gaps.mean()) + logger.info("gap std : %.4f eV", gaps.std()) + logger.info("Done. Run one of the evaluation scripts next.") + logger.info("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/examples/dpa_adapt/scripts/run_evaluate_frozen_head.py b/examples/dpa_adapt/scripts/run_evaluate_frozen_head.py new file mode 100644 index 0000000000..d19875761c --- /dev/null +++ b/examples/dpa_adapt/scripts/run_evaluate_frozen_head.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Minimal demo: frozen_head fine-tuning on QM9 HOMO-LUMO gap.""" + +import logging +from pathlib import ( + Path, +) + +from dpa_adapt import ( + DPAFineTuner, +) + +logging.basicConfig(level=logging.INFO, format="%(message)s") +logger = logging.getLogger(__name__) + +HERE = Path(__file__).resolve().parent.parent +DATA = HERE / "data" + +model = DPAFineTuner( + pretrained="DPA-3.1-3M", + strategy="frozen_head", + property_name="gap", + learning_rate=1e-3, + batch_size=128, + max_steps=5, +) +model.fit(train_data=str(DATA / "train" / "*"), valid_data=str(DATA / "test" / "*")) + +pred = model.predict(data=str(DATA / "test" / "*")) +metrics = model.evaluate(data=str(DATA / "test" / "*")) + +logger.info(pred.predictions) +logger.info("%s %s %s", metrics.mae, metrics.rmse, metrics.r2) diff --git a/examples/dpa_adapt/scripts/run_evaluate_frozen_sklearn.py b/examples/dpa_adapt/scripts/run_evaluate_frozen_sklearn.py new file mode 100644 index 0000000000..9200b81035 --- /dev/null +++ b/examples/dpa_adapt/scripts/run_evaluate_frozen_sklearn.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Minimal demo: frozen_sklearn + Ridge on QM9 HOMO-LUMO gap.""" + +import logging +from pathlib import ( + Path, +) + +from dpa_adapt import ( + DPAFineTuner, +) + +logging.basicConfig(level=logging.INFO, format="%(message)s") +logger = logging.getLogger(__name__) + +HERE = Path(__file__).resolve().parent.parent +DATA = HERE / "data" + +model = DPAFineTuner( + pretrained="DPA-3.1-3M", + model_branch="Domains_Drug", + strategy="frozen_sklearn", + predictor="linear", + seed=42, +) +model.fit(train_data=str(DATA / "train" / "*"), target_key="gap") + +m = model.evaluate(data=str(DATA / "test" / "*")) +logger.info("MAE = %.4f eV", m.mae) +logger.info("RMSE = %.4f eV", m.rmse) +logger.info("R2 = %.4f", m.r2) diff --git a/pyproject.toml b/pyproject.toml index 8541ccf413..fd7e156a87 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,9 +91,19 @@ test = [ "pytest-split", "pytest-timeout", "dpgui", + # DPA-ADAPT tests import sklearn via dpa_adapt.cv at module load time. + "scikit-learn", # to support Array API 2024.12 'array-api-strict>=2.2;python_version>="3.9"', ] +dpa-adapt = [ + "scikit-learn", + "dpdata", + "torch", + "ase", + "rdkit", + "e3nn", +] docs = [ "sphinx>=3.1.1", "sphinx-book-theme", @@ -148,6 +158,8 @@ jax = [ [tool.deepmd_build_backend.scripts] dp = "deepmd.main:main" +dpa-adapt = "dpa_adapt.main:main" +dpaad = "dpa_adapt.main:main" [dependency-groups] dev = [ @@ -207,6 +219,7 @@ sdist.exclude = [ ] wheel.packages = [ "deepmd", + "dpa_adapt", ] wheel.py-api = "py37" build-dir = "build/{wheel_tag}" @@ -274,10 +287,12 @@ DP_ENABLE_PADDLE = "1" CMAKE_ARGS = "-DCMAKE_DISABLE_FIND_PACKAGE_OpenMP=1" [[tool.cibuildwheel.overrides]] -# error: 'value' is unavailable: introduced in macOS 10.13 +# PaddlePaddle does not provide macOS x86_64 wheels for this test matrix. select = "*-macosx_x86_64" inherit.environment = "append" +test-extras = ["cpu", "test", "lmp", "ipi", "torch"] environment.MACOSX_DEPLOYMENT_TARGET = "11.0" +environment.DP_ENABLE_PADDLE = "0" [tool.cibuildwheel.linux] repair-wheel-command = "auditwheel repair --exclude libtensorflow_framework.so.2 --exclude libtensorflow_framework.so.1 --exclude libtensorflow_framework.so --exclude _pywrap_tensorflow_internal.so --exclude libtensorflow_cc.so.2 --exclude libc10.so --exclude libtorch.so --exclude libtorch_cpu.so --exclude libmpi.so.12 -w {dest_dir} {wheel}" @@ -290,7 +305,7 @@ environment-pass = [ ] before-all = [ """if [ ! -z "${DP_PKG_NAME}" ]; then sed -i "s/name = \\"deepmd-kit\\"/name = \\"${DP_PKG_NAME}\\"/g" pyproject.toml; fi""", - """{ if [ "$(uname -m)" = "x86_64" ] ; then yum config-manager --add-repo http://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo && yum install -y cuda-nvcc-${CUDA_VERSION/./-} cuda-cudart-devel-${CUDA_VERSION/./-}; fi }""", + """{ if [ -n "${CUDA_VERSION}" ] && [ "$(uname -m)" = "x86_64" ] ; then yum config-manager --add-repo http://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo && yum install -y cuda-nvcc-${CUDA_VERSION/./-} cuda-cudart-devel-${CUDA_VERSION/./-}; fi }""", ] before-build = [ ] diff --git a/source/install/uv_with_retry.sh b/source/install/uv_with_retry.sh index 2d9a524f6b..deb0228d82 100755 --- a/source/install/uv_with_retry.sh +++ b/source/install/uv_with_retry.sh @@ -1,5 +1,5 @@ #!/bin/bash -# This script is used to retry the uv command if the error "error decoding response body" is encountered. +# This script is used to retry the uv command if a transient network error is encountered. # See also: # https://github.com/astral-sh/uv/issues/2586 # https://github.com/astral-sh/uv/issues/3456 @@ -15,16 +15,16 @@ while true; do rm -f "${tmpstderr}" exit 0 fi - # check if "error decoding response body" is in the stderr - if grep -q "error decoding response body" "${tmpstderr}"; then - echo "Retrying uv in 1 s..." + # check if a retryable network error is in the stderr + if grep -qE "error decoding response body|network timeout|I/O operation failed during extraction|Failed to download distribution" "${tmpstderr}"; then + echo "Retrying uv in 5 s..." max_retry=$((max_retry - 1)) if [ $max_retry -eq 0 ]; then echo "Max retry reached, exiting..." rm -f "${tmpstderr}" exit 1 fi - sleep 1 + sleep 5 else rm -f "${tmpstderr}" exit $exit_code diff --git a/source/tests/dpa_adapt/__init__.py b/source/tests/dpa_adapt/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/dpa_adapt/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/dpa_adapt/test_auto_convert.py b/source/tests/dpa_adapt/test_auto_convert.py new file mode 100644 index 0000000000..1ba35b558b --- /dev/null +++ b/source/tests/dpa_adapt/test_auto_convert.py @@ -0,0 +1,204 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for ``convert`` and the CSV-sniffing helpers.""" + +from __future__ import ( + annotations, +) + +from importlib.util import ( + find_spec, +) +from pathlib import ( + Path, +) + +import pytest + +_HAS_RDKIT = find_spec("rdkit") is not None + +from dpa_adapt.data.convert import ( + _is_smiles_input, + _sniff_csv, + _sniff_xlsx, + convert, +) + +# --------------------------------------------------------------------------- +# CSV sniffing +# --------------------------------------------------------------------------- + + +class TestSniffCsv: + def test_detects_smiles_column(self, tmp_path): + f = tmp_path / "data.csv" + f.write_text("id,SMILES,Property\n0,CCO,1.23\n1,c1ccccc1,4.56\n") + assert _is_smiles_input(str(f)) is True + + def test_detects_smi_column(self, tmp_path): + f = tmp_path / "data.csv" + f.write_text("idx,smi,target\n0,CCO,1.0\n") + assert _is_smiles_input(str(f)) is True + + def test_rejects_non_smiles_csv(self, tmp_path): + f = tmp_path / "data.csv" + f.write_text("formula,energy\nH2O,-1.0\n") + assert _is_smiles_input(str(f)) is False + + def test_non_csv_extension(self, tmp_path): + f = tmp_path / "POSCAR" + f.write_text("Si\n1.0\n0 0 0\n") + assert _is_smiles_input(str(f)) is False + + def test_malformed_csv(self, tmp_path): + f = tmp_path / "bad.csv" + f.write_bytes(b"\x00\x01\x02") + assert _sniff_csv(str(f)) is None + + def test_empty_csv(self, tmp_path): + f = tmp_path / "empty.csv" + f.write_text("") + assert _sniff_csv(str(f)) is None + + +class TestSniffXlsx: + @pytest.fixture(autouse=True) + def _require_openpyxl(self): + pytest.importorskip("openpyxl") + + @pytest.mark.parametrize("filename", ["data.xlsx", "data.xls"]) + def test_detects_smiles_column(self, tmp_path, filename): + pd = pytest.importorskip("pandas") + f = tmp_path / filename + pd.DataFrame({"SMILES": ["CCO", "c1ccccc1"], "Prop": [1.0, 2.0]}).to_excel( + f, + index=False, + engine="openpyxl", + ) + assert _is_smiles_input(str(f)) is True + + def test_rejects_non_smiles_xlsx(self, tmp_path): + pd = pytest.importorskip("pandas") + f = tmp_path / "data.xlsx" + pd.DataFrame({"formula": ["H2O"], "energy": [1.0]}).to_excel( + f, + index=False, + engine="openpyxl", + ) + assert _is_smiles_input(str(f)) is False + + def test_pandas_not_installed(self, tmp_path, monkeypatch): + f = tmp_path / "data.xlsx" + f.write_text("dummy") # not a real xlsx, but we won't reach pandas + monkeypatch.setitem(__import__("sys").modules, "pandas", None) + assert _sniff_xlsx(str(f)) is None + + +# --------------------------------------------------------------------------- +# convert routing +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not _HAS_RDKIT, reason="RDKit not installed") +class TestAutoConvertSmiles: + """convert routes CSV-with-SMILES to the SMILES pipeline.""" + + def test_routes_csv_smiles_to_smiles_method(self, tmp_path): + f = tmp_path / "mol.csv" + f.write_text("SMILES,Property\nCCO,1.5\nCN,2.0\n") + out = tmp_path / "npy" + + result = convert(str(f), str(out)) + + assert result["method"] == "smiles" + assert result["samples_used"] == 2 + assert "C" in result["type_map"] + assert len(result["train_systems"]) > 0 + assert len(result["valid_systems"]) > 0 + + def test_explicit_fmt_smiles_overrides_sniff(self, tmp_path): + f = tmp_path / "mol.csv" + f.write_text("SMILES,val\nC,1.0\nCC,2.0\n") + out = tmp_path / "npy2" + + result = convert(str(f), str(out), fmt="smiles", property_col="val") + + assert result["method"] == "smiles" + assert result["samples_used"] == 2 + assert "failed_rows" in result + assert "skipped_zero" in result + assert "skipped_overlap" in result + + def test_explicit_fmt_smiles_is_case_insensitive(self, tmp_path): + f = tmp_path / "mol.csv" + f.write_text("SMILES,val\nC,1.0\nCC,2.0\n") + out = tmp_path / "npy3" + + result = convert(str(f), str(out), fmt="SMILES", property_col="val") + + assert result["method"] == "smiles" + assert result["samples_used"] == 2 + + +class TestAutoConvertStructure: + """convert routes structure files through dpdata.""" + + def test_routes_poscar_to_dpdata(self, tmp_path): + f = tmp_path / "POSCAR" + f.write_text("Si\n1.0\n5.43 0 0\n0 5.43 0\n0 0 5.43\nSi\n1\nCartesian\n0 0 0\n") + out = tmp_path / "npy" + + result = convert(str(f), str(out)) + + assert result["method"] == "dpdata" + out_dir = result["output_dir"] + assert (Path(out_dir) / "type.raw").exists() + assert (Path(out_dir) / "set.000" / "coord.npy").exists() + + def test_explicit_fmt_passed_through(self, tmp_path): + f = tmp_path / "POSCAR" + f.write_text("Si\n1.0\n5.43 0 0\n0 5.43 0\n0 0 5.43\nSi\n1\nCartesian\n0 0 0\n") + out = tmp_path / "npy2" + + result = convert(str(f), str(out), fmt="vasp/poscar") + + assert result["method"] == "dpdata" + + +class TestAutoConvertNoSmiles: + """CSV without recognised SMILES column falls through to dpdata.""" + + def test_falls_through_to_dpdata(self, tmp_path): + f = tmp_path / "props.csv" + f.write_text("formula,energy\nH2O,-1.0\n") + out = tmp_path / "npy" + + # dpdata may or may not handle this, but it must NOT go to SMILES + with pytest.raises(Exception): # dpdata won't recognise it either + convert(str(f), str(out)) + + +@pytest.mark.skipif(not _HAS_RDKIT, reason="RDKit not installed") +class TestSmoke: + """Minimal round-trip: SMILES → npy → load_data.""" + + def test_smiles_round_trip(self, tmp_path): + from dpa_adapt.data.loader import ( + load_data, + ) + + f = tmp_path / "round.csv" + f.write_text("SMILES,Property\nCCO,1.5\nCN,2.0\n") + out = tmp_path / "npy" + + result = convert( + str(f), + str(out), + property_name="homo", + property_col="Property", + ) + assert result["method"] == "smiles" + + # Verify one of the output systems is loadable and carries the label. + systems = load_data(result["train_systems"]) + assert len(systems) > 0 + assert "homo" in systems[0].data diff --git a/source/tests/dpa_adapt/test_backend_contract.py b/source/tests/dpa_adapt/test_backend_contract.py new file mode 100644 index 0000000000..649eef24d4 --- /dev/null +++ b/source/tests/dpa_adapt/test_backend_contract.py @@ -0,0 +1,516 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Contract tests for ``dpa_adapt._backend``. + +These tests call **real** deepmd APIs — no mocks — on a minimal synthetic +DPA-3 descriptor model. Their purpose is to catch silent breakage when +deepmd internal APIs change signatures, ``"Default"`` key paths shift, or +the descriptor-hook behaviour is altered upstream. + +No large checkpoint file is needed; we build a tiny model from a config +dict and run a single forward pass. +""" + +from __future__ import ( + annotations, +) + +import numpy as np +import pytest + +# Smallest possible DPA-3 descriptor config that get_model accepts. +_MINIMAL_DPA3_CONFIG = { + "type_map": ["H", "O"], + "descriptor": { + "type": "dpa3", + "repflow": { + "n_dim": 16, + "e_dim": 8, + "a_dim": 4, + "nlayers": 2, + "e_rcut": 4.0, + "e_rcut_smth": 3.5, + "e_sel": 10, + "a_rcut": 3.0, + "a_rcut_smth": 2.5, + "a_sel": 5, + "axis_neuron": 2, + "skip_stat": True, + "a_compress_rate": 1, + "a_compress_e_rate": 2, + "a_compress_use_split": True, + "update_angle": True, + "smooth_edge_update": True, + "use_dynamic_sel": True, + "sel_reduce_factor": 10.0, + "update_style": "res_residual", + "update_residual": 0.1, + "update_residual_init": "const", + "n_multi_edge_message": 1, + "optim_update": True, + "use_exp_switch": True, + }, + "activation_function": "silu", + "precision": "float64", + "use_tebd_bias": False, + "concat_output_tebd": False, + "exclude_types": [], + "env_protection": 0.0, + "trainable": True, + "use_econf_tebd": False, + }, + "fitting_net": { + "type": "ener", + "neuron": [16, 16], + "activation_function": "tanh", + "precision": "float64", + "resnet_dt": True, + "use_tebd_bias": False, + "exclude_types": [], + "numb_fparam": 0, + "numb_aparam": 0, + }, +} + + +@pytest.fixture(autouse=True) +def _clear_default_torch_device(): + """Keep these CPU contract tests isolated from leaked torch defaults.""" + try: + import torch + import torch.utils._device as _device + from torch.overrides import ( + _get_current_function_mode_stack, + ) + except Exception: + yield + return + + def _pop_device_contexts(): + while True: + modes = _get_current_function_mode_stack() + if not modes or not isinstance(modes[-1], _device.DeviceContext): + break + modes[-1].__exit__(None, None, None) + + _pop_device_contexts() + torch.set_default_device(None) + try: + yield + finally: + _pop_device_contexts() + torch.set_default_device(None) + + +def _run_forward_cpu(extractor, coords, atype, box): + """Run the descriptor forward path, skipping CPU-only CI CUDA leaks.""" + import torch + + try: + with torch.device("cpu"): + return extractor._run_forward(coords, atype, box) + except AssertionError as exc: + if "Torch not compiled with CUDA enabled" in str(exc): + pytest.skip(f"PyTorch default-device CUDA leak in CPU-only build: {exc}") + raise + + +@pytest.mark.skipif(True, reason="requires real DPA checkpoint / GPU — CI contract") +class _HeavyContract: + """Guarded heavy tests that need DPA checkpoint + GPU.""" + + def test_real_checkpoint_descriptor_shape( + self, + ): # placeholder for future Bohrium-only tests + pass + + +class _HookOwner: + def __init__(self): + self.flags = [] + self.eval_descriptor_list = [object()] + + def set_eval_descriptor_hook(self, enable): + self.flags.append(enable) + + def eval_descriptor(self): + return None + + +class _FakeWrapper: + def __init__(self, inner): + self.model = {"Default": inner} + + +class TestDescriptorHookResolution: + def test_prefers_inner_model_hook(self): + from dpa_adapt._backend import ( + _DescriptorExtraction, + ) + + inner = _HookOwner() + inner.atomic_model = object() + + extractor = _DescriptorExtraction(_FakeWrapper(inner)) + extractor._enable_hook() + extractor._disable_hook() + + assert extractor._descriptor_hook_model is inner + assert inner.flags == [True, False] + + def test_falls_back_to_atomic_model_hook(self): + from dpa_adapt._backend import ( + _DescriptorExtraction, + ) + + atomic = _HookOwner() + inner = type("Inner", (), {"atomic_model": atomic})() + + extractor = _DescriptorExtraction(_FakeWrapper(inner)) + extractor._enable_hook() + extractor._clear_accumulator() + extractor._disable_hook() + + assert extractor._descriptor_hook_model is atomic + assert atomic.flags == [True, False] + assert atomic.eval_descriptor_list == [] + + +class _FakeInnerWithEmbedding: + """Inner model with forward_embedding() but no hook API (e.g. DPA3).""" + + def __init__(self, descriptor_tensor): + self._descriptor = descriptor_tensor + self.atomic_model = object() + + def forward_embedding(self, coord, atype, box): + # Mirror the real forward_embedding contract: dict[str, torch.Tensor] + # with keys ``descriptor``, ``atomic_feature``, ``structural_feature``. + return { + "descriptor": self._descriptor, + "atomic_feature": None, + "structural_feature": None, + } + + +class TestForwardEmbeddingFallback: + def test_enable_hook_is_noop_without_hook_model(self): + import sys + from unittest.mock import ( + MagicMock, + ) + + if isinstance(sys.modules.get("torch"), MagicMock): + pytest.skip("torch is mocked by another test") + + import torch + + from dpa_adapt._backend import ( + _DescriptorExtraction, + ) + + desc = torch.zeros(1, 2, 16, dtype=torch.float64) + inner = _FakeInnerWithEmbedding(desc) + extractor = _DescriptorExtraction(_FakeWrapper(inner)) + + assert extractor._descriptor_hook_model is None + extractor._enable_hook() + extractor._disable_hook() + extractor._clear_accumulator() + + def test_run_forward_uses_forward_embedding(self): + import sys + from unittest.mock import ( + MagicMock, + ) + + if isinstance(sys.modules.get("torch"), MagicMock): + pytest.skip("torch is mocked by another test") + + import torch + + from dpa_adapt._backend import ( + _DescriptorExtraction, + ) + + desc = torch.ones(1, 2, 16, dtype=torch.float64) + inner = _FakeInnerWithEmbedding(desc) + extractor = _DescriptorExtraction(_FakeWrapper(inner)) + + coord = torch.zeros(1, 6, dtype=torch.float64, requires_grad=True) + atype = torch.tensor([[0, 1]], dtype=torch.long) + box = torch.eye(3, dtype=torch.float64).ravel().unsqueeze(0) + + result = extractor._run_forward(coord, atype, box) + + assert result.shape == (1, 2, 16) + assert not result.requires_grad + + +class TestBackendContract: + """Contract tests using real deepmd APIs (no mocks). + + These require a fully-functional deepmd-kit installation. They are + skipped when the environment is incomplete (e.g. CI without MPI). + """ + + @pytest.fixture(autouse=True) + def _require_deepmd(self): + """Skip if the deepmd model builder is not usable.""" + try: + from dpa_adapt._backend import ( + build_model_from_config, + ) + + build_model_from_config(_MINIMAL_DPA3_CONFIG) + except Exception as exc: + pytest.skip(f"deepmd build_model_from_config not functional: {exc}") + + @pytest.fixture + def _extractor(self): + """Build a model + extractor, yield it, then **always** disable the + descriptor hook so a test failure never leaks global state. + """ + from dpa_adapt._backend import ( + _DescriptorExtraction, + build_model_from_config, + ) + + wrapper = build_model_from_config(_MINIMAL_DPA3_CONFIG) + wrapper.eval() + extractor = _DescriptorExtraction(wrapper) + missing_hook_api = [ + name + for name in ( + "set_eval_descriptor_hook", + "eval_descriptor", + "eval_descriptor_list", + ) + if not hasattr(extractor._atomic_model, name) + ] + if missing_hook_api: + pytest.skip( + "deepmd descriptor-hook API is not available on " + f"{type(extractor._atomic_model).__name__}: " + f"{', '.join(missing_hook_api)}" + ) + extractor._enable_hook() + try: + yield extractor + finally: + extractor._disable_hook() + + def test_build_model_from_config(self): + """``build_model_from_config`` succeeds with minimal config.""" + from dpa_adapt._backend import ( + build_model_from_config, + ) + + wrapper = build_model_from_config(_MINIMAL_DPA3_CONFIG) + assert wrapper is not None + assert "Default" in wrapper.model, ( + "ModelWrapper.model must contain 'Default' key" + ) + + def test_descriptor_extraction_chain(self, _extractor): + """Full chain: build → hook → forward → eval_descriptor → shape check.""" + import torch + + # Synthetic input: 1 frame, 2 atoms (H and O), reasonable distances + n_frames = 1 + n_atoms = 2 + coords = torch.tensor( + [[0.0, 0.0, 0.0, 1.5, 0.0, 0.0]], + dtype=torch.float64, + device="cpu", + ).requires_grad_(True) + atype = torch.tensor([[0, 1]], dtype=torch.long, device="cpu") # H, O + box = torch.tensor( + [[10.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 10.0]], + dtype=torch.float64, + device="cpu", + ) + + desc = _run_forward_cpu(_extractor, coords, atype, box) + + assert desc.ndim == 3, ( + f"expected (n_frames, n_atoms, feat_dim), got {desc.shape}" + ) + assert desc.shape[0] == n_frames + assert desc.shape[1] == n_atoms + assert desc.shape[2] > 0, "feature dim must be > 0" + assert not torch.any(torch.isnan(desc)), "descriptor contains NaN" + assert not torch.any(torch.isinf(desc)), "descriptor contains Inf" + + def test_descriptor_feat_dim_matches_repflow(self, _extractor): + """The feature dimension matches n_dim from the repflow config.""" + import torch + + coords = torch.tensor( + [[0.0, 0.0, 0.0, 1.5, 0.0, 0.0]], + dtype=torch.float64, + device="cpu", + ).requires_grad_(True) + atype = torch.tensor([[0, 1]], dtype=torch.long, device="cpu") + box = torch.tensor( + [[10.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 10.0]], + dtype=torch.float64, + device="cpu", + ) + + desc = _run_forward_cpu(_extractor, coords, atype, box) + + n_dim = _MINIMAL_DPA3_CONFIG["descriptor"]["repflow"]["n_dim"] + assert desc.shape[2] == n_dim, ( + f"descriptor feat dim {desc.shape[2]} != repflow n_dim {n_dim}" + ) + + def test_forward_common_fails_without_grad(self, _extractor): + """``forward_common`` requires gradients on coords — verify the guard.""" + import torch + + coords = torch.tensor( + [[0.0, 0.0, 0.0, 1.5, 0.0, 0.0]], + dtype=torch.float64, + device="cpu", + ) # NO requires_grad + atype = torch.tensor([[0, 1]], dtype=torch.long, device="cpu") + box = torch.tensor( + [[10.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 10.0]], + dtype=torch.float64, + device="cpu", + ) + + with pytest.raises(RuntimeError, match="grad"): + _run_forward_cpu(_extractor, coords, atype, box) + + +class TestBackendHelpers: + """Unit-level checks for _backend utility functions.""" + + def test_resolve_dp_command_keeps_symlinked_venv_scripts( + self, tmp_path, monkeypatch + ): + import os + import sys + from pathlib import ( + Path, + ) + + from dpa_adapt._backend import ( + resolve_dp_command, + ) + + exe_name = "dp.exe" if os.name == "nt" else "dp" + python_name = "python.exe" if os.name == "nt" else "python" + + real_bin = tmp_path / "real" / "bin" + venv_bin = tmp_path / "venv" / "bin" + real_bin.mkdir(parents=True) + venv_bin.mkdir(parents=True) + + real_python = real_bin / python_name + real_python.write_text("") + symlink_python = venv_bin / python_name + symlink_python.write_text("") + + wrong_dp = real_bin / exe_name + wrong_dp.write_text("") + expected_dp = venv_bin / exe_name + expected_dp.write_text("") + + def _fake_resolve(self): + if self == symlink_python: + return real_python + return self + + monkeypatch.setattr(Path, "resolve", _fake_resolve) + monkeypatch.setattr(sys, "executable", os.fspath(symlink_python)) + monkeypatch.setattr( + "sysconfig.get_path", + lambda name: os.fspath(tmp_path / "other") if name == "scripts" else "", + ) + monkeypatch.setattr("shutil.which", lambda name: None) + + assert resolve_dp_command() == os.fspath(expected_dp) + + def test_get_torch_device_returns_device(self): + import sys + from unittest.mock import ( + MagicMock, + ) + + if isinstance(sys.modules.get("torch"), MagicMock): + pytest.skip("torch is mocked by another test") + + from dpa_adapt._backend import ( + get_torch_device, + ) + + device = get_torch_device() + assert device.type in ("cpu", "cuda") + + def test_load_torch_file_roundtrip(self, tmp_path): + import sys + from unittest.mock import ( + MagicMock, + ) + + if isinstance(sys.modules.get("torch"), MagicMock): + pytest.skip("torch is mocked by another test") + + import torch + + from dpa_adapt._backend import ( + load_torch_file, + ) + + path = str(tmp_path / "test.pt") + data = {"key": "value", "n": 42} + torch.save(data, path) + loaded = load_torch_file(path) + assert loaded == data + + +class TestFormatVersion: + """format_version contract.""" + + def test_freeze_bundle_has_format_version(self, tmp_path): + """A frozen bundle from DPAFineTuner.freeze() must carry format_version=1.""" + from unittest.mock import ( + patch, + ) + + from dpa_adapt import ( + DPAFineTuner, + ) + + system = tmp_path / "sys" + system.mkdir() + (system / "type.raw").write_text("0\n1\n") + (system / "type_map.raw").write_text("Cu\nO\n") + sd = system / "set.000" + sd.mkdir() + np.save(sd / "coord.npy", np.zeros((3, 6))) + np.save(sd / "box.npy", np.tile(np.eye(3).ravel(), (3, 1))) + np.save(sd / "energy.npy", np.arange(3, dtype=float)) + + def _fake_extract(self, systems): + return np.random.default_rng(0).random((3, 8)) + + with ( + patch.object(DPAFineTuner, "_load_descriptor_model", lambda self: None), + patch.object(DPAFineTuner, "_extract_features", _fake_extract), + ): + ft = DPAFineTuner(pretrained="fake.pt", predictor="linear") + ft._checkpoint_type_map = ["Cu", "O"] + ft.fit(str(system), target_key="energy") + frozen = ft.freeze(str(tmp_path / "model.pth")) + + from dpa_adapt._backend import ( + load_torch_file, + ) + + bundle = load_torch_file(frozen) + assert bundle.get("format_version") == 1, ( + f"format_version missing or wrong: {bundle.get('format_version')!r}" + ) diff --git a/source/tests/dpa_adapt/test_cache.py b/source/tests/dpa_adapt/test_cache.py new file mode 100644 index 0000000000..e16be6199b --- /dev/null +++ b/source/tests/dpa_adapt/test_cache.py @@ -0,0 +1,221 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for descriptor cache (desc_cache.py).""" + +import numpy as np + +from dpa_adapt.data.desc_cache import ( + _cache_dir, + _cache_key, + _data_fingerprint, + _per_system_cache_path, + _system_fingerprint, +) +from dpa_adapt.data.loader import ( + load_data, +) +from dpa_adapt.finetuner import ( + ensure_per_system_cache, +) + + +def _make_system(tmp_path, name="sys", natoms=2, nframes=3, elements=None): + """Create a minimal deepmd/npy system dir and load it via dpdata.""" + if elements is None: + elements = ["H", "O"] + root = tmp_path / name + root.mkdir(parents=True, exist_ok=True) + (root / "type.raw").write_text( + "\n".join(str(i % len(elements)) for i in range(natoms)) + "\n" + ) + (root / "type_map.raw").write_text("\n".join(elements) + "\n") + sd = root / "set.000" + sd.mkdir(exist_ok=True) + np.save(sd / "coord.npy", np.random.default_rng().random((nframes, natoms * 3))) + np.save(sd / "box.npy", np.tile(np.eye(3).ravel(), (nframes, 1))) + return load_data(str(root))[0] + + +class TestSystemFingerprint: + def test_same_data_same_fp(self, tmp_path): + s = _make_system(tmp_path, "s1") + fp1 = _system_fingerprint(s) + fp2 = _system_fingerprint(s) + assert fp1 == fp2 + + def test_different_data_different_fp(self, tmp_path): + s1 = _make_system(tmp_path, "s1", nframes=3) + s2 = _make_system(tmp_path, "s2", nframes=5) + assert _system_fingerprint(s1) != _system_fingerprint(s2) + + def test_different_elements_different_fp(self, tmp_path): + s1 = _make_system(tmp_path, "s1", elements=["H", "O"]) + s2 = _make_system(tmp_path, "s2", elements=["Cu", "O"]) + assert _system_fingerprint(s1) != _system_fingerprint(s2) + + +class TestFingerprint: + def test_identical_data_same_fp(self, tmp_path): + s = _make_system(tmp_path, "s1") + fp1 = _data_fingerprint([s]) + fp2 = _data_fingerprint([s]) + assert fp1 == fp2 + + def test_different_data_different_fp(self, tmp_path): + s1 = _make_system(tmp_path, "s1", nframes=3) + s2 = _make_system(tmp_path, "s2", nframes=5) + fp1 = _data_fingerprint([s1]) + fp2 = _data_fingerprint([s2]) + assert fp1 != fp2 + + def test_system_order_changes_fp(self, tmp_path): + s1 = _make_system(tmp_path, "s1", nframes=3) + s2 = _make_system(tmp_path, "s2", nframes=5) + fp1 = _data_fingerprint([s1, s2]) + fp2 = _data_fingerprint([s2, s1]) + assert fp1 != fp2 + + +class TestCacheKey: + def test_same_inputs_same_key(self, tmp_path): + s = _make_system(tmp_path, "s1") + ckpt = tmp_path / "dummy.pt" + ckpt.write_text("dummy") + k1 = _cache_key([s], str(ckpt), None, "mean") + k2 = _cache_key([s], str(ckpt), None, "mean") + assert k1 == k2 + + def test_different_pooling_different_key(self, tmp_path): + s = _make_system(tmp_path, "s1") + ckpt = tmp_path / "dummy.pt" + ckpt.write_text("dummy") + k1 = _cache_key([s], str(ckpt), None, "mean") + k2 = _cache_key([s], str(ckpt), None, "mean+std") + assert k1 != k2 + + def test_different_branch_different_key(self, tmp_path): + s = _make_system(tmp_path, "s1") + ckpt = tmp_path / "dummy.pt" + ckpt.write_text("dummy") + k1 = _cache_key([s], str(ckpt), "Omat24", "mean") + k2 = _cache_key([s], str(ckpt), "Domains_Drug", "mean") + assert k1 != k2 + + def test_different_checkpoint_different_key(self, tmp_path): + s = _make_system(tmp_path, "s1") + ckpt1 = tmp_path / "dummy1.pt" + ckpt2 = tmp_path / "dummy2.pt" + ckpt1.write_text("dummy") + ckpt2.write_text("different") + k1 = _cache_key([s], str(ckpt1), None, "mean") + k2 = _cache_key([s], str(ckpt2), None, "mean") + assert k1 != k2 + + def test_different_type_map_different_key(self, tmp_path): + s = _make_system(tmp_path, "s1") + ckpt = tmp_path / "dummy.pt" + ckpt.write_text("dummy") + k1 = _cache_key([s], str(ckpt), None, "mean", type_map=("H", "O")) + k2 = _cache_key([s], str(ckpt), None, "mean", type_map=("O", "H")) + assert k1 != k2 + + def test_different_system_order_different_key(self, tmp_path): + s1 = _make_system(tmp_path, "s1", nframes=3) + s2 = _make_system(tmp_path, "s2", nframes=5) + ckpt = tmp_path / "dummy.pt" + ckpt.write_text("dummy") + k1 = _cache_key([s1, s2], str(ckpt), None, "mean") + k2 = _cache_key([s2, s1], str(ckpt), None, "mean") + assert k1 != k2 + + +class TestCacheDir: + def test_respects_xdg(self, monkeypatch, tmp_path): + monkeypatch.setenv("XDG_CACHE_HOME", str(tmp_path)) + d = _cache_dir() + assert str(tmp_path) in str(d) + assert "dpa_adapt" in str(d) + + +class TestPerSystemCachePath: + def test_uses_hash_not_path(self, tmp_path): + s = _make_system(tmp_path, "s1") + ckpt = tmp_path / "dummy.pt" + ckpt.write_text("dummy") + path = _per_system_cache_path(s, str(ckpt)) + # Should be under the cache dir, not next to the original data + assert "dpa_adapt" in str(path) + assert path.suffix == ".npy" + + def test_includes_type_map(self, tmp_path): + s = _make_system(tmp_path, "s1") + ckpt = tmp_path / "dummy.pt" + ckpt.write_text("dummy") + p1 = _per_system_cache_path(s, str(ckpt), type_map=("H", "O")) + p2 = _per_system_cache_path(s, str(ckpt), type_map=("O", "H")) + assert p1 != p2 + + +class TestEnsurePerSystemCache: + def _write_dummy_desc_cache(self, system, pretrained, feat_dim=8, nframes=2): + cache_path = _per_system_cache_path(system, pretrained) + cache_path.parent.mkdir(parents=True, exist_ok=True) + np.save(cache_path, np.zeros((nframes, feat_dim))) + + def test_all_cached_does_not_load_model(self, tmp_path, monkeypatch): + s1 = _make_system(tmp_path, "sys1") + s2 = _make_system(tmp_path, "sys2") + ckpt = tmp_path / "dummy.pt" + ckpt.write_text("dummy") + self._write_dummy_desc_cache(s1, str(ckpt)) + self._write_dummy_desc_cache(s2, str(ckpt)) + + called = [] + + class FakeFineTuner: + def __init__(self, **kwargs): + called.append(True) + + def _extract_features(self, systems): + return np.zeros((2, 8)) + + monkeypatch.setattr( + "dpa_adapt.finetuner.DPAFineTuner", + FakeFineTuner, + ) + ensure_per_system_cache( + [s1, s2], + pretrained=str(ckpt), + pooling="mean", + ) + assert called == [], "DPAFineTuner was called but all systems were cached" + + def test_some_missing_loads_model(self, tmp_path, monkeypatch): + s1 = _make_system(tmp_path, "sys1") + s2 = _make_system(tmp_path, "sys2") + ckpt = tmp_path / "dummy.pt" + ckpt.write_text("dummy") + self._write_dummy_desc_cache(s1, str(ckpt)) + + called = [] + + class FakeFineTuner: + def __init__(self, **kwargs): + called.append(True) + + def _extract_features(self, systems): + return np.zeros((2, 8)) + + _device = None + + monkeypatch.setattr( + "dpa_adapt.finetuner.DPAFineTuner", + FakeFineTuner, + ) + ensure_per_system_cache( + [s1, s2], + pretrained=str(ckpt), + pooling="mean", + ) + assert len(called) == 1, ( + "DPAFineTuner should be called exactly once for the missing system" + ) diff --git a/source/tests/dpa_adapt/test_cli_smoke.py b/source/tests/dpa_adapt/test_cli_smoke.py new file mode 100644 index 0000000000..d258f612c9 --- /dev/null +++ b/source/tests/dpa_adapt/test_cli_smoke.py @@ -0,0 +1,237 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Smoke tests for the standalone ``dpa-adapt`` / ``dpaad`` CLI. + +Test that all verbs are reachable, ``--help`` does not trigger eager loading +of torch or any DPA implementation, and dispatch tables cover all verbs. +""" + +from __future__ import ( + annotations, +) + +import sys + + +class TestDpaAdaptParserRegistration: + """Verify all dpa-adapt verbs are registered in the standalone parser.""" + + def test_dpa_verbs_registered(self): + from dpa_adapt.cli import ( + get_parser, + ) + + parser = get_parser() + sub_action = next(a for a in parser._actions if a.dest == "command") + verbs = sorted(sub_action.choices) + for expected in ( + "extract-descriptors", + "fit", + "cv", + "predict", + "evaluate", + "data", + ): + assert expected in verbs, f"{expected!r} missing from {verbs}" + assert "mft" not in verbs, "mft should be folded into fit --strategy mft" + + def test_data_subcommands_registered(self): + from dpa_adapt.cli import ( + get_parser, + ) + + parser = get_parser() + sub_action = next(a for a in parser._actions if a.dest == "command") + data_parser = sub_action.choices["data"] + data_sub_action = next( + a for a in data_parser._actions if a.dest == "data_command" + ) + data_verbs = sorted(data_sub_action.choices) + for expected in ("convert", "validate", "attach-labels"): + assert expected in data_verbs, f"{expected!r} missing from {data_verbs}" + + +class TestDpaAdaptHelpNoTorch: + """``dpa-adapt --help`` must NOT trigger a torch import.""" + + def test_help_does_not_load_torch(self): + from unittest.mock import ( + MagicMock, + ) + + from dpa_adapt.cli import ( + get_parser, + ) + + # Other tests may inject a mock torch into sys.modules; that's fine + # as long as OUR parser path doesn't cause a *new* import. + torch_already = "torch" in sys.modules + if torch_already: + existing = sys.modules["torch"] + if not isinstance(existing, MagicMock): + import pytest + + pytest.skip("torch already loaded by another test") + + parser = get_parser() + + # Format the help text — this is the code path that argparse runs + # when --help is requested. + parser.format_help() + + if not torch_already: + assert "torch" not in sys.modules, ( + "torch was loaded during dpa-adapt --help path!" + ) + + def test_main_without_subcommand_prints_help(self, capsys): + from dpa_adapt.cli import ( + main, + ) + + main([]) + captured = capsys.readouterr() + + assert "usage:" in captured.out + assert "subcommands" in captured.out + assert captured.err == "" + + +class TestDpaDispatch: + """Verify the dispatch table covers all registered verbs.""" + + def test_dispatch_keys_match_parser_verbs(self): + from dpa_adapt.cli import ( + _DISPATCH, + get_parser, + ) + + parser = get_parser() + sub_action = next(a for a in parser._actions if a.dest == "command") + + parser_verbs = set(sub_action.choices) + dispatch_verbs = set(_DISPATCH) | {"data"} + + extra_in_parser = parser_verbs - dispatch_verbs + extra_in_dispatch = dispatch_verbs - parser_verbs + assert not extra_in_parser, ( + f"Verbs in parser but not in dispatch: {extra_in_parser}" + ) + assert not extra_in_dispatch, ( + f"Verbs in dispatch but not in parser: {extra_in_dispatch}" + ) + + def test_data_dispatch_keys_match_parser_verbs(self): + from dpa_adapt.cli import ( + _DATA_DISPATCH, + get_parser, + ) + + parser = get_parser() + sub_action = next(a for a in parser._actions if a.dest == "command") + data_parser = sub_action.choices["data"] + data_sub_action = next( + a for a in data_parser._actions if a.dest == "data_command" + ) + + parser_verbs = set(data_sub_action.choices) + dispatch_verbs = set(_DATA_DISPATCH) + + extra_in_parser = parser_verbs - dispatch_verbs + extra_in_dispatch = dispatch_verbs - parser_verbs + assert not extra_in_parser, ( + f"Data verbs in parser but not in dispatch: {extra_in_parser}" + ) + assert not extra_in_dispatch, ( + f"Data verbs in dispatch but not in parser: {extra_in_dispatch}" + ) + + +class TestDpaFitArgumentNormalization: + """Verify fit list arguments normalize argparse ``nargs`` values.""" + + def test_maybe_split_list_accepts_string_sequences(self): + from dpa_adapt.cli import ( + _maybe_split_list, + ) + + assert _maybe_split_list(["train_a", "train_b,train_c"]) == [ + "train_a", + "train_b", + "train_c", + ] + assert _maybe_split_list("H,C, O") == ["H", "C", "O"] + assert _maybe_split_list(None) is None + + def test_batch_size_parser_preserves_deepmd_specs(self): + from dpa_adapt.cli import ( + _parse_batch_size, + ) + + assert _parse_batch_size("128") == 128 + assert _parse_batch_size("auto:512") == "auto:512" + + def test_fit_accepts_downstream_auto_batch_size(self): + from dpa_adapt.cli import ( + get_parser, + ) + + args = get_parser().parse_args( + [ + "fit", + "--train-data", + "train", + "--strategy", + "mft", + "--downstream-batch-size", + "auto:512", + ] + ) + + assert args.downstream_batch_size == "auto:512" + + def test_fit_batch_size_numbers_parse_to_int(self): + from dpa_adapt.cli import ( + get_parser, + ) + + args = get_parser().parse_args( + [ + "fit", + "--train-data", + "train", + "--batch-size", + "64", + "--aux-batch-size", + "128", + "--downstream-batch-size", + "256", + ] + ) + + assert args.batch_size == 64 + assert args.aux_batch_size == 128 + assert args.downstream_batch_size == 256 + + +class TestInitAllExports: + """Verify __all__ covers the key public names.""" + + def test_all_exports(self): + import dpa_adapt + + for name in [ + "DPAFineTuner", + "DPAPredictor", + "MFTFineTuner", + "DPATrainer", + "cross_validate", + "train_test_split", + "extract_descriptors", + "convert", + "attach_labels", + "check_data", + "load_dataset", + "ConditionManager", + "DPAConditionError", + ]: + assert hasattr(dpa_adapt, name), f"{name!r} not found on dpa_adapt" diff --git a/source/tests/dpa_adapt/test_conditions.py b/source/tests/dpa_adapt/test_conditions.py new file mode 100644 index 0000000000..f2aefad714 --- /dev/null +++ b/source/tests/dpa_adapt/test_conditions.py @@ -0,0 +1,249 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for ConditionManager and conditions integration — no real DPA checkpoint needed.""" + +import pickle +import sys +from pathlib import ( + Path, +) +from unittest.mock import ( + MagicMock, + patch, +) + +import numpy as np +import pytest + +# ---- mock torch (same pattern as test_predictor.py) ---- + + +def _pickle_save(obj, path, **kwargs): + with open(path, "wb") as f: + pickle.dump(obj, f) + + +def _pickle_load(path, **kwargs): + with open(path, "rb") as f: + return pickle.load(f) + + +# Only stub torch when it is genuinely absent; injecting a MagicMock into +# sys.modules unconditionally leaks into other test modules during a full +# pytest run (the stub wins the import race and stays session-wide). Same +# guard as test_predictor.py. +try: + import torch as _unused_torch_for_test +except Exception: + _mock_torch = MagicMock() + _mock_torch.save = _pickle_save + _mock_torch.load = _pickle_load + _mock_torch.cuda.is_available.return_value = False + _mock_torch.Tensor = type("Tensor", (), {}) + _unused_torch_for_test = _mock_torch + sys.modules.setdefault("torch", _mock_torch) +else: + _unused_torch_for_test.set_default_device(None) + +from dpa_adapt import ( + DPAFineTuner, + DPAPredictor, +) +from dpa_adapt.conditions import ( + ConditionManager, + DPAConditionError, +) + +# ---- helpers ---- + + +def _make_npy_system(root: Path, n_frames: int = 3, n_atoms: int = 2) -> None: + (root / "type.raw").write_text("0\n1\n") + (root / "type_map.raw").write_text("Cu\nO\n") + set_dir = root / "set.000" + set_dir.mkdir() + np.save(set_dir / "coord.npy", np.zeros((n_frames, n_atoms * 3))) + np.save(set_dir / "box.npy", np.eye(3).reshape(1, 9).repeat(n_frames, 0)) + np.save(set_dir / "energy.npy", np.arange(n_frames, dtype=float)) + + +FEAT_DIM = 8 + + +def _mock_extract_features(self, systems): + n_frames = sum(s.data["coords"].shape[0] for s in systems) + rng = np.random.default_rng(0) + return rng.random((n_frames, FEAT_DIM)) + + +def _mock_load_descriptor_model(self): + self._checkpoint_type_map = ["Cu", "O"] + return None + + +# ====================================================================== +# ConditionManager tests +# ====================================================================== + + +class TestConditionManager: + def test_fit_transform_single_key(self): + cm = ConditionManager() + cond = {"T": np.array([300.0, 400.0, 500.0])} + X = cm.fit_transform(cond) + assert X.shape == (3, 1) + + def test_fit_transform_multi_key(self): + cm = ConditionManager() + cond = { + "T": np.array([300.0, 400.0, 500.0]), + "P": np.array([1.0, 2.0, 3.0]), + } + X = cm.fit_transform(cond) + assert X.shape == (3, 2) + + def test_transform_normalizes_correctly(self): + cm = ConditionManager() + cond = {"T": np.array([300.0, 400.0, 500.0])} + X = cm.fit_transform(cond) + assert abs(X.mean()) < 1e-6 + assert abs(X.std(ddof=0) - 1.0) < 1e-6 + + def test_save_load_roundtrip(self, tmp_path): + cm = ConditionManager() + cond = {"T": np.array([300.0, 400.0, 500.0])} + cm.fit(cond) + expected = cm.transform(cond) + + path = str(tmp_path / "cm.pkl") + cm.save(path) + cm2 = ConditionManager.load(path) + result = cm2.transform(cond) + np.testing.assert_array_equal(result, expected) + + def test_transform_before_fit_raises(self): + cm = ConditionManager() + with pytest.raises(DPAConditionError, match="before fit"): + cm.transform({"T": np.array([1.0])}) + + def test_transform_missing_key_raises(self): + cm = ConditionManager() + cm.fit({"T": np.array([1.0, 2.0])}) + with pytest.raises(DPAConditionError, match="missing from transform"): + cm.transform({"other": np.array([1.0, 2.0])}) + + +# ====================================================================== +# DPAFineTuner with conditions +# ====================================================================== + + +class TestFineTunerWithConditions: + def test_fit_with_conditions_changes_feature_dim(self, tmp_path): + system = tmp_path / "sys" + system.mkdir() + _make_npy_system(system, n_frames=4) + np.save(system / "set.000" / "fparam.npy", np.zeros((4, 1))) + + with ( + patch.object( + DPAFineTuner, "_load_descriptor_model", _mock_load_descriptor_model + ), + patch.object(DPAFineTuner, "_extract_features", _mock_extract_features), + ): + ft = DPAFineTuner(pretrained="fake.pt", predictor="linear", fparam_dim=1) + ft.fit(str(system), target_key="energy") + + # The pipeline's first step (StandardScaler) reveals the input dim + scaler = ft.predictor.named_steps["standardscaler"] + assert scaler.n_features_in_ == FEAT_DIM + 1 + + def test_predict_missing_conditions_raises(self, tmp_path): + system_fit = tmp_path / "sys_fit" + system_fit.mkdir() + _make_npy_system(system_fit, n_frames=4) + np.save(system_fit / "set.000" / "fparam.npy", np.zeros((4, 1))) + + system_predict = tmp_path / "sys_predict" + system_predict.mkdir() + _make_npy_system(system_predict, n_frames=4) + # No fparam.npy here — should trigger DPAConditionError on predict + + with ( + patch.object( + DPAFineTuner, "_load_descriptor_model", _mock_load_descriptor_model + ), + patch.object(DPAFineTuner, "_extract_features", _mock_extract_features), + ): + ft = DPAFineTuner(pretrained="fake.pt", predictor="linear", fparam_dim=1) + ft.fit(str(system_fit), target_key="energy") + + with pytest.raises(DPAConditionError, match="fit with fparam"): + ft.predict(str(system_predict)) + + def test_predict_with_unexpected_fparam_does_not_raise(self, tmp_path): + system = tmp_path / "sys" + system.mkdir() + _make_npy_system(system, n_frames=4) + # fparam.npy present even though model was NOT trained with fparam_dim + np.save(system / "set.000" / "fparam.npy", np.zeros((4, 1))) + + with ( + patch.object( + DPAFineTuner, "_load_descriptor_model", _mock_load_descriptor_model + ), + patch.object(DPAFineTuner, "_extract_features", _mock_extract_features), + ): + ft = DPAFineTuner(pretrained="fake.pt", predictor="linear") + ft.fit(str(system), target_key="energy") + + # fparam.npy is silently ignored when model was fitted without fparam_dim + result = ft.predict(str(system)) + + assert result.predictions.shape == (4, 1) + + def test_freeze_load_with_conditions(self, tmp_path): + system = tmp_path / "sys" + system.mkdir() + _make_npy_system(system, n_frames=4) + np.save(system / "set.000" / "fparam.npy", np.zeros((4, 1))) + + with ( + patch.object( + DPAFineTuner, "_load_descriptor_model", _mock_load_descriptor_model + ), + patch.object(DPAFineTuner, "_extract_features", _mock_extract_features), + ): + ft = DPAFineTuner(pretrained="fake.pt", predictor="linear", fparam_dim=1) + ft.fit(str(system), target_key="energy") + + frozen = ft.freeze(str(tmp_path / "model.pth")) + + pred = DPAPredictor(frozen) + result = pred.predict(str(system)) + + assert result.predictions.shape == (4, 1) + + +# ====================================================================== +# DPAFineTuner without conditions (backward compat) +# ====================================================================== + + +class TestFineTunerNoConditions: + def test_fit_predict_no_conditions_unchanged(self, tmp_path): + system = tmp_path / "sys" + system.mkdir() + _make_npy_system(system, n_frames=4) + + with ( + patch.object( + DPAFineTuner, "_load_descriptor_model", _mock_load_descriptor_model + ), + patch.object(DPAFineTuner, "_extract_features", _mock_extract_features), + ): + ft = DPAFineTuner(pretrained="fake.pt", predictor="linear") + ft.fit(str(system), target_key="energy") + + result = ft.predict(str(system)) + + assert result.predictions.shape == (4, 1) diff --git a/source/tests/dpa_adapt/test_config_merge.py b/source/tests/dpa_adapt/test_config_merge.py new file mode 100644 index 0000000000..4375b136f8 --- /dev/null +++ b/source/tests/dpa_adapt/test_config_merge.py @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for recursive dict merge (was dpa_adapt ConfigHandler).""" + +from __future__ import ( + annotations, +) + +from dpa_adapt.data.smiles import _deep_merge # re-exported for reuse + + +def test_merge_deep_updates_nested_dicts() -> None: + base = {"training": {"numb_steps": 10, "data": {"batch_size": 1}}, "loss": "mae"} + updates = {"training": {"data": {"batch_size": 4}}} + + merged = _deep_merge(base, updates) + + assert merged["training"]["numb_steps"] == 10 + assert merged["training"]["data"]["batch_size"] == 4 + assert base["training"]["data"]["batch_size"] == 1 # original untouched diff --git a/source/tests/dpa_adapt/test_convert.py b/source/tests/dpa_adapt/test_convert.py new file mode 100644 index 0000000000..5ed5715a29 --- /dev/null +++ b/source/tests/dpa_adapt/test_convert.py @@ -0,0 +1,322 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for convert() routing and validation wiring. + +Uses hand-written VASP POSCAR files as inputs — a single-file, structure-only +format dpdata reads reliably, which is enough to exercise globbing, tree +mirroring, the manifest, and skip-on-failure. +""" + +import importlib +import json +import logging +from pathlib import ( + Path, +) + +import pytest + +from dpa_adapt.data.convert import ( + _glob_base, + convert, +) +from dpa_adapt.data.validate import ( + Issue, +) + +# The dpa_adapt.data package re-exports the convert() function, which shadows +# the submodule name — grab the real module object for monkeypatching. +convert_mod = importlib.import_module("dpa_adapt.data.convert") + + +_POSCAR = """\ +Cu O test +1.0 +10.0 0.0 0.0 +0.0 10.0 0.0 +0.0 0.0 10.0 +Cu O +1 1 +Cartesian +0.0 0.0 0.0 +1.0 1.0 1.0 +""" + + +def _write_poscar(path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(_POSCAR) + + +# --------------------------------------------------------------------------- +# _glob_base +# --------------------------------------------------------------------------- + + +def test_glob_base_recursive_wildcard(): + assert _glob_base("calcs/**/OUTCAR") == Path("calcs") + + +def test_glob_base_single_wildcard(): + assert _glob_base("data/raw/*.xyz") == Path("data/raw") + + +def test_glob_base_no_wildcard_uses_parent(tmp_path): + f = tmp_path / "only" / "POSCAR" + _write_poscar(f) + assert _glob_base(str(f)) == f.parent + + +# --------------------------------------------------------------------------- +# convert() glob batch routing +# --------------------------------------------------------------------------- + + +def test_convert_glob_mirrors_input_tree(tmp_path): + _write_poscar(tmp_path / "in" / "a" / "POSCAR") + _write_poscar(tmp_path / "in" / "b" / "c" / "POSCAR") + out = tmp_path / "out" + + result = convert( + str(tmp_path / "in" / "**" / "POSCAR"), + str(out), + fmt="vasp/poscar", + type_map=["Cu", "O"], + ) + results = result["output_dirs"] + + assert result["method"] == "batch_dpdata" + assert len(results) == 2 + # input tree mirrored, file stem used as the leaf system directory + assert (out / "a" / "POSCAR" / "type.raw").exists() + assert (out / "b" / "c" / "POSCAR" / "type.raw").exists() + assert (out / "a" / "POSCAR" / "set.000" / "coord.npy").exists() + # returned paths point at the created system dirs + assert all(Path(r).is_dir() for r in results) + + +def test_convert_glob_writes_manifest(tmp_path): + _write_poscar(tmp_path / "in" / "a" / "POSCAR") + out = tmp_path / "out" + result = convert( + str(tmp_path / "in" / "**" / "POSCAR"), + str(out), + fmt="vasp/poscar", + type_map=["Cu", "O"], + ) + assert result["manifest"] == str(out.resolve() / "manifest.json") + manifest = json.loads((out / "manifest.json").read_text()) + assert manifest["fmt"] == "vasp/poscar" + assert manifest["type_map"] == ["Cu", "O"] + assert len(manifest["converted"]) == 1 + assert manifest["skipped"] == [] + assert manifest["converted"][0]["input"].endswith("POSCAR") + + +def test_convert_glob_skips_bad_file(tmp_path, caplog): + _write_poscar(tmp_path / "in" / "good" / "POSCAR") + bad = tmp_path / "in" / "bad" / "POSCAR" + bad.parent.mkdir(parents=True) + bad.write_text("garbage not a poscar\n") + out = tmp_path / "out" + + with caplog.at_level(logging.WARNING, logger="dpa_adapt"): + result = convert( + str(tmp_path / "in" / "**" / "POSCAR"), + str(out), + fmt="vasp/poscar", + type_map=["Cu", "O"], + ) + results = result["output_dirs"] + + # good file converted, bad file skipped and recorded + assert len(results) == 1 + assert "good" in results[0] + manifest = json.loads((out / "manifest.json").read_text()) + assert len(manifest["converted"]) == 1 + assert len(manifest["skipped"]) == 1 + assert "bad" in manifest["skipped"][0]["input"] + assert manifest["skipped"][0]["error"] + assert "skipping" in caplog.text + # the empty output subdir left by the failed convert is cleaned up + assert not (out / "bad" / "POSCAR").exists() + + +def test_convert_glob_strict_fails_fast_on_bad_file(tmp_path): + bad = tmp_path / "in" / "bad" / "POSCAR" + bad.parent.mkdir(parents=True) + bad.write_text("garbage not a poscar\n") + out = tmp_path / "out" + with pytest.raises(Exception): + convert( + str(tmp_path / "in" / "**" / "POSCAR"), + str(out), + fmt="vasp/poscar", + type_map=["Cu", "O"], + strict=True, + ) + + +# --------------------------------------------------------------------------- +# convert() validation wiring +# --------------------------------------------------------------------------- + + +def test_convert_validate_true_runs_check(tmp_path, monkeypatch): + _write_poscar(tmp_path / "POSCAR") + seen = {} + + def _fake_check(data, strict=False): + seen["is_system"] = hasattr(data, "data") # dpdata.System + seen["strict"] = strict + return [] + + monkeypatch.setattr(convert_mod, "check_data", _fake_check) + result = convert( + str(tmp_path / "POSCAR"), + str(tmp_path / "out"), + fmt="vasp/poscar", + type_map=["Cu", "O"], + validate=True, + ) + out = result["output_dir"] + assert seen["is_system"] is True # check_data received a dpdata object + assert seen["strict"] is False + assert Path(out).exists() + + +def test_convert_validate_false_skips_check(tmp_path, monkeypatch): + _write_poscar(tmp_path / "POSCAR") + + def _boom(*a, **k): + raise AssertionError("check_data must not run when validate=False") + + monkeypatch.setattr(convert_mod, "check_data", _boom) + result = convert( + str(tmp_path / "POSCAR"), + str(tmp_path / "out"), + fmt="vasp/poscar", + type_map=["Cu", "O"], + validate=False, + ) + out = result["output_dir"] + assert Path(out).exists() + + +def test_convert_validation_issues_are_logged(tmp_path, monkeypatch, caplog): + _write_poscar(tmp_path / "POSCAR") + fake = Issue("error", "sys", "", "energies", "boom description") + monkeypatch.setattr(convert_mod, "check_data", lambda data, strict=False: [fake]) + with caplog.at_level(logging.WARNING, logger="dpa_adapt"): + convert( + str(tmp_path / "POSCAR"), + str(tmp_path / "out"), + fmt="vasp/poscar", + type_map=["Cu", "O"], + validate=True, + ) + assert "boom description" in caplog.text + + +def test_convert_strict_passed_through(tmp_path, monkeypatch): + _write_poscar(tmp_path / "POSCAR") + seen = {} + + def _fake_check(path, strict=False): + seen["strict"] = strict + return [] + + monkeypatch.setattr(convert_mod, "check_data", _fake_check) + convert( + str(tmp_path / "POSCAR"), + str(tmp_path / "out"), + fmt="vasp/poscar", + type_map=["Cu", "O"], + validate=True, + strict=True, + ) + assert seen["strict"] is True + + +# --------------------------------------------------------------------------- +# convert() glob support +# --------------------------------------------------------------------------- + + +def test_convert_glob_single_match(tmp_path): + """Pass a glob pattern that matches exactly one file → batch output.""" + raw_dir = tmp_path / "raw" + raw_dir.mkdir() + _write_poscar(raw_dir / "input.sdf") + + out = tmp_path / "out" + result = convert( + str(raw_dir / "*.sdf"), + str(out), + fmt="vasp/poscar", + type_map=["Cu", "O"], + validate=False, + ) + assert result["method"] == "batch_dpdata" + assert len(result["output_dirs"]) == 1 + system_dir = out / "input" + assert system_dir.is_dir() + assert (system_dir / "type.raw").exists() + assert (system_dir / "set.000" / "coord.npy").exists() + assert (out / "manifest.json").exists() + + +def test_convert_glob_multi_match(tmp_path): + """Pass a glob pattern matching 3 files → mirrored batch output.""" + raw_dir = tmp_path / "raw" + raw_dir.mkdir() + for name in ("a.sdf", "b.sdf", "c.sdf"): + _write_poscar(raw_dir / name) + + out = tmp_path / "out" + result = convert( + str(raw_dir / "*.sdf"), + str(out), + fmt="vasp/poscar", + type_map=["Cu", "O"], + validate=False, + ) + assert result["method"] == "batch_dpdata" + assert len(result["output_dirs"]) == 3 + for sub in ("a", "b", "c"): + sub_dir = out / sub + assert sub_dir.is_dir(), f"missing {sub}" + assert (sub_dir / "type.raw").exists() + assert (sub_dir / "set.000" / "coord.npy").exists() + subdirs = [p.name for p in out.iterdir() if p.is_dir()] + assert sorted(subdirs) == ["a", "b", "c"] + + +def test_convert_glob_no_match(tmp_path): + """Pass a glob pattern with no matches → FileNotFoundError.""" + raw_dir = tmp_path / "raw" + raw_dir.mkdir() + + with pytest.raises(FileNotFoundError, match="No files matched pattern"): + convert( + str(raw_dir / "*.sdf"), + str(tmp_path / "out"), + fmt="vasp/poscar", + type_map=["Cu", "O"], + validate=False, + ) + + +def test_convert_literal_path_unchanged(tmp_path): + """Pass a literal path with no wildcards → works as before.""" + _write_poscar(tmp_path / "POSCAR") + out = tmp_path / "out" + result = convert( + str(tmp_path / "POSCAR"), + str(out), + fmt="vasp/poscar", + type_map=["Cu", "O"], + validate=False, + ) + assert result["method"] == "dpdata" + assert Path(result["output_dir"]).is_dir() + assert (Path(result["output_dir"]) / "type.raw").exists() diff --git a/source/tests/dpa_adapt/test_dataset.py b/source/tests/dpa_adapt/test_dataset.py new file mode 100644 index 0000000000..963088f5ce --- /dev/null +++ b/source/tests/dpa_adapt/test_dataset.py @@ -0,0 +1,74 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for load_dataset().""" + +import logging +from pathlib import ( + Path, +) + +import numpy as np +import pytest + +from dpa_adapt.data.dataset import ( + load_dataset, +) +from dpa_adapt.data.errors import ( + DPADataError, +) +from dpa_adapt.data.loader import ( + load_data, +) + + +def _write_system( + root: str, + natoms: int = 2, + nframes: int = 3, + label_key: str = "energy", + elements: list[str] | None = None, +) -> Path: + """Create a minimal deepmd/npy system directory. Returns its Path.""" + if elements is None: + elements = ["H", "O"] + root = Path(root) + root.mkdir(parents=True, exist_ok=True) + (root / "type.raw").write_text( + "\n".join(str(i % len(elements)) for i in range(natoms)) + "\n" + ) + (root / "type_map.raw").write_text("\n".join(elements) + "\n") + sdir = root / "set.000" + sdir.mkdir(exist_ok=True) + np.save(sdir / "coord.npy", np.zeros((nframes, natoms * 3))) + np.save(sdir / "box.npy", np.tile(np.eye(3).ravel(), (nframes, 1))) + np.save(sdir / f"{label_key}.npy", np.zeros((nframes, 1))) + return root + + +class TestLoadDataset: + def test_label_filter(self, tmp_path): + root = _write_system(str(tmp_path / "sys1"), label_key="energy") + # load_dataset resolves "energy" → "energies" via alias + systems = load_dataset(str(root), label_key="energy") + assert len(systems) == 1 + + def test_label_filter_skips_missing(self, tmp_path, caplog): + root = _write_system(str(tmp_path / "sys1"), label_key="energy") + caplog.set_level(logging.WARNING, logger="dpa_adapt.data.dataset") + with pytest.raises(DPADataError, match="no valid systems"): + load_dataset(str(root), label_key="nonexistent") + + def test_explicit_list(self, tmp_path): + s1 = load_data(str(_write_system(str(tmp_path / "s1"), label_key="energy")))[0] + s2 = load_data(str(_write_system(str(tmp_path / "s2"), label_key="energy")))[0] + systems = load_dataset([s1, s2], label_key="energy") + assert len(systems) == 2 + + def test_single_path(self, tmp_path): + root = _write_system(str(tmp_path / "s1"), label_key="energy") + systems = load_dataset(str(root), label_key="energy") + assert len(systems) == 1 + + def test_no_label_filter_raises_when_all_skipped(self, tmp_path): + root = _write_system(str(tmp_path / "s1"), label_key="energy") + with pytest.raises(DPADataError): + load_dataset(str(root), label_key="bandgap") diff --git a/source/tests/dpa_adapt/test_finetuner_strategies.py b/source/tests/dpa_adapt/test_finetuner_strategies.py new file mode 100644 index 0000000000..6f280ffa90 --- /dev/null +++ b/source/tests/dpa_adapt/test_finetuner_strategies.py @@ -0,0 +1,503 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for DPAFineTuner training-paradigm strategies +(frozen_head / finetune). + +Mock ``dp --pt train`` via ``subprocess.run``; verify: +- Correct DPATrainer params per strategy +- Auto type_map inference (non-empty, checkpoint-derived) +- Config structure (input.json) +""" + +from __future__ import ( + annotations, +) + +import json +import os +from pathlib import ( + Path, +) +from unittest.mock import ( + patch, +) + +import pytest + +from dpa_adapt.finetuner import ( + DPAFineTuner, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_FULL_TYPE_MAP = ["H", "He", "Li", "Be", "B", "C", "N", "O"] # 8 el. subset + + +def _fake_ckpt_sd(type_map=None): + """Minimal DPA-3.1-3M-like state_dict.""" + if type_map is None: + type_map = list(_FULL_TYPE_MAP) + descriptor = { + "type": "dpa3", + "repflow": { + "n_dim": 128, + "e_dim": 64, + "a_dim": 32, + "nlayers": 16, + "e_rcut": 6.0, + "e_rcut_smth": 5.3, + "e_sel": 1200, + "a_rcut": 4.0, + "a_rcut_smth": 3.5, + "a_sel": 300, + "axis_neuron": 4, + "skip_stat": True, + "a_compress_rate": 1, + "a_compress_e_rate": 2, + "a_compress_use_split": True, + "update_angle": True, + "smooth_edge_update": True, + "use_dynamic_sel": True, + "sel_reduce_factor": 10.0, + "update_style": "res_residual", + "update_residual": 0.1, + "update_residual_init": "const", + "n_multi_edge_message": 1, + "optim_update": True, + "use_exp_switch": True, + }, + "activation_function": "custom_silu:3.0", + "precision": "float32", + "use_tebd_bias": False, + "concat_output_tebd": False, + "exclude_types": [], + "env_protection": 0.0, + "trainable": True, + "use_econf_tebd": False, + } + return { + "model": { + "_extra_state": { + "model_params": { + "shared_dict": { + "dpa3_descriptor": descriptor, + "type_map": type_map, + }, + # model_dict must be non-empty for read_checkpoint_type_map + # to enter the multi-task branch and scan shared_dict. + "model_dict": { + "SPICE2": {"fitting_net": {"type": "ener"}}, + }, + } + } + } + } + + +def _make_system_dirs(tmp_path, formulas=("CompA", "CompB"), n=3): + """Create minimal system dirs with type_map.raw, set.000/coord.npy, + and set.000/overpotential.npy. + """ + import numpy as np + + systems = [] + for formula in formulas: + for i in range(n): + sysdir = tmp_path / formula / str(i) + sysdir.mkdir(parents=True) + (sysdir / "type_map.raw").write_text("H\nO\n") + (sysdir / "type.raw").write_text("0\n1\n") + sdir = sysdir / "set.000" + sdir.mkdir() + np.save(sdir / "coord.npy", np.zeros((2, 6))) + np.save(sdir / "box.npy", np.tile(np.eye(3).ravel(), (2, 1))) + np.save(sdir / "overpotential.npy", np.ones((2, 1))) + systems.append(str(sysdir)) + return systems + + +def _make_system_dirs(tmp_path, formulas=("CompA", "CompB"), n=3): + """Create minimal system dirs with type_map.raw, set.000/coord.npy, + and set.000/overpotential.npy. + """ + import numpy as np + + systems = [] + for formula in formulas: + for i in range(n): + sysdir = tmp_path / formula / str(i) + sysdir.mkdir(parents=True) + (sysdir / "type_map.raw").write_text("H\nO\n") + (sysdir / "type.raw").write_text("0\n1\n") + sdir = sysdir / "set.000" + sdir.mkdir() + np.save(sdir / "coord.npy", np.zeros((2, 6))) + np.save(sdir / "box.npy", np.tile(np.eye(3).ravel(), (2, 1))) + np.save(sdir / "overpotential.npy", np.ones((2, 1))) + systems.append(str(sysdir)) + return systems + + +def _mock_dp_train(ckpt_dir): + """Return a ``subprocess.run`` side-effect that writes a fake ckpt.""" + + def _run(cmd, *args, **kwargs): + os.makedirs(ckpt_dir, exist_ok=True) + # Determine max_steps from config + for a in cmd if isinstance(cmd, list) else []: + if a.endswith(".json"): + with open(a) as f: + cfg = json.load(f) + step = cfg["training"]["numb_steps"] + (Path(ckpt_dir) / f"model.ckpt-{step}.pt").write_bytes(b"") + break + + class R: + returncode = 0 + + return R() + + return _run + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestStrategyValidation: + def test_invalid_strategy_raises(self): + with pytest.raises(ValueError, match="strategy"): + DPAFineTuner(strategy="nonexistent") + + def test_default_is_frozen_sklearn(self): + m = DPAFineTuner() + assert m.strategy == "frozen_sklearn" + + +class TestAutoTypeMap: + """Auto type_map inference for training paradigms.""" + + def test_resolve_type_maps_from_checkpoint(self, monkeypatch, tmp_path): + """LP/FT: type_map from checkpoint (8 elements).""" + import torch + + monkeypatch.setattr(torch, "load", lambda *a, **kw: _fake_ckpt_sd()) + + systems = _make_system_dirs(tmp_path) + m = DPAFineTuner( + pretrained="/fake.pt", + strategy="frozen_head", + init_branch="SPICE2", + ) + tm = m._resolve_type_maps(systems) + assert tm == _FULL_TYPE_MAP + assert len(tm) == 8 + assert tm != [] + + def test_no_type_map_raw_is_ok(self, monkeypatch, tmp_path): + """LP/FT: missing type_map.raw should not crash (checkpoint fallback).""" + import torch + + monkeypatch.setattr(torch, "load", lambda *a, **kw: _fake_ckpt_sd()) + + import numpy as np + + systems = [] + for i in range(2): + sysdir = tmp_path / f"sys_{i}" + sysdir.mkdir(parents=True) + sdir = sysdir / "set.000" + sdir.mkdir() + np.save(sdir / "coord.npy", np.zeros((2, 6))) + np.save(sdir / "box.npy", np.tile(np.eye(3).ravel(), (2, 1))) + np.save(sdir / "overpotential.npy", np.ones((2, 1))) + systems.append(str(sysdir)) + + m = DPAFineTuner( + pretrained="/fake.pt", + strategy="finetune", + ) + tm = m._resolve_type_maps(systems) + assert tm == _FULL_TYPE_MAP # still reads from checkpoint + + +class TestTrainingParadigms: + """End-to-end: each strategy builds correct config, type_map auto-inferred, + dp train mocked to write a fake checkpoint. + """ + + @pytest.fixture(autouse=True) + def _mock_torch(self, monkeypatch, tmp_path): + import torch + + monkeypatch.setattr(torch, "load", lambda *a, **kw: _fake_ckpt_sd()) + # DPATrainer.__init__ checks os.path.isfile(pretrained); create a + # real file so the check passes. + self._ckpt = tmp_path / "fake.pt" + self._ckpt.write_bytes(b"") + + @pytest.mark.parametrize( + "strategy,expect_freeze,expect_tm_len", + [ + ("frozen_head", True, 8), + ("finetune", False, 8), + ], + ) + def test_config_type_map_nonempty( + self, + tmp_path, + strategy, + expect_freeze, + expect_tm_len, + ): + """input.json must have non-empty type_map (not []) for each strategy.""" + out_dir = tmp_path / "out" + systems = _make_system_dirs(tmp_path) + valid_systems = _make_system_dirs(tmp_path, formulas=("CompC",), n=2) + + m = DPAFineTuner( + pretrained=str(self._ckpt), + strategy=strategy, + property_name="overpotential", + task_dim=1, + intensive=True, + max_steps=20, + output_dir=str(out_dir), + ) + + with patch("subprocess.run", side_effect=_mock_dp_train(str(out_dir))): + ckpt = m._fit_training( + systems, valid_systems, m._resolve_type_maps(systems) + ) + + assert ckpt is not None + assert "model.ckpt-20.pt" in ckpt + + # Check the generated input.json + input_json = out_dir / "input.json" + assert input_json.is_file(), f"input.json not found in {out_dir}" + cfg = json.loads(input_json.read_text()) + tm = cfg["model"]["type_map"] + assert isinstance(tm, list), f"type_map is not a list: {tm!r}" + assert len(tm) == expect_tm_len, ( + f"{strategy}: type_map should be {expect_tm_len} elements, " + f"got {len(tm)}: {tm}" + ) + assert tm != [], "type_map is empty — would cause CUDA gather out-of-bounds" + + @pytest.mark.parametrize("strategy", ["frozen_head", "finetune"]) + def test_strategy_to_trainer_params(self, tmp_path, strategy): + """Each strategy produces correct DPATrainer freeze_backbone / pretrained.""" + out_dir = tmp_path / "out" + systems = _make_system_dirs(tmp_path) + valid_systems = _make_system_dirs(tmp_path, formulas=("CompC",), n=2) + + m = DPAFineTuner( + pretrained=str(self._ckpt), + strategy=strategy, + property_name="gap", + task_dim=1, + intensive=True, + max_steps=20, + output_dir=str(out_dir), + init_branch="SPICE2", + ) + + with patch("subprocess.run", side_effect=_mock_dp_train(str(out_dir))): + m._fit_training(systems, valid_systems, list(_FULL_TYPE_MAP)) + + cfg = json.loads((out_dir / "input.json").read_text()) + + # Check fitting_net params were propagated + fn = cfg["model"]["fitting_net"] + assert fn["property_name"] == "gap" + assert fn["task_dim"] == 1 + assert fn["intensive"] is True + + # LP must freeze backbone + if strategy == "frozen_head": + assert cfg["model"]["descriptor"]["trainable"] is False + else: + assert cfg["model"]["descriptor"]["trainable"] is True + + def test_fit_dispatch_calls_training_path(self, tmp_path): + """fit() with a training strategy calls _fit_training, not sklearn.""" + out_dir = tmp_path / "out" + systems = _make_system_dirs(tmp_path) + valid_systems = _make_system_dirs(tmp_path, formulas=("CompC",), n=2) + + m = DPAFineTuner( + pretrained=str(self._ckpt), + strategy="finetune", + property_name="overpotential", + max_steps=20, + output_dir=str(out_dir), + ) + + with patch("subprocess.run", side_effect=_mock_dp_train(str(out_dir))): + m.fit(train_data=systems, valid_data=valid_systems) + + assert m._fitted is True + assert (out_dir / "input.json").is_file() + cfg = json.loads((out_dir / "input.json").read_text()) + assert len(cfg["model"]["type_map"]) == 8 + + +def _mock_load_descriptor_model_cache_test(self): + self._checkpoint_type_map = ["H", "O"] + return None + + +class TestFitDescriptorCache: + """_fit_sklearn() caches extracted descriptors via desc_cache.""" + + def test_fit_uses_cache(self, tmp_path, monkeypatch): + """Second fit() on same data hits the cache — extraction called once.""" + import numpy as np + + # Isolate cache to a temp directory. + monkeypatch.setenv("XDG_CACHE_HOME", str(tmp_path)) + + # Create pretrained checkpoint file (cache key uses its mtime). + ckpt = tmp_path / "fake.pt" + ckpt.write_text("mock") + + # Create a minimal deepmd/npy system. + root = tmp_path / "sys" + root.mkdir() + (root / "type.raw").write_text("0\n1\n") + (root / "type_map.raw").write_text("H\nO\n") + sd = root / "set.000" + sd.mkdir() + np.save(sd / "coord.npy", np.zeros((3, 6))) + np.save(sd / "box.npy", np.tile(np.eye(3).ravel(), (3, 1))) + np.save(sd / "energy.npy", np.arange(3, dtype=float)) + + call_count = 0 + + def _fake_extract(self, systems): + nonlocal call_count + call_count += 1 + n_frames = sum(s.data["coords"].shape[0] for s in systems) + return np.random.default_rng(42).random((n_frames, 32)) + + with ( + patch.object( + DPAFineTuner, + "_load_descriptor_model", + _mock_load_descriptor_model_cache_test, + ), + patch.object(DPAFineTuner, "_extract_features", _fake_extract), + ): + m = DPAFineTuner(pretrained=str(ckpt), predictor="ridge") + m.fit(str(root), target_key="energy") + + m2 = DPAFineTuner(pretrained=str(ckpt), predictor="ridge") + m2.fit(str(root), target_key="energy") + + assert call_count == 1, f"Expected 1 extraction call, got {call_count}" + + +class TestFreezeStrategies: + def test_freeze_training_strategy_runs_dp_freeze(self, tmp_path, monkeypatch): + out_dir = tmp_path / "out" + out_dir.mkdir() + (out_dir / "model.ckpt-20.pt").write_bytes(b"ckpt") + target = tmp_path / "frozen_training.pth" + + calls = [] + + def fake_run(cmd, *args, **kwargs): + calls.append((cmd, kwargs)) + output_name = cmd[cmd.index("-o") + 1] + Path(kwargs["cwd"], output_name).write_bytes(b"frozen") + + class R: + returncode = 0 + stdout = "" + stderr = "" + + return R() + + monkeypatch.setattr("subprocess.run", fake_run) + m = DPAFineTuner(strategy="finetune", output_dir=str(out_dir)) + m._fitted = True + + assert m.freeze(str(target)) == str(target.resolve()) + assert target.read_bytes() == b"frozen" + assert calls[0][0][1:3] == ["--pt", "freeze"] + assert calls[0][1]["cwd"] == str(out_dir.resolve()) + + def test_freeze_mft_strategy_copies_downstream_freeze(self, tmp_path): + src = tmp_path / "out" / "frozen_property.pth" + src.parent.mkdir() + src.write_bytes(b"mft") + target = tmp_path / "custom_mft.pth" + + class FakeMFT: + downstream_task_type = "property" + + def _freeze_ckpt(self): + return str(src) + + m = DPAFineTuner(strategy="mft", property_name="gap") + m._fitted = True + m._mft = FakeMFT() + + assert m.freeze(str(target)) == str(target.resolve()) + assert target.read_bytes() == b"mft" + + +def test_extract_features_detaches_grad_tensors_before_numpy(monkeypatch): + import numpy as np + import torch + + from dpa_adapt import finetuner as finetuner_mod + + class FakeExtractor: + def __init__(self, model): + self.model = model + + def _enable_hook(self): + pass + + def _disable_hook(self): + pass + + def _run_forward(self, coord_t, atype_t, box_t): + return (coord_t * 2.0).reshape(coord_t.shape[0], atype_t.shape[1], 3) + + class FakeSystem: + orig = "fake" + + def __init__(self): + self.data = {"atom_names": ["H"]} + + monkeypatch.setattr(finetuner_mod, "_DescriptorExtraction", FakeExtractor) + monkeypatch.setattr( + finetuner_mod, + "_load_npy_system", + lambda system: ( + np.array([[[1.0, 2.0, 3.0]]]), + np.tile(np.eye(3).ravel(), (1, 1)), + np.array([0], dtype=np.int64), + ), + ) + + ft = finetuner_mod._FrozenSklearnPipeline( + pretrained="fake.pt", + model_branch=None, + predictor_type="linear", + pooling="mean", + seed=42, + ) + ft._model = object() + ft._device = torch.device("cpu") + ft.type_map = ["H"] + ft._checkpoint_type_map = ["H"] + + features = ft.extract_features([FakeSystem()]) + + np.testing.assert_allclose(features, np.array([[2.0, 4.0, 6.0]])) diff --git a/source/tests/dpa_adapt/test_fparam.py b/source/tests/dpa_adapt/test_fparam.py new file mode 100644 index 0000000000..02cf39edc2 --- /dev/null +++ b/source/tests/dpa_adapt/test_fparam.py @@ -0,0 +1,306 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +# Tests for fparam (frame-level condition input) support. +# Heavy deps (torch, dpdata, dp subprocess) are mocked throughout. + +from __future__ import ( + annotations, +) + +from unittest.mock import ( + patch, +) + +import numpy as np +import pytest + +from dpa_adapt.data.errors import ( + DPADataError, +) +from dpa_adapt.trainer import ( + DPATrainer, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +DUMMY_TYPE_MAP = ["H", "C", "N", "O"] + + +def _make_systems(tmp_path, prefix: str, n: int) -> str: + """Create n empty system dirs and return a glob pattern matching them.""" + root = tmp_path / prefix + root.mkdir(parents=True, exist_ok=True) + for i in range(n): + (root / f"sys_{i:03d}").mkdir() + return str(root / "sys_*") + + +def _make_dummy_trainer(fparam_dim=0, **kwargs): + """Construct a DPATrainer with minimal valid args.""" + defaults = { + "pretrained": None, + "train_systems": "dummy_train", + "valid_systems": "dummy_valid", + "type_map": DUMMY_TYPE_MAP, + "fparam_dim": fparam_dim, + } + defaults.update(kwargs) + return DPATrainer(**defaults) + + +# --------------------------------------------------------------------------- +# Tests: trainer fparam_dim validation in __init__ +# --------------------------------------------------------------------------- + + +def test_trainer_fparam_dim_negative_raises(): + """DPATrainer(fparam_dim=-1) raises ValueError.""" + with pytest.raises(ValueError, match="fparam_dim must be a non-negative"): + _make_dummy_trainer(fparam_dim=-1) + + +def test_trainer_fparam_dim_non_int_raises(): + """DPATrainer(fparam_dim='3') raises ValueError.""" + with pytest.raises(ValueError, match="fparam_dim must be a non-negative"): + _make_dummy_trainer(fparam_dim="3") # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# Tests: trainer._build_fitting_net fparam injection +# --------------------------------------------------------------------------- + + +def test_trainer_fparam_dim_injected_in_fitting_net(): + """DPATrainer(fparam_dim=3)._build_fitting_net() includes numb_fparam=3.""" + t = _make_dummy_trainer(fparam_dim=3) + fn = t._build_fitting_net() + assert fn["numb_fparam"] == 3 + + +def test_trainer_fparam_dim_zero_not_injected(): + """DPATrainer(fparam_dim=0)._build_fitting_net() does NOT contain 'fparam_dim'.""" + t = _make_dummy_trainer(fparam_dim=0) + fn = t._build_fitting_net() + assert "fparam_dim" not in fn + + +# --------------------------------------------------------------------------- +# Tests: trainer._validate_fparam +# --------------------------------------------------------------------------- + + +def test_validate_fparam_missing_file_raises(tmp_path): + """_validate_fparam raises DPADataError when fparam.npy is missing.""" + sys_dir = tmp_path / "system" + set_dir = sys_dir / "set.000" + set_dir.mkdir(parents=True) + + with pytest.raises(DPADataError, match="is missing"): + DPATrainer._validate_fparam([str(sys_dir)], fparam_dim=2) + + +def test_validate_fparam_wrong_shape_raises(tmp_path): + """_validate_fparam raises DPADataError when shape[1] != fparam_dim.""" + sys_dir = tmp_path / "system" + set_dir = sys_dir / "set.000" + set_dir.mkdir(parents=True) + # shape (5, 3), expected dim 2 + np.save(str(set_dir / "fparam.npy"), np.zeros((5, 3))) + + with pytest.raises(DPADataError, match="has shape"): + DPATrainer._validate_fparam([str(sys_dir)], fparam_dim=2) + + +def test_validate_fparam_correct_passes(tmp_path): + """_validate_fparam does NOT raise when shape matches.""" + sys_dir = tmp_path / "system" + set_dir = sys_dir / "set.000" + set_dir.mkdir(parents=True) + np.save(str(set_dir / "fparam.npy"), np.zeros((5, 2))) + + # Should not raise + DPATrainer._validate_fparam([str(sys_dir)], fparam_dim=2) + + +def test_validate_fparam_1d_raises_dpadataerror(tmp_path): + """A malformed 1-D fparam.npy raises DPADataError, not a bare IndexError.""" + sys_dir = tmp_path / "system" + set_dir = sys_dir / "set.000" + set_dir.mkdir(parents=True) + np.save(str(set_dir / "fparam.npy"), np.zeros((5,))) # 1-D, not (n, dim) + + with pytest.raises(DPADataError, match="2-D"): + DPATrainer._validate_fparam([str(sys_dir)], fparam_dim=2) + + +def test_validate_fparam_row_count_mismatch_raises(tmp_path): + """Row count must match the set's frame count (coord.npy).""" + sys_dir = tmp_path / "system" + set_dir = sys_dir / "set.000" + set_dir.mkdir(parents=True) + np.save(str(set_dir / "fparam.npy"), np.zeros((5, 2))) # 5 rows + np.save(str(set_dir / "coord.npy"), np.zeros((4, 6))) # 4 frames + + with pytest.raises(DPADataError, match="rows but set"): + DPATrainer._validate_fparam([str(sys_dir)], fparam_dim=2) + + +def test_validate_fparam_multiple_systems(tmp_path): + """_validate_fparam checks all set.* dirs across multiple systems.""" + for i in range(2): + sys_dir = tmp_path / f"sys_{i}" + for s in ("set.000", "set.001"): + (sys_dir / s).mkdir(parents=True) + np.save(str(sys_dir / s / "fparam.npy"), np.zeros((10, 3))) + + DPATrainer._validate_fparam( + [str(tmp_path / "sys_0"), str(tmp_path / "sys_1")], + fparam_dim=3, + ) + + +# --------------------------------------------------------------------------- +# Tests: DPAFineTuner forwards fparam_dim to DPATrainer +# --------------------------------------------------------------------------- + + +def test_finetuner_fparam_forwarded_to_trainer(): + """DPAFineTuner(fparam_dim=4, strategy='finetune') passes fparam_dim=4 to DPATrainer.""" + with patch("dpa_adapt.trainer.DPATrainer") as mock_trainer_cls: + from dpa_adapt.finetuner import ( + DPAFineTuner, + ) + + ft = DPAFineTuner( + pretrained="dummy.pt", + strategy="finetune", + fparam_dim=4, + ) + + # Call _fit_training directly (skip type_map resolution, skip actual fit) + ft._fit_training("dummy_train", "dummy_valid", DUMMY_TYPE_MAP) + + mock_trainer_cls.assert_called_once() + _, kwargs = mock_trainer_cls.call_args + assert kwargs["fparam_dim"] == 4 + + +def test_finetuner_fparam_zero_not_forwarded(): + """DPAFineTuner(fparam_dim=0) passes fparam_dim=0 (default, disabled).""" + with patch("dpa_adapt.trainer.DPATrainer") as mock_trainer_cls: + from dpa_adapt.finetuner import ( + DPAFineTuner, + ) + + ft = DPAFineTuner( + pretrained="dummy.pt", + strategy="finetune", + ) + + ft._fit_training("dummy_train", "dummy_valid", DUMMY_TYPE_MAP) + + mock_trainer_cls.assert_called_once() + _, kwargs = mock_trainer_cls.call_args + assert kwargs["fparam_dim"] == 0 + + +# --------------------------------------------------------------------------- +# Tests: CLI --fparam-dim parsing +# --------------------------------------------------------------------------- + + +def test_cli_fparam_dim_parsed(): + """--fparam-dim 3 is parsed to args.fparam_dim == 3.""" + from dpa_adapt.cli import ( + get_parser, + ) + + parser = get_parser() + args = parser.parse_args( + [ + "fit", + "--train-data", + "x", + "--fparam-dim", + "3", + ] + ) + assert args.fparam_dim == 3 + + +def test_cli_fparam_dim_default_zero(): + """Without --fparam-dim, args.fparam_dim defaults to 0.""" + from dpa_adapt.cli import ( + get_parser, + ) + + parser = get_parser() + args = parser.parse_args( + [ + "fit", + "--train-data", + "x", + ] + ) + assert args.fparam_dim == 0 + + +# --------------------------------------------------------------------------- +# Tests: MFTFineTuner.fit() calls _validate_fparam +# --------------------------------------------------------------------------- + + +def test_mft_fparam_validate_called_on_fit(): + """MFTFineTuner.fit() calls _validate_fparam when fparam_dim > 0.""" + with ( + patch("dpa_adapt.trainer.DPATrainer._validate_fparam") as mock_validate, + patch("dpa_adapt.config.manager.MFTConfigManager") as mock_cm_class, + patch("dpa_adapt.mft.subprocess.Popen") as mock_popen, + ): + from dpa_adapt.mft import ( + MFTFineTuner, + ) + + mock_process = mock_popen.return_value + mock_process.stdout = [] + mock_process.returncode = 0 + + mft = MFTFineTuner( + pretrained="dummy.pt", + property_name="homo", + fparam_dim=3, + type_map=["H"], + ) + mft.fit(train_data="dummy_train", aux_data="dummy_aux") + + mock_validate.assert_called_once() + args, _kwargs = mock_validate.call_args + assert args[0] == "dummy_train" + assert args[1] == 3 + + +def test_mft_fparam_validate_skipped_when_zero(): + """MFTFineTuner.fit() does NOT call _validate_fparam when fparam_dim=0.""" + with ( + patch("dpa_adapt.trainer.DPATrainer._validate_fparam") as mock_validate, + patch("dpa_adapt.config.manager.MFTConfigManager") as mock_cm_class, + patch("dpa_adapt.mft.subprocess.Popen") as mock_popen, + ): + from dpa_adapt.mft import ( + MFTFineTuner, + ) + + mock_process = mock_popen.return_value + mock_process.stdout = [] + mock_process.returncode = 0 + + mft = MFTFineTuner( + pretrained="dummy.pt", + property_name="homo", + fparam_dim=0, + type_map=["H"], + ) + mft.fit(train_data="dummy_train", aux_data="dummy_aux") + + mock_validate.assert_not_called() diff --git a/source/tests/dpa_adapt/test_loader.py b/source/tests/dpa_adapt/test_loader.py new file mode 100644 index 0000000000..cac089674a --- /dev/null +++ b/source/tests/dpa_adapt/test_loader.py @@ -0,0 +1,432 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for data loading, dpdata integration, and attach_labels.""" + +import numpy as np +import pytest + +from dpa_adapt.data.convert import ( + _key_from_head, + attach_labels, +) +from dpa_adapt.data.errors import ( + DPADataError, +) +from dpa_adapt.data.loader import ( + load_data, +) +from dpa_adapt.finetuner import ( + _load_labels, + _load_npy_system, +) + + +def _make_system(tmp_path, name="sys", set_indices=(0,), n_atoms=2, n_frames=3): + """Create a minimal deepmd/npy system dir and load it via dpdata.""" + root = tmp_path / name + root.mkdir() + (root / "type.raw").write_text("\n".join(str(i % 2) for i in range(n_atoms)) + "\n") + (root / "type_map.raw").write_text("H\nO\n") + for idx in set_indices: + sd = root / f"set.{idx:03d}" + sd.mkdir() + rng_coord = np.random.default_rng() + rng_energy = np.random.default_rng() + np.save(sd / "coord.npy", rng_coord.random((n_frames, n_atoms * 3))) + np.save(sd / "box.npy", np.tile(np.eye(3).ravel(), (n_frames, 1))) + np.save(sd / "energy.npy", rng_energy.random((n_frames,))) + return load_data(str(root))[0] + + +# --------------------------------------------------------------------------- +# set.* sort ordering +# --------------------------------------------------------------------------- + + +class TestSetDirSorting: + """dpdata preserves set.* numeric ordering during loading.""" + + def test_sorted_order_in_load_labels(self, tmp_path): + root = tmp_path / "sys" + root.mkdir() + (root / "type.raw").write_text("0\n") + (root / "type_map.raw").write_text("H\n") + + markers = {0: 0.0, 1: 1.0, 10: 10.0, 100: 100.0} + for idx, val in markers.items(): + sd = root / f"set.{idx:03d}" + sd.mkdir() + np.save(sd / "coord.npy", np.zeros((1, 3))) + np.save(sd / "box.npy", np.eye(3).reshape(1, 9)) + np.save(sd / "energy.npy", np.array([val])) + + system = load_data(str(root))[0] + labels = _load_labels([system], "energy") + assert list(labels) == [0.0, 1.0, 10.0, 100.0], ( + f"Expected [0, 1, 10, 100], got {list(labels)}" + ) + + def test_sorted_order_in_load_npy_system(self, tmp_path): + root = tmp_path / "sys" + root.mkdir() + (root / "type.raw").write_text("0\n") + (root / "type_map.raw").write_text("H\n") + + for idx in [0, 1, 10, 100]: + sd = root / f"set.{idx:03d}" + sd.mkdir() + np.save(sd / "coord.npy", np.full((1, 3), float(idx))) + np.save(sd / "box.npy", np.eye(3).reshape(1, 9)) + + system = load_data(str(root))[0] + coords, _, _ = _load_npy_system(system) + frame_values = coords[:, 0].tolist() # first atom, first coord axis + assert frame_values == [0.0, 1.0, 10.0, 100.0], ( + f"Expected [0, 1, 10, 100], got {frame_values}" + ) + + +# --------------------------------------------------------------------------- +# load_data +# --------------------------------------------------------------------------- + + +class TestLoadData: + def test_valid_system_returns_dpdata_system(self, tmp_path): + system = _make_system(tmp_path) + result = load_data(system) + assert len(result) == 1 + assert result[0] is system # passthrough, no copy + + def test_path_loads_dpdata_system(self, tmp_path): + root = tmp_path / "sys" + root.mkdir() + (root / "type.raw").write_text("0\n1\n") + (root / "type_map.raw").write_text("H\nO\n") + sd = root / "set.000" + sd.mkdir() + np.save(sd / "coord.npy", np.zeros((2, 6))) + np.save(sd / "box.npy", np.tile(np.eye(3).ravel(), (2, 1))) + + result = load_data(str(root)) + assert len(result) == 1 + import dpdata + + assert isinstance(result[0], dpdata.System) + + def test_list_of_systems(self, tmp_path): + s1 = _make_system(tmp_path, "a") + s2 = _make_system(tmp_path, "b") + result = load_data([s1, s2]) + assert len(result) == 2 + + def test_mixed_list_paths_and_objects(self, tmp_path): + s1 = _make_system(tmp_path, "a") + root = tmp_path / "b" + root.mkdir() + (root / "type.raw").write_text("0\n") + (root / "type_map.raw").write_text("H\n") + sd = root / "set.000" + sd.mkdir() + np.save(sd / "coord.npy", np.zeros((2, 3))) + np.save(sd / "box.npy", np.tile(np.eye(3).ravel(), (2, 1))) + + result = load_data([s1, str(root)]) + assert len(result) == 2 + + def test_nonexistent_path_raises(self, tmp_path): + with pytest.raises(DPADataError, match="does not exist"): + load_data(str(tmp_path / "ghost")) + + def test_passthrough_no_copy(self, tmp_path): + s = _make_system(tmp_path) + result = load_data(s) + assert result[0] is s + + +class TestGlob: + def test_mixed_files_and_dirs_fails_fast(self, tmp_path): + """Glob with deepmd/npy fmt must reject non-directory matches.""" + # Create a valid deepmd/npy directory + _make_system(tmp_path, "sys") + # Create a non-directory file + (tmp_path / "file.xyz").write_text("dummy") + + with pytest.raises(DPADataError, match="non-directory paths"): + load_data(str(tmp_path / "*")) + + def test_explicit_fmt_bypasses_precheck(self, tmp_path): + """With an explicit non-deepmd/npy fmt the pre-check is skipped.""" + (tmp_path / "file.xyz").write_text("6\n\nH 0 0 0\nO 1 1 1\n") + + with pytest.raises(DPADataError, match="Failed to load"): + # Not deepmd/npy → skips the directory pre-check, tries dpdata + load_data(str(tmp_path / "file.xyz"), fmt="xyz") + + +# --------------------------------------------------------------------------- +# attach_labels — _key_from_head +# --------------------------------------------------------------------------- + + +class TestKeyFromHead: + def test_string_head(self): + assert _key_from_head("energy") == "energy" + assert _key_from_head("bandgap") == "bandgap" + + def test_dict_with_property_name(self): + assert ( + _key_from_head( + {"type": "property", "property_name": "bandgap", "task_dim": 1} + ) + == "bandgap" + ) + assert _key_from_head({"property_name": "humo"}) == "humo" + + def test_dict_known_types(self): + assert _key_from_head({"type": "dos", "numb_dos": 250}) == "dos" + assert _key_from_head({"type": "dipole"}) == "dipole" + assert _key_from_head({"type": "polar"}) == "polar" + + def test_dict_unknown_type_raises_with_supported_list(self): + with pytest.raises(ValueError, match="Unknown dict head type 'forces'"): + _key_from_head({"type": "forces"}) + with pytest.raises(ValueError, match=r"dos.*dipole|dipole.*dos"): + _key_from_head({"type": "unknown_xyz"}) + + def test_dict_property_type_without_property_name_raises(self): + with pytest.raises(ValueError, match="property_name"): + _key_from_head({"type": "property", "task_dim": 1}) + + def test_dict_missing_both_keys_raises(self): + with pytest.raises(ValueError, match=r"property_name.*type"): + _key_from_head({"task_dim": 1}) + + def test_non_str_non_dict_raises(self): + with pytest.raises(TypeError, match="str or dict"): + _key_from_head(42) + + +def _make_system_path(tmp_path, name="sys", set_indices=(0,), n_atoms=2, n_frames=3): + """Create a minimal deepmd/npy system directory on disk (no dpdata loading). + + Returns the **Path** to the system root. + """ + root = tmp_path / name + root.mkdir() + (root / "type.raw").write_text("\n".join(str(i % 2) for i in range(n_atoms)) + "\n") + (root / "type_map.raw").write_text("H\nO\n") + for idx in set_indices: + sd = root / f"set.{idx:03d}" + sd.mkdir() + rng_coord = np.random.default_rng() + rng_energy = np.random.default_rng() + np.save(sd / "coord.npy", rng_coord.random((n_frames, n_atoms * 3))) + np.save(sd / "box.npy", np.tile(np.eye(3).ravel(), (n_frames, 1))) + np.save(sd / "energy.npy", rng_energy.random((n_frames,))) + return root + + +class TestAttachLabels: + """Path-based attach_labels: single and multi-system.""" + + # ── single-system ──────────────────────────────────────────────────── + + def test_string_head_writes_npy(self, tmp_path): + sys_path = _make_system_path(tmp_path, name="sys", n_frames=3) + attach_labels(sys_path, head="bandgap", values=np.array([1.0, 2.0, 3.0])) + written = np.load(sys_path / "set.000" / "bandgap.npy") + np.testing.assert_array_equal(written, [1.0, 2.0, 3.0]) + + def test_dict_head_property_name(self, tmp_path): + sys_path = _make_system_path(tmp_path, name="sys", n_frames=3) + values = np.array([[1.0], [2.0], [3.0]]) + attach_labels( + sys_path, + head={"type": "property", "property_name": "gap", "task_dim": 1}, + values=values, + ) + written = np.load(sys_path / "set.000" / "gap.npy") + np.testing.assert_array_equal(written, values) + + def test_2d_values_written_correctly(self, tmp_path): + sys_path = _make_system_path(tmp_path, name="sys", n_frames=3) + values = np.arange(3 * 250, dtype=float).reshape(3, 250) + attach_labels(sys_path, head={"type": "dos", "numb_dos": 250}, values=values) + written = np.load(sys_path / "set.000" / "dos.npy") + assert written.shape == (3, 250) + np.testing.assert_array_equal(written, values) + + def test_frame_count_mismatch_raises(self, tmp_path): + sys_path = _make_system_path(tmp_path, name="sys", n_frames=3) + with pytest.raises(ValueError, match="frames"): + attach_labels(sys_path, head="energy", values=np.array([1.0, 2.0])) + + def test_same_key_overwrites(self, tmp_path): + sys_path = _make_system_path(tmp_path, name="sys", n_frames=3) + attach_labels(sys_path, head="energy", values=np.array([1.0, 2.0, 3.0])) + attach_labels(sys_path, head="energy", values=np.array([9.0, 8.0, 7.0])) + written = np.load(sys_path / "set.000" / "energy.npy") + np.testing.assert_array_equal(written, [9.0, 8.0, 7.0]) + + def test_different_keys_are_additive(self, tmp_path): + sys_path = _make_system_path(tmp_path, name="sys", n_frames=3) + attach_labels(sys_path, head="energy", values=np.array([1.0, 2.0, 3.0])) + attach_labels(sys_path, head="bandgap", values=np.array([4.0, 5.0, 6.0])) + e_written = np.load(sys_path / "set.000" / "energy.npy") + b_written = np.load(sys_path / "set.000" / "bandgap.npy") + np.testing.assert_array_equal(e_written, [1.0, 2.0, 3.0]) + np.testing.assert_array_equal(b_written, [4.0, 5.0, 6.0]) + + def test_multi_set_not_implemented(self, tmp_path): + sys_path = _make_system_path( + tmp_path, name="sys", set_indices=(0, 1), n_frames=3 + ) + with pytest.raises(NotImplementedError, match="Multiple set"): + attach_labels(sys_path, head="energy", values=np.array([1.0, 2.0, 3.0])) + + def test_no_set_dir_raises(self, tmp_path): + empty = tmp_path / "empty" + empty.mkdir() + (empty / "type.raw").write_text("0\n") + with pytest.raises(ValueError, match="No set"): + attach_labels(empty, head="energy", values=np.array([1.0])) + + def test_path_is_file_raises(self, tmp_path): + f = tmp_path / "not_a_dir" + f.write_text("dummy") + with pytest.raises(ValueError, match="not a directory"): + attach_labels(f, head="energy", values=np.array([1.0])) + + def test_coord_npy_missing_raises(self, tmp_path): + sys_path = _make_system_path(tmp_path, name="sys", n_frames=3) + (sys_path / "set.000" / "coord.npy").unlink() + with pytest.raises(ValueError, match=r"coord\.npy not found"): + attach_labels(sys_path, head="energy", values=np.array([1.0, 2.0, 3.0])) + + # ── multi-system ───────────────────────────────────────────────────── + + def test_multi_system_all_written(self, tmp_path): + parent = tmp_path / "multi" + parent.mkdir() + for i in range(3): + _make_system_path(parent, name=f"sys_{i:04d}", n_frames=2) + values = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + attach_labels(parent, head="bandgap", values=values) + for i in range(3): + written = np.load(parent / f"sys_{i:04d}" / "set.000" / "bandgap.npy") + np.testing.assert_array_equal(written, values[i]) + + def test_multi_system_1d_values_written_as_one_frame_labels(self, tmp_path): + parent = tmp_path / "multi" + parent.mkdir() + for i in range(3): + _make_system_path(parent, name=f"sys_{i:04d}", n_frames=1) + values = np.array([1.0, 3.0, 5.0]) + attach_labels(parent, head="bandgap", values=values) + for i in range(3): + written = np.load(parent / f"sys_{i:04d}" / "set.000" / "bandgap.npy") + np.testing.assert_array_equal(written, [values[i]]) + + def test_multi_system_3d_values_multidim_labels(self, tmp_path): + # Multi-dim labels in multi-system mode use 3-D values of shape + # (n_systems, n_frames, dim); each system's (n_frames, dim) slice is + # written verbatim. This is the unambiguous path for a dim-vector + # label on a one-frame system. + parent = tmp_path / "multi" + parent.mkdir() + for i in range(3): + _make_system_path(parent, name=f"sys_{i:04d}", n_frames=1) + values = np.arange(3 * 1 * 4, dtype=float).reshape(3, 1, 4) + attach_labels(parent, head={"type": "dos", "numb_dos": 4}, values=values) + for i in range(3): + written = np.load(parent / f"sys_{i:04d}" / "set.000" / "dos.npy") + assert written.shape == (1, 4) + np.testing.assert_array_equal(written, values[i]) + + def test_multi_system_2d_values_are_per_frame_not_per_dim(self, tmp_path): + # A 2-D (n_systems, k) array means k frames per system (one scalar + # label per frame) — NOT a single k-dim label per system. On one-frame + # systems it therefore mismatches and must raise; the multi-dim case is + # served by 3-D (n_systems, 1, dim), see the test above. + parent = tmp_path / "multi" + parent.mkdir() + for i in range(2): + _make_system_path(parent, name=f"sys_{i:04d}", n_frames=1) + with pytest.raises(ValueError, match="frames"): + attach_labels( + parent, + head="dos", + values=np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), + ) + + def test_multi_system_values_mismatch_raises(self, tmp_path): + parent = tmp_path / "multi" + parent.mkdir() + _make_system_path(parent, name="sys_0000", n_frames=2) + _make_system_path(parent, name="sys_0001", n_frames=2) + with pytest.raises(ValueError, match="entries along the first axis"): + attach_labels( + parent, + head="bandgap", + values=np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]), + ) + + def test_multi_system_no_subdirs_raises(self, tmp_path): + empty = tmp_path / "empty" + empty.mkdir() + with pytest.raises(ValueError, match=r"No set.* directories or system"): + attach_labels(empty, head="energy", values=np.array([1.0])) + + def test_multi_system_hidden_dirs_ignored(self, tmp_path): + parent = tmp_path / "multi" + parent.mkdir() + _make_system_path(parent, name="sys_0000", n_frames=2) + (parent / ".hidden").mkdir() + values = np.array([[1.0, 2.0]]) + attach_labels(parent, head="bandgap", values=values) + written = np.load(parent / "sys_0000" / "set.000" / "bandgap.npy") + np.testing.assert_array_equal(written, [1.0, 2.0]) + + +# --------------------------------------------------------------------------- +# _load_labels — custom label key fallback +# --------------------------------------------------------------------------- + + +class TestLoadLabelsCustomKey: + """_load_labels falls back to set.*/key.npy when key not in dpdata's store.""" + + def test_custom_label_key_loaded_from_npy(self, tmp_path): + """target_key="property" loads set.000/property.npy directly.""" + root = tmp_path / "sys" + root.mkdir() + (root / "type.raw").write_text("0\n1\n") + (root / "type_map.raw").write_text("H\nO\n") + sd = root / "set.000" + sd.mkdir() + np.save(sd / "coord.npy", np.zeros((3, 6))) + np.save(sd / "box.npy", np.tile(np.eye(3).ravel(), (3, 1))) + # Custom label — NOT loaded by dpdata into system.data + np.save(sd / "property.npy", np.array([10.0, 20.0, 30.0])) + + [system] = load_data(str(root)) + assert "property" not in system.data + + labels = _load_labels([system], "property") + np.testing.assert_array_equal(labels, [10.0, 20.0, 30.0]) + + def test_custom_key_not_found_raises_clear_error(self, tmp_path): + """When neither dpdata nor set.*/key.npy has the key, error lists both.""" + root = tmp_path / "sys" + root.mkdir() + (root / "type.raw").write_text("0\n") + (root / "type_map.raw").write_text("H\n") + sd = root / "set.000" + sd.mkdir() + np.save(sd / "coord.npy", np.zeros((2, 3))) + np.save(sd / "box.npy", np.tile(np.eye(3).ravel(), (2, 1))) + + [system] = load_data(str(root)) + + with pytest.raises(DPADataError, match="nonexistent"): + _load_labels([system], "nonexistent") diff --git a/source/tests/dpa_adapt/test_mft_config.py b/source/tests/dpa_adapt/test_mft_config.py new file mode 100644 index 0000000000..03b4a8e77e --- /dev/null +++ b/source/tests/dpa_adapt/test_mft_config.py @@ -0,0 +1,462 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import os +from typing import ( + ClassVar, +) + +import pytest + +from dpa_adapt.config.manager import ( + MFTConfigManager, +) +from dpa_adapt.mft import ( + MFTFineTuner, +) + + +class FakeTuner: + pretrained = "/share/DPA-3.1-3M.pt" + aux_branch = "MP_traj_v024_alldata_mixu" + aux_prob = 0.5 + type_map: ClassVar[list[str]] = ["Cu", "O"] + fitting_net_params: ClassVar[dict[str, object]] = { + "type": "ener", + "neuron": [240, 240, 240], + } + downstream_task_type = "ener" + learning_rate = 1e-3 + stop_lr = 1e-5 + max_steps = 1000 + batch_size = "auto:32" + seed = 42 + output_dir = "/tmp/mft_test" + save_freq = 500 + disp_freq = 100 + train_data = "/data/downstream" + aux_data = "/data/aux" + valid_data = None + + +def test_build_has_model_dict(): + config = MFTConfigManager(FakeTuner()).build() + assert "model_dict" in config["model"] + assert "shared_dict" in config["model"] + + +def test_aux_branch_key_present(): + config = MFTConfigManager(FakeTuner()).build() + assert "MP_traj_v024_alldata_mixu" in config["model"]["model_dict"] + assert "DOWNSTREAM" in config["model"]["model_dict"] + + +def test_finetune_head_correct(): + config = MFTConfigManager(FakeTuner()).build() + downstream = config["model"]["model_dict"]["DOWNSTREAM"] + assert downstream["finetune_head"] == "MP_traj_v024_alldata_mixu" + + +def test_model_prob_values(): + config = MFTConfigManager(FakeTuner()).build() + prob = config["training"]["model_prob"] + assert prob["MP_traj_v024_alldata_mixu"] == 0.5 + assert prob["DOWNSTREAM"] == 1.0 + + +def test_data_dict_paths(): + config = MFTConfigManager(FakeTuner()).build() + dd = config["training"]["data_dict"] + assert dd["MP_traj_v024_alldata_mixu"]["training_data"]["systems"] == ["/data/aux"] + assert dd["DOWNSTREAM"]["training_data"]["systems"] == ["/data/downstream"] + + +def test_training_save_ckpt_under_output_dir(): + """save_ckpt pins the checkpoint prefix under output_dir, so DeePMD writes + model.ckpt-*.pt where _freeze_ckpt()/evaluate()/predict() later look. + """ + config = MFTConfigManager(FakeTuner()).build() + assert config["training"]["save_ckpt"] == os.path.join( + "/tmp/mft_test", "model.ckpt" + ) + + +def test_mft_delegate_preserves_omitted_type_map_as_none(): + """DPAFineTuner(strategy='mft') without type_map must hand None (not []) to + the MFT delegate, so it auto-detects the type_map from the checkpoint. + """ + from dpa_adapt.finetuner import ( + DPAFineTuner, + ) + + ft = DPAFineTuner(strategy="mft", property_name="homo") + assert ft.type_map == [] # frozen-sklearn path keeps a concrete list + assert ft._ensure_mft().type_map is None # MFT delegate gets None + + +def test_mft_fit_type_map_updates_delegate(monkeypatch): + """fit(..., type_map=...) must override the constructor MFT type_map.""" + from dpa_adapt.finetuner import ( + DPAFineTuner, + ) + + ft = DPAFineTuner(strategy="mft", property_name="homo", type_map=["Cu", "O"]) + ft._ensure_mft() + + captured = {} + + def _fake_fit_mft(self, train_data, aux_data, valid_data=None): + captured["self_type_map"] = self.type_map + captured["delegate_type_map"] = self._mft.type_map + return self.output_dir + + monkeypatch.setattr(DPAFineTuner, "_fit_mft", _fake_fit_mft) + + result = ft.fit( + "train", + aux_data="aux", + type_map=["H", "C", "N", "O"], + ) + + assert result == ft.output_dir + assert captured["self_type_map"] == ["H", "C", "N", "O"] + assert captured["delegate_type_map"] == ["H", "C", "N", "O"] + + +def test_aux_fitting_net_is_ener(): + config = MFTConfigManager(FakeTuner()).build() + fn = config["model"]["model_dict"]["MP_traj_v024_alldata_mixu"]["fitting_net"] + assert fn["type"] == "ener" + + +def test_build_cmd_flags(): + cm = MFTConfigManager(FakeTuner()) + cmd = cm.build_cmd("input.json") + assert "--use-pretrain-script" not in cmd + assert "--model-branch" not in cmd + assert "--finetune" in cmd + assert cmd[cmd.index("--finetune") + 1] == "/share/DPA-3.1-3M.pt" + assert "--skip-neighbor-stat" in cmd + + +def test_descriptor_has_repflow_params(): + config = MFTConfigManager(FakeTuner()).build() + desc = config["model"]["shared_dict"]["dpa3_descriptor"] + assert desc["type"] == "dpa3" + assert "repflow" in desc + rf = desc["repflow"] + assert rf["n_dim"] == 128 + assert rf["e_dim"] == 64 + assert rf["a_dim"] == 32 + assert rf["nlayers"] == 16 + assert rf["e_rcut"] == 6.0 + assert rf["a_rcut"] == 4.0 + assert desc["activation_function"] == "custom_silu:3.0" + assert desc["precision"] == "float32" + + +def test_systems_accepts_list(): + t = FakeTuner() + t.train_data = ["/data/d1", "/data/d2"] + t.aux_data = ["/data/a1", "/data/a2", "/data/a3"] + config = MFTConfigManager(t).build() + dd = config["training"]["data_dict"] + assert dd["DOWNSTREAM"]["training_data"]["systems"] == ["/data/d1", "/data/d2"] + assert dd["MP_traj_v024_alldata_mixu"]["training_data"]["systems"] == [ + "/data/a1", + "/data/a2", + "/data/a3", + ] + + +def test_type_map_in_shared_dict(): + config = MFTConfigManager(FakeTuner()).build() + shared = config["model"]["shared_dict"] + assert "type_map" in shared + assert isinstance(shared["type_map"], list) + assert shared["type_map"] == ["Cu", "O"] + + +def test_branch_type_map_is_string(): + config = MFTConfigManager(FakeTuner()).build() + md = config["model"]["model_dict"] + assert md["MP_traj_v024_alldata_mixu"]["type_map"] == "type_map" + assert md["DOWNSTREAM"]["type_map"] == "type_map" + + +def test_data_dict_has_training_data(): + config = MFTConfigManager(FakeTuner()).build() + dd = config["training"]["data_dict"] + assert "training_data" in dd["MP_traj_v024_alldata_mixu"] + assert "training_data" in dd["DOWNSTREAM"] + + +def test_no_validation_data_when_absent(): + config = MFTConfigManager(FakeTuner()).build() + dd = config["training"]["data_dict"] + assert "validation_data" not in dd["DOWNSTREAM"] + + +def test_validation_data_written_to_downstream_branch(): + t = FakeTuner() + t.valid_data = ["/data/valid1", "/data/valid2"] + config = MFTConfigManager(t).build() + downstream = config["training"]["data_dict"]["DOWNSTREAM"] + assert downstream["validation_data"] == { + "systems": ["/data/valid1", "/data/valid2"], + "batch_size": "auto:32", + } + + +def test_aux_prob_out_of_range_raises(): + t = FakeTuner() + t.aux_prob = 1.5 + with pytest.raises(ValueError, match="aux_prob"): + MFTConfigManager(t).build() + + +def test_fitting_net_params_used(): + config = MFTConfigManager(FakeTuner()).build() + md = config["model"]["model_dict"] + assert md["MP_traj_v024_alldata_mixu"]["fitting_net"] == { + "type": "ener", + "neuron": [240, 240, 240], + } + assert md["DOWNSTREAM"]["fitting_net"] == { + "type": "ener", + "neuron": [240, 240, 240], + } + + +def test_fitting_net_default_when_none(): + t = FakeTuner() + t.fitting_net_params = None + config = MFTConfigManager(t).build() + md = config["model"]["model_dict"] + assert md["MP_traj_v024_alldata_mixu"]["fitting_net"] == {"type": "ener"} + assert md["DOWNSTREAM"]["fitting_net"] == {"type": "ener"} + + +# --- MFTFineTuner.__init__ auto-reading fitting_net from checkpoint ---------- + + +def _fake_sd(branches): + """Build a minimal state_dict mirroring the real checkpoint layout.""" + return { + "model": { + "_extra_state": { + "model_params": { + "model_dict": { + name: {"fitting_net": fn} for name, fn in branches.items() + } + } + } + } + } + + +def test_explicit_fitting_net_params_skips_ckpt_load(monkeypatch): + """Backward compat: when user supplies fitting_net_params, the + checkpoint is not touched and the user's value is kept verbatim. + """ + import torch + + def _explode(*args, **kwargs): + raise AssertionError( + "torch.load must not be called when fitting_net_params is provided" + ) + + monkeypatch.setattr(torch, "load", _explode) + + custom = {"type": "ener", "neuron": [123, 456], "resnet_dt": True} + t = MFTFineTuner( + pretrained="/does/not/exist.pt", + aux_branch="Domains_Alloy", + property_name="homo", + fitting_net_params=custom, + ) + assert t.fitting_net_params == custom + + +def test_fitting_net_params_auto_read_from_ckpt(monkeypatch): + """When fitting_net_params is omitted, MFTFineTuner pulls it out of the + checkpoint at the documented nested path. + """ + import torch + + expected = {"type": "ener", "neuron": [240, 240, 240], "resnet_dt": True} + fake = _fake_sd( + { + "Domains_Alloy": expected, + "MP_traj_v024_alldata_mixu": {"type": "ener", "neuron": [120, 120]}, + } + ) + monkeypatch.setattr(torch, "load", lambda *a, **kw: fake) + + t = MFTFineTuner( + pretrained="/does/not/exist.pt", + aux_branch="Domains_Alloy", + property_name="homo", + ) + assert t.fitting_net_params == expected + + +class TestAutoTypeMap: + """When type_map is not provided, MFTFineTuner auto-detects it from the + checkpoint and validates data type_maps. + """ + + def _fake_ckpt_sd(self, type_map=None): + """Minimal DPA-3.1-3M-like state_dict with a shared type_map.""" + if type_map is None: + type_map = ["H", "He", "Li", "Be", "B", "C", "N", "O"] + return { + "model": { + "_extra_state": { + "model_params": { + "shared_dict": { + "dpa3_descriptor": {"type": "dpa3"}, + "type_map": type_map, + }, + "model_dict": { + "Domains_Alloy": { + "fitting_net": {"type": "ener"}, + }, + }, + } + } + } + } + + def test_validate_and_resolve_sets_type_map(self, monkeypatch, tmp_path): + """_validate_and_resolve_type_map reads checkpoint type_map.""" + import torch + + monkeypatch.setattr( + torch, + "load", + lambda *a, **kw: self._fake_ckpt_sd(), + ) + + t = MFTFineTuner( + pretrained="/fake.pt", + aux_branch="Domains_Alloy", + property_name="homo", + ) + assert t.type_map is None + + t._validate_and_resolve_type_map(str(tmp_path), str(tmp_path)) + assert t.type_map == ["H", "He", "Li", "Be", "B", "C", "N", "O"] + + def test_config_has_nonempty_type_map(self, monkeypatch): + """Generated mft_input.json must have a non-empty global type_map + when the user does not pass one explicitly. + """ + import torch + + monkeypatch.setattr( + torch, + "load", + lambda *a, **kw: self._fake_ckpt_sd(), + ) + + t = MFTFineTuner( + pretrained="/fake.pt", + aux_branch="Domains_Alloy", + property_name="homo", + ) + t.train_data = "/data/downstream" + t.aux_data = "/data/aux" + t._validate_and_resolve_type_map(t.train_data, t.aux_data) + + config = MFTConfigManager(t).build() + shared = config["model"]["shared_dict"] + assert "type_map" in shared + assert isinstance(shared["type_map"], list) + assert len(shared["type_map"]) == 8 + assert shared["type_map"][0] == "H" + # Must NOT be empty — empty [] causes CUDA gather out-of-bounds + assert shared["type_map"] != [] + + def test_explicit_type_map_still_respected(self, monkeypatch): + """When user passes type_map explicitly, it is used verbatim.""" + import torch + + monkeypatch.setattr( + torch, + "load", + lambda *a, **kw: self._fake_ckpt_sd(), + ) + + t = MFTFineTuner( + pretrained="/fake.pt", + aux_branch="Domains_Alloy", + property_name="homo", + type_map=["Cu", "O"], + ) + t.train_data = "/data/downstream" + t.aux_data = "/data/aux" + + config = MFTConfigManager(t).build() + shared = config["model"]["shared_dict"] + assert shared["type_map"] == ["Cu", "O"] + + def test_data_type_map_validated_against_checkpoint(self, monkeypatch, tmp_path): + """If data type_map.raw contains elements not in the checkpoint, + _validate_and_resolve_type_map raises ValueError. + """ + import numpy as np + import torch + + monkeypatch.setattr( + torch, + "load", + lambda *a, **kw: self._fake_ckpt_sd(), + ) + + t = MFTFineTuner( + pretrained="/fake.pt", + aux_branch="Domains_Alloy", + property_name="homo", + ) + + # Create a system with an unsupported element + sysdir = tmp_path / "sys" + sysdir.mkdir() + (sysdir / "type.raw").write_text("0\n1\n") + (sysdir / "type_map.raw").write_text("Pu\nU\n") + sd = sysdir / "set.000" + sd.mkdir() + np.save(sd / "coord.npy", np.zeros((1, 6))) + np.save(sd / "box.npy", np.eye(3).reshape(1, 9)) + + with pytest.raises(ValueError, match="Pu"): + t._validate_and_resolve_type_map(str(sysdir), str(tmp_path)) + + +def test_unknown_aux_branch_raises_with_branch_list(monkeypatch): + """If aux_branch is not in the checkpoint, the error names the bad + branch and lists what IS available. With lazy loading the error is + raised on first access to ``fitting_net_params``, not at construction. + """ + import torch + + fake = _fake_sd( + { + "Domains_Alloy": {"type": "ener"}, + "MP_traj_v024_alldata_mixu": {"type": "ener"}, + "Omat24": {"type": "ener"}, + } + ) + monkeypatch.setattr(torch, "load", lambda *a, **kw: fake) + + t = MFTFineTuner( + pretrained="/does/not/exist.pt", + aux_branch="NotARealBranch", + property_name="homo", + ) + with pytest.raises(ValueError) as exc_info: + _ = t.fitting_net_params # triggers lazy load + msg = str(exc_info.value) + assert "NotARealBranch" in msg + assert "Domains_Alloy" in msg + assert "MP_traj_v024_alldata_mixu" in msg + assert "Omat24" in msg diff --git a/source/tests/dpa_adapt/test_mft_evaluate.py b/source/tests/dpa_adapt/test_mft_evaluate.py new file mode 100644 index 0000000000..174afec87b --- /dev/null +++ b/source/tests/dpa_adapt/test_mft_evaluate.py @@ -0,0 +1,462 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for dpa_adapt.mft.MFTFineTuner.evaluate output parsing and pipeline.""" + +from __future__ import ( + annotations, +) + +import os +from pathlib import ( + Path, +) +from unittest.mock import ( + patch, +) + +import pytest + +from dpa_adapt.mft import ( + MFTFineTuner, +) + +DUMMY_TYPE_MAP = ["H", "C", "N", "O"] + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_systems(tmp_path, prefix: str, n: int) -> str: + """Create n empty system dirs and return a glob pattern matching them.""" + root = tmp_path / prefix + root.mkdir(parents=True, exist_ok=True) + for i in range(n): + (root / f"sys_{i:03d}").mkdir() + return str(root / "sys_*") + + +def _make_finetuner(tmp_path, max_steps=100): + """ + Build an MFTFineTuner without going through __init__'s ckpt-reading path. + We bypass __init__ because fitting_net auto-read calls torch.load on the + pretrained ckpt, which we don't have in unit tests. + """ + ft = MFTFineTuner.__new__(MFTFineTuner) + ft.pretrained = str(tmp_path / "dummy.pt") + ft.aux_branch = "SPICE2" + ft.aux_prob = 0.5 + ft.type_map = DUMMY_TYPE_MAP + ft.fitting_net_params = {} + # Paper property-mode evaluation: downstream head is named "property". + ft.downstream_task_type = "property" + ft.property_name = "homo" + ft.task_dim = 1 + ft.intensive = True + ft.learning_rate = 1e-3 + ft.stop_lr = 1e-5 + ft.max_steps = max_steps + ft.batch_size = "auto:32" + ft.seed = 42 + ft.output_dir = str(tmp_path / "out") + ft.save_freq = 10 + ft.disp_freq = 10 + ft.train_data = None + ft.aux_data = None + ft.valid_data = None + os.makedirs(ft.output_dir, exist_ok=True) + return ft + + +# --------------------------------------------------------------------------- +# Parser: real DeePMD-kit 3.1.3 output shape +# --------------------------------------------------------------------------- + + +def test_parse_real_dp_output_shape(): + """The real `dp --pt test` output prints both 'Energy MAE' (per-molecule) + and 'Energy MAE/Natoms' (per-atom). The parser must pick only the + per-molecule one. + """ + stdout = ( + "[2026-05-19 INFO] # number of test data : 1000\n" + "[2026-05-19 INFO] Energy MAE : 4.314543e-02 eV\n" + "[2026-05-19 INFO] Energy MAE/Natoms : 3.318879e-03 eV\n" + "[2026-05-19 INFO] Energy RMSE : 6.000000e-02 eV\n" + "[2026-05-19 INFO] Energy RMSE/Natoms : 4.500000e-03 eV\n" + ) + out = MFTFineTuner._parse_test_output(stdout) + assert out["mae"] == pytest.approx(4.314543e-02) + assert out["rmse"] == pytest.approx(6.000000e-02) + + +def test_parse_excludes_natoms_variant_explicitly(): + """If only the /Natoms variant appears, the parser should NOT match it. + This guards against a regex that accidentally allows /Natoms through. + """ + stdout = ( + "[INFO] Energy MAE/Natoms : 1.234567e-03 eV\n" + "[INFO] Energy RMSE/Natoms : 2.345678e-03 eV\n" + ) + with pytest.raises(RuntimeError, match="Could not parse"): + MFTFineTuner._parse_test_output(stdout) + + +# --------------------------------------------------------------------------- +# Parser: weighted-average behavior (must take LAST match) +# --------------------------------------------------------------------------- + + +def test_parse_takes_weighted_average_last_match(): + """Dp --pt test prints per-system blocks followed by a + 'weighted average of errors' block. Parser must return the weighted + average (the LAST occurrence), not the first per-system value. + """ + stdout = ( + "[INFO] # ---------------system 0--------------\n" + "[INFO] Energy MAE : 1.00e-01 eV\n" + "[INFO] Energy RMSE : 2.00e-01 eV\n" + "[INFO] # ---------------system 1--------------\n" + "[INFO] Energy MAE : 5.00e-01 eV\n" + "[INFO] Energy RMSE : 6.00e-01 eV\n" + "[INFO] # ----------weighted average of errors-----------\n" + "[INFO] Energy MAE : 3.50e-01 eV\n" + "[INFO] Energy RMSE : 4.50e-01 eV\n" + ) + out = MFTFineTuner._parse_test_output(stdout) + # Must be the weighted-average (final) values. + assert out["mae"] == pytest.approx(3.50e-01) + assert out["rmse"] == pytest.approx(4.50e-01) + + +# --------------------------------------------------------------------------- +# Parser: n_systems extraction +# --------------------------------------------------------------------------- + + +def test_parse_extracts_n_systems(): + stdout = ( + "[INFO] # number of systems : 7\n" + "[INFO] Energy MAE : 1.00e-02 eV\n" + "[INFO] Energy RMSE : 2.00e-02 eV\n" + ) + out = MFTFineTuner._parse_test_output(stdout) + assert out["n_systems"] == 7 + + +def test_parse_n_systems_falls_back_to_resolved_count(): + """If the 'number of systems' line is missing, fall back to the count of + resolved system paths so the caller still gets a usable number. + """ + stdout = "[INFO] Energy MAE : 1.00e-02 eV\n[INFO] Energy RMSE : 2.00e-02 eV\n" + out = MFTFineTuner._parse_test_output(stdout, n_resolved=42) + assert out["n_systems"] == 42 + + +# --------------------------------------------------------------------------- +# Parser: failure mode (was previously silent NaN — must now raise) +# --------------------------------------------------------------------------- + + +def test_parse_failure_raises_runtimeerror(): + """When dp test produced no Energy MAE/RMSE lines (the Bug-1 all-zero + failure mode), raise RuntimeError instead of silently returning NaN. + """ + stdout = "no MAE or RMSE lines here, just garbage" + with pytest.raises(RuntimeError) as exc_info: + MFTFineTuner._parse_test_output(stdout) + msg = str(exc_info.value) + assert "Could not parse" in msg + # Tail should be included for diagnostics. + assert "garbage" in msg + + +def test_parse_failure_includes_tail_of_output(): + """Long unparseable input: tail of last 100 lines must appear in the + error message so the user can diagnose without grepping logs. + """ + lines = [f"line_{i}" for i in range(200)] + stdout = "\n".join(lines) + with pytest.raises(RuntimeError) as exc_info: + MFTFineTuner._parse_test_output(stdout) + msg = str(exc_info.value) + # Last line should appear; very early lines should be trimmed. + assert "line_199" in msg + assert "line_0\n" not in msg + + +# --------------------------------------------------------------------------- +# Parser: scientific notation handling +# --------------------------------------------------------------------------- + + +def test_parse_scientific_notation(): + stdout = ( + "[INFO] Energy MAE : 4.314543e-02 eV\n" + "[INFO] Energy RMSE : 1.23E+01 eV\n" + ) + out = MFTFineTuner._parse_test_output(stdout) + assert out["mae"] == pytest.approx(4.314543e-02) + assert out["rmse"] == pytest.approx(1.23e01) + + +# --------------------------------------------------------------------------- +# Parser: property-mode output (PROPERTY MAE / PROPERTY RMSE) +# --------------------------------------------------------------------------- + + +def test_parse_property_output_weighted_average(): + """Property-task dp test prints per-system blocks then a + 'weighted average of errors' block. Parser must return the LAST match. + """ + stdout = ( + "[INFO] # ---------------system 0--------------\n" + "[INFO] PROPERTY MAE : 2.395307e-03 units\n" + "[INFO] PROPERTY RMSE : 2.395307e-03 units\n" + "[INFO] # ---------------system 1--------------\n" + "[INFO] PROPERTY MAE : 1.500000e-03 units\n" + "[INFO] PROPERTY RMSE : 1.500000e-03 units\n" + "[INFO] # ----------weighted average of errors----------- \n" + "[INFO] # number of systems : 291\n" + "[INFO] PROPERTY MAE : 1.972088e-03 units\n" + "[INFO] PROPERTY RMSE : 2.837059e-03 units\n" + ) + out = MFTFineTuner._parse_test_output(stdout) + assert out["mae"] == pytest.approx(1.972088e-03) + assert out["rmse"] == pytest.approx(2.837059e-03) + assert out["n_systems"] == 291 + assert "PROPERTY" in out["_parser_pattern_used"] + + +def test_parse_property_scientific_notation(): + stdout = ( + "[INFO] PROPERTY MAE : 1.23e-04 units\n[INFO] PROPERTY RMSE : 5.67E+02 units\n" + ) + out = MFTFineTuner._parse_test_output(stdout) + assert out["mae"] == pytest.approx(1.23e-04) + assert out["rmse"] == pytest.approx(5.67e02) + + +def test_parse_property_n_systems_extraction(): + stdout = ( + "[INFO] # number of systems : 42\n" + "[INFO] PROPERTY MAE : 0.01 units\n" + "[INFO] PROPERTY RMSE : 0.02 units\n" + ) + out = MFTFineTuner._parse_test_output(stdout) + assert out["n_systems"] == 42 + + +def test_parse_property_n_systems_fallback(): + stdout = "[INFO] PROPERTY MAE : 0.01 units\n[INFO] PROPERTY RMSE : 0.02 units\n" + out = MFTFineTuner._parse_test_output(stdout, n_resolved=99) + assert out["n_systems"] == 99 + + +# --------------------------------------------------------------------------- +# evaluate(): end-to-end pipeline with mocked subprocess +# --------------------------------------------------------------------------- + + +def test_evaluate_freezes_then_tests(tmp_path): + """evaluate() must (a) call dp freeze first to produce frozen .pth, + (b) then call dp test with -m pointing to that .pth, (c) parse output. + """ + ft = _make_finetuner(tmp_path, max_steps=100) + # Pretend training produced a ckpt + (Path(ft.output_dir) / "model.ckpt-100.pt").write_bytes(b"") + test_glob = _make_systems(tmp_path, "test_sys", 5) + + canned_test_output = ( + "[INFO] # number of systems : 5\n" + "[INFO] # number of test data : 50\n" + "[INFO] Energy MAE : 1.234567e-02 eV\n" + "[INFO] Energy MAE/Natoms : 9.876543e-04 eV\n" + "[INFO] Energy RMSE : 2.345678e-02 eV\n" + "[INFO] Energy RMSE/Natoms : 1.234567e-03 eV\n" + ) + + calls = [] + + class _Result: + def __init__(self, stdout="", stderr="", rc=0): + self.stdout = stdout + self.stderr = stderr + self.returncode = rc + + def _fake_run(cmd, *args, **kwargs): + calls.append({"cmd": cmd, "kwargs": kwargs}) + # First call is freeze; simulate by creating frozen.pth. + if "freeze" in cmd: + cwd = kwargs.get("cwd", ".") + Path(cwd, "frozen_property.pth").write_bytes(b"") + return _Result(stdout="frozen ok", stderr="", rc=0) + # Second call is dp test + return _Result(stdout="", stderr=canned_test_output, rc=0) + + with patch("subprocess.run", side_effect=_fake_run): + out = ft.evaluate(test_glob) + + # 1. freeze was called first with cwd=output_dir + assert len(calls) == 2 + assert isinstance(calls[0]["cmd"], list) + assert "freeze" in calls[0]["cmd"] + assert calls[0]["cmd"][calls[0]["cmd"].index("--head") + 1] == "property" + assert calls[0]["kwargs"].get("cwd") == ft.output_dir + + # 2. dp test was called with frozen .pth via -m, list-form cmd + test_cmd = calls[1]["cmd"] + assert isinstance(test_cmd, list) + m_idx = test_cmd.index("-m") + assert test_cmd[m_idx + 1].endswith("frozen_property.pth") + assert "-f" in test_cmd + assert "-s" not in test_cmd + + # 3. Parsed values are per-molecule MAE/RMSE, not /Natoms. + assert out["mae"] == pytest.approx(1.234567e-02) + assert out["rmse"] == pytest.approx(2.345678e-02) + assert out["n_systems"] == 5 + + +def test_evaluate_skips_freeze_if_pth_exists(tmp_path): + """If frozen_property.pth already exists, do NOT call dp freeze again.""" + ft = _make_finetuner(tmp_path, max_steps=100) + (Path(ft.output_dir) / "model.ckpt-100.pt").write_bytes(b"") + (Path(ft.output_dir) / "frozen_property.pth").write_bytes(b"") + test_glob = _make_systems(tmp_path, "test_skip", 3) + + canned = ( + "[INFO] # number of systems : 3\n" + "[INFO] Energy MAE : 5.0e-03 eV\n" + "[INFO] Energy RMSE : 6.0e-03 eV\n" + ) + + calls = [] + + class _Result: + stdout = "" + stderr = canned + returncode = 0 + + def _fake_run(cmd, *args, **kwargs): + calls.append(cmd) + return _Result() + + with patch("subprocess.run", side_effect=_fake_run): + out = ft.evaluate(test_glob) + + assert len(calls) == 1, f"Expected only dp test, got {len(calls)} calls" + assert isinstance(calls[0], list) + assert calls[0][1:3] == ["--pt", "test"] + assert out["mae"] == pytest.approx(5.0e-03) + + +def test_evaluate_freeze_failure_raises(tmp_path): + """If dp freeze fails, evaluate() must raise RuntimeError with diagnostics + rather than proceeding into a doomed dp test. + """ + ft = _make_finetuner(tmp_path, max_steps=100) + (Path(ft.output_dir) / "model.ckpt-100.pt").write_bytes(b"") + test_glob = _make_systems(tmp_path, "test_fz_fail", 2) + + class _Result: + stdout = "freeze stdout" + stderr = "freeze failed: missing branch" + returncode = 1 + + with patch("subprocess.run", return_value=_Result()): + with pytest.raises(RuntimeError, match="freeze"): + ft.evaluate(test_glob) + + +def test_evaluate_accepts_single_path(tmp_path): + """A single non-glob string path should be written verbatim into the + datafile (single line) and passed via -f. + """ + ft = _make_finetuner(tmp_path, max_steps=100) + (Path(ft.output_dir) / "model.ckpt-100.pt").write_bytes(b"") + (Path(ft.output_dir) / "frozen_property.pth").write_bytes(b"") + + single = tmp_path / "single_sys" + single.mkdir() + test_data = str(single) + + canned = ( + "[INFO] # number of systems : 1\n" + "[INFO] Energy MAE : 7.0e-03 eV\n" + "[INFO] Energy RMSE : 8.0e-03 eV\n" + ) + + captured = {} + + class _Result: + stdout = "" + stderr = canned + returncode = 0 + + def _fake_run(cmd, *args, **kwargs): + captured["cmd"] = cmd + return _Result() + + with patch("subprocess.run", side_effect=_fake_run): + out = ft.evaluate(test_data) + + cmd = captured["cmd"] + f_idx = cmd.index("-f") + datafile = cmd[f_idx + 1] + with open(datafile) as f: + lines = [l for l in f.read().split("\n") if l.strip()] + assert lines == [test_data] + assert out["mae"] == pytest.approx(7.0e-03) + assert out["n_systems"] == 1 + + +def test_evaluate_accepts_list(tmp_path): + """A list of paths should be written one-per-line into the datafile.""" + ft = _make_finetuner(tmp_path, max_steps=100) + (Path(ft.output_dir) / "model.ckpt-100.pt").write_bytes(b"") + (Path(ft.output_dir) / "frozen_property.pth").write_bytes(b"") + + paths = [] + for i in range(4): + d = tmp_path / f"list_sys_{i}" + d.mkdir() + paths.append(str(d)) + + canned = ( + "[INFO] # number of systems : 4\n" + "[INFO] Energy MAE : 9.0e-03 eV\n" + "[INFO] Energy RMSE : 1.0e-02 eV\n" + ) + + captured = {} + + class _Result: + stdout = "" + stderr = canned + returncode = 0 + + def _fake_run(cmd, *args, **kwargs): + captured["cmd"] = cmd + return _Result() + + with patch("subprocess.run", side_effect=_fake_run): + out = ft.evaluate(paths) + + cmd = captured["cmd"] + datafile = cmd[cmd.index("-f") + 1] + with open(datafile) as f: + lines = [l for l in f.read().split("\n") if l.strip()] + assert lines == paths + assert out["n_systems"] == 4 + + +def test_evaluate_missing_ckpt_raises(tmp_path): + """If no model.ckpt-{max_steps}.pt exists and frozen.pth also missing, + _freeze_ckpt must raise rather than silently call freeze and explode. + """ + ft = _make_finetuner(tmp_path, max_steps=100) + test_glob = _make_systems(tmp_path, "test_no_ckpt", 2) + + with pytest.raises(RuntimeError, match="not found"): + ft.evaluate(test_glob) diff --git a/source/tests/dpa_adapt/test_mft_property_task.py b/source/tests/dpa_adapt/test_mft_property_task.py new file mode 100644 index 0000000000..e5fa5f3045 --- /dev/null +++ b/source/tests/dpa_adapt/test_mft_property_task.py @@ -0,0 +1,391 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for MFT downstream_task_type='property' branch. + +These cover the paper-faithful (arXiv:2601.08486) DOWNSTREAM=property +configuration: a fresh property fitting_net + property loss for the +downstream head, while the aux branch keeps its ener fitting_net pulled +from the ckpt. + +Back-compat: callers that don't pass downstream_task_type stay on the +legacy ener path (used by mp_data MFT sensitivity-analysis experiments). +""" + +from __future__ import ( + annotations, +) + +from typing import ( + ClassVar, +) + +import pytest + +from dpa_adapt.config.manager import ( + MFTConfigManager, +) +from dpa_adapt.mft import ( + MFTFineTuner, +) + + +class _FakePropertyTuner: + """Tuner-shaped object configured for downstream_task_type='property'. + + Bypasses MFTFineTuner.__init__ so tests don't need a real ckpt. + """ + + pretrained = "/share/DPA-3.1-3M.pt" + aux_branch = "SPICE2" + aux_prob = 0.5 + type_map: ClassVar[list[str]] = ["H", "C", "N", "O"] + # aux fitting_net pulled from ckpt — an ener config (the actual SPICE2 head) + fitting_net_params: ClassVar[dict[str, object]] = { + "type": "ener", + "neuron": [240, 240, 240], + } + downstream_task_type = "property" + property_name = "homo" + task_dim = 1 + intensive = True + learning_rate = 1e-3 + stop_lr = 1e-5 + max_steps = 1000 + batch_size = "auto:32" + seed = 42 + output_dir = "/tmp/mft_property_test" + save_freq = 500 + disp_freq = 100 + train_data = "/data/qm9_train" + aux_data = "/data/spice2" + valid_data = None + + +class _FakeEnerTuner: + """Legacy back-compat tuner. + + NO downstream_task_type attr at all — must still build a valid ener-mode + config (mp_data sensitivity callers construct tuners this way). + """ + + pretrained = "/share/DPA-3.1-3M.pt" + aux_branch = "MP_traj_v024_alldata_mixu" + aux_prob = 0.5 + type_map: ClassVar[list[str]] = ["Cu", "O"] + fitting_net_params: ClassVar[dict[str, object]] = { + "type": "ener", + "neuron": [240, 240, 240], + } + learning_rate = 1e-3 + stop_lr = 1e-5 + max_steps = 1000 + batch_size = "auto:32" + seed = 42 + output_dir = "/tmp/mft_ener_test" + save_freq = 500 + disp_freq = 100 + train_data = "/data/downstream" + aux_data = "/data/aux" + valid_data = None + + +# --------------------------------------------------------------------------- +# Property task: config shape +# --------------------------------------------------------------------------- + + +def test_property_task_config_has_property_fitting_net(): + """DOWNSTREAM fitting_net must be type='property' with the right + property_name / task_dim / intensive, NOT the aux ener fitting_net. + """ + config = MFTConfigManager(_FakePropertyTuner()).build() + fn = config["model"]["model_dict"]["property"]["fitting_net"] + assert fn["type"] == "property" + assert fn["property_name"] == "homo" + assert fn["task_dim"] == 1 + assert fn["intensive"] is True + assert fn["neuron"] == [240, 240, 240] + assert fn["activation_function"] == "tanh" + assert fn["seed"] == 42 + # Required for DPA-3.1-3M multi-task case-embedding layer. + assert fn["dim_case_embd"] == 31 + + +def test_property_task_config_has_property_loss(): + """DOWNSTREAM loss must be type='property' with mse + mae/rmse metrics.""" + config = MFTConfigManager(_FakePropertyTuner()).build() + loss = config["loss_dict"]["property"] + assert loss["type"] == "property" + assert loss["loss_func"] == "mse" + assert "mae" in loss["metric"] + assert "rmse" in loss["metric"] + + +def test_property_task_no_force_pref_in_loss(): + """The ener-task force/virial prefs MUST NOT leak into property loss. + This is the regression that made MFT/homo training useless: the loss + forced the model to predict zero forces against QM9 labels that don't + have forces. + """ + config = MFTConfigManager(_FakePropertyTuner()).build() + loss = config["loss_dict"]["property"] + for forbidden in ( + "start_pref_f", + "limit_pref_f", + "start_pref_v", + "limit_pref_v", + "start_pref_e", + "limit_pref_e", + ): + assert forbidden not in loss, ( + f"property loss must not contain {forbidden}; got loss={loss!r}" + ) + + +def test_property_task_no_property_name_in_loss(): + """Deepmd 3.1.3 strict-mode dargs rejects unknown keys inside loss_property. + + Property_name belongs on fitting_net, not loss. + (Verified empirically; see manager.py _build_property_loss docstring.) + """ + config = MFTConfigManager(_FakePropertyTuner()).build() + loss = config["loss_dict"]["property"] + assert "property_name" not in loss + + +# --------------------------------------------------------------------------- +# Property task: aux branch is unaffected +# --------------------------------------------------------------------------- + + +def test_property_task_aux_branch_keeps_ener_fitting_net(): + """The aux branch (SPICE2 force-field) must keep its ener fitting_net. + Only DOWNSTREAM gets the new property head. + """ + config = MFTConfigManager(_FakePropertyTuner()).build() + aux_fn = config["model"]["model_dict"]["SPICE2"]["fitting_net"] + assert aux_fn["type"] == "ener" + assert aux_fn == {"type": "ener", "neuron": [240, 240, 240]} + + +def test_property_task_aux_branch_keeps_ener_loss(): + """The aux branch loss must remain ener-style (it has forces+virials).""" + config = MFTConfigManager(_FakePropertyTuner()).build() + aux_loss = config["loss_dict"]["SPICE2"] + assert aux_loss["type"] == "ener" + assert "start_pref_f" in aux_loss + + +def test_property_task_extensive_property(): + """When intensive=False, the property head reflects that — extensive + properties like total dipole moment use sum-pool. + """ + + class _T(_FakePropertyTuner): + property_name = "total_dipole" + intensive = False + + config = MFTConfigManager(_T()).build() + fn = config["model"]["model_dict"]["property"]["fitting_net"] + assert fn["intensive"] is False + assert fn["property_name"] == "total_dipole" + + +def test_property_task_multidim_task_dim(): + """task_dim > 1 is honored (e.g. multitask HOMO+LUMO regression).""" + + class _T(_FakePropertyTuner): + task_dim = 2 + property_name = "homo_lumo" + + config = MFTConfigManager(_T()).build() + fn = config["model"]["model_dict"]["property"]["fitting_net"] + assert fn["task_dim"] == 2 + + +# --------------------------------------------------------------------------- +# Back-compat: ener mode is unchanged +# --------------------------------------------------------------------------- + + +def test_ener_task_unchanged_when_no_attr(): + """Tuners without downstream_task_type attr (existing mp_data callers) + must still get the legacy ener-mode config: DOWNSTREAM reuses the aux + fitting_net and gets an ener loss with force/virial prefs. + """ + config = MFTConfigManager(_FakeEnerTuner()).build() + md = config["model"]["model_dict"] + # DOWNSTREAM fitting_net == aux fitting_net (the legacy behavior) + assert ( + md["DOWNSTREAM"]["fitting_net"] + == md["MP_traj_v024_alldata_mixu"]["fitting_net"] + ) + assert md["DOWNSTREAM"]["fitting_net"]["type"] == "ener" + # ener loss with force/virial prefs + loss = config["loss_dict"]["DOWNSTREAM"] + assert loss["type"] == "ener" + assert loss["start_pref_f"] == 100 + assert loss["start_pref_v"] == 0.02 + + +def test_ener_task_explicit_attr_unchanged(): + """Explicitly setting downstream_task_type='ener' is equivalent to + not setting it at all. + """ + t = _FakeEnerTuner() + t.downstream_task_type = "ener" + config = MFTConfigManager(t).build() + md = config["model"]["model_dict"] + assert md["DOWNSTREAM"]["fitting_net"]["type"] == "ener" + assert config["loss_dict"]["DOWNSTREAM"]["type"] == "ener" + + +# --------------------------------------------------------------------------- +# MFTFineTuner.__init__: argument validation +# --------------------------------------------------------------------------- + + +def test_property_task_requires_property_name(monkeypatch): + """downstream_task_type='property' without property_name must raise.""" + import torch + + monkeypatch.setattr( + torch, + "load", + lambda *a, **kw: { + "model": { + "_extra_state": { + "model_params": { + "model_dict": {"SPICE2": {"fitting_net": {"type": "ener"}}} + } + } + } + }, + ) + with pytest.raises(ValueError, match="property_name"): + MFTFineTuner( + pretrained="/does/not/exist.pt", + aux_branch="SPICE2", + downstream_task_type="property", + # property_name omitted on purpose + ) + + +def test_property_task_property_name_must_be_identifier(monkeypatch): + """property_name with slashes/spaces is rejected.""" + import torch + + monkeypatch.setattr( + torch, + "load", + lambda *a, **kw: { + "model": { + "_extra_state": { + "model_params": { + "model_dict": {"SPICE2": {"fitting_net": {"type": "ener"}}} + } + } + } + }, + ) + with pytest.raises(ValueError, match="property_name"): + MFTFineTuner( + pretrained="/does/not/exist.pt", + aux_branch="SPICE2", + downstream_task_type="property", + property_name="homo lumo", # invalid identifier + ) + + +def test_invalid_downstream_task_type_raises(monkeypatch): + """Typos like 'properties' or 'energy' must raise immediately.""" + import torch + + monkeypatch.setattr( + torch, + "load", + lambda *a, **kw: { + "model": { + "_extra_state": { + "model_params": { + "model_dict": {"SPICE2": {"fitting_net": {"type": "ener"}}} + } + } + } + }, + ) + with pytest.raises(ValueError, match="downstream_task_type"): + MFTFineTuner( + pretrained="/does/not/exist.pt", + aux_branch="SPICE2", + downstream_task_type="properties", # typo + ) + + +@pytest.mark.parametrize("aux_prob", [-0.1, 1.2, "not-a-number"]) +def test_aux_prob_must_be_probability(aux_prob): + """Invalid MFT branch probabilities must fail at construction time.""" + with pytest.raises(ValueError, match="aux_prob"): + MFTFineTuner( + pretrained="/does/not/exist.pt", + aux_branch="SPICE2", + downstream_task_type="property", + property_name="homo", + aux_prob=aux_prob, + ) + + +def test_property_task_stores_attrs(monkeypatch): + """The MFTFineTuner exposes downstream_task_type / property_name / + task_dim / intensive so MFTConfigManager can read them. + """ + import torch + + monkeypatch.setattr( + torch, + "load", + lambda *a, **kw: { + "model": { + "_extra_state": { + "model_params": { + "model_dict": {"SPICE2": {"fitting_net": {"type": "ener"}}} + } + } + } + }, + ) + t = MFTFineTuner( + pretrained="/does/not/exist.pt", + aux_branch="SPICE2", + downstream_task_type="property", + property_name="lumo", + task_dim=1, + intensive=True, + ) + assert t.downstream_task_type == "property" + assert t.property_name == "lumo" + assert t.task_dim == 1 + assert t.intensive is True + + +def test_ener_default_when_unspecified(monkeypatch): + """Back-compat: not passing downstream_task_type defaults to 'ener'.""" + import torch + + monkeypatch.setattr( + torch, + "load", + lambda *a, **kw: { + "model": { + "_extra_state": { + "model_params": { + "model_dict": {"Foo": {"fitting_net": {"type": "ener"}}} + } + } + } + }, + ) + t = MFTFineTuner( + pretrained="/does/not/exist.pt", aux_branch="Foo", downstream_task_type="ener" + ) + assert t.downstream_task_type == "ener" + assert t.property_name is None diff --git a/source/tests/dpa_adapt/test_paper_alignment.py b/source/tests/dpa_adapt/test_paper_alignment.py new file mode 100644 index 0000000000..78fa9c44ef --- /dev/null +++ b/source/tests/dpa_adapt/test_paper_alignment.py @@ -0,0 +1,456 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Verify emitted input.json matches the MFT paper repo +(Chengqian-Zhang/Multitask-finetuning/examples/qm9_gap/). + +Covers single-task FT/LP/Scratch (DPATrainer) and multi-task property-mode +MFT (MFTConfigManager). Configs are round-tripped through json to confirm +the fields survive serialization (no GPU / no real ckpt needed). + +Backward-compat note: legacy ener-mode MFT (mp_data sensitivity analysis) +must stay byte-for-byte unchanged; that is locked by +test_ener_mode_byte_for_byte_unchanged. +""" + +from __future__ import ( + annotations, +) + +import json +from typing import ( + ClassVar, +) +from unittest.mock import ( + patch, +) + +from dpa_adapt.config.manager import ( + MFTConfigManager, +) +from dpa_adapt.trainer import ( + DPATrainer, +) + +TYPE_MAP = ["H", "C", "N", "O"] + + +def _make_sys(tmp_path) -> str: + """Create one real system dir and return a glob matching it (DPATrainer + expands globs against the filesystem). + """ + root = tmp_path / "sys" + root.mkdir(parents=True, exist_ok=True) + (root / "s_000").mkdir(exist_ok=True) + return str(root / "s_*") + + +# --------------------------------------------------------------------------- +# DPATrainer (FT / LP / Scratch) helpers +# --------------------------------------------------------------------------- + + +def _fake_descriptor_sd() -> dict: + """Checkpoint state_dict shaped like DPA-3.1-3M: a custom_silu descriptor + with no fix_stat_std, to prove _get_descriptor overrides both. + """ + descriptor = { + "type": "dpa3", + "repflow": {"n_dim": 128, "e_dim": 64, "a_dim": 32, "nlayers": 16}, + "activation_function": "custom_silu:3.0", + "precision": "float32", + "trainable": True, + } + return { + "model": { + "_extra_state": { + "model_params": {"shared_dict": {"dpa3_descriptor": descriptor}} + } + } + } + + +def _patch_torch_load(): + return patch("torch.load", lambda *a, **kw: _fake_descriptor_sd()) + + +def _trainer(pretrained, tmp_path, **overrides): + sys_glob = _make_sys(tmp_path) + kwargs = { + "pretrained": pretrained, + "train_systems": sys_glob, + "valid_systems": sys_glob, + "type_map": TYPE_MAP, + } + kwargs.update(overrides) + return DPATrainer(**kwargs) + + +def _lp_config(tmp_path): + ckpt = tmp_path / "ckpt.pt" + ckpt.write_bytes(b"") + t = _trainer( + str(ckpt), tmp_path, freeze_backbone=True, output_dir=str(tmp_path / "o") + ) + with _patch_torch_load(): + config = t._build_config() + # Round-trip through json to mirror how fit() writes input.json. + return json.loads(json.dumps(config)), t + + +def _ft_config(tmp_path): + ckpt = tmp_path / "ckpt.pt" + ckpt.write_bytes(b"") + t = _trainer( + str(ckpt), tmp_path, freeze_backbone=False, output_dir=str(tmp_path / "o") + ) + with _patch_torch_load(): + config = t._build_config() + return json.loads(json.dumps(config)), t + + +# --------------------------------------------------------------------------- +# LP single-task input.json +# --------------------------------------------------------------------------- + + +def test_lp_input_json_no_dim_case_embd(tmp_path): + config, _ = _lp_config(tmp_path) + assert "dim_case_embd" not in config["model"]["fitting_net"] + + +def test_lp_input_json_descriptor_trainable_false(tmp_path): + config, _ = _lp_config(tmp_path) + assert config["model"]["descriptor"]["trainable"] is False + + +def test_lp_input_json_activation_silut(tmp_path): + config, _ = _lp_config(tmp_path) + assert config["model"]["descriptor"]["activation_function"] == "silut:3.0" + + +def test_lp_input_json_fix_stat_std_0_3(tmp_path): + config, _ = _lp_config(tmp_path) + assert config["model"]["descriptor"]["repflow"]["fix_stat_std"] == 0.3 + + +def test_lp_input_json_decay_steps_1000(tmp_path): + config, _ = _lp_config(tmp_path) + assert config["learning_rate"]["decay_steps"] == 1000 + + +def test_lp_input_json_gradient_max_norm_5(tmp_path): + config, _ = _lp_config(tmp_path) + assert config["training"]["gradient_max_norm"] == 5.0 + + +def test_lp_cmd_no_model_branch_flag(tmp_path): + _, t = _lp_config(tmp_path) + cmd = t._build_cmd("input.json") + assert "--model-branch" not in cmd + assert "--finetune" in cmd + assert "--skip-neighbor-stat" in cmd + + +def test_lp_input_json_loss_is_property(tmp_path): + config, _ = _lp_config(tmp_path) + loss = config["loss"] + assert loss["type"] == "property" + assert loss["loss_func"] == "mse" + assert loss["metric"] == ["mae", "rmse"] + + +# --------------------------------------------------------------------------- +# FT single-task input.json +# --------------------------------------------------------------------------- + + +def test_ft_input_json_descriptor_trainable_true(tmp_path): + """FT (freeze_backbone=False) keeps the descriptor trainable; paper FT + input.json omits trainable (defaults true). We emit trainable=true, which + is the same effective config. + """ + config, _ = _ft_config(tmp_path) + assert config["model"]["descriptor"]["trainable"] is True + + +def test_ft_input_json_no_dim_case_embd(tmp_path): + config, _ = _ft_config(tmp_path) + assert "dim_case_embd" not in config["model"]["fitting_net"] + + +def test_ft_cmd_no_model_branch_flag(tmp_path): + _, t = _ft_config(tmp_path) + cmd = t._build_cmd("input.json") + assert "--model-branch" not in cmd + assert "--finetune" in cmd + + +# --------------------------------------------------------------------------- +# MFT multi-task property-mode input.json +# --------------------------------------------------------------------------- + + +class _PropertyTuner: + pretrained = "/share/DPA-3.1-3M.pt" + aux_branch = "SPICE2" + aux_prob = 0.5 + type_map: ClassVar[list[str]] = ["H", "C", "N", "O"] + fitting_net_params: ClassVar[dict[str, object]] = { + "type": "ener", + "neuron": [240, 240, 240], + "dim_case_embd": 31, + "seed": 1, + } + downstream_task_type = "property" + property_name = "homo" + task_dim = 1 + intensive = True + learning_rate = 1e-3 + stop_lr = 1e-5 + max_steps = 100000 + batch_size = "auto:32" + seed = 42 + output_dir = "/tmp/mft_paper" + save_freq = 500 + disp_freq = 100 + train_data = "/data/qm9" + aux_data = "/data/spice2" + valid_data = None + + +def _mft_property_config(): + return json.loads(json.dumps(MFTConfigManager(_PropertyTuner()).build())) + + +def test_mft_input_json_downstream_branch_key_is_property(): + """Paper repo names the downstream branch "property" (not "DOWNSTREAM") + across model_dict / loss_dict / model_prob / data_dict. + """ + config = _mft_property_config() + md = config["model"]["model_dict"] + assert "property" in md + assert "DOWNSTREAM" not in md + assert "property" in config["loss_dict"] + assert "property" in config["training"]["model_prob"] + assert "property" in config["training"]["data_dict"] + + +def test_mft_input_json_downstream_finetune_head_random(): + config = _mft_property_config() + assert config["model"]["model_dict"]["property"]["finetune_head"] == "RANDOM" + + +def test_mft_input_json_aux_finetune_head_branch_name(): + config = _mft_property_config() + assert config["model"]["model_dict"]["SPICE2"]["finetune_head"] == "SPICE2" + + +def test_mft_input_json_downstream_has_dim_case_embd(): + config = _mft_property_config() + fn = config["model"]["model_dict"]["property"]["fitting_net"] + assert fn["dim_case_embd"] == 31 + + +def test_mft_input_json_aux_keeps_dim_case_embd(): + config = _mft_property_config() + fn = config["model"]["model_dict"]["SPICE2"]["fitting_net"] + assert fn["dim_case_embd"] == 31 + + +def test_mft_input_json_property_mode_loss_is_property(): + config = _mft_property_config() + loss = config["loss_dict"]["property"] + assert loss["type"] == "property" + assert loss["loss_func"] == "mse" + # aux branch keeps ener loss + assert config["loss_dict"]["SPICE2"]["type"] == "ener" + + +def test_mft_input_json_descriptor_silut_and_fix_stat_std(): + config = _mft_property_config() + desc = config["model"]["shared_dict"]["dpa3_descriptor"] + assert desc["activation_function"] == "silut:3.0" + assert desc["repflow"]["fix_stat_std"] == 0.3 + + +def test_mft_input_json_decay_steps_1000_and_grad_norm(): + config = _mft_property_config() + assert config["learning_rate"]["decay_steps"] == 1000 + assert config["training"]["gradient_max_norm"] == 5.0 + + +def test_mft_input_json_batch_sizes(): + config = _mft_property_config() + dd = config["training"]["data_dict"] + assert dd["SPICE2"]["training_data"]["batch_size"] == "auto:128" + assert dd["property"]["training_data"]["batch_size"] == "auto:512" + + +def test_mft_input_json_model_prob_default_half_half(): + config = _mft_property_config() + prob = config["training"]["model_prob"] + assert prob["SPICE2"] == 0.5 + assert prob["property"] == 0.5 + + +def test_mft_cmd_no_model_branch(): + cm = MFTConfigManager(_PropertyTuner()) + cmd = cm.build_cmd("input.json") + assert "--model-branch" not in cmd + assert "--finetune" in cmd + + +# --------------------------------------------------------------------------- +# Backward compat: legacy ener-mode MFT must be byte-for-byte unchanged +# --------------------------------------------------------------------------- + + +class _EnerTuner: + """No downstream_task_type attr — legacy mp_data sensitivity-analysis + caller. Must produce the pre-paper-alignment config exactly. + """ + + pretrained = "/share/DPA-3.1-3M.pt" + aux_branch = "MP_traj_v024_alldata_mixu" + aux_prob = 0.5 + type_map: ClassVar[list[str]] = ["Cu", "O"] + fitting_net_params: ClassVar[dict[str, object]] = { + "type": "ener", + "neuron": [240, 240, 240], + } + learning_rate = 1e-3 + stop_lr = 1e-5 + max_steps = 1000 + batch_size = "auto:32" + seed = 42 + output_dir = "/tmp/mft_ener" + save_freq = 500 + disp_freq = 100 + train_data = "/data/downstream" + aux_data = "/data/aux" + valid_data = None + + +# The expected legacy config, frozen from the pre-2026-05-20 manager.py output. +_LEGACY_ENER_EXPECTED = { + "model": { + "shared_dict": { + "dpa3_descriptor": { + "type": "dpa3", + "repflow": { + "n_dim": 128, + "e_dim": 64, + "a_dim": 32, + "nlayers": 16, + "e_rcut": 6.0, + "e_rcut_smth": 5.3, + "e_sel": 1200, + "a_rcut": 4.0, + "a_rcut_smth": 3.5, + "a_sel": 300, + "axis_neuron": 4, + "skip_stat": True, + "a_compress_rate": 1, + "a_compress_e_rate": 2, + "a_compress_use_split": True, + "update_angle": True, + "smooth_edge_update": True, + "use_dynamic_sel": True, + "sel_reduce_factor": 10.0, + "update_style": "res_residual", + "update_residual": 0.1, + "update_residual_init": "const", + "n_multi_edge_message": 1, + "optim_update": True, + "use_exp_switch": True, + }, + "activation_function": "custom_silu:3.0", + "precision": "float32", + "use_tebd_bias": False, + "concat_output_tebd": False, + "exclude_types": [], + "env_protection": 0.0, + "trainable": True, + "use_econf_tebd": False, + }, + "type_map": ["Cu", "O"], + }, + "model_dict": { + "MP_traj_v024_alldata_mixu": { + "type_map": "type_map", + "descriptor": "dpa3_descriptor", + "fitting_net": {"type": "ener", "neuron": [240, 240, 240]}, + }, + "DOWNSTREAM": { + "finetune_head": "MP_traj_v024_alldata_mixu", + "type_map": "type_map", + "descriptor": "dpa3_descriptor", + "fitting_net": {"type": "ener", "neuron": [240, 240, 240]}, + }, + }, + }, + "learning_rate": { + "type": "exp", + "start_lr": 1e-3, + "stop_lr": 1e-5, + "decay_steps": 5000, + }, + "loss_dict": { + "MP_traj_v024_alldata_mixu": { + "type": "ener", + "start_pref_e": 0.2, + "limit_pref_e": 20, + "start_pref_f": 100, + "limit_pref_f": 60, + "start_pref_v": 0.02, + "limit_pref_v": 1, + }, + "DOWNSTREAM": { + "type": "ener", + "start_pref_e": 0.2, + "limit_pref_e": 20, + "start_pref_f": 100, + "limit_pref_f": 60, + "start_pref_v": 0.02, + "limit_pref_v": 1, + }, + }, + "training": { + "model_prob": {"MP_traj_v024_alldata_mixu": 0.5, "DOWNSTREAM": 1.0}, + "data_dict": { + "MP_traj_v024_alldata_mixu": { + "training_data": {"systems": ["/data/aux"], "batch_size": "auto:32"} + }, + "DOWNSTREAM": { + "training_data": { + "systems": ["/data/downstream"], + "batch_size": "auto:32", + } + }, + }, + "numb_steps": 1000, + "save_freq": 500, + "save_ckpt": "/tmp/mft_ener/model.ckpt", + "disp_freq": 100, + "seed": 42, + }, +} + + +def test_ener_mode_byte_for_byte_unchanged(): + """Legacy ener MFT config (and its JSON serialization) must equal the + frozen pre-paper-alignment output exactly — including key order. + """ + config = MFTConfigManager(_EnerTuner()).build() + assert config == _LEGACY_ENER_EXPECTED + # Byte-for-byte JSON (key order preserved by Python dict insertion order). + assert json.dumps(config) == json.dumps(_LEGACY_ENER_EXPECTED) + + +def test_ener_mode_no_gradient_max_norm(): + config = MFTConfigManager(_EnerTuner()).build() + assert "gradient_max_norm" not in config["training"] + + +def test_ener_mode_no_fix_stat_std(): + config = MFTConfigManager(_EnerTuner()).build() + assert "fix_stat_std" not in config["model"]["shared_dict"]["dpa3_descriptor"] diff --git a/source/tests/dpa_adapt/test_predictor.py b/source/tests/dpa_adapt/test_predictor.py new file mode 100644 index 0000000000..5f0a8135ba --- /dev/null +++ b/source/tests/dpa_adapt/test_predictor.py @@ -0,0 +1,649 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for DPAPredictor — no real DPA checkpoint or torch required. + +A mock torch module is injected into sys.modules so that torch.save / +torch.load are backed by pickle. All DPA descriptor calls are also mocked. +""" + +import pickle +import sys +from pathlib import ( + Path, +) +from unittest.mock import ( + MagicMock, + patch, +) + +import numpy as np +import pytest + +# --------------------------------------------------------------------------- +# Use real torch serialization when available; otherwise fall back to a minimal +# pickle-backed mock so these tests can still run without a torch install. +# --------------------------------------------------------------------------- + + +def _pickle_save(obj, path, **kwargs): + with open(path, "wb") as f: + pickle.dump(obj, f) + + +def _pickle_load(path, **kwargs): + with open(path, "rb") as f: + return pickle.load(f) + + +try: + import torch as _torch_for_test +except Exception: + _mock_torch = MagicMock() + _mock_torch.save = _pickle_save + _mock_torch.load = _pickle_load + _mock_torch.cuda.is_available.return_value = False + # Prevent scipy._lib.array_api_compat.is_torch_array from crashing + # (it tries issubclass(cls, torch.Tensor); we make Tensor a real class). + _mock_torch.Tensor = type("Tensor", (), {}) + _torch_for_test = _mock_torch + + # Inject before any dpa_adapt import so the lazy `import torch` lines inside + # freeze() / DPAPredictor.__init__ pick up the mock. + sys.modules.setdefault("torch", _mock_torch) +else: + _torch_for_test.set_default_device(None) + +from dpa_adapt import ( + DPAFineTuner, + DPAPredictor, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_npy_system(root: Path, n_frames: int = 3, n_atoms: int = 2) -> None: + """Create a minimal deepmd/npy system directory for testing.""" + (root / "type.raw").write_text("0\n1\n") + (root / "type_map.raw").write_text("Cu\nO\n") + set_dir = root / "set.000" + set_dir.mkdir() + np.save(set_dir / "coord.npy", np.zeros((n_frames, n_atoms * 3))) + np.save(set_dir / "box.npy", np.eye(3).reshape(1, 9).repeat(n_frames, 0)) + np.save(set_dir / "energy.npy", np.arange(n_frames, dtype=float)) + + +FEAT_DIM = 8 + + +def _mock_extract_features(self, systems): + n_frames = sum(s.data["coords"].shape[0] for s in systems) + return np.random.default_rng(0).random((n_frames, FEAT_DIM)) + + +def _mock_load_descriptor_model(self): + self._checkpoint_type_map = ["Cu", "O"] + return None + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestPredictRoundtrip: + """Freeze a Ridge on mock features, reload with DPAPredictor, check shape.""" + + def test_predict_roundtrip(self, tmp_path): + system = tmp_path / "sys" + system.mkdir() + _make_npy_system(system, n_frames=4) + + with ( + patch.object( + DPAFineTuner, "_load_descriptor_model", _mock_load_descriptor_model + ), + patch.object(DPAFineTuner, "_extract_features", _mock_extract_features), + ): + ft = DPAFineTuner(pretrained="fake.pt", predictor="linear") + ft.fit(str(system), target_key="energy") + frozen = ft.freeze(str(tmp_path / "model.pth")) + + pred = DPAPredictor(frozen) + result = pred.predict(str(system)) + + assert hasattr(result, "predictions") + assert result.predictions.shape == (4, 1) + + +class TestEvaluateReturnsMetrics: + """evaluate() must return mae/rmse/r2/predictions/labels with consistent shapes.""" + + def test_evaluate_returns_metrics(self, tmp_path): + system = tmp_path / "sys" + system.mkdir() + _make_npy_system(system, n_frames=5) + + with ( + patch.object( + DPAFineTuner, "_load_descriptor_model", _mock_load_descriptor_model + ), + patch.object(DPAFineTuner, "_extract_features", _mock_extract_features), + ): + ft = DPAFineTuner(pretrained="fake.pt", predictor="linear") + ft.fit(str(system), target_key="energy") + frozen = ft.freeze(str(tmp_path / "model.pth")) + + pred = DPAPredictor(frozen) + result = pred.evaluate(str(system)) + + for key in ("mae", "rmse", "r2", "predictions", "labels"): + assert hasattr(result, key), f"Missing key: {key}" + + assert result.predictions.shape == result.labels.shape + assert result.predictions.shape[0] == 5 + assert isinstance(result.mae, float) + assert isinstance(result.rmse, float) + + +class TestFreezeBundleHasModelBranch: + """freeze() bundle must include model_branch (guards the §1 bug fix).""" + + def test_freeze_bundle_has_model_branch(self, tmp_path): + system = tmp_path / "sys" + system.mkdir() + _make_npy_system(system, n_frames=3) + + with ( + patch.object( + DPAFineTuner, "_load_descriptor_model", _mock_load_descriptor_model + ), + patch.object(DPAFineTuner, "_extract_features", _mock_extract_features), + ): + ft = DPAFineTuner( + pretrained="fake.pt", + model_branch="Omat24", + predictor="linear", + ) + ft.fit(str(system), target_key="energy") + frozen = ft.freeze(str(tmp_path / "model.pth")) + + from dpa_adapt._backend import ( + load_torch_file, + ) + + bundle = load_torch_file(frozen) + + assert "model_branch" in bundle, "Bundle is missing 'model_branch' key" + assert bundle["model_branch"] == "Omat24" + + +# --------------------------------------------------------------------------- +# Committee helpers +# --------------------------------------------------------------------------- + + +def _make_mlp_bundle(tmp_path, n_frames=20): + """Create a frozen bundle with an MLPRegressor (uses random_state).""" + from sklearn.neural_network import ( + MLPRegressor, + ) + from sklearn.pipeline import ( + make_pipeline, + ) + from sklearn.preprocessing import ( + StandardScaler, + ) + + pipeline = make_pipeline( + StandardScaler(), + MLPRegressor( + hidden_layer_sizes=(10, 5), + max_iter=300, + random_state=42, + early_stopping=False, + ), + ) + + from dpa_adapt._backend import ( + load_torch_file, + ) + + bundle = { + "predictor": pipeline, + "target_key": "energy", + "type_map": ["Cu", "O"], + "task_dim": 1, + "pretrained": "fake.pt", + "pooling": "mean", + "model_branch": None, + "condition_manager": None, + } + path = str(tmp_path / "mlp_model.pth") + _torch_for_test.save(bundle, path) + assert load_torch_file(path)["target_key"] == "energy" + return path + + +def _make_rf_bundle(tmp_path, n_frames=20): + """Create a frozen bundle with a pre-fitted RandomForestRegressor.""" + from sklearn.ensemble import ( + RandomForestRegressor, + ) + from sklearn.pipeline import ( + make_pipeline, + ) + from sklearn.preprocessing import ( + StandardScaler, + ) + + pipeline = make_pipeline( + StandardScaler(), + RandomForestRegressor( + n_estimators=100, + random_state=42, + ), + ) + # Pre-fit on synthetic data so that tree estimators are available. + rng = np.random.default_rng(0) + X = rng.random((n_frames, FEAT_DIM)) + y = rng.random(n_frames) + pipeline.fit(X, y) + + from dpa_adapt._backend import ( + load_torch_file, + ) + + bundle = { + "predictor": pipeline, + "target_key": "energy", + "type_map": ["Cu", "O"], + "task_dim": 1, + "pretrained": "fake.pt", + "pooling": "mean", + "model_branch": None, + "condition_manager": None, + } + path = str(tmp_path / "rf_model.pth") + _torch_for_test.save(bundle, path) + assert load_torch_file(path)["target_key"] == "energy" + return path + + +def _make_multioutput_rf_bundle(tmp_path, n_frames=20): + """Create a frozen bundle with MultiOutputRegressor(RandomForestRegressor).""" + from sklearn.ensemble import ( + RandomForestRegressor, + ) + from sklearn.multioutput import ( + MultiOutputRegressor, + ) + from sklearn.pipeline import ( + make_pipeline, + ) + from sklearn.preprocessing import ( + StandardScaler, + ) + + pipeline = make_pipeline( + StandardScaler(), + MultiOutputRegressor( + RandomForestRegressor( + n_estimators=100, + random_state=42, + ) + ), + ) + rng = np.random.default_rng(0) + X = rng.random((n_frames, FEAT_DIM)) + y = rng.random((n_frames, 2)) + pipeline.fit(X, y) + + from dpa_adapt._backend import ( + load_torch_file, + ) + + bundle = { + "predictor": pipeline, + "target_key": ["homo", "lumo"], + "type_map": ["Cu", "O"], + "task_dim": 2, + "pretrained": "fake.pt", + "pooling": "mean", + "model_branch": None, + "condition_manager": None, + } + path = str(tmp_path / "multioutput_rf_model.pth") + _torch_for_test.save(bundle, path) + assert load_torch_file(path)["target_key"] == ["homo", "lumo"] + return path + + +# --------------------------------------------------------------------------- +# Committee tests +# --------------------------------------------------------------------------- + + +class TestCommitteeFitPredict: + """n_committee > 1 trains ensemble and returns mean+std.""" + + def test_committee_fit_predict(self, tmp_path): + system = tmp_path / "sys" + system.mkdir() + _make_npy_system(system, n_frames=20) + bundle_path = _make_mlp_bundle(tmp_path, n_frames=20) + + with ( + patch.object( + DPAFineTuner, "_load_descriptor_model", _mock_load_descriptor_model + ), + patch.object(DPAFineTuner, "_extract_features", _mock_extract_features), + ): + pred = DPAPredictor(bundle_path, n_committee=5) + pred.fit(str(system), target_key="energy") + result = pred.predict(str(system), return_uncertainty=True) + + assert hasattr(result, "predictions") + assert hasattr(result, "uncertainty") + assert result.predictions.shape == (20, 1) + assert result.uncertainty.shape == (20, 1) + assert np.all(result.uncertainty >= 0) + assert np.any(result.uncertainty > 0), ( + "Committee std should be > 0 for some samples" + ) + + +class TestCommitteeThreshold: + """After fit, uncertainty_threshold_ is set.""" + + def test_committee_threshold_set(self, tmp_path): + system = tmp_path / "sys" + system.mkdir() + _make_npy_system(system, n_frames=20) + bundle_path = _make_mlp_bundle(tmp_path, n_frames=20) + + with ( + patch.object( + DPAFineTuner, "_load_descriptor_model", _mock_load_descriptor_model + ), + patch.object(DPAFineTuner, "_extract_features", _mock_extract_features), + ): + pred = DPAPredictor(bundle_path, n_committee=5) + pred.fit(str(system), target_key="energy") + + assert hasattr(pred, "uncertainty_threshold_") + assert isinstance(pred.uncertainty_threshold_, float) + assert pred.uncertainty_threshold_ > 0 + + +class TestCommitteeN1BackwardCompat: + """n_committee=1 must behave identically to the current single-estimator behaviour.""" + + def test_committee_n1_backward_compat(self, tmp_path): + system = tmp_path / "sys" + system.mkdir() + _make_npy_system(system, n_frames=4) + + with ( + patch.object( + DPAFineTuner, "_load_descriptor_model", _mock_load_descriptor_model + ), + patch.object(DPAFineTuner, "_extract_features", _mock_extract_features), + ): + ft = DPAFineTuner(pretrained="fake.pt", predictor="linear") + ft.fit(str(system), target_key="energy") + frozen = ft.freeze(str(tmp_path / "model.pth")) + + pred = DPAPredictor(frozen, n_committee=1) + result = pred.predict(str(system)) + + assert hasattr(result, "predictions") + assert result.predictions.shape == (4, 1) + + +class TestReturnUncertaintyFalse: + """Default return_uncertainty=False returns DotDict (not a tuple).""" + + def test_return_uncertainty_false(self, tmp_path): + system = tmp_path / "sys" + system.mkdir() + _make_npy_system(system, n_frames=20) + bundle_path = _make_mlp_bundle(tmp_path, n_frames=20) + + with ( + patch.object( + DPAFineTuner, "_load_descriptor_model", _mock_load_descriptor_model + ), + patch.object(DPAFineTuner, "_extract_features", _mock_extract_features), + ): + pred = DPAPredictor(bundle_path, n_committee=5) + pred.fit(str(system), target_key="energy") + result = pred.predict(str(system)) # default return_uncertainty=False + + assert not isinstance(result, tuple) + assert hasattr(result, "predictions") + assert not hasattr(result, "uncertainty"), ( + "uncertainty should not be present when return_uncertainty=False" + ) + + +class TestRfUncertainty: + """RF natively supports uncertainty via per-tree std.""" + + def test_rf_uncertainty(self, tmp_path): + system = tmp_path / "sys" + system.mkdir() + _make_npy_system(system, n_frames=20) + bundle_path = _make_rf_bundle(tmp_path, n_frames=20) + + with ( + patch.object( + DPAFineTuner, "_load_descriptor_model", _mock_load_descriptor_model + ), + patch.object(DPAFineTuner, "_extract_features", _mock_extract_features), + ): + pred = DPAPredictor(bundle_path) + result = pred.predict(str(system), return_uncertainty=True) + + assert hasattr(result, "predictions") + assert hasattr(result, "uncertainty") + assert result.predictions.shape == (20, 1) + assert result.uncertainty.shape == (20, 1) + assert np.all(result.uncertainty >= 0) + assert np.any(result.uncertainty > 0), ( + "RF tree-level std should be > 0 for some samples" + ) + + def test_multioutput_rf_uncertainty(self, tmp_path): + system = tmp_path / "sys" + system.mkdir() + _make_multi_npy_system(system, n_frames=20) + bundle_path = _make_multioutput_rf_bundle(tmp_path, n_frames=20) + + with ( + patch.object( + DPAFineTuner, "_load_descriptor_model", _mock_load_descriptor_model + ), + patch.object(DPAFineTuner, "_extract_features", _mock_extract_features), + ): + pred = DPAPredictor(bundle_path) + result = pred.predict(str(system), return_uncertainty=True) + + assert hasattr(result, "predictions") + assert hasattr(result, "uncertainty") + assert result.predictions.shape == (20, 2) + assert result.uncertainty.shape == (20, 2) + assert np.all(result.uncertainty >= 0) + assert np.any(result.uncertainty > 0), ( + "Multi-output RF tree-level std should be > 0 for some samples" + ) + + +class TestRidgeUncertaintyRaises: + """Ridge cannot produce uncertainty — calling return_uncertainty=True must raise.""" + + def test_ridge_uncertainty_raises(self, tmp_path): + system = tmp_path / "sys" + system.mkdir() + _make_npy_system(system, n_frames=4) + + with ( + patch.object( + DPAFineTuner, "_load_descriptor_model", _mock_load_descriptor_model + ), + patch.object(DPAFineTuner, "_extract_features", _mock_extract_features), + ): + ft = DPAFineTuner(pretrained="fake.pt", predictor="linear") + ft.fit(str(system), target_key="energy") + frozen = ft.freeze(str(tmp_path / "model.pth")) + + pred = DPAPredictor(frozen) + with pytest.raises(ValueError, match="Ridge regression"): + pred.predict(str(system), return_uncertainty=True) + + +# --------------------------------------------------------------------------- +# Multi-property tests +# --------------------------------------------------------------------------- + + +def _make_multi_npy_system(root: Path, n_frames: int = 5, n_atoms: int = 2) -> None: + """Create a minimal system with homo.npy and lumo.npy label files.""" + (root / "type.raw").write_text("0\n1\n") + (root / "type_map.raw").write_text("Cu\nO\n") + set_dir = root / "set.000" + set_dir.mkdir() + np.save(set_dir / "coord.npy", np.zeros((n_frames, n_atoms * 3))) + np.save(set_dir / "box.npy", np.eye(3).reshape(1, 9).repeat(n_frames, 0)) + np.save(set_dir / "homo.npy", -np.arange(n_frames, dtype=float) - 0.1) + np.save(set_dir / "lumo.npy", np.arange(n_frames, dtype=float) + 0.1) + + +class TestMultiPropertyFit: + """fit() with list[str] target_key must produce multi-output predictions.""" + + @pytest.mark.parametrize("predictor_type", ["ridge", "rf", "mlp"]) + def test_multi_output_all_predictors(self, tmp_path, predictor_type): + # MLP needs enough samples to split a validation set (10% of n_frames). + n = 50 if predictor_type == "mlp" else 5 + system = tmp_path / "sys" + system.mkdir() + _make_multi_npy_system(system, n_frames=n) + + with ( + patch.object( + DPAFineTuner, "_load_descriptor_model", _mock_load_descriptor_model + ), + patch.object(DPAFineTuner, "_extract_features", _mock_extract_features), + ): + ft = DPAFineTuner(pretrained="fake.pt", predictor=predictor_type) + ft.fit(str(system), target_key=["homo", "lumo"]) + + assert ft._task_dim == 2 + assert ft._fitted is True + + result = ft.predict(str(system)) + assert result.predictions.shape == (n, 2), ( + f"{predictor_type}: expected ({n},2), got {result.predictions.shape}" + ) + + +class TestMultiPropertyEvaluate: + """evaluate() with list target_key returns per-property metrics dict.""" + + def test_evaluate_returns_per_property_dict(self, tmp_path): + system = tmp_path / "sys" + system.mkdir() + _make_multi_npy_system(system, n_frames=5) + + with ( + patch.object( + DPAFineTuner, "_load_descriptor_model", _mock_load_descriptor_model + ), + patch.object(DPAFineTuner, "_extract_features", _mock_extract_features), + ): + ft = DPAFineTuner(pretrained="fake.pt", predictor="ridge") + ft.fit(str(system), target_key=["homo", "lumo"]) + result = ft.evaluate(str(system)) + + assert isinstance(result.mae, dict), ( + f"Expected dict mae, got {type(result.mae)}" + ) + assert isinstance(result.rmse, dict) + assert isinstance(result.r2, dict) + assert set(result.mae.keys()) == {"homo", "lumo"} + assert set(result.rmse.keys()) == {"homo", "lumo"} + assert set(result.r2.keys()) == {"homo", "lumo"} + assert all(isinstance(v, float) for v in result.mae.values()) + assert result.predictions.shape == result.labels.shape + assert result.predictions.shape[0] == 5 + + def test_single_property_still_returns_float(self, tmp_path): + """Backward compat: single str target_key returns flat floats, not dict.""" + system = tmp_path / "sys" + system.mkdir() + _make_npy_system(system, n_frames=5) + + with ( + patch.object( + DPAFineTuner, "_load_descriptor_model", _mock_load_descriptor_model + ), + patch.object(DPAFineTuner, "_extract_features", _mock_extract_features), + ): + ft = DPAFineTuner(pretrained="fake.pt", predictor="ridge") + ft.fit(str(system), target_key="energy") + result = ft.evaluate(str(system)) + + assert isinstance(result.mae, float), ( + f"Expected float mae, got {type(result.mae)}" + ) + assert isinstance(result.rmse, float) + assert isinstance(result.r2, float) + + +class TestMultiPropertyFreezeRoundtrip: + """freeze/load round-trip preserves list target_key and multi-output.""" + + def test_freeze_load_roundtrip_list_target_key(self, tmp_path): + system = tmp_path / "sys" + system.mkdir() + _make_multi_npy_system(system, n_frames=5) + + with ( + patch.object( + DPAFineTuner, "_load_descriptor_model", _mock_load_descriptor_model + ), + patch.object(DPAFineTuner, "_extract_features", _mock_extract_features), + ): + ft = DPAFineTuner(pretrained="fake.pt", predictor="ridge") + ft.fit(str(system), target_key=["homo", "lumo"]) + frozen = ft.freeze(str(tmp_path / "model.pth")) + + pred = DPAPredictor(frozen) + result = pred.predict(str(system)) + + assert result.predictions.shape == (5, 2) + assert pred._target_key == ["homo", "lumo"] + assert pred._task_dim == 2 + + def test_freeze_load_roundtrip_evaluate_per_property(self, tmp_path): + system = tmp_path / "sys" + system.mkdir() + _make_multi_npy_system(system, n_frames=50) + + with ( + patch.object( + DPAFineTuner, "_load_descriptor_model", _mock_load_descriptor_model + ), + patch.object(DPAFineTuner, "_extract_features", _mock_extract_features), + ): + ft = DPAFineTuner(pretrained="fake.pt", predictor="mlp") + ft.fit(str(system), target_key=["homo", "lumo"]) + frozen = ft.freeze(str(tmp_path / "model.pth")) + + pred = DPAPredictor(frozen) + metrics = pred.evaluate(str(system)) + + assert isinstance(metrics.mae, dict) + assert set(metrics.mae.keys()) == {"homo", "lumo"} + assert metrics.predictions.shape == (50, 2) diff --git a/source/tests/dpa_adapt/test_smiles_data.py b/source/tests/dpa_adapt/test_smiles_data.py new file mode 100644 index 0000000000..6ceabf4442 --- /dev/null +++ b/source/tests/dpa_adapt/test_smiles_data.py @@ -0,0 +1,181 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from __future__ import ( + annotations, +) + +from typing import ( + TYPE_CHECKING, +) +from unittest import ( + mock, +) + +if TYPE_CHECKING: + from pathlib import ( + Path, + ) + +import numpy as np + +from dpa_adapt.data import smiles as mol_module +from dpa_adapt.data.smiles import ( + _build_type_map_from_elements, + _has_overlapping_atoms, + _parse_property_value, + _records_from_csv_mol, + _records_from_csv_smiles, + predict_records_from_data, + read_mol_coords, + records_from_direct_data, +) + + +def test__parse_property_value_accepts_text_with_units() -> None: + assert _parse_property_value("gap = -1.25 eV") == -1.25 + + +def test_overlap_detection() -> None: + coords = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], dtype=np.float32) + + assert _has_overlapping_atoms(coords, 1e-6) + + +def test_type_map_uses_periodic_table_order() -> None: + assert _build_type_map_from_elements({"O", "C", "H"}) == ["H", "C", "O"] + + +def test_records_from_direct_data() -> None: + records, rows = records_from_direct_data( + { + "atoms": [["O", "H", "H"]], + "coordinates": [np.zeros((3, 3), dtype=np.float32)], + "target": [1.5], + } + ) + + assert records[0][0] == ["O", "H", "H"] + assert records[0][2] == 1.5 + assert rows == [{"sample_id": 0, "target": 1.5}] + + +def test_records_from_csv_smiles_generates_coordinates(tmp_path: Path) -> None: + dataset = tmp_path / "dataset.csv" + dataset.write_text("SMILES,Property\nO,1.5\n", encoding="utf-8") + + with mock.patch.object( + mol_module, + "smiles_to_3d_coords", + return_value=( + ["O", "H", "H"], + np.array( + [[0.0, 0.0, 0.0], [0.9, 0.0, 0.0], [-0.2, 0.9, 0.0]], + dtype=np.float32, + ), + ), + ) as smiles_mock: + records, failed_rows, skipped_zero, skipped_overlap, rows = ( + _records_from_csv_smiles( + dataset=dataset, + property_col="Property", + ) + ) + + smiles_mock.assert_called_once_with("O", random_seed=42) + assert records[0][0] == ["O", "H", "H"] + assert records[0][2] == 1.5 + assert failed_rows == [] + assert skipped_zero == 0 + assert skipped_overlap == 0 + assert rows == [{"SMILES": "O", "Property": "1.5"}] + + +def test_records_from_csv_smiles_collects_failed_rows(tmp_path: Path) -> None: + dataset = tmp_path / "dataset.csv" + dataset.write_text("SMILES,Property\nbad,1.5\n", encoding="utf-8") + + with mock.patch.object( + mol_module, + "smiles_to_3d_coords", + side_effect=ValueError("bad smiles"), + ): + records, failed_rows, skipped_zero, skipped_overlap, rows = ( + _records_from_csv_smiles( + dataset=dataset, + property_col="Property", + ) + ) + + assert records == [] + assert failed_rows == [(0, "bad", "bad smiles")] + assert skipped_zero == 0 + assert skipped_overlap == 0 + assert rows == [] + + +def test_csv_mol_path_does_not_use_smiles_generation(tmp_path: Path) -> None: + dataset = tmp_path / "dataset.csv" + dataset.write_text("SMILES,Property\nbad,1.5\n", encoding="utf-8") + mol_dir = tmp_path / "mol" + mol_dir.mkdir() + mol_path = mol_dir / "id0.mol" + mol_path.write_text( + "\n".join( + [ + "methane", + "", + "", + " 1 0 0 0 0 0 999 V2000", + " 0.1000 0.2000 0.3000 C 0 0 0 0 0 0 0 0 0 0 0 0", + "M END", + ] + ), + encoding="utf-8", + ) + + with mock.patch.object( + mol_module, + "smiles_to_3d_coords", + side_effect=AssertionError("SMILES generation should not be used"), + ): + records, failed_rows, skipped_zero, skipped_overlap, rows = ( + _records_from_csv_mol( + dataset=dataset, + mol_dir=mol_dir, + property_col="Property", + ) + ) + atoms, coords, pred_rows = predict_records_from_data( + {"dataset": dataset, "mol_dir": mol_dir}, + property_col=None, + ) + + assert records[0][0] == ["C"] + assert failed_rows == [] + assert skipped_zero == 0 + assert skipped_overlap == 0 + assert rows == [{"SMILES": "bad", "Property": "1.5"}] + assert atoms == [["C"]] + assert coords[0].shape == (1, 3) + assert pred_rows == [{"SMILES": "bad", "Property": "1.5"}] + + +def test_read_mol_coords(tmp_path: Path) -> None: + mol_path = tmp_path / "id0.mol" + mol_path.write_text( + "\n".join( + [ + "methane", + "", + "", + " 1 0 0 0 0 0 999 V2000", + " 0.0000 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0", + "M END", + ] + ), + encoding="utf-8", + ) + + symbols, coords = read_mol_coords(mol_path) + + assert symbols == ["C"] + assert coords.shape == (1, 3) diff --git a/source/tests/dpa_adapt/test_split_cv.py b/source/tests/dpa_adapt/test_split_cv.py new file mode 100644 index 0000000000..aedc188e23 --- /dev/null +++ b/source/tests/dpa_adapt/test_split_cv.py @@ -0,0 +1,261 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for train_test_split() and cross_validate().""" + +import json +import os +import tempfile +from pathlib import ( + Path, +) + +import numpy as np +import pytest + +from dpa_adapt.cv import ( + _build_fold_groups, + _extract_formula, + _formula_to_group, + cross_validate, + train_test_split, +) +from dpa_adapt.data.loader import ( + load_data, +) + + +def _write_system( + root: str, + natoms: int = 2, + nframes: int = 3, + label_key: str = "energy", + elements: list[str] | None = None, +): + """Create a deepmd/npy system dir, load it, return dpdata.System.""" + if elements is None: + elements = ["H", "O"] + root = Path(root) + root.mkdir(parents=True, exist_ok=True) + n_atoms = len(elements) + (root / "type.raw").write_text( + "\n".join(str(i % n_atoms) for i in range(natoms)) + "\n" + ) + (root / "type_map.raw").write_text("\n".join(elements) + "\n") + sdir = root / "set.000" + sdir.mkdir(exist_ok=True) + np.save(sdir / "coord.npy", np.zeros((nframes, natoms * 3))) + np.save(sdir / "box.npy", np.tile(np.eye(3).ravel(), (nframes, 1))) + np.save(sdir / f"{label_key}.npy", np.ones((nframes, 1))) + return load_data(str(root))[0] + + +def _write_oer_tree( + tmpdir: str, formulas: list[str], nsets: int = 3, label_key: str = "energy" +) -> list: + """Create an OER-style tree and return loaded dpdata.System objects.""" + systems = [] + for formula in formulas: + for s in range(1, nsets + 1): + sysdir = Path(tmpdir) / f"set_{s:02d}" / formula / "353" + sys = _write_system(str(sysdir), natoms=10, nframes=3, label_key=label_key) + systems.append(sys) + return sorted(systems, key=lambda s: s._dpa_source) + + +def _make_manifest( + formula_parts: list[list[str]], test: list[str], tag: str = "ni" +) -> str: + m = { + "meta": {"mode": "stratified", "k": len(formula_parts), "seed": 123}, + "co": {"test": [], "parts": []}, + tag: {"test": test, "parts": formula_parts}, + } + fd, path = tempfile.mkstemp(suffix=".json") + os.close(fd) + Path(path).write_text(json.dumps(m)) + return path + + +class TestExtractFormula: + def test_oer_path_from_dpa_source(self, tmp_path): + sys = _write_system(str(tmp_path / "set_01" / "Ni0.5Fe0.5O2H1" / "353")) + assert "Ni0.5Fe0.5O2H1" in _extract_formula(sys) + + def test_formula_to_group(self, tmp_path): + s1 = _write_system(str(tmp_path / "set_01" / "A" / "1")) + s2 = _write_system(str(tmp_path / "set_02" / "A" / "1")) + s3 = _write_system(str(tmp_path / "set_01" / "B" / "1")) + groups = _formula_to_group([s1, s2, s3]) + assert groups == ["A", "A", "B"] + + +class TestBuildFoldGroups: + def test_three_folds(self): + parts = [["A", "B"], ["C", "D"], ["E"]] + path = _make_manifest(parts, test=["F"]) + folds, test = _build_fold_groups(path) + assert len(folds) == 3 + assert folds[0] == {"A", "B"} + assert test == {"F"} + + +class TestTrainTestSplit: + @pytest.fixture(autouse=True) + def setup(self, tmp_path): + self.tmp = tmp_path + formulas = [f"Comp{i}" for i in range(10)] + self.systems = _write_oer_tree( + str(tmp_path), formulas, nsets=2, label_key="energy" + ) + + def test_manifest_split(self): + parts = [ + ["Comp0", "Comp1"], + ["Comp2", "Comp3"], + ["Comp4", "Comp5"], + ["Comp6", "Comp7"], + ["Comp8"], + ] + mpath = _make_manifest(parts, test=["Comp9"]) + train, valid, test = train_test_split(self.systems, manifest=mpath) + assert len(train) == 16, f"got {len(train)}" + assert len(valid) == 2 + assert len(test) == 2 + t = set(_formula_to_group(train)) + v = set(_formula_to_group(valid)) + e = set(_formula_to_group(test)) + assert len(t & v) == 0 + assert len(t & e) == 0 + assert "Comp9" in e + assert "Comp8" in v + + def test_group_by_formula(self): + train, valid, test = train_test_split( + self.systems, + group_by="formula", + test_size=0.1, + valid_size=0.2, + seed=42, + ) + t = set(_formula_to_group(train)) + v = set(_formula_to_group(valid)) + e = set(_formula_to_group(test)) + assert len(t & v) == 0 + assert len(t & e) == 0 + assert len(v & e) == 0 + + def test_group_by_explicit_list(self): + groups = _formula_to_group(self.systems) + train, valid, test = train_test_split( + self.systems, + group_by=groups, + test_size=0.1, + valid_size=0.1, + seed=42, + ) + t = set(_formula_to_group(train)) + v = set(_formula_to_group(valid)) + assert len(t & v) == 0 + + def test_no_group_by_raises(self): + with pytest.raises(ValueError, match="Either manifest"): + train_test_split(self.systems) + + +class TestCrossValidate: + @pytest.fixture(autouse=True) + def setup(self, tmp_path): + self.tmp = tmp_path + formulas = [f"Comp{i}" for i in range(5)] + self.systems = _write_oer_tree( + str(tmp_path), formulas, nsets=2, label_key="energy" + ) + + def test_expensive_cv_guard(self): + class FakeModel: + strategy = "finetune" + pretrained = None + model_branch = None + pooling = "mean" + + with pytest.raises(ValueError, match="allow_expensive_cv"): + cross_validate( + FakeModel(), + self.systems, + label_key="energy", + cv=3, + group_by="formula", + ) + + def test_invalid_granularity(self): + class FakeModel: + strategy = "frozen_sklearn" + pretrained = None + model_branch = None + pooling = "mean" + + with pytest.raises(ValueError, match="granularity"): + cross_validate( + FakeModel(), + self.systems, + label_key="energy", + cv=5, + group_by="formula", + granularity="invalid", + ) + + def test_invalid_cv_value(self): + class FakeModel: + strategy = "frozen_sklearn" + pretrained = None + model_branch = None + pooling = "mean" + + with pytest.raises(ValueError, match="cv must be"): + cross_validate( + FakeModel(), + self.systems, + label_key="energy", + cv=1, + group_by="formula", + ) + + +class TestStandardScalerConsistency: + def test_same_predictions_on_same_data(self): + from sklearn.linear_model import ( + Ridge, + ) + from sklearn.pipeline import ( + make_pipeline, + ) + from sklearn.preprocessing import ( + StandardScaler, + ) + + from dpa_adapt.cv import ( + _build_sklearn_head, + ) + + rng = np.random.default_rng(42) + X = rng.normal(size=(100, 32)) + y = rng.normal(size=(100,)) + + head1 = make_pipeline(StandardScaler(), _build_sklearn_head("ridge", seed=42)) + head1.fit(X, y) + pred1 = head1.predict(X) + + head2 = make_pipeline(StandardScaler(), Ridge(alpha=1.0, random_state=42)) + head2.fit(X, y) + pred2 = head2.predict(X) + + np.testing.assert_array_almost_equal(pred1, pred2) + + +class TestDeterministicCV: + """Ensures cross_validate with frozen_sklearn + GroupKFold is deterministic.""" + + def test_deterministic_folds_same_result_twice(self, tmp_path, monkeypatch): + raise pytest.skip("needs real DPA checkpoint to extract descriptors") + + def test_manifest_folds(self, tmp_path, monkeypatch): + raise pytest.skip("needs real DPA checkpoint to extract descriptors") diff --git a/source/tests/dpa_adapt/test_trainer.py b/source/tests/dpa_adapt/test_trainer.py new file mode 100644 index 0000000000..19ec3f3da5 --- /dev/null +++ b/source/tests/dpa_adapt/test_trainer.py @@ -0,0 +1,534 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for dpa_adapt.trainer.DPATrainer.""" + +from __future__ import ( + annotations, +) + +import os +from pathlib import ( + Path, +) +from unittest.mock import ( + patch, +) + +import pytest + +from dpa_adapt.trainer import ( + DPATrainer, +) + +# --------------------------------------------------------------------------- +# Helpers / fixtures +# --------------------------------------------------------------------------- + +DUMMY_TYPE_MAP = ["H", "C", "N", "O"] + + +def _make_systems(tmp_path, prefix: str, n: int) -> str: + """Create n empty system dirs and return a glob pattern matching them.""" + root = tmp_path / prefix + root.mkdir(parents=True, exist_ok=True) + for i in range(n): + (root / f"sys_{i:03d}").mkdir() + return str(root / "sys_*") + + +def _fake_descriptor_sd() -> dict: + """Minimal checkpoint state_dict with the descriptor path the trainer reads.""" + descriptor = { + "type": "dpa3", + "repflow": { + "n_dim": 128, + "e_dim": 64, + "a_dim": 32, + "nlayers": 16, + "e_rcut": 6.0, + "e_rcut_smth": 5.3, + "e_sel": 1200, + "a_rcut": 4.0, + "a_rcut_smth": 3.5, + "a_sel": 300, + "axis_neuron": 4, + "skip_stat": True, + "a_compress_rate": 1, + "a_compress_e_rate": 2, + "a_compress_use_split": True, + "update_angle": True, + "smooth_edge_update": True, + "use_dynamic_sel": True, + "sel_reduce_factor": 10.0, + "update_style": "res_residual", + "update_residual": 0.1, + "update_residual_init": "const", + "n_multi_edge_message": 1, + "optim_update": True, + "use_exp_switch": True, + }, + "activation_function": "custom_silu:3.0", + "precision": "float32", + "use_tebd_bias": False, + "concat_output_tebd": False, + "exclude_types": [], + "env_protection": 0.0, + "trainable": True, + "use_econf_tebd": False, + } + return { + "model": { + "_extra_state": { + "model_params": { + "shared_dict": {"dpa3_descriptor": descriptor}, + } + } + } + } + + +@pytest.fixture +def systems(tmp_path): + """Build train + valid system globs in a tmp directory.""" + train_glob = _make_systems(tmp_path, "train", 60) + valid_glob = _make_systems(tmp_path, "valid", 60) + return train_glob, valid_glob + + +@pytest.fixture +def dummy_ckpt(tmp_path): + """Create an empty file to act as a 'pretrained' checkpoint path.""" + ckpt = tmp_path / "dummy.pt" + ckpt.write_bytes(b"") + return str(ckpt) + + +def _patch_torch_load(): + """Patch torch.load to return our fake descriptor state_dict.""" + return patch("torch.load", lambda *a, **kw: _fake_descriptor_sd()) + + +# --------------------------------------------------------------------------- +# 1. init validation +# --------------------------------------------------------------------------- + + +def test_init_validation(tmp_path, systems): + train_glob, valid_glob = systems + + # train_systems is None + with pytest.raises(ValueError, match="train_systems"): + DPATrainer( + valid_systems=valid_glob, + type_map=DUMMY_TYPE_MAP, + ) + + # type_map is None + with pytest.raises(ValueError, match="type_map"): + DPATrainer( + train_systems=train_glob, + valid_systems=valid_glob, + ) + + # freeze_backbone=True without pretrained + with pytest.raises(ValueError, match="LP requires"): + DPATrainer( + train_systems=train_glob, + valid_systems=valid_glob, + type_map=DUMMY_TYPE_MAP, + freeze_backbone=True, + ) + + # pretrained path does not exist + with pytest.raises(ValueError, match="not found"): + DPATrainer( + pretrained=str(tmp_path / "does_not_exist.pt"), + train_systems=train_glob, + valid_systems=valid_glob, + type_map=DUMMY_TYPE_MAP, + ) + + +# --------------------------------------------------------------------------- +# 2. FT config +# --------------------------------------------------------------------------- + + +def test_config_ft(systems, dummy_ckpt, tmp_path): + train_glob, valid_glob = systems + t = DPATrainer( + pretrained=dummy_ckpt, + freeze_backbone=False, + train_systems=train_glob, + valid_systems=valid_glob, + type_map=DUMMY_TYPE_MAP, + output_dir=str(tmp_path / "out"), + ) + with _patch_torch_load(): + config = t._build_config() + cmd = t._build_cmd("input.json") + + assert "--finetune" in cmd + # pretrained must immediately follow --finetune + assert cmd[cmd.index("--finetune") + 1] == dummy_ckpt + # Paper alignment: single-task fine-tune passes NO --model-branch. + assert "--model-branch" not in cmd + assert "--skip-neighbor-stat" in cmd + + assert config["model"]["descriptor"]["trainable"] is True + + +# --------------------------------------------------------------------------- +# 4. LP config +# --------------------------------------------------------------------------- + + +def test_config_lp(systems, dummy_ckpt, tmp_path): + train_glob, valid_glob = systems + t = DPATrainer( + pretrained=dummy_ckpt, + freeze_backbone=True, + train_systems=train_glob, + valid_systems=valid_glob, + type_map=DUMMY_TYPE_MAP, + output_dir=str(tmp_path / "out"), + ) + with _patch_torch_load(): + config = t._build_config() + cmd = t._build_cmd("input.json") + + assert "--finetune" in cmd + assert cmd[cmd.index("--finetune") + 1] == dummy_ckpt + # Paper alignment: single-task fine-tune passes NO --model-branch. + assert "--model-branch" not in cmd + assert "--skip-neighbor-stat" in cmd + assert config["model"]["descriptor"]["trainable"] is False + + +# --------------------------------------------------------------------------- +# 5. Glob expansion +# --------------------------------------------------------------------------- + + +def test_glob_expansion(tmp_path): + train_glob = _make_systems(tmp_path, "train", 70) + valid_glob = _make_systems(tmp_path, "valid", 70) + + t = DPATrainer( + train_systems=train_glob, + valid_systems=valid_glob, + type_map=DUMMY_TYPE_MAP, + output_dir=str(tmp_path / "out"), + ) + config = t._build_config() + assert len(config["training"]["training_data"]["systems"]) == 70 + assert len(config["training"]["validation_data"]["systems"]) == 70 + + # Empty glob raises + empty_glob = str(tmp_path / "nope" / "*") + t_empty = DPATrainer( + train_systems=empty_glob, + valid_systems=valid_glob, + type_map=DUMMY_TYPE_MAP, + output_dir=str(tmp_path / "out2"), + ) + with pytest.raises(ValueError, match="resolved to 0 systems"): + t_empty._build_config() + + +# --------------------------------------------------------------------------- +# 6. evaluate() output parsing +# --------------------------------------------------------------------------- + + +def test_evaluate_parse(systems, tmp_path): + train_glob, valid_glob = systems + t = DPATrainer( + train_systems=train_glob, + valid_systems=valid_glob, + type_map=DUMMY_TYPE_MAP, + output_dir=str(tmp_path / "out"), + ) + + # Place a fake checkpoint so _final_ckpt_path() finds it. + os.makedirs(t.output_dir, exist_ok=True) + fake_ckpt = os.path.join(t.output_dir, "model.ckpt-100.pt") + open(fake_ckpt, "w").close() + + # Need an existing system path for the test glob to resolve. + test_glob = _make_systems(tmp_path, "test", 5) + + canned_stdout = ( + "DEEPMD INFO # number of test data : 42\n" + "DEEPMD INFO PROPERTY MAE : 0.006789 units\n" + "DEEPMD INFO PROPERTY RMSE : 0.012345 units\n" + ) + + class _Result: + stdout = canned_stdout + stderr = "" + returncode = 0 + + with patch("subprocess.run", return_value=_Result()): + out = t.evaluate(test_glob) + + assert out["rmse"] == pytest.approx(0.012345) + assert out["mae"] == pytest.approx(0.006789) + assert out["n_frames"] == 42 + # evaluate() concatenates stdout + "\n" + stderr; canned_stdout must be in it. + assert canned_stdout in out["_raw_stdout"] + assert ( + "rmse" in out["_parser_pattern_used"].lower() + or "mae" in out["_parser_pattern_used"].lower() + ) + + +# --------------------------------------------------------------------------- +# 7. Parser: property-explicit pattern +# --------------------------------------------------------------------------- + + +def test_evaluate_parse_property_explicit(): + stdout = ( + "DEEPMD INFO PROPERTY RMSE : 0.0123 units\n" + "DEEPMD INFO PROPERTY MAE : 0.0080 units\n" + ) + out = DPATrainer._parse_test_output(stdout) + assert out["rmse"] == pytest.approx(0.0123) + assert out["mae"] == pytest.approx(0.0080) + assert "PROPERTY" in out["_parser_pattern_used"] + assert out["_raw_stdout"] == stdout + + +# --------------------------------------------------------------------------- +# 8. Parser: property format (no generic fallback — removed during refactor) +# --------------------------------------------------------------------------- + + +def test_evaluate_parse_property_format_explicit(): + r"""Parser auto-detects PROPERTY output and matches the well-anchored regex. + Generic \brmse\b / \bmae\b fallback patterns were removed. + """ + stdout = ( + "DEEPMD INFO PROPERTY MAE : 0.0234 units\n" + "DEEPMD INFO PROPERTY RMSE : 0.0150 units\n" + ) + out = DPATrainer._parse_test_output(stdout) + assert out["mae"] == pytest.approx(0.0234) + assert out["rmse"] == pytest.approx(0.0150) + assert "PROPERTY" in out["_parser_pattern_used"] + + +# --------------------------------------------------------------------------- +# 9. Parser: unparseable input raises RuntimeError +# --------------------------------------------------------------------------- + + +def test_evaluate_parse_unparseable(): + stdout = "no numbers here" + with pytest.raises(RuntimeError) as exc_info: + DPATrainer._parse_test_output(stdout) + assert "no numbers here" in str(exc_info.value) + + +# --------------------------------------------------------------------------- +# 10. Idempotency: skip when a longer checkpoint exists +# --------------------------------------------------------------------------- + + +def test_idempotency_skip_when_longer_ckpt_exists(systems, tmp_path): + train_glob, valid_glob = systems + out_dir = tmp_path / "out_skip" + out_dir.mkdir() + # Place a model.ckpt-100.pt; ask for max_steps=50 → should skip. + longer_ckpt = out_dir / "model.ckpt-100.pt" + longer_ckpt.write_bytes(b"") + + t = DPATrainer( + train_systems=train_glob, + valid_systems=valid_glob, + type_map=DUMMY_TYPE_MAP, + max_steps=50, + output_dir=str(out_dir), + ) + with patch("subprocess.run") as run_mock: + result = t.fit() + run_mock.assert_not_called() + assert result == str(longer_ckpt) + + +# --------------------------------------------------------------------------- +# 11. Idempotency: retrain when only a shorter checkpoint exists +# --------------------------------------------------------------------------- + + +def test_idempotency_retrain_when_shorter_ckpt_exists(systems, tmp_path): + train_glob, valid_glob = systems + out_dir = tmp_path / "out_retrain" + out_dir.mkdir() + # Place a model.ckpt-50.pt; ask for max_steps=100 → should retrain. + (out_dir / "model.ckpt-50.pt").write_bytes(b"") + + t = DPATrainer( + train_systems=train_glob, + valid_systems=valid_glob, + type_map=DUMMY_TYPE_MAP, + max_steps=100, + output_dir=str(out_dir), + ) + + # Mock subprocess.run so we never call real `dp`. After "training", + # create the model.ckpt-100.pt the production code will look for. + final_ckpt = out_dir / "model.ckpt-100.pt" + + def _fake_run(cmd, *args, **kwargs): + final_ckpt.write_bytes(b"") + + class R: + returncode = 0 + + return R() + + with patch("subprocess.run", side_effect=_fake_run) as run_mock: + result = t.fit() + run_mock.assert_called_once() + assert result == str(final_ckpt) + + +# --------------------------------------------------------------------------- +# 12. Seed propagation +# --------------------------------------------------------------------------- + + +def test_seed_propagation(systems, tmp_path): + train_glob, valid_glob = systems + t = DPATrainer( + pretrained=None, + train_systems=train_glob, + valid_systems=valid_glob, + type_map=DUMMY_TYPE_MAP, + seed=12345, + output_dir=str(tmp_path / "out_seed"), + ) + cfg = t._build_config() + assert cfg["model"]["descriptor"]["seed"] == 12345 + assert cfg["model"]["fitting_net"]["seed"] == 12345 + assert cfg["training"]["seed"] == 12345 + # Top-level "seed" was removed: deepmd 3.1.3 dargs is strict-mode and + # rejects unknown root keys. Seeds live on descriptor, fitting_net, and + # training instead. + assert "seed" not in cfg + + +# --------------------------------------------------------------------------- +# 13. Parser: takes weighted-average (last) match +# --------------------------------------------------------------------------- + + +def test_evaluate_parse_takes_weighted_average(): + """When dp prints per-system + weighted-average blocks, return the + weighted average (last match). + """ + stdout = ( + "PROPERTY MAE : 0.10 units\n" + "PROPERTY RMSE : 0.20 units\n" + "# ----------weighted average of errors-----------\n" + "PROPERTY MAE : 0.05 units\n" + "PROPERTY RMSE : 0.08 units\n" + ) + out = DPATrainer._parse_test_output(stdout) + # Must be the weighted-average (second/last) values, not the per-system + # (first) values. + assert out["mae"] == pytest.approx(0.05) + assert out["rmse"] == pytest.approx(0.08) + + +# --------------------------------------------------------------------------- +# 14. evaluate() combines stdout + stderr +# --------------------------------------------------------------------------- + + +def test_evaluate_combines_stderr(systems, tmp_path): + train_glob, valid_glob = systems + t = DPATrainer( + train_systems=train_glob, + valid_systems=valid_glob, + type_map=DUMMY_TYPE_MAP, + output_dir=str(tmp_path / "out_stderr"), + ) + os.makedirs(t.output_dir, exist_ok=True) + (Path(t.output_dir) / "model.ckpt-100.pt").write_bytes(b"") + test_glob = _make_systems(tmp_path, "test_stderr", 5) + + canned_stderr = ( + "DEEPMD INFO # number of test data : 100\n" + "DEEPMD INFO PROPERTY MAE : 0.0123 units\n" + "DEEPMD INFO PROPERTY RMSE : 0.0456 units\n" + ) + + class _Result: + stdout = "" + stderr = canned_stderr + returncode = 0 + + with patch("subprocess.run", return_value=_Result()): + out = t.evaluate(test_glob) + assert out["mae"] == pytest.approx(0.0123) + assert out["rmse"] == pytest.approx(0.0456) + + +# --------------------------------------------------------------------------- +# 15. evaluate() writes datafile and passes -f, not -s +# --------------------------------------------------------------------------- + + +def test_evaluate_writes_datafile_and_uses_f_flag(systems, tmp_path): + """evaluate() must write a datafile with one system per line and + pass it to dp test via -f (single value), not multiplex -s flags. + """ + train_glob, valid_glob = systems + out_dir = tmp_path / "out_datafile" + t = DPATrainer( + train_systems=train_glob, + valid_systems=valid_glob, + type_map=DUMMY_TYPE_MAP, + output_dir=str(out_dir), + ) + os.makedirs(t.output_dir, exist_ok=True) + (Path(t.output_dir) / "model.ckpt-100.pt").write_bytes(b"") + test_glob = _make_systems(tmp_path, "test_df", 5) + + captured_cmd = [] + canned_stderr = ( + "DEEPMD INFO # number of test data : 50\n" + "DEEPMD INFO # number of systems : 5\n" + "DEEPMD INFO PROPERTY MAE : 0.01 units\n" + "DEEPMD INFO PROPERTY RMSE : 0.02 units\n" + ) + + class _Result: + stdout = "" + stderr = canned_stderr + returncode = 0 + + def _capture(cmd, *args, **kwargs): + captured_cmd.extend(cmd) + return _Result() + + with patch("subprocess.run", side_effect=_capture): + out = t.evaluate(test_glob) + + # No -s anywhere; exactly one -f flag. + assert "-s" not in captured_cmd, f"-s should not appear: {captured_cmd}" + assert captured_cmd.count("-f") == 1 + + # -f points to a real datafile with 5 lines. + f_idx = captured_cmd.index("-f") + datafile = captured_cmd[f_idx + 1] + assert os.path.isfile(datafile), f"datafile not written: {datafile}" + with open(datafile) as f: + lines = [l for l in f.read().split("\n") if l.strip()] + assert len(lines) == 5, f"Expected 5 systems in datafile, got {len(lines)}" + + assert out["mae"] == pytest.approx(0.01) + assert out["rmse"] == pytest.approx(0.02) + assert out["n_systems"] == 5 diff --git a/source/tests/dpa_adapt/test_trainer_dim_case_embd.py b/source/tests/dpa_adapt/test_trainer_dim_case_embd.py new file mode 100644 index 0000000000..a81acf1fb2 --- /dev/null +++ b/source/tests/dpa_adapt/test_trainer_dim_case_embd.py @@ -0,0 +1,58 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Lock DPATrainer._build_fitting_net's dim_case_embd behavior. + +History (the "repeatedly reverted" patch): 2026-05-18 a dim_case_embd=31 +injection was added for FT/LP, because `--finetune --model-branch ` +tried to copy the branch's [159, 240] property head and failed without it. +On 2026-05-20 the FT/LP command was realigned to the paper repo, which uses +`--finetune` WITHOUT --model-branch: single-task fine-tune copies only the +backbone and random-inits the property head at [128, 240]. With no branch +head to size-match, dim_case_embd must NOT be injected (the paper qm9_gap +input.json omits it). + +So: FT/LP fitting_net has no dim_case_embd unless the user sets it +explicitly via fitting_net_params. These tests build config only. +""" + +from __future__ import ( + annotations, +) + +from dpa_adapt.trainer import ( + DPATrainer, +) + +TYPE_MAP = ["H", "C", "N", "O"] +DUMMY_SYS = ["/data/sys"] + + +def _trainer(pretrained, **overrides): + kwargs = { + "pretrained": pretrained, + "train_systems": DUMMY_SYS, + "valid_systems": DUMMY_SYS, + "type_map": TYPE_MAP, + } + kwargs.update(overrides) + return DPATrainer(**kwargs) + + +def test_pretrained_mode_no_dim_case_embd(tmp_path): + """FT/LP (pretrained != None) must NOT inject dim_case_embd: the paper + single-task fine-tune random-inits the property head, so there is no + [159, 240] checkpoint head to match. + """ + ckpt = tmp_path / "ckpt.pt" + ckpt.write_bytes(b"") + t = _trainer(str(ckpt)) + fn = t._build_fitting_net() + assert fn.get("dim_case_embd") is None + + +def test_user_fitting_net_params_can_set_dim_case_embd(tmp_path): + """An explicit user-supplied dim_case_embd is still honored verbatim.""" + ckpt = tmp_path / "ckpt.pt" + ckpt.write_bytes(b"") + t = _trainer(str(ckpt), fitting_net_params={"dim_case_embd": 99}) + fn = t._build_fitting_net() + assert fn["dim_case_embd"] == 99 diff --git a/source/tests/dpa_adapt/test_type_map.py b/source/tests/dpa_adapt/test_type_map.py new file mode 100644 index 0000000000..3317bc08b8 --- /dev/null +++ b/source/tests/dpa_adapt/test_type_map.py @@ -0,0 +1,209 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for type_map validation and local→global atom-type remapping.""" + +import sys +from unittest.mock import ( + MagicMock, +) + +import numpy as np +import pytest + +# Only stub torch when it is genuinely absent. Injecting a MagicMock into +# sys.modules unconditionally leaks into other test modules during a full +# pytest run (the stub wins the import race and stays session-wide), breaking +# tests that do real tensor math. Same guard as test_predictor.py. +try: + import torch as _unused_torch_for_test +except Exception: + _mock_torch = MagicMock() + _mock_torch.Tensor = type("Tensor", (), {}) + _unused_torch_for_test = _mock_torch + sys.modules.setdefault("torch", _mock_torch) +else: + _unused_torch_for_test.set_default_device(None) + +from dpa_adapt.data.errors import ( + DPADataError, +) +from dpa_adapt.data.loader import ( + load_data, +) +from dpa_adapt.finetuner import ( + DPAFineTuner, + _read_data_type_map, +) + +PERIODIC_PREFIX_9 = ["H", "He", "Li", "Be", "B", "C", "N", "O", "F"] + + +def _make_system(tmp_path, name, type_indices, type_map): + """Create a minimal deepmd/npy system and load it via dpdata.""" + root = tmp_path / name + root.mkdir(parents=True, exist_ok=True) + n_atoms = len(type_indices) + (root / "type.raw").write_text("\n".join(str(i) for i in type_indices) + "\n") + (root / "type_map.raw").write_text("\n".join(type_map) + "\n") + sd = root / "set.000" + sd.mkdir(exist_ok=True) + np.save(sd / "coord.npy", np.zeros((1, n_atoms * 3))) + np.save(sd / "box.npy", np.eye(3).reshape(1, 9)) + return load_data(str(root))[0] + + +# --------------------------------------------------------------------------- +# _validate_type_map +# --------------------------------------------------------------------------- + + +class TestValidateTypeMapSubset: + def test_non_prefix_subset_accepted(self, tmp_path): + sys = _make_system(tmp_path, "qm9", [0, 1, 2], ["H", "C", "N"]) + ft = DPAFineTuner(pretrained="fake.pt") + ft._checkpoint_type_map = list(PERIODIC_PREFIX_9) + ft._validate_type_map([], [sys]) + ft._validate_type_map(["H", "C", "N", "O", "F"], [sys]) + + def test_empty_checkpoint_skips(self, tmp_path): + sys = _make_system(tmp_path, "sys", [0], ["Xx"]) + ft = DPAFineTuner(pretrained="fake.pt") + ft._checkpoint_type_map = [] + ft._validate_type_map(["Xx"], [sys]) + + def test_no_type_map_raw_skips(self, tmp_path): + root = tmp_path / "sys" + root.mkdir() + (root / "type.raw").write_text("0\n") + # No type_map.raw → no atom_names + sd = root / "set.000" + sd.mkdir() + np.save(sd / "coord.npy", np.zeros((1, 3))) + np.save(sd / "box.npy", np.eye(3).reshape(1, 9)) + sys = load_data(str(root))[0] + ft = DPAFineTuner(pretrained="fake.pt") + ft._checkpoint_type_map = list(PERIODIC_PREFIX_9) + ft._validate_type_map([], [sys]) + + +class TestValidateTypeMapUnsupported: + def test_unsupported_in_user_type_map(self, tmp_path): + ft = DPAFineTuner(pretrained="fake.pt") + ft._checkpoint_type_map = list(PERIODIC_PREFIX_9) + with pytest.raises(DPADataError) as ei: + ft._validate_type_map(["H", "C", "Xx"], []) + msg = str(ei.value) + assert "not supported" in msg + assert "Xx" in msg + assert "prefix" not in msg.lower() + + def test_unsupported_in_data_type_map(self, tmp_path): + sys = _make_system(tmp_path, "sys", [0, 1], ["H", "Xx"]) + ft = DPAFineTuner(pretrained="fake.pt") + ft._checkpoint_type_map = list(PERIODIC_PREFIX_9) + with pytest.raises(DPADataError) as ei: + ft._validate_type_map([], [sys]) + msg = str(ei.value) + assert "not supported" in msg + assert "Xx" in msg + assert "prefix" not in msg.lower() + + +# --------------------------------------------------------------------------- +# _remap_atom_types +# --------------------------------------------------------------------------- + + +class TestRemapAtomTypes: + def test_remap_via_atom_names(self, tmp_path): + sys = _make_system(tmp_path, "qm9", [0, 1, 2, 3, 4], ["H", "C", "N", "O", "F"]) + ft = DPAFineTuner(pretrained="fake.pt") + ft._checkpoint_type_map = list(PERIODIC_PREFIX_9) + atom_types = np.array([0, 1, 2, 3, 4], dtype=np.int64) + out = ft._remap_atom_types(atom_types, sys) + np.testing.assert_array_equal(out, [0, 5, 6, 7, 8]) + + def test_remap_with_arbitrary_order(self, tmp_path): + sys = _make_system(tmp_path, "sys", [0, 1, 0], ["O", "H"]) + ft = DPAFineTuner(pretrained="fake.pt") + ft._checkpoint_type_map = list(PERIODIC_PREFIX_9) + out = ft._remap_atom_types(np.array([0, 1, 0]), sys) + np.testing.assert_array_equal(out, [7, 0, 7]) + + def test_fallback_to_user_type_map(self, tmp_path): + root = tmp_path / "sys" + root.mkdir() + (root / "type.raw").write_text("0\n1\n") + sd = root / "set.000" + sd.mkdir() + np.save(sd / "coord.npy", np.zeros((1, 6))) + np.save(sd / "box.npy", np.eye(3).reshape(1, 9)) + sys = load_data(str(root))[0] + ft = DPAFineTuner(pretrained="fake.pt") + ft._checkpoint_type_map = list(PERIODIC_PREFIX_9) + ft.type_map = ["C", "F"] + out = ft._remap_atom_types(np.array([0, 1]), sys) + np.testing.assert_array_equal(out, [5, 8]) + + def test_no_type_map_in_range_passes_through(self, tmp_path): + root = tmp_path / "sys" + root.mkdir() + (root / "type.raw").write_text("0\n1\n") + sd = root / "set.000" + sd.mkdir() + np.save(sd / "coord.npy", np.zeros((1, 6))) + np.save(sd / "box.npy", np.eye(3).reshape(1, 9)) + sys = load_data(str(root))[0] + ft = DPAFineTuner(pretrained="fake.pt") + ft._checkpoint_type_map = list(PERIODIC_PREFIX_9) + out = ft._remap_atom_types(np.array([0, 1]), sys) + np.testing.assert_array_equal(out, [0, 1]) + + def test_no_type_map_out_of_range_raises(self, tmp_path): + root = tmp_path / "sys" + root.mkdir() + (root / "type.raw").write_text("0\n42\n") + sd = root / "set.000" + sd.mkdir() + np.save(sd / "coord.npy", np.zeros((1, 6))) + np.save(sd / "box.npy", np.eye(3).reshape(1, 9)) + sys = load_data(str(root))[0] + ft = DPAFineTuner(pretrained="fake.pt") + ft._checkpoint_type_map = list(PERIODIC_PREFIX_9) + with pytest.raises(DPADataError, match="out of range"): + ft._remap_atom_types(np.array([0, 42]), sys) + + def test_unsupported_element_in_data_type_map_raises(self, tmp_path): + sys = _make_system(tmp_path, "sys", [0], ["Xx"]) + ft = DPAFineTuner(pretrained="fake.pt") + ft._checkpoint_type_map = list(PERIODIC_PREFIX_9) + with pytest.raises(DPADataError) as ei: + ft._remap_atom_types(np.array([0]), sys) + assert "not supported" in str(ei.value) + assert "Xx" in str(ei.value) + + +# --------------------------------------------------------------------------- +# _read_data_type_map +# --------------------------------------------------------------------------- + + +class TestReadDataTypeMap: + def test_reads_elements(self, tmp_path): + sys = _make_system(tmp_path, "sys", [0, 1, 2], ["H", "C", "N"]) + assert _read_data_type_map(sys) == ["H", "C", "N"] + + def test_returns_empty_when_missing(self, tmp_path): + root = tmp_path / "sys" + root.mkdir() + (root / "type.raw").write_text("0\n") + # No type_map.raw + sd = root / "set.000" + sd.mkdir() + np.save(sd / "coord.npy", np.zeros((1, 3))) + np.save(sd / "box.npy", np.eye(3).reshape(1, 9)) + sys = load_data(str(root))[0] + assert _read_data_type_map(sys) == [] + + def test_strips_blank_lines(self, tmp_path): + sys = _make_system(tmp_path, "sys", [0, 1], ["H", "C"]) + assert _read_data_type_map(sys) == ["H", "C"] diff --git a/source/tests/dpa_adapt/test_validate.py b/source/tests/dpa_adapt/test_validate.py new file mode 100644 index 0000000000..3bd6baae43 --- /dev/null +++ b/source/tests/dpa_adapt/test_validate.py @@ -0,0 +1,210 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for check_data() — content-level sanity checks on dpdata systems.""" + +import numpy as np +import pytest + +from dpa_adapt.data.errors import ( + DPADataError, +) +from dpa_adapt.data.loader import ( + load_data, +) +from dpa_adapt.data.validate import ( + check_data, +) + + +def _make_set_dir( + set_dir, *, coord=None, box=None, energy=None, force=None, n_frames=3, n_atoms=2 +): + set_dir.mkdir(parents=True) + if coord is None: + coord = np.random.RandomState(0).rand(n_frames, n_atoms * 3) + if box is None: + box = (np.eye(3) * 10.0).reshape(1, 9).repeat(n_frames, 0) + np.save(set_dir / "coord.npy", coord) + np.save(set_dir / "box.npy", box) + if energy is not None: + np.save(set_dir / "energy.npy", energy) + if force is not None: + np.save(set_dir / "force.npy", force) + + +def _system(tmp_path, **set_kwargs): + """Create a valid deepmd/npy system, load via dpdata, return it.""" + root = tmp_path / "sys" + root.mkdir() + (root / "type.raw").write_text("0\n0\n") + (root / "type_map.raw").write_text("H\nH\n") + _make_set_dir(root / "set.000", **set_kwargs) + return load_data(str(root))[0] + + +# --------------------------------------------------------------------------- +# Clean data +# --------------------------------------------------------------------------- + + +def test_clean_data_no_issues(tmp_path): + system = _system(tmp_path) + issues = check_data(system) + assert len(issues) == 0 + + +def test_structure_only_no_energy_force_is_clean(tmp_path): + # Create system with only coords + box (no energy/force) + root = tmp_path / "sys" + root.mkdir() + (root / "type.raw").write_text("0\n0\n") + (root / "type_map.raw").write_text("H\nH\n") + _make_set_dir(root / "set.000") + # Remove energy.npy and force.npy before loading + system = load_data(str(root))[0] + issues = check_data(system) + assert len(issues) == 0, [i.description for i in issues] + + +# --------------------------------------------------------------------------- +# NaN / Inf +# --------------------------------------------------------------------------- + + +def test_energy_nan_is_error(tmp_path): + system = _system(tmp_path, energy=np.array([np.nan, 0.0, 0.0])) + issues = check_data(system) + assert any("energies" in i.file and "non-finite" in i.description for i in issues) + + +def test_force_inf_is_error(tmp_path): + system = _system(tmp_path) + # Inject bad forces after loading (dpdata may refuse to load inf arrays) + system.data["forces"] = np.full((3, 2, 3), np.inf) + issues = check_data(system) + assert any("forces" in i.file and "non-finite" in i.description for i in issues) + + +def test_box_nan_is_error(tmp_path): + system = _system(tmp_path, box=np.full((3, 9), np.nan)) + issues = check_data(system) + assert any("cells" in i.file and "non-finite" in i.description for i in issues) + + +# --------------------------------------------------------------------------- +# Degenerate box +# --------------------------------------------------------------------------- + + +def test_degenerate_box_is_error_with_det_in_description(tmp_path): + system = _system(tmp_path, box=np.zeros((3, 9))) + issues = check_data(system) + assert any("cells" in i.file and "degenerate" in i.description for i in issues) + + +def test_box_det_tolerance_boundary(tmp_path): + # A very thin but valid box near the default tolerance + box = np.tile(np.diag([10.0, 1e-11, 10.0]).ravel(), (3, 1)) + system = _system(tmp_path, box=box) + issues = check_data(system) + # |det| = 10 * 1e-11 * 10 = 1e-9, which is > 1e-10 default tol → clean + assert not any("degenerate" in i.description for i in issues) + + +def test_box_det_tol_is_configurable(tmp_path): + box = np.tile(np.diag([10.0, 1e-11, 10.0]).ravel(), (3, 1)) + system = _system(tmp_path, box=box) + issues = check_data(system, box_det_tol=1e-8) + # |det| = 1e-9 < 1e-8 tol → degenerate + assert any("degenerate" in i.description for i in issues) + + +# --------------------------------------------------------------------------- +# Magnitude warnings +# --------------------------------------------------------------------------- + + +def test_energy_magnitude_warning(tmp_path): + system = _system(tmp_path, energy=np.array([1e5, 0.0, 0.0])) + issues = check_data(system) + assert any( + "energies" in i.file and "suspicious magnitude" in i.description for i in issues + ) + + +def test_force_magnitude_warning(tmp_path): + system = _system(tmp_path) + big_force = np.zeros((3, 2, 3)) + big_force[0, 0, 0] = 5000.0 + system.data["forces"] = big_force + issues = check_data(system) + assert any( + "forces" in i.file and "suspicious magnitude" in i.description for i in issues + ) + + +# --------------------------------------------------------------------------- +# Frame count alignment +# --------------------------------------------------------------------------- + + +def test_frame_count_mismatch_is_error(tmp_path): + system = _system(tmp_path, coord=np.zeros((3, 6))) + system.data["energies"] = np.zeros(5) # mismatched + issues = check_data(system) + assert any( + "energies" in i.file and "frame counts must align" in i.description + for i in issues + ) + + +# --------------------------------------------------------------------------- +# Strict mode +# --------------------------------------------------------------------------- + + +def test_strict_raises_on_first_issue(tmp_path): + system = _system(tmp_path, energy=np.array([np.nan, 0.0, 0.0])) + with pytest.raises(DPADataError, match="check_data"): + check_data(system, strict=True) + + +# --------------------------------------------------------------------------- +# List input +# --------------------------------------------------------------------------- + + +def test_list_input_aggregates_across_systems(tmp_path): + s1 = _system(tmp_path, energy=np.array([np.nan, 0.0, 0.0])) + # use a different tmp subdir to avoid conflict + s2_root = tmp_path / "sys2" + s2_root.mkdir() + (s2_root / "type.raw").write_text("0\n0\n") + (s2_root / "type_map.raw").write_text("H\nH\n") + from dpa_adapt.data.loader import ( + load_data, + ) + + _make_set_dir(s2_root / "set.000") + s2 = load_data(str(s2_root))[0] + issues = check_data([s1, s2]) + assert len(issues) >= 1 + + +def test_set_dirs_checked_in_numeric_order(tmp_path): + # dpdata loads all set.* dirs; check covers all frames + system = _system(tmp_path, energy=np.array([1e5, 0.0, 0.0])) + issues = check_data(system) + # magnitude warning should reference frame 0 + mag_issues = [i for i in issues if "suspicious magnitude" in i.description] + assert len(mag_issues) >= 1 + + +def test_issue_namedtuple_shape(tmp_path): + system = _system(tmp_path, energy=np.array([np.nan, 0.0, 0.0])) + issues = check_data(system) + assert len(issues) > 0 + issue = issues[0] + assert issue.severity in ("warning", "error") + assert isinstance(issue.system, str) + assert isinstance(issue.file, str) + assert isinstance(issue.description, str)