From 158504866cf9ef9ddb44db036ecb8800d4b12e55 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 26 Jun 2026 00:50:25 +0800 Subject: [PATCH 1/6] feat(jax): add JAX-MD interface --- deepmd/jax/jax_md.py | 423 ++++++++++++++++++++++++++++ doc/third-party/index.rst | 1 + doc/third-party/jaxmd.md | 115 ++++++++ examples/water/jax_md/README.md | 44 +++ examples/water/jax_md/run_jax_md.py | 177 ++++++++++++ source/tests/jax/test_jax_md.py | 136 +++++++++ 6 files changed, 896 insertions(+) create mode 100644 deepmd/jax/jax_md.py create mode 100644 doc/third-party/jaxmd.md create mode 100644 examples/water/jax_md/README.md create mode 100644 examples/water/jax_md/run_jax_md.py create mode 100644 source/tests/jax/test_jax_md.py diff --git a/deepmd/jax/jax_md.py b/deepmd/jax/jax_md.py new file mode 100644 index 0000000000..1e22fea5e5 --- /dev/null +++ b/deepmd/jax/jax_md.py @@ -0,0 +1,423 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""JAX-MD compatible interface for JAX DeePMD models.""" + +import inspect +import json +from collections.abc import ( + Callable, + Sequence, +) +from pathlib import ( + Path, +) +from typing import ( + Any, +) + +from deepmd.dpmodel.utils.serialization import ( + load_dp_model, +) +from deepmd.jax.env import ( + jax, + jnp, +) +from deepmd.jax.model.base_model import ( + BaseModel, +) +from deepmd.jax.model.hlo import ( + HLO, +) +from deepmd.jax.utils.serialization import ( + serialize_from_file, +) + +Array = jax.Array +EnergyFn = Callable[..., Array] + +_JAX_MD_SENTINEL = object() + + +def load_model(model: str | Path | Any) -> Any: + """Load a JAX DeePMD model, or return an already constructed model.""" + if not isinstance(model, str | Path): + return model + + model_path = str(Path(model).resolve()) + if model_path.endswith(".jax"): + data = serialize_from_file(model_path) + jax_model = BaseModel.deserialize(data["model"]) + jax_model.model_def_script = json.dumps(data.get("model_def_script", {})) + return jax_model + if model_path.endswith(".hlo"): + return _load_hlo_model(model_path) + raise ValueError("JAX-MD interface supports .jax checkpoints and .hlo models.") + + +def energy_fn( + model: str | Path | Any, + atom_types: Sequence[int | str] | Array, + *, + box: Array | Sequence[float] | None = None, + displacement_fn: Callable[..., Array] | None = None, + fparam: Array | Sequence[float] | None = None, + aparam: Array | Sequence[float] | None = None, + charge_spin: Array | Sequence[float] | None = None, +) -> EnergyFn: + """Create a JAX-MD compatible ``energy_fn(R, neighbor=None, **kwargs)``. + + The returned function accepts a single-frame coordinate array ``R`` with + shape ``(natoms, 3)`` and returns a scalar total energy. If a JAX-MD dense + neighbor list object is passed as ``neighbor``, DeePMD uses it to build the + lower-interface extended system. Otherwise DeePMD builds its native JAX + neighbor list from ``R`` and ``box``. + """ + jax_model = load_model(model) + default_atom_types = _normalize_atom_types(jax_model, atom_types) + default_box = _normalize_box(box) + default_fparam = fparam + default_aparam = aparam + default_charge_spin = charge_spin + + def apply( + R: Array, + *, + neighbor: Any | None = None, + atom_types: Sequence[int | str] | Array | None = None, + box: Array | Sequence[float] | None | object = _JAX_MD_SENTINEL, + fparam: Array | Sequence[float] | None | object = _JAX_MD_SENTINEL, + aparam: Array | Sequence[float] | None | object = _JAX_MD_SENTINEL, + charge_spin: Array | Sequence[float] | None | object = _JAX_MD_SENTINEL, + **kwargs: Any, + ) -> Array: + """Evaluate a single-frame total energy in the JAX-MD call convention.""" + coord = _normalize_coord(R) + atype = ( + default_atom_types + if atom_types is None + else _normalize_atom_types(jax_model, atom_types) + ) + current_box = default_box if box is _JAX_MD_SENTINEL else _normalize_box(box) + current_fparam = default_fparam if fparam is _JAX_MD_SENTINEL else fparam + current_aparam = default_aparam if aparam is _JAX_MD_SENTINEL else aparam + current_charge_spin = ( + default_charge_spin if charge_spin is _JAX_MD_SENTINEL else charge_spin + ) + fparam_batch = _normalize_fparam(jax_model, current_fparam, coord.dtype) + aparam_batch = _normalize_aparam( + jax_model, current_aparam, coord.shape[0], coord.dtype + ) + charge_spin_batch = _normalize_charge_spin(current_charge_spin, coord.dtype) + + if neighbor is None: + model_kwargs = { + "box": None if current_box is None else current_box[None, ...], + "fparam": fparam_batch, + "aparam": aparam_batch, + } + if charge_spin_batch is not None: + if not _accepts_keyword(jax_model, "charge_spin"): + raise TypeError("This model does not accept charge_spin input.") + model_kwargs["charge_spin"] = charge_spin_batch + ret = jax_model( + coord[None, ...], + atype[None, ...], + **model_kwargs, + ) + else: + ret = _eval_with_jax_md_neighbor( + jax_model, + coord, + atype, + neighbor, + displacement_fn, + fparam_batch, + aparam_batch, + charge_spin_batch, + kwargs, + ) + return _extract_energy(ret) + + return apply + + +def force_fn(energy: EnergyFn) -> EnergyFn: + """Create a JAX-MD compatible force function from an energy function.""" + + def apply(R: Array, **kwargs: Any) -> Array: + """Evaluate forces by differentiating the supplied energy function.""" + return -jax.grad(lambda coord: energy(coord, **kwargs))(R) + + return apply + + +def neighbor_list( + model: str | Path | Any, + displacement_or_metric: Callable[..., Array], + box: Array | Sequence[float], + **kwargs: Any, +) -> Any: + """Create a dense JAX-MD neighbor-list function using the model cutoff.""" + try: + from jax_md import ( + partition, + ) + except ImportError as exc: + raise ImportError( + "The JAX-MD neighbor-list helper requires the optional jax-md package." + ) from exc + + jax_model = load_model(model) + kwargs.setdefault("format", partition.NeighborListFormat.Dense) + return partition.neighbor_list( + displacement_or_metric, + box, + r_cutoff=jax_model.get_rcut(), + **kwargs, + ) + + +def as_jax_md( + model: str | Path | Any, + displacement_or_metric: Callable[..., Array], + box: Array | Sequence[float], + atom_types: Sequence[int | str] | Array, + **kwargs: Any, +) -> tuple[Any, EnergyFn]: + """Return ``(neighbor_fn, energy_fn)`` in the usual JAX-MD style.""" + jax_model = load_model(model) + potential = energy_fn( + jax_model, + atom_types, + box=_normalize_box(box), + displacement_fn=displacement_or_metric, + fparam=kwargs.pop("fparam", None), + aparam=kwargs.pop("aparam", None), + charge_spin=kwargs.pop("charge_spin", None), + ) + nlist_fn = neighbor_list(jax_model, displacement_or_metric, box, **kwargs) + return nlist_fn, potential + + +def _load_hlo_model(model_file: str) -> HLO: + """Load a DeePMD HLO model into the JAX inference wrapper.""" + model_data = load_dp_model(model_file) + return HLO( + stablehlo=model_data["@variables"]["stablehlo"].tobytes(), + stablehlo_atomic_virial=model_data["@variables"][ + "stablehlo_atomic_virial" + ].tobytes(), + stablehlo_no_ghost=model_data["@variables"]["stablehlo_no_ghost"].tobytes(), + stablehlo_atomic_virial_no_ghost=model_data["@variables"][ + "stablehlo_atomic_virial_no_ghost" + ].tobytes(), + model_def_script=json.dumps(model_data["model_def_script"]), + **model_data["constants"], + ) + + +def _normalize_atom_types(model: Any, atom_types: Sequence[int | str] | Array) -> Array: + """Convert type names or type indexes to a JAX int32 type array.""" + if isinstance(atom_types, jax.Array): + return atom_types.astype(jnp.int32) + atom_types_list = list(atom_types) + if atom_types_list and isinstance(atom_types_list[0], str): + type_map = {name: idx for idx, name in enumerate(model.get_type_map())} + atom_types_list = [type_map[str(atom_type)] for atom_type in atom_types_list] + return jnp.asarray(atom_types_list, dtype=jnp.int32) + + +def _accepts_keyword(callable_obj: Callable[..., Any], keyword: str) -> bool: + """Return whether a callable signature accepts a keyword argument.""" + try: + signature = inspect.signature(callable_obj) + except (TypeError, ValueError): + return False + for parameter in signature.parameters.values(): + if parameter.kind == inspect.Parameter.VAR_KEYWORD: + return True + return keyword in signature.parameters + + +def _normalize_coord(coord: Array) -> Array: + """Validate and convert a single-frame coordinate array.""" + coord = jnp.asarray(coord) + if coord.ndim != 2 or coord.shape[-1] != 3: + raise ValueError("JAX-MD DeePMD energy functions require R with shape (N, 3).") + return coord + + +def _normalize_box(box: Array | Sequence[float] | None) -> Array | None: + """Convert supported box representations to a 3-by-3 cell matrix.""" + if box is None: + return None + box_array = jnp.asarray(box) + if box_array.ndim == 0: + return jnp.eye(3, dtype=box_array.dtype) * box_array + if box_array.shape == (3,): + return jnp.diag(box_array) + if box_array.shape == (9,): + return box_array.reshape(3, 3) + if box_array.shape == (3, 3): + return box_array + raise ValueError("box must be a scalar, shape (3,), shape (9,), or shape (3, 3).") + + +def _normalize_fparam( + model: Any, fparam: Array | Sequence[float] | None, dtype: Any +) -> Array | None: + """Convert frame parameters to DeePMD's batched JAX input shape.""" + dim_fparam = model.get_dim_fparam() + if dim_fparam == 0: + return None + if fparam is None: + if getattr(model, "has_default_fparam", lambda: False)(): + default_fparam = model.get_default_fparam() + if default_fparam is not None: + return jnp.asarray(default_fparam, dtype=dtype).reshape(1, dim_fparam) + raise ValueError("This model requires fparam, but none was provided.") + return jnp.asarray(fparam, dtype=dtype).reshape(1, dim_fparam) + + +def _normalize_aparam( + model: Any, + aparam: Array | Sequence[float] | None, + natoms: int, + dtype: Any, +) -> Array | None: + """Convert atomic parameters to DeePMD's batched JAX input shape.""" + dim_aparam = model.get_dim_aparam() + if dim_aparam == 0: + return None + if aparam is None: + raise ValueError("This model requires aparam, but none was provided.") + aparam_array = jnp.asarray(aparam, dtype=dtype) + if aparam_array.shape == (dim_aparam,): + aparam_array = jnp.tile(aparam_array[None, :], (natoms, 1)) + return aparam_array.reshape(1, natoms, dim_aparam) + + +def _normalize_charge_spin( + charge_spin: Array | Sequence[float] | None, dtype: Any +) -> Array | None: + """Convert charge-spin parameters to a batched JAX input array.""" + if charge_spin is None: + return None + return jnp.asarray(charge_spin, dtype=dtype)[None, ...] + + +def _eval_with_jax_md_neighbor( + model: Any, + coord: Array, + atype: Array, + neighbor: Any, + displacement_fn: Callable[..., Array] | None, + fparam: Array | None, + aparam: Array | None, + charge_spin: Array | None, + displacement_kwargs: dict[str, Any], +) -> dict[str, Array]: + """Evaluate a DeePMD model with a precomputed dense JAX-MD neighbor list.""" + if not hasattr(model, "call_lower"): + raise TypeError("JAX-MD neighbor lists require a DeePMD model with call_lower.") + extended_coord, extended_atype, nlist, mapping = _jax_md_neighbor_to_lower_inputs( + coord, + atype, + neighbor, + displacement_fn, + displacement_kwargs, + ) + return model.call_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + charge_spin=charge_spin, + ) + + +def _jax_md_neighbor_to_lower_inputs( + coord: Array, + atype: Array, + neighbor: Any, + displacement_fn: Callable[..., Array] | None, + displacement_kwargs: dict[str, Any], +) -> tuple[Array, Array, Array, Array]: + """Convert a dense JAX-MD neighbor list to DeePMD lower-interface inputs. + + DeePMD's lower interface expects an extended coordinate array plus a + neighbor list that points from local atoms into that extended array. JAX-MD + dense neighbor lists store original atom indexes instead, so each valid edge + is materialized as a ghost coordinate with the minimum-image displacement + supplied by the JAX-MD displacement function. + """ + if not hasattr(neighbor, "idx"): + raise TypeError("Expected a JAX-MD neighbor object with an idx attribute.") + nlist = jnp.asarray(neighbor.idx) + natoms = coord.shape[0] + if nlist.ndim != 2 or nlist.shape[0] != natoms: + raise ValueError( + "Only dense JAX-MD neighbor lists with shape (N, max_occupancy) are supported." + ) + + valid = (nlist >= 0) & (nlist < natoms) + safe_nlist = jnp.where(valid, nlist, 0).astype(jnp.int32) + neighbor_coord = coord[safe_nlist] + central_coord = jnp.broadcast_to(coord[:, None, :], neighbor_coord.shape) + + if displacement_fn is None: + ghost_coord = neighbor_coord + else: + displacement = jax.vmap( + jax.vmap( + lambda central, neighbor: displacement_fn( + central, neighbor, **displacement_kwargs + ) + ) + )(central_coord, neighbor_coord) + # JAX-MD displacement functions use the Ra - Rb convention. + ghost_coord = central_coord - displacement + + ghost_coord = jnp.where(valid[..., None], ghost_coord, central_coord) + ghost_atype = jnp.where(valid, atype[safe_nlist], -1) + nedge = nlist.size + ghost_start = natoms + ghost_indices = jnp.arange(ghost_start, ghost_start + nedge, dtype=jnp.int64) + lower_nlist = jnp.where(valid, ghost_indices.reshape(nlist.shape), -1) + + extended_coord = jnp.concatenate( + [coord, ghost_coord.reshape(nedge, 3)], + axis=0, + )[None, ...] + extended_atype = jnp.concatenate( + [atype, ghost_atype.reshape(nedge)], + axis=0, + )[None, ...] + mapping = jnp.concatenate( + [ + jnp.arange(natoms, dtype=jnp.int64), + safe_nlist.reshape(nedge).astype(jnp.int64), + ], + axis=0, + )[None, ...] + return extended_coord, extended_atype, lower_nlist[None, ...], mapping + + +def _extract_energy(ret: Any) -> Array: + """Extract a scalar total energy from a DeePMD model return value.""" + if isinstance(ret, tuple): + ret = ret[0] + for key in ("energy", "energy_redu"): + if key in ret and ret[key] is not None: + return jnp.ravel(ret[key])[0] + raise KeyError("Model output does not contain an energy value.") + + +__all__ = [ + "as_jax_md", + "energy_fn", + "force_fn", + "load_model", + "neighbor_list", +] diff --git a/doc/third-party/index.rst b/doc/third-party/index.rst index cd0726a4bb..e32ae4ae57 100644 --- a/doc/third-party/index.rst +++ b/doc/third-party/index.rst @@ -9,6 +9,7 @@ Note that the model for inference is required to be compatible with the DeePMD-k dpdata ase lammps-command + jaxmd ipi gromacs out-of-deepmd-kit diff --git a/doc/third-party/jaxmd.md b/doc/third-party/jaxmd.md new file mode 100644 index 0000000000..632e21e22d --- /dev/null +++ b/doc/third-party/jaxmd.md @@ -0,0 +1,115 @@ +# Run MD with JAX-MD + +:::{note} +See [Environment variables](../env.md) for the runtime environment variables. +::: + +DeePMD-kit provides a JAX-MD compatible interface for DeePMD models trained with +the JAX backend. The interface adapts a DeePMD model to the usual JAX-MD style, +where a neighbor-list factory and an energy function are passed to JAX-MD +simulation routines. + +The interface is available from `deepmd.jax.jax_md`. + +## Requirements + +Install DeePMD-kit with the JAX backend and install +[JAX-MD](https://github.com/jax-md/jax-md). The JAX-MD package is an optional +runtime dependency and is not required for other DeePMD-kit interfaces. + +## Basic usage + +The most common entry point is `as_jax_md`, which returns a JAX-MD neighbor-list +function and a potential energy function: + +```python +import jax +import jax.numpy as jnp +from jax_md import space + +from deepmd.jax.jax_md import as_jax_md + +box = jnp.asarray([12.4447, 12.4447, 12.4447]) +coord = jnp.asarray(...) # shape: (natoms, 3) +atom_types = jnp.asarray(...) # shape: (natoms,), DeePMD type indexes + +displacement_fn, shift_fn = space.periodic(box) +neighbor_fn, potential_fn = as_jax_md( + "model.ckpt.jax", + displacement_fn, + box, + atom_types, + dr_threshold=0.2, + capacity_multiplier=1.5, +) + +neighbor = neighbor_fn.allocate(coord) +energy = potential_fn(coord, neighbor=neighbor) +force = -jax.grad(lambda x: potential_fn(x, neighbor=neighbor))(coord) +``` + +The returned `potential_fn` accepts a single-frame coordinate array with shape +`(natoms, 3)` and returns the scalar total energy. The optional `neighbor` +argument should be a dense JAX-MD neighbor list allocated by the returned +`neighbor_fn`. + +## Running dynamics + +The potential can be used with JAX-MD simulation routines. A minimal NVE loop +looks like: + +```python +from jax_md import simulate + +K_B_EV_PER_K = 8.617333262145e-5 +kT = K_B_EV_PER_K * 330.0 +mass = jnp.ones((coord.shape[0], 1)) + +init_fn, step_fn = simulate.nve(potential_fn, shift_fn, dt=0.0005) +state = init_fn(jax.random.key(0), coord, kT=kT, mass=mass, neighbor=neighbor) + +for _ in range(10): + neighbor = neighbor_fn.update(state.position, neighbor) + state = step_fn(state, neighbor=neighbor) +``` + +For a complete water example using the same 192-atom configuration as the +LAMMPS example, see `examples/water/jax_md`. + +## Model files + +`deepmd.jax.jax_md.load_model` accepts: + +- a DeePMD JAX checkpoint path ending in `.jax`, +- a DeePMD HLO model path ending in `.hlo`, +- an already constructed JAX DeePMD model object. + +The `atom_types` argument may be an integer array of DeePMD type indexes. It +may also be a sequence of type names if the model has a `type_map`. + +## Neighbor lists + +The helper `neighbor_list` creates a dense JAX-MD neighbor-list function using +the model cutoff: + +```python +from deepmd.jax.jax_md import energy_fn, neighbor_list + +neighbor_fn = neighbor_list("model.ckpt.jax", displacement_fn, box) +potential_fn = energy_fn( + "model.ckpt.jax", + atom_types, + box=box, + displacement_fn=displacement_fn, +) +``` + +Only dense JAX-MD neighbor lists are currently supported. If the neighbor-list +buffer overflows during a simulation, increase `capacity_multiplier` or rebuild +the neighbor list with a larger capacity. + +## Units + +The JAX-MD interface does not perform unit conversion. Coordinates, box +vectors, energies, forces, masses, and timesteps should be provided in units +consistent with the DeePMD model and the chosen JAX-MD simulation setup. diff --git a/examples/water/jax_md/README.md b/examples/water/jax_md/README.md new file mode 100644 index 0000000000..991f456791 --- /dev/null +++ b/examples/water/jax_md/README.md @@ -0,0 +1,44 @@ +# JAX-MD water example + +This example runs a short JAX-MD NVE smoke simulation using a DeePMD JAX +checkpoint and the same 192-atom water configuration used by the LAMMPS example +in `../lmp/water.lmp`. + +It is intentionally small so it can be used as an integration check. The +JAX-MD run itself is short; the checkpoint should be produced from the existing +`../se_e2_a` water training directory or supplied with `--model`. The script +uses dpdata to read the LAMMPS data file. + +## Train a JAX checkpoint + +Reuse the existing `se_e2_a` training input: + +```bash +cd ../se_e2_a +dp --jax train input.json --skip-neighbor-stat +``` + +This writes `model.ckpt.jax`, a stable checkpoint pointer to the latest +checkpoint directory. The checked-in `../se_e2_a/input.json` is a full training +example; for a quick integration check, use an existing checkpoint or make a +scratch copy with smaller `training.numb_steps` and `training.save_freq`. + +## Run JAX-MD + +```bash +cd ../jax_md +python run_jax_md.py --model ../se_e2_a/model.ckpt.jax --steps 10 +``` + +The script prints the JAX backend/device, neighbor-list shape, and a small +thermo table. It uses: + +- `jax_md.space.periodic` for the periodic cubic water box, +- `jax_md.partition.neighbor_list` for a dense JAX-MD neighbor list, +- `deepmd.jax.jax_md.as_jax_md` to adapt the DeePMD checkpoint to a JAX-MD + potential, +- `dpdata.System` to load the LAMMPS water data file. + +The default timestep, temperature, and random seed follow `../lmp/in.lammps` +(`0.0005`, `330 K`, `23456789`). Masses are taken from the LAMMPS input +(`O=16`, `H=2`) and converted to `eV ps^2 / A^2` for metal-style units. diff --git a/examples/water/jax_md/run_jax_md.py b/examples/water/jax_md/run_jax_md.py new file mode 100644 index 0000000000..2c2e80fdec --- /dev/null +++ b/examples/water/jax_md/run_jax_md.py @@ -0,0 +1,177 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Run a short JAX-MD trajectory with a DeePMD JAX checkpoint.""" + +from __future__ import ( + annotations, +) + +import argparse +import sys +from pathlib import ( + Path, +) +from typing import ( + Any, +) + +import numpy as np +from jax_md import ( + quantity, + simulate, + space, +) + +K_B_EV_PER_K = 8.617333262145e-5 +AMU_TO_EV_PS2_PER_A2 = 1.0364269656262175e-4 +WATER_TYPE_MAP = ("O", "H") +WATER_MASS_AMU = { + "O": 16.0, + "H": 2.0, +} + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Run DeePMD JAX checkpoint through JAX-MD on water.lmp." + ) + parser.add_argument( + "--model", + default="../se_e2_a/model.ckpt.jax", + help="Path to a DeePMD JAX checkpoint directory or stable checkpoint pointer.", + ) + parser.add_argument( + "--data", + default="../lmp/water.lmp", + help="LAMMPS data file used as initial coordinates.", + ) + parser.add_argument("--steps", type=int, default=10, help="NVE integration steps.") + parser.add_argument( + "--dt", + type=float, + default=0.0005, + help="Timestep in ps, matching the LAMMPS metal-unit example.", + ) + parser.add_argument( + "--temperature", + type=float, + default=330.0, + help="Initial temperature in K.", + ) + parser.add_argument( + "--seed", + type=int, + default=23456789, + help="Random seed matching the LAMMPS velocity command.", + ) + parser.add_argument( + "--dr-threshold", + type=float, + default=0.2, + help="JAX-MD neighbor-list update threshold in Angstrom.", + ) + parser.add_argument( + "--capacity-multiplier", + type=float, + default=1.5, + help="JAX-MD neighbor-list capacity multiplier.", + ) + return parser.parse_args() + + +def read_lammps_water( + path: Path, +) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Read the water LAMMPS data file with dpdata.""" + import dpdata + + system = dpdata.System(path, fmt="lammps/lmp", type_map=list(WATER_TYPE_MAP)) + coord = np.asarray(system.data["coords"][0], dtype=np.float64) + atom_types = np.asarray(system.data["atom_types"], dtype=np.int32) + cell = np.asarray(system.data["cells"][0], dtype=np.float64) + off_diagonal = cell - np.diag(np.diag(cell)) + if np.any(np.abs(off_diagonal) > 1e-12): + raise ValueError("This JAX-MD example only supports orthogonal boxes.") + type_names = np.asarray(WATER_TYPE_MAP, dtype=object)[atom_types] + masses = np.asarray([WATER_MASS_AMU[name] for name in type_names], dtype=np.float64) + masses = masses[:, None] * AMU_TO_EV_PS2_PER_A2 + box = np.diag(cell) + return coord, atom_types, masses, box + + +def emit(line: str) -> None: + sys.stdout.write(line + "\n") + + +def main() -> None: + from deepmd.jax.env import ( + jax, + jnp, + ) + from deepmd.jax.jax_md import ( + as_jax_md, + ) + + args = parse_args() + backend = jax.default_backend() + devices = jax.devices() + + coord_np, atom_types_np, masses_np, box_np = read_lammps_water(Path(args.data)) + coord = jnp.asarray(coord_np) + atom_types = jnp.asarray(atom_types_np) + masses = jnp.asarray(masses_np) + box = jnp.asarray(box_np) + kT = K_B_EV_PER_K * args.temperature + + displacement_fn, shift_fn = space.periodic(box) + neighbor_fn, potential_fn = as_jax_md( + args.model, + displacement_fn, + box, + atom_types, + dr_threshold=args.dr_threshold, + capacity_multiplier=args.capacity_multiplier, + ) + neighbor = neighbor_fn.allocate(coord) + init_fn, step_fn = simulate.nve(potential_fn, shift_fn, dt=args.dt) + key = jax.random.key(args.seed) + state = init_fn(key, coord, kT=kT, mass=masses, neighbor=neighbor) + + emit(f"jax_backend {backend}") + emit("jax_devices " + ", ".join(str(device) for device in devices)) + emit(f"neighbor_idx_shape {tuple(neighbor.idx.shape)}") + emit("# step potential_eV kinetic_eV temperature_K neighbor_overflow") + + def thermo(current_state: Any, current_neighbor: Any) -> tuple[Any, Any, Any]: + energy = potential_fn(current_state.position, neighbor=current_neighbor) + kinetic = quantity.kinetic_energy( + momentum=current_state.momentum, + mass=current_state.mass, + ) + temperature = quantity.temperature( + momentum=current_state.momentum, + mass=current_state.mass, + ) + return energy, kinetic, temperature + + @jax.jit + def md_step(current_state: Any, current_neighbor: Any) -> tuple[Any, ...]: + current_neighbor = neighbor_fn.update(current_state.position, current_neighbor) + current_state = step_fn(current_state, neighbor=current_neighbor) + return current_state, current_neighbor, *thermo(current_state, current_neighbor) + + energy, kinetic, temperature = thermo(state, neighbor) + emit( + f"0 {float(energy):.12e} {float(kinetic):.12e} " + f"{float(temperature / K_B_EV_PER_K):.6f} {bool(neighbor.did_buffer_overflow)}" + ) + for step in range(1, args.steps + 1): + state, neighbor, energy, kinetic, temperature = md_step(state, neighbor) + emit( + f"{step} {float(energy):.12e} {float(kinetic):.12e} " + f"{float(temperature / K_B_EV_PER_K):.6f} " + f"{bool(neighbor.did_buffer_overflow)}" + ) + + +if __name__ == "__main__": + main() diff --git a/source/tests/jax/test_jax_md.py b/source/tests/jax/test_jax_md.py new file mode 100644 index 0000000000..244a0105ee --- /dev/null +++ b/source/tests/jax/test_jax_md.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from dataclasses import ( + dataclass, +) +from importlib.util import ( + find_spec, +) + +import numpy as np + +from deepmd.jax.env import ( + jax, + jnp, +) +from deepmd.jax.jax_md import ( + as_jax_md, + energy_fn, + force_fn, +) + + +class HarmonicModel: + def get_type_map(self): + return ["O", "H"] + + def get_rcut(self): + return 6.0 + + def get_dim_fparam(self): + return 0 + + def get_dim_aparam(self): + return 0 + + def __call__( + self, + coord, + atype, + box=None, + fparam=None, + aparam=None, + charge_spin=None, + ): + del atype, box, fparam, aparam, charge_spin + return {"energy": jnp.sum(coord**2, axis=(1, 2))[:, None]} + + +class EdgeModel(HarmonicModel): + def call_lower( + self, + extended_coord, + extended_atype, + nlist, + mapping, + fparam=None, + aparam=None, + charge_spin=None, + ): + del extended_atype, mapping, fparam, aparam, charge_spin + valid = nlist >= 0 + safe_nlist = jnp.where(valid, nlist, 0) + neighbor_coord = jax.vmap(lambda coord, idx: coord[idx])( + extended_coord, safe_nlist + ) + nloc = nlist.shape[1] + center_coord = extended_coord[:, :nloc, None, :] + edge_vec = jnp.where(valid[..., None], neighbor_coord - center_coord, 0.0) + return {"energy": 0.5 * jnp.sum(edge_vec**2, axis=(1, 2, 3))[:, None]} + + +@dataclass +class DenseNeighbor: + idx: jax.Array + + +def test_energy_and_force_fn(): + potential = energy_fn(HarmonicModel(), ["O", "H"]) + coord = jnp.asarray( + [ + [1.0, 2.0, 3.0], + [0.5, 0.0, -1.0], + ] + ) + + np.testing.assert_allclose(potential(coord), 15.25) + np.testing.assert_allclose(force_fn(potential)(coord), -2.0 * coord) + np.testing.assert_allclose(jax.jit(potential)(coord), 15.25) + + +def test_dense_neighbor_uses_jax_md_displacement_convention(): + potential = energy_fn( + EdgeModel(), + [0, 1], + displacement_fn=lambda ra, rb: (ra - rb) - 10.0 * jnp.round((ra - rb) / 10.0), + ) + coord = jnp.asarray( + [ + [0.1, 0.0, 0.0], + [9.9, 0.0, 0.0], + ] + ) + neighbor = DenseNeighbor(jnp.asarray([[1], [0]], dtype=jnp.int32)) + + np.testing.assert_allclose(potential(coord, neighbor=neighbor), 0.04, atol=1e-12) + + +@unittest.skipIf(find_spec("jax_md") is None, "jax-md is not installed") +def test_actual_jax_md_neighbor_list(): + from jax_md import ( + space, + ) + + displacement, _ = space.periodic(10.0) + neighbor_fn, potential = as_jax_md( + EdgeModel(), + displacement, + 10.0, + [0, 1], + dr_threshold=0.1, + ) + coord = jnp.asarray( + [ + [0.1, 0.0, 0.0], + [9.9, 0.0, 0.0], + ] + ) + neighbor = neighbor_fn.allocate(coord) + + np.testing.assert_array_equal(np.asarray(neighbor.idx), [[1], [0]]) + np.testing.assert_allclose(potential(coord, neighbor=neighbor), 0.04, atol=1e-12) + np.testing.assert_allclose( + force_fn(potential)(coord, neighbor=neighbor), + [[-0.4, 0.0, 0.0], [0.4, 0.0, 0.0]], + atol=1e-12, + ) From c4271951dacab4c3d2b24f21ded40355f545c06a Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 27 Jun 2026 00:18:20 +0800 Subject: [PATCH 2/6] fix(jax): reject HLO models in JAX-MD --- deepmd/jax/jax_md.py | 31 ++++++------------------------- source/tests/jax/test_jax_md.py | 6 ++++++ 2 files changed, 12 insertions(+), 25 deletions(-) diff --git a/deepmd/jax/jax_md.py b/deepmd/jax/jax_md.py index 1e22fea5e5..188730bdfc 100644 --- a/deepmd/jax/jax_md.py +++ b/deepmd/jax/jax_md.py @@ -14,9 +14,6 @@ Any, ) -from deepmd.dpmodel.utils.serialization import ( - load_dp_model, -) from deepmd.jax.env import ( jax, jnp, @@ -24,9 +21,6 @@ from deepmd.jax.model.base_model import ( BaseModel, ) -from deepmd.jax.model.hlo import ( - HLO, -) from deepmd.jax.utils.serialization import ( serialize_from_file, ) @@ -49,8 +43,12 @@ def load_model(model: str | Path | Any) -> Any: jax_model.model_def_script = json.dumps(data.get("model_def_script", {})) return jax_model if model_path.endswith(".hlo"): - return _load_hlo_model(model_path) - raise ValueError("JAX-MD interface supports .jax checkpoints and .hlo models.") + raise NotImplementedError( + "JAX-MD does not support .hlo models yet. The JAX-MD simulation " + "helpers require differentiating the energy function, while exported " + "StableHLO models do not expose a VJP to JAX. Use a .jax checkpoint." + ) + raise ValueError("JAX-MD interface supports .jax checkpoints.") def energy_fn( @@ -198,23 +196,6 @@ def as_jax_md( return nlist_fn, potential -def _load_hlo_model(model_file: str) -> HLO: - """Load a DeePMD HLO model into the JAX inference wrapper.""" - model_data = load_dp_model(model_file) - return HLO( - stablehlo=model_data["@variables"]["stablehlo"].tobytes(), - stablehlo_atomic_virial=model_data["@variables"][ - "stablehlo_atomic_virial" - ].tobytes(), - stablehlo_no_ghost=model_data["@variables"]["stablehlo_no_ghost"].tobytes(), - stablehlo_atomic_virial_no_ghost=model_data["@variables"][ - "stablehlo_atomic_virial_no_ghost" - ].tobytes(), - model_def_script=json.dumps(model_data["model_def_script"]), - **model_data["constants"], - ) - - def _normalize_atom_types(model: Any, atom_types: Sequence[int | str] | Array) -> Array: """Convert type names or type indexes to a JAX int32 type array.""" if isinstance(atom_types, jax.Array): diff --git a/source/tests/jax/test_jax_md.py b/source/tests/jax/test_jax_md.py index 244a0105ee..27ca55d59b 100644 --- a/source/tests/jax/test_jax_md.py +++ b/source/tests/jax/test_jax_md.py @@ -17,6 +17,7 @@ as_jax_md, energy_fn, force_fn, + load_model, ) @@ -88,6 +89,11 @@ def test_energy_and_force_fn(): np.testing.assert_allclose(jax.jit(potential)(coord), 15.25) +def test_hlo_model_raises_not_implemented(): + with np.testing.assert_raises(NotImplementedError): + load_model("model.hlo") + + def test_dense_neighbor_uses_jax_md_displacement_convention(): potential = energy_fn( EdgeModel(), From 5e78ce751360f8e298acecd431c3af1bc74e439a Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 27 Jun 2026 01:49:41 +0800 Subject: [PATCH 3/6] fix(jax): validate JAX-MD neighbor inputs --- deepmd/jax/jax_md.py | 23 ++++++++++++++- source/tests/jax/test_jax_md.py | 51 +++++++++++++++++++++++++++++---- 2 files changed, 68 insertions(+), 6 deletions(-) diff --git a/deepmd/jax/jax_md.py b/deepmd/jax/jax_md.py index 188730bdfc..94510ea9b9 100644 --- a/deepmd/jax/jax_md.py +++ b/deepmd/jax/jax_md.py @@ -165,7 +165,10 @@ def neighbor_list( ) from exc jax_model = load_model(model) - kwargs.setdefault("format", partition.NeighborListFormat.Dense) + neighbor_format = kwargs.setdefault("format", partition.NeighborListFormat.Dense) + if neighbor_format != partition.NeighborListFormat.Dense: + raise ValueError("Only dense JAX-MD neighbor lists are supported.") + _validate_displacement_or_metric(displacement_or_metric) return partition.neighbor_list( displacement_or_metric, box, @@ -196,6 +199,19 @@ def as_jax_md( return nlist_fn, potential +def _validate_displacement_or_metric( + displacement_or_metric: Callable[..., Array], +) -> None: + """Reject scalar metrics where DeePMD needs vector displacements.""" + coord = jnp.zeros((3,), dtype=jnp.float32) + displacement = jnp.asarray(displacement_or_metric(coord, coord)) + if displacement.shape != coord.shape: + raise ValueError( + "Dense neighbor evaluation requires a displacement function returning " + "vectors with shape (..., 3); scalar metric functions are not supported." + ) + + def _normalize_atom_types(model: Any, atom_types: Sequence[int | str] | Array) -> Array: """Convert type names or type indexes to a JAX int32 type array.""" if isinstance(atom_types, jax.Array): @@ -357,6 +373,11 @@ def _jax_md_neighbor_to_lower_inputs( ) ) )(central_coord, neighbor_coord) + if displacement.shape != neighbor_coord.shape: + raise ValueError( + "Dense neighbor evaluation requires a displacement function returning " + "vectors with shape (..., 3); scalar metric functions are not supported." + ) # JAX-MD displacement functions use the Ra - Rb convention. ghost_coord = central_coord - displacement diff --git a/source/tests/jax/test_jax_md.py b/source/tests/jax/test_jax_md.py index 27ca55d59b..ce4b7288c6 100644 --- a/source/tests/jax/test_jax_md.py +++ b/source/tests/jax/test_jax_md.py @@ -1,13 +1,10 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import unittest from dataclasses import ( dataclass, ) -from importlib.util import ( - find_spec, -) import numpy as np +import pytest from deepmd.jax.env import ( jax, @@ -18,6 +15,7 @@ energy_fn, force_fn, load_model, + neighbor_list, ) @@ -111,8 +109,51 @@ def test_dense_neighbor_uses_jax_md_displacement_convention(): np.testing.assert_allclose(potential(coord, neighbor=neighbor), 0.04, atol=1e-12) -@unittest.skipIf(find_spec("jax_md") is None, "jax-md is not installed") +def test_dense_neighbor_rejects_scalar_metric(): + potential = energy_fn( + EdgeModel(), + [0, 1], + displacement_fn=lambda ra, rb: jnp.linalg.norm(ra - rb), + ) + coord = jnp.asarray( + [ + [0.1, 0.0, 0.0], + [9.9, 0.0, 0.0], + ] + ) + neighbor = DenseNeighbor(jnp.asarray([[1], [0]], dtype=jnp.int32)) + + with pytest.raises(ValueError, match="scalar metric"): + potential(coord, neighbor=neighbor) + + +def test_neighbor_list_rejects_unsupported_format_and_scalar_metric(): + pytest.importorskip("jax_md") + from jax_md import ( + partition, + space, + ) + + displacement, _ = space.periodic(10.0) + with pytest.raises(ValueError, match="Only dense"): + neighbor_list( + EdgeModel(), + displacement, + 10.0, + format=partition.NeighborListFormat.Sparse, + ) + + with pytest.raises(ValueError, match="scalar metric"): + as_jax_md( + EdgeModel(), + lambda ra, rb: jnp.linalg.norm(ra - rb), + 10.0, + [0, 1], + ) + + def test_actual_jax_md_neighbor_list(): + pytest.importorskip("jax_md") from jax_md import ( space, ) From c274f1dd4db7ed785f86ee46a932ee055701db53 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 27 Jun 2026 14:32:18 +0800 Subject: [PATCH 4/6] fix(jax): use public energy output in JAX-MD --- deepmd/jax/jax_md.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/deepmd/jax/jax_md.py b/deepmd/jax/jax_md.py index 94510ea9b9..45080dbf8c 100644 --- a/deepmd/jax/jax_md.py +++ b/deepmd/jax/jax_md.py @@ -410,9 +410,8 @@ def _extract_energy(ret: Any) -> Array: """Extract a scalar total energy from a DeePMD model return value.""" if isinstance(ret, tuple): ret = ret[0] - for key in ("energy", "energy_redu"): - if key in ret and ret[key] is not None: - return jnp.ravel(ret[key])[0] + if "energy" in ret and ret["energy"] is not None: + return jnp.ravel(ret["energy"])[0] raise KeyError("Model output does not contain an energy value.") From 30228236832c2137e5cc4b9bce94f7603cfb17cd Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 27 Jun 2026 20:49:29 +0800 Subject: [PATCH 5/6] refactor(jax): move JAX-MD adapter into package --- deepmd/jax/{jax_md.py => jax_md/__init__.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename deepmd/jax/{jax_md.py => jax_md/__init__.py} (100%) diff --git a/deepmd/jax/jax_md.py b/deepmd/jax/jax_md/__init__.py similarity index 100% rename from deepmd/jax/jax_md.py rename to deepmd/jax/jax_md/__init__.py From 4e0b1df3cf5986f1cb8a41a5d4ff262bce5325da Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 28 Jun 2026 22:30:10 +0800 Subject: [PATCH 6/6] docs: remove the claim of support for hlo --- doc/third-party/jaxmd.md | 1 - 1 file changed, 1 deletion(-) diff --git a/doc/third-party/jaxmd.md b/doc/third-party/jaxmd.md index 632e21e22d..0b7cf2c2fd 100644 --- a/doc/third-party/jaxmd.md +++ b/doc/third-party/jaxmd.md @@ -81,7 +81,6 @@ LAMMPS example, see `examples/water/jax_md`. `deepmd.jax.jax_md.load_model` accepts: - a DeePMD JAX checkpoint path ending in `.jax`, -- a DeePMD HLO model path ending in `.hlo`, - an already constructed JAX DeePMD model object. The `atom_types` argument may be an integer array of DeePMD type indexes. It