diff --git a/deepmd/dpmodel/infer/deep_eval.py b/deepmd/dpmodel/infer/deep_eval.py index 9fd96ed491..4b30d97c29 100644 --- a/deepmd/dpmodel/infer/deep_eval.py +++ b/deepmd/dpmodel/infer/deep_eval.py @@ -391,4 +391,14 @@ def _get_output_shape(self, odef, nframes, natoms): def get_model_def_script(self) -> dict: """Get model definition script.""" - return json.loads(self.model.get_model_def_script()) + return json.loads(self.dp.get_model_def_script()) + + def get_model(self) -> "BaseModel": + """Get the dpmodel BaseModel. + + Returns + ------- + BaseModel + The dpmodel BaseModel. + """ + return self.dp diff --git a/deepmd/infer/deep_eval.py b/deepmd/infer/deep_eval.py index d067c4322e..49c9576afd 100644 --- a/deepmd/infer/deep_eval.py +++ b/deepmd/infer/deep_eval.py @@ -341,6 +341,20 @@ def get_observed_types(self) -> dict: """Get observed types (elements) of the model during data statistics.""" raise NotImplementedError("Not implemented in this backend.") + @abstractmethod + def get_model(self) -> Any: + """Get the model module implemented by the deep learning framework. + + For PyTorch, this returns the nn.Module. For Paddle, this returns + the paddle.nn.Layer. For TensorFlow, this returns the graph. + For dpmodel, this returns the BaseModel. + + Returns + ------- + model + The model module implemented by the deep learning framework. + """ + class DeepEval(ABC): """High-level Deep Evaluator interface. @@ -685,3 +699,17 @@ def get_model_size(self) -> dict: def get_observed_types(self) -> dict: """Get observed types (elements) of the model during data statistics.""" return self.deep_eval.get_observed_types() + + def get_model(self) -> Any: + """Get the model module implemented by the deep learning framework. + + For PyTorch, this returns the nn.Module. For Paddle, this returns + the paddle.nn.Layer. For TensorFlow, this returns the graph. + For dpmodel, this returns the BaseModel. + + Returns + ------- + model + The model module implemented by the deep learning framework. + """ + return self.deep_eval.get_model() diff --git a/deepmd/jax/infer/deep_eval.py b/deepmd/jax/infer/deep_eval.py index acfd42b66a..2e74c15fff 100644 --- a/deepmd/jax/infer/deep_eval.py +++ b/deepmd/jax/infer/deep_eval.py @@ -420,3 +420,13 @@ def _get_output_shape(self, odef, nframes, natoms): def get_model_def_script(self) -> dict: """Get model definition script.""" return json.loads(self.dp.get_model_def_script()) + + def get_model(self) -> Any: + """Get the JAX model as BaseModel. + + Returns + ------- + BaseModel + The JAX model as BaseModel instance. + """ + return self.dp diff --git a/deepmd/pd/infer/deep_eval.py b/deepmd/pd/infer/deep_eval.py index 2363e29100..61c3f9e9a3 100644 --- a/deepmd/pd/infer/deep_eval.py +++ b/deepmd/pd/infer/deep_eval.py @@ -46,6 +46,10 @@ if TYPE_CHECKING: import ase.neighborlist + from deepmd.pd.model.model.model import ( + BaseModel, + ) + class DeepEval(DeepEvalBackend): """Paddle backend implementation of DeepEval. @@ -506,6 +510,16 @@ def get_model_size(self) -> dict: "total": sum_param_des + sum_param_fit, } + def get_model(self) -> "BaseModel": + """Get the Paddle model. + + Returns + ------- + BaseModel + The Paddle model instance. + """ + return self.dp.model["Default"] + def eval_descriptor( self, coords: np.ndarray, diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index 13bd4d2bf0..25caf12b64 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -75,6 +75,10 @@ if TYPE_CHECKING: import ase.neighborlist + from deepmd.pt.model.model.model import ( + BaseModel, + ) + log = logging.getLogger(__name__) @@ -706,6 +710,16 @@ def get_observed_types(self) -> dict: "observed_type": sort_element_type(observed_type_list), } + def get_model(self) -> "BaseModel": + """Get the PyTorch model. + + Returns + ------- + BaseModel + The PyTorch model instance. + """ + return self.dp.model["Default"] + def eval_descriptor( self, coords: np.ndarray, diff --git a/deepmd/tf/infer/deep_eval.py b/deepmd/tf/infer/deep_eval.py index a7682d2e58..75440accb9 100644 --- a/deepmd/tf/infer/deep_eval.py +++ b/deepmd/tf/infer/deep_eval.py @@ -1126,6 +1126,16 @@ def get_model_def_script(self) -> dict: model_def_script = script.decode("utf-8") return json.loads(model_def_script)["model"] + def get_model(self) -> "tf.Graph": + """Get the TensorFlow graph. + + Returns + ------- + tf.Graph + The TensorFlow graph. + """ + return self.graph + class DeepEvalOld: # old class for DipoleChargeModifier only diff --git a/source/tests/infer/test_get_model.py b/source/tests/infer/test_get_model.py new file mode 100644 index 0000000000..4c52dda0a1 --- /dev/null +++ b/source/tests/infer/test_get_model.py @@ -0,0 +1,101 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +from deepmd.infer.deep_eval import ( + DeepEval, +) + +from ..consistent.common import ( + parameterized, +) +from .case import ( + get_cases, +) + + +@parameterized( + ( + "se_e2_a", + "fparam_aparam", + ), # key + (".pb", ".pth"), # model extension +) +class TestGetModelMethod(unittest.TestCase): + """Test the new get_model method functionality.""" + + @classmethod + def setUpClass(cls) -> None: + key, extension = cls.param + cls.case = get_cases()[key] + cls.model_name = cls.case.get_model(extension) + cls.dp = DeepEval(cls.model_name) + + @classmethod + def tearDownClass(cls) -> None: + cls.dp = None + + def test_get_model_method_exists(self): + """Test that get_model method exists.""" + self.assertTrue( + hasattr(self.dp, "get_model"), "DeepEval should have get_model method" + ) + + def test_get_model_returns_valid_object(self): + """Test that get_model returns a valid model object.""" + model = self.dp.get_model() + self.assertIsNotNone(model, "get_model should return a non-None object") + + def test_get_model_backend_specific(self): + """Test that get_model returns the expected type for each backend.""" + key, extension = self.param + model = self.dp.get_model() + + if extension == ".pth": + # For PyTorch .pth models (TorchScript), should return torch.jit.ScriptModule + import torch + + self.assertIsInstance( + model, + torch.jit.ScriptModule, + "PyTorch .pth model should return TorchScript ScriptModule instance", + ) + # TorchScript modules are also nn.Module instances + self.assertIsInstance( + model, + torch.nn.Module, + "PyTorch .pth model should be a torch.nn.Module instance", + ) + # Check if it has common model methods + self.assertTrue( + hasattr(model, "get_type_map"), + "PyTorch model should have get_type_map method", + ) + self.assertTrue( + hasattr(model, "get_rcut"), + "PyTorch model should have get_rcut method", + ) + elif extension == ".pb": + # For TensorFlow models, should return graph + try: + # Should be a TensorFlow graph or have graph-like properties + self.assertTrue( + hasattr(model, "get_operations") + or str(type(model)).find("Graph") >= 0, + "TensorFlow model should be a graph or graph-like object", + ) + except ImportError: + # If TensorFlow not available, skip this assertion + pass + + def test_get_model_consistency(self): + """Test that get_model always returns the same object.""" + model1 = self.dp.get_model() + model2 = self.dp.get_model() + # Should return the same object (not necessarily equal, but same reference) + self.assertIs( + model1, model2, "get_model should return consistent object reference" + ) + + +if __name__ == "__main__": + unittest.main()