diff --git a/deepmd/dpmodel/infer/deep_eval.py b/deepmd/dpmodel/infer/deep_eval.py
index d735443383..255ea4a385 100644
--- a/deepmd/dpmodel/infer/deep_eval.py
+++ b/deepmd/dpmodel/infer/deep_eval.py
@@ -416,6 +416,10 @@ def get_model_def_script(self) -> dict:
         """Get model definition script."""
         return json.loads(self.dp.get_model_def_script())
 
+    def serialize(self) -> dict[str, Any]:
+        """Return the model tree produced by ``self.dp.serialize()``."""
+        return self.dp.serialize()
+
     def get_observed_types(self) -> dict:
         """Get observed types (elements) of the model during data statistics.
 
diff --git a/deepmd/entrypoints/show.py b/deepmd/entrypoints/show.py
index b156e9d43d..002fb159c4 100644
--- a/deepmd/entrypoints/show.py
+++ b/deepmd/entrypoints/show.py
@@ -149,3 +149,17 @@ def show(
         observed_types = model.get_observed_types()
         log.info(f"Number of observed types: {observed_types['type_num']} ")
         log.info(f"Observed types: {observed_types['observed_type']} ")
+
+    if "serialization-tree" in ATTRIBUTES:
+        from deepmd.dpmodel.utils.serialization import (
+            Node,
+        )
+
+        if model_is_multi_task:
+            for branch in model_params["model_dict"]:
+                branch_model = DeepEval(INPUT, head=branch)
+                root = Node.deserialize(branch_model.serialize())
+                log.info("Model serialization tree of branch %s:\n%s", branch, root)
+        else:
+            root = Node.deserialize(model.serialize())
+            log.info("Model serialization tree:\n%s", root)
diff --git a/deepmd/infer/deep_eval.py b/deepmd/infer/deep_eval.py
index 557f3ddd23..a08ec885f3 100644
--- a/deepmd/infer/deep_eval.py
+++ b/deepmd/infer/deep_eval.py
@@ -380,6 +380,32 @@ def get_model(self) -> Any:
             The model module implemented by the deep learning framework.
         """
 
