Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
c8de6a6
feat(show): add serialization-tree via DeepEval.serialize
njzjz-bot Mar 16, 2026
b695aaa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 16, 2026
1694360
refactor(deepeval): implement backend serialize via model.serialize
njzjz-bot Mar 16, 2026
7675388
fix(show): align backend serialize output contracts
njzjz-bot Apr 7, 2026
a8a80f4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 7, 2026
c18814c
refactor(show): decouple serialization tree from deep eval wrapper
njzjz-bot Apr 8, 2026
5678d78
refactor(deepeval): serialize model tree only
njzjz-bot Apr 8, 2026
8b48f46
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 8, 2026
5e97073
test(io): cover deep eval serialization in consistent io
njzjz-bot Apr 8, 2026
8bf3176
fix(review): handle paddle static serialize and io assert
njzjz-bot Apr 8, 2026
6cf9dbf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 8, 2026
53f6f12
fix(ci): handle savedmodel and torchscript serialize
njzjz-bot Apr 8, 2026
a2431e8
fix(ci): guard pt serialize fallback shape
njzjz-bot Apr 8, 2026
39aab42
fix(test): pt_expt serialize fallback returns model tree with @class key
njzjz-bot Apr 8, 2026
67f4d1e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 8, 2026
024a058
fix(ci): revert pt deep eval mixed pte serialization path
njzjz-bot Apr 8, 2026
b951333
fix(ci): restore backend-specific serialize fallbacks
njzjz-bot Apr 8, 2026
f65bb5f
fix(jax): avoid lossless savedmodel serialization
njzjz Jun 21, 2026
ebda14d
test(pt-expt): compare serialized trees with numpy-aware assert
njzjz-bot Jun 21, 2026
980c5fe
Merge branch 'master' into feat/deepeval-serialize
njzjz Jun 21, 2026
1140851
fix: address serialization tree review comments
njzjz-bot Jun 22, 2026
5e34f15
test(show): fix serialization tree CI failures
njzjz Jun 28, 2026
1373976
test(io): skip TF2 SavedModel serialize assert
njzjz Jul 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions deepmd/dpmodel/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,10 @@ def get_model_def_script(self) -> dict:
"""Get model definition script."""
return json.loads(self.dp.get_model_def_script())

def serialize(self) -> dict[str, Any]:
model = self.dp
return model.serialize()

def get_observed_types(self) -> dict:
"""Get observed types (elements) of the model during data statistics.

Expand Down
14 changes: 14 additions & 0 deletions deepmd/entrypoints/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,17 @@ def show(
observed_types = model.get_observed_types()
log.info(f"Number of observed types: {observed_types['type_num']} ")
log.info(f"Observed types: {observed_types['observed_type']} ")

if "serialization-tree" in ATTRIBUTES:
from deepmd.dpmodel.utils.serialization import (
Node,
)

if model_is_multi_task:
for branch in model_params["model_dict"]:
branch_model = DeepEval(INPUT, head=branch)
root = Node.deserialize(branch_model.serialize())
log.info("Model serialization tree of branch %s:\n%s", branch, root)
else:
root = Node.deserialize(model.serialize())
log.info("Model serialization tree:\n%s", root)
32 changes: 32 additions & 0 deletions deepmd/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,27 @@ def get_model(self) -> Any:
The model module implemented by the deep learning framework.
"""

def serialize(self) -> dict[str, Any]:
"""Serialize the loaded model as a model tree.

Most in-tree backends return the lossless, weight-bearing ``model``
subtree from the serialized file payload. Backends that cannot recover a
lossless tree may override this method to document and implement their
narrower behavior.

Returns
-------
dict
Serialized model tree that can be consumed by ``Node.deserialize``.
"""
model = self.get_model()
if hasattr(model, "serialize"):
return model.serialize()
raise NotImplementedError(
f"{type(self).__name__} does not implement serialize(), and its "
"model object has no serialize() method."
)


