Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion deepmd/dpmodel/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,4 +391,14 @@ def _get_output_shape(self, odef, nframes, natoms):

def get_model_def_script(self) -> dict:
"""Get model definition script."""
return json.loads(self.model.get_model_def_script())
return json.loads(self.dp.get_model_def_script())

def get_model(self) -> "BaseModel":
"""Get the dpmodel BaseModel.

Returns
-------
BaseModel
The dpmodel BaseModel.
"""
return self.dp
28 changes: 28 additions & 0 deletions deepmd/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,20 @@ def get_observed_types(self) -> dict:
"""Get observed types (elements) of the model during data statistics."""
raise NotImplementedError("Not implemented in this backend.")

@abstractmethod
def get_model(self) -> Any:
"""Get the model module implemented by the deep learning framework.

For PyTorch, this returns the nn.Module. For Paddle, this returns
the paddle.nn.Layer. For TensorFlow, this returns the graph.
For dpmodel, this returns the BaseModel.

Returns
-------
model
The model module implemented by the deep learning framework.
"""


class DeepEval(ABC):
"""High-level Deep Evaluator interface.
Expand Down Expand Up @@ -685,3 +699,17 @@ def get_model_size(self) -> dict:
def get_observed_types(self) -> dict:
"""Get observed types (elements) of the model during data statistics."""
return self.deep_eval.get_observed_types()

def get_model(self) -> Any:
"""Get the model module implemented by the deep learning framework.

For PyTorch, this returns the nn.Module. For Paddle, this returns
the paddle.nn.Layer. For TensorFlow, this returns the graph.
For dpmodel, this returns the BaseModel.

Returns
-------
model
The model module implemented by the deep learning framework.
"""
return self.deep_eval.get_model()
10 changes: 10 additions & 0 deletions deepmd/jax/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,3 +420,13 @@ def _get_output_shape(self, odef, nframes, natoms):
def get_model_def_script(self) -> dict:
"""Get model definition script."""
return json.loads(self.dp.get_model_def_script())

def get_model(self) -> Any:
"""Get the JAX model as BaseModel.

Returns
-------
BaseModel
The JAX model as BaseModel instance.
"""
return self.dp
14 changes: 14 additions & 0 deletions deepmd/pd/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@
if TYPE_CHECKING:
import ase.neighborlist

from deepmd.pd.model.model.model import (
BaseModel,
)


class DeepEval(DeepEvalBackend):
"""Paddle backend implementation of DeepEval.
Expand Down Expand Up @@ -506,6 +510,16 @@ def get_model_size(self) -> dict:
"total": sum_param_des + sum_param_fit,
}

def get_model(self) -> "BaseModel":
"""Get the Paddle model.

Returns
-------
BaseModel
The Paddle model instance.
"""
return self.dp.model["Default"]

def eval_descriptor(
self,
coords: np.ndarray,
Expand Down
14 changes: 14 additions & 0 deletions deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@
if TYPE_CHECKING:
import ase.neighborlist

from deepmd.pt.model.model.model import (
BaseModel,
)

log = logging.getLogger(__name__)


Expand Down Expand Up @@ -706,6 +710,16 @@ def get_observed_types(self) -> dict:
"observed_type": sort_element_type(observed_type_list),
}

def get_model(self) -> "BaseModel":
"""Get the PyTorch model.

Returns
-------
BaseModel
The PyTorch model instance.
"""
return self.dp.model["Default"]

def eval_descriptor(
self,
coords: np.ndarray,
Expand Down
10 changes: 10 additions & 0 deletions deepmd/tf/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1126,6 +1126,16 @@ def get_model_def_script(self) -> dict:
model_def_script = script.decode("utf-8")
return json.loads(model_def_script)["model"]

def get_model(self) -> "tf.Graph":
"""Get the TensorFlow graph.

Returns
-------
tf.Graph
The TensorFlow graph.
"""
return self.graph


class DeepEvalOld:
# old class for DipoleChargeModifier only
Expand Down
101 changes: 101 additions & 0 deletions source/tests/infer/test_get_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import unittest

from deepmd.infer.deep_eval import (
DeepEval,
)

from ..consistent.common import (
parameterized,
)
from .case import (
get_cases,
)


@parameterized(
(
"se_e2_a",
"fparam_aparam",
), # key
(".pb", ".pth"), # model extension
)
class TestGetModelMethod(unittest.TestCase):
"""Test the new get_model method functionality."""

@classmethod
def setUpClass(cls) -> None:
key, extension = cls.param
cls.case = get_cases()[key]
cls.model_name = cls.case.get_model(extension)
cls.dp = DeepEval(cls.model_name)

@classmethod
def tearDownClass(cls) -> None:
cls.dp = None

def test_get_model_method_exists(self):
"""Test that get_model method exists."""
self.assertTrue(
hasattr(self.dp, "get_model"), "DeepEval should have get_model method"
)

def test_get_model_returns_valid_object(self):
"""Test that get_model returns a valid model object."""
model = self.dp.get_model()
self.assertIsNotNone(model, "get_model should return a non-None object")

def test_get_model_backend_specific(self):
"""Test that get_model returns the expected type for each backend."""
key, extension = self.param
model = self.dp.get_model()

if extension == ".pth":
# For PyTorch .pth models (TorchScript), should return torch.jit.ScriptModule
import torch

self.assertIsInstance(
model,
torch.jit.ScriptModule,
"PyTorch .pth model should return TorchScript ScriptModule instance",
)
# TorchScript modules are also nn.Module instances
self.assertIsInstance(
model,
torch.nn.Module,
"PyTorch .pth model should be a torch.nn.Module instance",
)
# Check if it has common model methods
self.assertTrue(
hasattr(model, "get_type_map"),
"PyTorch model should have get_type_map method",
)
self.assertTrue(
hasattr(model, "get_rcut"),
"PyTorch model should have get_rcut method",
)
elif extension == ".pb":
# For TensorFlow models, should return graph
try:
# Should be a TensorFlow graph or have graph-like properties
self.assertTrue(
hasattr(model, "get_operations")
or str(type(model)).find("Graph") >= 0,
"TensorFlow model should be a graph or graph-like object",
)
except ImportError:
# If TensorFlow not available, skip this assertion
pass

def test_get_model_consistency(self):
"""Test that get_model always returns the same object."""
model1 = self.dp.get_model()
model2 = self.dp.get_model()
# Should return the same object (not necessarily equal, but same reference)
self.assertIs(
model1, model2, "get_model should return consistent object reference"
)


if __name__ == "__main__":
unittest.main()