+    def serialize(self) -> dict[str, Any]:
+        """Serialize the loaded model as a model tree.
+
+        Most in-tree backends return the lossless, weight-bearing ``model``
+        subtree from the serialized file payload. Backends that cannot recover a
+        lossless tree may override this method to document and implement their
+        narrower behavior.
+
+        Returns
+        -------
+        dict
+            Serialized model tree that can be consumed by ``Node.deserialize``.
+
+        Raises
+        ------
+        NotImplementedError
+            If the object returned by ``get_model()`` has no ``serialize``.
+        """
+        model = self.get_model()
+        if hasattr(model, "serialize"):
+            return model.serialize()
+        raise NotImplementedError(
+            f"{type(self).__name__} does not implement serialize(), and its "
+            "model object has no serialize() method."
+        )
+
 
 class DeepEval(ABC):
     """High-level Deep Evaluator interface.
@@ -423,6 +449,7 @@ def __init__(
         neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None,
         **kwargs: Any,
     ) -> None:
+        self.model_file = model_file
         self.deep_eval = DeepEvalBackend(
             model_file,
             self.output_def,
@@ -439,6 +466,16 @@ def __init__(
     def output_def(self) -> ModelOutputDef:
         """Returns the output variable definitions."""
 
+    def serialize(self) -> dict[str, Any]:
+        """Serialize the loaded model as a model tree.
+
+        Most backends return the lossless, weight-bearing ``model`` subtree from
+        the serialized file payload. JAX ``.savedmodel`` inputs are the known
+        exception: they are reconstructed from the model definition script and
+        therefore do not preserve trained weights.
+        """
+        return self.deep_eval.serialize()
+
     def get_rcut(self) -> float:
         """Get the cutoff radius of this model."""
         return self.deep_eval.get_rcut()
diff --git a/deepmd/jax/infer/deep_eval.py b/deepmd/jax/infer/deep_eval.py
index 2e028225f7..09b5783ba9 100644
--- a/deepmd/jax/infer/deep_eval.py
+++ b/deepmd/jax/infer/deep_eval.py
@@ -187,6 +187,31 @@ def get_ntypes_spin(self) -> int:
         """Get the number of spin atom types of this model."""
         return 0
 
+    def serialize(self) -> dict[str, Any]:
+        """Serialize the loaded model as a model tree.
+
+        JAX-native ``.jax``/``.hlo`` inputs return the lossless, weight-bearing
+        ``model`` subtree from the file payload. TensorFlow-wrapped
+        ``.savedmodel`` inputs cannot be converted back losslessly; for that
+        format this method reconstructs the model tree from the definition
+        script, so trained weights are not preserved.
+        """
+        if str(self.model_path).endswith(".savedmodel"):
+            from deepmd.jax.model.model import (
+                get_model,
+            )
+
+            return get_model(self.get_model_def_script()).serialize()
+
+        from deepmd.jax.utils.serialization import (
+            serialize_from_file,
+        )
+
+        data = serialize_from_file(self.model_path)
+        if "model" not in data:
+            raise RuntimeError("Serialized model data does not contain key 'model'.")
+        return data["model"]
+
     def eval(
         self,
         coords: np.ndarray,
diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py
index 14386d9f3d..eded22b692 100644
--- a/deepmd/jax/utils/serialization.py
+++ b/deepmd/jax/utils/serialization.py
@@ -202,5 +202,13 @@ def convert_str_to_int_key(item: dict) -> None:
         data.pop("constants")
         data["@variables"].pop("stablehlo")
         return data
+    elif model_file.endswith(".savedmodel"):
+        raise ValueError(
+            "JAX SavedModel does not support lossless file serialization. "
+            "Use DeepEval.serialize() for a structure-only model tree."
+        )
     else:
-        raise ValueError("JAX backend only supports converting .jax directory")
+        raise ValueError(
+            "JAX backend only supports lossless file serialization for .jax "
+            "directory and .hlo."
+        )
diff --git a/deepmd/main.py b/deepmd/main.py
index bf59dfdad5..c08dd7ba79 100644
--- a/deepmd/main.py
+++ b/deepmd/main.py
@@ -951,6 +951,7 @@ def main_parser() -> argparse.ArgumentParser:
             "fitting-net",
             "size",
             "observed-type",
+            "serialization-tree",
         ],
         nargs="+",
     )
diff --git a/deepmd/pd/infer/deep_eval.py b/deepmd/pd/infer/deep_eval.py
index fcc9534097..c8bb113495 100644
--- a/deepmd/pd/infer/deep_eval.py
+++ b/deepmd/pd/infer/deep_eval.py
@@ -740,6 +740,24 @@ def get_model_def_script(self) -> dict:
         """Get model definition script."""
         return self.model_def_script
 
+    def serialize(self) -> dict[str, Any]:
+        """Serialize the loaded model as a model tree.
+
+        Uses the in-memory model's ``serialize()`` when it has one; otherwise
+        returns the ``model`` entry of ``serialize_from_file(self.model_path)``.
+        """
+        model = (
+            self.dp.model["Default"] if isinstance(self.dp, ModelWrapper) else self.dp
+        )
+        if hasattr(model, "serialize"):
+            return model.serialize()
+
+        from deepmd.pd.utils.serialization import (
+            serialize_from_file,
+        )
+
+        return serialize_from_file(self.model_path)["model"]
+
     def get_model_size(self) -> dict:
         """Get model parameter count.
 
diff --git a/deepmd/pretrained/deep_eval.py b/deepmd/pretrained/deep_eval.py
index 2dc671b0cc..aa15a50760 100644
--- a/deepmd/pretrained/deep_eval.py
+++ b/deepmd/pretrained/deep_eval.py
@@ -184,3 +184,6 @@ def get_ntypes_spin(self) -> int:
 
     def get_model(self) -> Any:
         return self._backend.get_model()