class DeepEval(ABC):
"""High-level Deep Evaluator interface.
Expand Down Expand Up @@ -423,6 +444,7 @@ def __init__(
neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None,
**kwargs: Any,
) -> None:
self.model_file = model_file
self.deep_eval = DeepEvalBackend(
model_file,
self.output_def,
Expand All @@ -439,6 +461,16 @@ def __init__(
def output_def(self) -> ModelOutputDef:
"""Returns the output variable definitions."""

def serialize(self) -> dict[str, Any]:
"""Serialize the loaded model as a model tree.

Most backends return the lossless, weight-bearing ``model`` subtree from
the serialized file payload. JAX ``.savedmodel`` inputs are the known
exception: they are reconstructed from the model definition script and
therefore do not preserve trained weights.
"""
return self.deep_eval.serialize()

def get_rcut(self) -> float:
"""Get the cutoff radius of this model."""
return self.deep_eval.get_rcut()
Expand Down
25 changes: 25 additions & 0 deletions deepmd/jax/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,31 @@ def get_ntypes_spin(self) -> int:
"""Get the number of spin atom types of this model."""
return 0

def serialize(self) -> dict[str, Any]:
"""Serialize the loaded model as a model tree.

JAX-native ``.jax``/``.hlo`` inputs return the lossless, weight-bearing
``model`` subtree from the file payload. TensorFlow-wrapped
``.savedmodel`` inputs cannot be converted back losslessly; for that
format this method reconstructs the model tree from the definition
script, so trained weights are not preserved.
"""
if str(self.model_path).endswith(".savedmodel"):
from deepmd.jax.model.model import (
get_model,
)

return get_model(self.get_model_def_script()).serialize()

from deepmd.jax.utils.serialization import (
serialize_from_file,
)

data = serialize_from_file(self.model_path)
if "model" not in data:
raise RuntimeError("Serialized model data does not contain key 'model'.")
return data["model"]

def eval(
self,
coords: np.ndarray,
Expand Down
10 changes: 9 additions & 1 deletion deepmd/jax/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,5 +202,13 @@ def convert_str_to_int_key(item: dict) -> None:
data.pop("constants")
data["@variables"].pop("stablehlo")
return data
elif model_file.endswith(".savedmodel"):
raise ValueError(
"JAX SavedModel does not support lossless file serialization. "
"Use DeepEval.serialize() for a structure-only model tree."
)
else:
raise ValueError("JAX backend only supports converting .jax directory")
raise ValueError(
"JAX backend only supports lossless file serialization for .jax "
"directory and .hlo."
)
1 change: 1 addition & 0 deletions deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,7 @@ def main_parser() -> argparse.ArgumentParser:
"fitting-net",
"size",
"observed-type",
"serialization-tree",
],
nargs="+",
)
Expand Down
13 changes: 13 additions & 0 deletions deepmd/pd/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,19 @@ def get_model_def_script(self) -> dict:
"""Get model definition script."""
return self.model_def_script

def serialize(self) -> dict[str, Any]:
model = (
self.dp.model["Default"] if isinstance(self.dp, ModelWrapper) else self.dp
)
if hasattr(model, "serialize"):
return model.serialize()

from deepmd.pd.utils.serialization import (
serialize_from_file,
)

return serialize_from_file(self.model_path)["model"]

Comment thread
njzjz marked this conversation as resolved.
def get_model_size(self) -> dict:
"""Get model parameter count.

Expand Down
3 changes: 3 additions & 0 deletions deepmd/pretrained/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,6 @@ def get_ntypes_spin(self) -> int:

def get_model(self) -> Any:
return self._backend.get_model()

def serialize(self) -> dict[str, Any]:
return self._backend.serialize()
11 changes: 11 additions & 0 deletions deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,17 @@ def get_model_def_script(self) -> dict:
"""Get model definition script."""
return self.model_def_script

def serialize(self) -> dict[str, Any]:
model = self.dp.model["Default"]
if hasattr(model, "serialize"):
return model.serialize()