+
+    def serialize(self) -> dict[str, Any]:
+        return self._backend.serialize()
diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py
index 7d54c7ef01..dce5e36186 100644
--- a/deepmd/pt/infer/deep_eval.py
+++ b/deepmd/pt/infer/deep_eval.py
@@ -932,6 +932,22 @@ def get_model_def_script(self) -> dict:
         """Get model definition script."""
         return self.model_def_script
 
+    def serialize(self) -> dict[str, Any]:
+        """Serialize the loaded model as a model tree.
+
+        Uses the in-memory model's ``serialize()`` when it has one; otherwise
+        returns the ``model`` entry of ``serialize_from_file(self.model_path)``.
+        """
+        model = self.dp.model["Default"]
+        if hasattr(model, "serialize"):
+            return model.serialize()
+
+        from deepmd.pt.utils.serialization import (
+            serialize_from_file,
+        )
+
+        return serialize_from_file(self.model_path)["model"]
+
     def get_model_size(self) -> dict:
         """Get model parameter count.
 
diff --git a/deepmd/pt_expt/infer/deep_eval.py b/deepmd/pt_expt/infer/deep_eval.py
index 42e4181d65..dfa2770b9e 100644
--- a/deepmd/pt_expt/infer/deep_eval.py
+++ b/deepmd/pt_expt/infer/deep_eval.py
@@ -1463,6 +1463,19 @@ def get_model_def_script(self) -> dict:
         """Get model definition script (training config)."""
         return self._model_def_script
 
+    def serialize(self) -> dict[str, Any]:
+        """Serialize the loaded model as a model tree.
+
+        Returns the ``model`` entry of ``serialize_from_file(self.model_path)``,
+        or the whole payload when it is not a dict containing that key.
+        """
+        from deepmd.pt_expt.utils.serialization import (
+            serialize_from_file,
+        )
+
+        data = serialize_from_file(self.model_path)
+        return data["model"] if isinstance(data, dict) and "model" in data else data
+
     def get_model(self) -> torch.nn.Module:
         """Get the exported model module.
 
diff --git a/deepmd/tf/infer/deep_eval.py b/deepmd/tf/infer/deep_eval.py
index 0ec2f1c74e..0d81ce588f 100644
--- a/deepmd/tf/infer/deep_eval.py
+++ b/deepmd/tf/infer/deep_eval.py
@@ -110,6 +110,7 @@ def __init__(
             input_map=input_map,
         )
         self.load_prefix = load_prefix
+        self.model_file = model_file
 
         # graph_compatable should be called after graph and prefix are set
         if not self._graph_compatable():
@@ -1121,6 +1122,27 @@ def get_model_def_script(self) -> dict:
         model_def_script = script.decode("utf-8")
         return json.loads(model_def_script)["model"]
 