from deepmd.pt.utils.serialization import (
serialize_from_file,
)

return serialize_from_file(self.model_path)["model"]

def get_model_size(self) -> dict:
"""Get model parameter count.

Expand Down
8 changes: 8 additions & 0 deletions deepmd/pt_expt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1463,6 +1463,14 @@ def get_model_def_script(self) -> dict:
"""Get model definition script (training config)."""
return self._model_def_script

def serialize(self) -> dict[str, Any]:
from deepmd.pt_expt.utils.serialization import (
serialize_from_file,
)

data = serialize_from_file(self.model_path)
return data["model"] if isinstance(data, dict) and "model" in data else data

def get_model(self) -> torch.nn.Module:
"""Get the exported model module.

Expand Down
18 changes: 18 additions & 0 deletions deepmd/tf/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def __init__(
input_map=input_map,
)
self.load_prefix = load_prefix
self.model_file = model_file

# graph_compatable should be called after graph and prefix are set
if not self._graph_compatable():
Expand Down Expand Up @@ -1121,6 +1122,22 @@ def get_model_def_script(self) -> dict:
model_def_script = script.decode("utf-8")
return json.loads(model_def_script)["model"]

def serialize(self) -> dict[str, Any]:
from deepmd.tf.model.model import (
Model,
)
from deepmd.tf.utils.graph import (
load_graph_def,
)

graph, graph_def = load_graph_def(str(self.model_file))

model_def_script = self.get_model_def_script()
model = Model(**model_def_script)
# important! must be called before serialize
model.init_variables(graph=graph, graph_def=graph_def)
return model.serialize()

def get_model(self) -> "tf.Graph":
"""Get the TensorFlow graph.

Expand Down Expand Up @@ -1172,6 +1189,7 @@ def __init__(
input_map=input_map,
)
self.load_prefix = load_prefix
self.model_file = model_file

# graph_compatable should be called after graph and prefix are set
if not self._graph_compatable():
Expand Down
11 changes: 7 additions & 4 deletions source/tests/consistent/io/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,11 +156,14 @@ def test_deep_eval(self) -> None:
if not backend.is_available():
continue
reference_data = copy.deepcopy(self.data)
self.save_data_to_model(
prefix + backend.suffixes[suffix_idx], reference_data
)
deep_eval = DeepEval(prefix + backend.suffixes[suffix_idx])
model_file = prefix + backend.suffixes[suffix_idx]
self.save_data_to_model(model_file, reference_data)
deep_eval = DeepEval(model_file)
self.assertIsInstance(deep_eval.get_model_def_script(), dict)
if not model_file.endswith((".savedmodel", ".savedmodeltf")):
# SavedModel formats store an executable graph, not a lossless model dict.
serialized_data = self.get_data_from_model(model_file)
np.testing.assert_equal(deep_eval.serialize(), serialized_data["model"])
if deep_eval.get_dim_fparam() > 0:
fparam = np.ones((nframes, deep_eval.get_dim_fparam()))
else:
Expand Down
9 changes: 9 additions & 0 deletions source/tests/pt_expt/infer/test_deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,15 @@ def test_model_api_delegation(self) -> None:
self.assertEqual(de.get_dim_aparam(), 0)
self.assertEqual(de.get_sel_type(), self.model.get_sel_type())

def test_serialize_returns_model_tree(self) -> None:
data = self.dp.deep_eval.serialize()
self.assertEqual(data["@class"], self.model.serialize()["@class"])
self.assertEqual(data["type"], self.model.serialize()["type"])
# The serialized model tree contains NumPy array leaves, so unittest's
# dict equality would try to coerce elementwise array comparisons to a
# single bool and fail with an ambiguous truth-value error.
np.testing.assert_equal(data, serialize_from_file(self.tmpfile.name)["model"])

def test_eval_consistency(self) -> None:
"""Test that DeepPot.eval gives same results as direct model forward."""
rng = np.random.default_rng(GLOBAL_SEED)
Expand Down
Loading
Loading