+    def serialize(self) -> dict[str, Any]:
+        """Serialize the loaded model as a model tree.
+
+        Rebuilds a ``Model`` from the model definition script, initializes its
+        variables from the graph loaded from ``self.model_file``, and serializes it.
+        """
+        from deepmd.tf.model.model import (
+            Model,
+        )
+        from deepmd.tf.utils.graph import (
+            load_graph_def,
+        )
+
+        graph, graph_def = load_graph_def(str(self.model_file))
+
+        model_def_script = self.get_model_def_script()
+        model = Model(**model_def_script)
+        # important! must be called before serialize
+        model.init_variables(graph=graph, graph_def=graph_def)
+        return model.serialize()
+
     def get_model(self) -> "tf.Graph":
         """Get the TensorFlow graph.
 
@@ -1172,6 +1194,7 @@ def __init__(
             input_map=input_map,
         )
         self.load_prefix = load_prefix
+        self.model_file = model_file
 
         # graph_compatable should be called after graph and prefix are set
         if not self._graph_compatable():
diff --git a/source/tests/consistent/io/test_io.py b/source/tests/consistent/io/test_io.py
index 982d56d8fa..904576d4fe 100644
--- a/source/tests/consistent/io/test_io.py
+++ b/source/tests/consistent/io/test_io.py
@@ -156,11 +156,14 @@ def test_deep_eval(self) -> None:
                 if not backend.is_available():
                     continue
                 reference_data = copy.deepcopy(self.data)
-                self.save_data_to_model(
-                    prefix + backend.suffixes[suffix_idx], reference_data
-                )
-                deep_eval = DeepEval(prefix + backend.suffixes[suffix_idx])
+                model_file = prefix + backend.suffixes[suffix_idx]
+                self.save_data_to_model(model_file, reference_data)
+                deep_eval = DeepEval(model_file)
                 self.assertIsInstance(deep_eval.get_model_def_script(), dict)
+                if not model_file.endswith((".savedmodel", ".savedmodeltf")):
+                    # SavedModel formats store an executable graph, not a lossless model dict.
+                    serialized_data = self.get_data_from_model(model_file)
+                    np.testing.assert_equal(deep_eval.serialize(), serialized_data["model"])
                 if deep_eval.get_dim_fparam() > 0:
                     fparam = np.ones((nframes, deep_eval.get_dim_fparam()))
                 else:
diff --git a/source/tests/pt_expt/infer/test_deep_eval.py b/source/tests/pt_expt/infer/test_deep_eval.py
index f831b365f6..4444d5e9a1 100644
--- a/source/tests/pt_expt/infer/test_deep_eval.py
+++ b/source/tests/pt_expt/infer/test_deep_eval.py
@@ -140,6 +140,16 @@ def test_model_api_delegation(self) -> None:
         self.assertEqual(de.get_dim_aparam(), 0)
         self.assertEqual(de.get_sel_type(), self.model.get_sel_type())
 
+    def test_serialize_returns_model_tree(self) -> None:
+        data = self.dp.deep_eval.serialize()
+        expected = self.model.serialize()
+        self.assertEqual(data["@class"], expected["@class"])
+        self.assertEqual(data["type"], expected["type"])
+        # The serialized model tree contains NumPy array leaves, so unittest's
+        # dict equality would try to coerce elementwise array comparisons to a
+        # single bool and fail with an ambiguous truth-value error.
+        np.testing.assert_equal(data, serialize_from_file(self.tmpfile.name)["model"])
+
     def test_eval_consistency(self) -> None:
         """Test that DeepPot.eval gives same results as direct model forward."""
         rng = np.random.default_rng(GLOBAL_SEED)
diff --git a/source/tests/test_deep_eval_serialize_api.py b/source/tests/test_deep_eval_serialize_api.py
new file mode 100644
index 0000000000..6959a1bcb3
--- /dev/null
+++ b/source/tests/test_deep_eval_serialize_api.py
@@ -0,0 +1,150 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import importlib
+import unittest
+from unittest.mock import (
+    Mock,
+    patch,
+)
+
+from deepmd.infer.deep_eval import (
+    DeepEvalBackend,
+)
+
+
+class _DefaultSerializeBackend(DeepEvalBackend):
+    def __init__(self, model: object) -> None:
+        self._model = model
+
+    def eval(self, *args: object, **kwargs: object) -> dict:
+        return {}
+
+    def get_rcut(self) -> float:
+        return 0.0
+
+    def get_ntypes(self) -> int:
+        return 0
+
+    def get_type_map(self) -> list[str]:
+        return []
+
+    def get_dim_fparam(self) -> int:
+        return 0
+
+    def get_dim_aparam(self) -> int:
+        return 0
+
+    @property
+    def model_type(self) -> type:
+        return object
+
+    def get_sel_type(self) -> list[int]:
+        return []
+
+    def get_ntypes_spin(self) -> int:
+        return 0
+
+    def get_model(self) -> object:
+        return self._model
+
+
+class TestDeepEvalBackendSerialize(unittest.TestCase):
+    def test_default_serialize_delegates_to_model_when_available(self) -> None:
+        model = Mock()
+        model.serialize.return_value = {"@class": "Model"}
+        backend = _DefaultSerializeBackend(model)
+
+        self.assertEqual(backend.serialize(), {"@class": "Model"})
+        model.serialize.assert_called_once_with()
+
+    def test_default_serialize_has_clear_error_without_model_method(self) -> None:
+        backend = _DefaultSerializeBackend(object())
+
+        with self.assertRaisesRegex(
+            NotImplementedError, "does not implement serialize"
+        ):
+            backend.serialize()
+
+
+def _load_deep_eval_backend(module_name: str, backend_name: str) -> type:
+    """Return the ``DeepEval`` attribute of a backend module, or skip the test.
+
+    Raises ``unittest.SkipTest`` when ``module_name`` cannot be imported.
+    """
+    try:
+        module = importlib.import_module(module_name)
+    except ImportError as exc:
+        raise unittest.SkipTest(
+            f"{backend_name} backend is not importable: {exc}"
+        ) from exc
+    return module.DeepEval
+
+
+class TestPaddleDeepEvalSerialize(unittest.TestCase):
+    def test_jit_model_falls_back_to_file_serializer(self) -> None:
+        PaddleDeepEvalBackend = _load_deep_eval_backend(
+            "deepmd.pd.infer.deep_eval", "Paddle"
+        )
+        backend = object.__new__(PaddleDeepEvalBackend)
+        backend.model_path = "frozen_model.json"
+        backend.dp = object()
+
+        with patch("deepmd.pd.utils.serialization.serialize_from_file") as serialize:
+            serialize.return_value = {"model": {"@class": "RecoveredModel"}}
+
+            self.assertEqual(backend.serialize(), {"@class": "RecoveredModel"})
+
+        serialize.assert_called_once_with("frozen_model.json")
+
+
+class TestPyTorchDeepEvalSerialize(unittest.TestCase):
+    def test_jit_model_falls_back_to_file_serializer(self) -> None:
+        PyTorchDeepEvalBackend = _load_deep_eval_backend(
+            "deepmd.pt.infer.deep_eval", "PyTorch"
+        )
+        backend = object.__new__(PyTorchDeepEvalBackend)
+        backend.model_path = "frozen_model.pth"
+        backend.dp = Mock()
+        backend.dp.model = {"Default": object()}
+
+        with patch("deepmd.pt.utils.serialization.serialize_from_file") as serialize:
+            serialize.return_value = {"model": {"@class": "RecoveredModel"}}
+
+            self.assertEqual(backend.serialize(), {"@class": "RecoveredModel"})
+
+        serialize.assert_called_once_with("frozen_model.pth")
+
+
+class TestPyTorchExportableDeepEvalSerialize(unittest.TestCase):
+    def test_raw_model_payload_fallback_is_preserved(self) -> None:
+        PyTorchExportableDeepEvalBackend = _load_deep_eval_backend(
+            "deepmd.pt_expt.infer.deep_eval", "PyTorch exportable"
+        )
+        backend = object.__new__(PyTorchExportableDeepEvalBackend)
+        backend.model_path = "frozen_model.pt"
+
+        with patch(
+            "deepmd.pt_expt.utils.serialization.serialize_from_file"
+        ) as serialize:
+            serialize.return_value = {"@class": "RawExportedModel"}
+
+            self.assertEqual(backend.serialize(), {"@class": "RawExportedModel"})
+
+        serialize.assert_called_once_with("frozen_model.pt")
+
+
+class TestJAXDeepEvalSerialize(unittest.TestCase):
+    def test_savedmodel_reconstructs_tree_from_model_def_script(self) -> None:
+        JAXDeepEvalBackend = _load_deep_eval_backend(
+            "deepmd.jax.infer.deep_eval", "JAX"
+        )
+        backend = object.__new__(JAXDeepEvalBackend)
+        backend.model_path = "frozen_model.savedmodel"
+        backend.get_model_def_script = Mock(return_value={"type_map": ["O", "H"]})
+
+        model = Mock()
+        model.serialize.return_value = {"@class": "SavedModelTree"}
+        with patch("deepmd.jax.model.model.get_model", return_value=model) as get_model:
+            self.assertEqual(backend.serialize(), {"@class": "SavedModelTree"})
+
+        get_model.assert_called_once_with({"type_map": ["O", "H"]})
+        model.serialize.assert_called_once_with()
diff --git a/source/tests/test_entrypoint_show_serialization_tree.py b/source/tests/test_entrypoint_show_serialization_tree.py
new file mode 100644
index 0000000000..839a6b48a7
--- /dev/null
+++ b/source/tests/test_entrypoint_show_serialization_tree.py
@@ -0,0 +1,74 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import unittest
+from unittest.mock import (
+    Mock,
+    call,
+    patch,
+)
+
+from deepmd.entrypoints.show import (
+    show,
+)
+
+
+class TestShowSerializationTree(unittest.TestCase):
+    def test_serialization_tree_uses_deep_eval_model_payload(self) -> None:
+        with (
+            patch("deepmd.entrypoints.show.DeepEval") as mock_deep_eval,
+            patch(
+                "deepmd.dpmodel.utils.serialization.Node.deserialize"
+            ) as mock_deserialize,
+            patch("deepmd.entrypoints.show.log.info") as mock_log_info,
+        ):
+            model = mock_deep_eval.return_value
+            model.get_model_def_script.return_value = {"type_map": ["H", "O"]}
+            model.get_model_size.return_value = {}
+            model.serialize.return_value = {"@class": "MockModel"}
+            mock_deserialize.return_value = "ROOT"
+
+            show(INPUT="mock.pte", ATTRIBUTES=["serialization-tree"])
+
+        model.serialize.assert_called_once_with()
+        mock_deserialize.assert_called_once_with({"@class": "MockModel"})
+        mock_log_info.assert_any_call("Model serialization tree:\n%s", "ROOT")
+
+    def test_serialization_tree_iterates_multitask_branches(self) -> None:
+        with (
+            patch("deepmd.entrypoints.show.DeepEval") as mock_deep_eval,
+            patch(
+                "deepmd.dpmodel.utils.serialization.Node.deserialize"
+            ) as mock_deserialize,
+            patch("deepmd.entrypoints.show.log.info") as mock_log_info,
+        ):
+            initial_model = Mock()
+            branch_a_model = Mock()
+            branch_b_model = Mock()
+            mock_deep_eval.side_effect = [initial_model, branch_a_model, branch_b_model]
+
+            initial_model.get_model_def_script.return_value = {
+                "model_dict": {"branch_a": {}, "branch_b": {}}
+            }
+            initial_model.get_model_size.return_value = {}
+            branch_a_model.serialize.return_value = {"@class": "BranchA"}
+            branch_b_model.serialize.return_value = {"@class": "BranchB"}
+            mock_deserialize.side_effect = ["ROOT_A", "ROOT_B"]
+
+            show(INPUT="mock-multitask.pte", ATTRIBUTES=["serialization-tree"])
+
+        self.assertEqual(
+            mock_deep_eval.call_args_list,
+            [
+                call("mock-multitask.pte", head=0),
+                call("mock-multitask.pte", head="branch_a"),
+                call("mock-multitask.pte", head="branch_b"),
+            ],
+        )
+        mock_deserialize.assert_has_calls(
+            [call({"@class": "BranchA"}), call({"@class": "BranchB"})]
+        )
+        mock_log_info.assert_any_call(
+            "Model serialization tree of branch %s:\n%s", "branch_a", "ROOT_A"
+        )
+        mock_log_info.assert_any_call(
+            "Model serialization tree of branch %s:\n%s", "branch_b", "ROOT_B"
+        )