From f8097299e2d6aeb89d891e451429dff73bc84612 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Mon, 16 Mar 2026 00:15:58 +0800 Subject: [PATCH 1/4] feat(pt_expt): add frozen model support Load a pre-frozen model file (.pte or any format) via convert_backend serialization, reconstruct with BaseModel.deserialize, and delegate all model API methods to the inner model. Cannot be trained. --- deepmd/pt_expt/model/__init__.py | 4 + deepmd/pt_expt/model/frozen.py | 210 +++++++++++++++++++++++++++++++ 2 files changed, 214 insertions(+) create mode 100644 deepmd/pt_expt/model/frozen.py diff --git a/deepmd/pt_expt/model/__init__.py b/deepmd/pt_expt/model/__init__.py index 7197e39634..08169b5cee 100644 --- a/deepmd/pt_expt/model/__init__.py +++ b/deepmd/pt_expt/model/__init__.py @@ -15,6 +15,9 @@ from .ener_model import ( EnergyModel, ) +from .frozen import ( + FrozenModel, +) from .get_model import ( get_model, ) @@ -34,6 +37,7 @@ "DPZBLModel", "DipoleModel", "EnergyModel", + "FrozenModel", "PolarModel", "PropertyModel", "get_model", diff --git a/deepmd/pt_expt/model/frozen.py b/deepmd/pt_expt/model/frozen.py new file mode 100644 index 0000000000..db61200c16 --- /dev/null +++ b/deepmd/pt_expt/model/frozen.py @@ -0,0 +1,210 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, + NoReturn, +) + +import torch + +from deepmd.dpmodel.output_def import ( + FittingOutputDef, +) +from deepmd.utils.data_system import ( + DeepmdDataSystem, +) + +from .model import ( + BaseModel, +) + + +@BaseModel.register("frozen") +class FrozenModel(BaseModel): + """Load model from a frozen model file, which cannot be trained. + + Parameters + ---------- + model_file : str + The path to the frozen model file. + """ + + def __init__(self, model_file: str, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.model_file = model_file + # Use convert_backend approach: serialize the model file into a dict, + # then reconstruct via get_model. + from deepmd.backend.backend import ( + Backend, + ) + + inp_backend: Backend = Backend.detect_backend_by_model(model_file)() + data = inp_backend.serialize_hook(model_file) + # data has "model" key with serialized model data, and optionally + # "model_def_script" with model params. + from deepmd.pt_expt.model.model import BaseModel as BaseModelPtExpt + + self.model = BaseModelPtExpt.deserialize(data["model"]) + self.model.eval() + + def fitting_output_def(self) -> FittingOutputDef: + """Get the output def of developer implemented atomic models.""" + return self.model.fitting_output_def() + + def get_rcut(self) -> float: + """Get the cut-off radius.""" + return self.model.get_rcut() + + def get_type_map(self) -> list[str]: + """Get the type map.""" + return self.model.get_type_map() + + def get_sel(self) -> list[int]: + """Returns the number of selected atoms for each type.""" + return self.model.get_sel() + + def get_dim_fparam(self) -> int: + """Get the number (dimension) of frame parameters of this atomic model.""" + return self.model.get_dim_fparam() + + def get_dim_aparam(self) -> int: + """Get the number (dimension) of atomic parameters of this atomic model.""" + return self.model.get_dim_aparam() + + def get_sel_type(self) -> list[int]: + """Get the selected atom types of this model. + + Only atoms with selected atom types have atomic contribution + to the result of the model. + If returning an empty list, all atom types are selected. + """ + return self.model.get_sel_type() + + def is_aparam_nall(self) -> bool: + """Check whether the shape of atomic parameters is (nframes, nall, ndim). + + If False, the shape is (nframes, nloc, ndim). + """ + return self.model.is_aparam_nall() + + def mixed_types(self) -> bool: + """If true, the model + 1. assumes total number of atoms aligned across frames; + 2. uses a neighbor list that does not distinguish different atomic types. + + If false, the model + 1. assumes total number of atoms of each atom type aligned across frames; + 2. uses a neighbor list that distinguishes different atomic types. + """ + return self.model.mixed_types() + + def has_message_passing(self) -> bool: + """Returns whether the descriptor has message passing.""" + return self.model.has_message_passing() + + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the model needs sorted nlist when using `forward_lower`.""" + return self.model.need_sorted_nlist_for_lower() + + def forward( + self, + coord: torch.Tensor, + atype: torch.Tensor, + box: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, torch.Tensor]: + return self.model.forward( + coord, + atype, + box=box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + + def forward_lower( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, torch.Tensor]: + return self.model.forward_lower( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + + def get_model_def_script(self) -> str: + """Get the model definition script.""" + # try to use the original script instead of "frozen model" + return self.model.get_model_def_script() + + def get_min_nbor_dist(self) -> float | None: + """Get the minimum neighbor distance.""" + return self.model.get_min_nbor_dist() + + def get_nnei(self) -> int: + """Returns the total number of selected neighboring atoms in the cut-off radius.""" + return self.model.get_nnei() + + def get_nsel(self) -> int: + """Returns the total number of selected neighboring atoms in the cut-off radius.""" + return self.model.get_nsel() + + def model_output_type(self) -> str: + """Get the output type for the model.""" + return self.model.model_output_type() + + def get_observed_type_list(self) -> list[str]: + """Get observed types (elements) of the model during data statistics.""" + return self.model.get_observed_type_list() + + def serialize(self) -> dict: + """Serialize the model. + + Returns + ------- + dict + The serialized data. + """ + return self.model.serialize() + + @classmethod + def deserialize(cls, data: dict) -> NoReturn: + raise RuntimeError("Should not touch here.") + + @classmethod + def update_sel( + cls, + train_data: DeepmdDataSystem, + type_map: list[str] | None, + local_jdata: dict, + ) -> tuple[dict, float | None]: + """Update the selection and perform neighbor statistics. + + Parameters + ---------- + train_data : DeepmdDataSystem + data used to do neighbor statistics + type_map : list[str], optional + The name of each type of atoms + local_jdata : dict + The local data refer to the current class + + Returns + ------- + dict + The updated local data + float + The minimum distance between two atoms + """ + return local_jdata, None From b7306cab37cd305c6f84c280300203ad0f569214 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Mon, 16 Mar 2026 09:26:16 +0800 Subject: [PATCH 2/4] refactor(pt_expt): refactor frozen model with dpmodel base and torch_module - Create dpmodel FrozenModel (NativeOP + BaseModel) with all delegation methods, so pt_expt can inherit instead of duplicating - Rewrite pt_expt FrozenModel to use @torch_module wrapping dpmodel class - Override __init__ to handle .pte files natively via serialize_from_file, fall back to generic backend detection for other formats - Override serialize() to delegate directly to inner model (unlike pt which must reconstruct from model_def_script due to opaque ScriptModule) - Add pt_expt support to frozen consistency test using BaseModel as pt_expt_class (same pattern as pt) - Guard setUpModule model generation with backend availability checks --- deepmd/dpmodel/model/frozen.py | 153 +++++++++++++ deepmd/pt_expt/model/frozen.py | 212 +++---------------- source/tests/consistent/model/test_frozen.py | 23 +- 3 files changed, 197 insertions(+), 191 deletions(-) create mode 100644 deepmd/dpmodel/model/frozen.py diff --git a/deepmd/dpmodel/model/frozen.py b/deepmd/dpmodel/model/frozen.py new file mode 100644 index 0000000000..a8bf74323b --- /dev/null +++ b/deepmd/dpmodel/model/frozen.py @@ -0,0 +1,153 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, + NoReturn, +) + +from deepmd.dpmodel.common import ( + NativeOP, +) +from deepmd.dpmodel.output_def import ( + FittingOutputDef, +) +from deepmd.utils.data_system import ( + DeepmdDataSystem, +) + +from .base_model import ( + BaseModel, +) + + +class FrozenModel(NativeOP, BaseModel): + """Load model from a frozen model file, which cannot be trained. + + The frozen model delegates all operations to the deserialized inner + model. ``serialize()`` returns the inner model's data, and + ``deserialize()`` dispatches to the appropriate model class via + ``BaseModel.deserialize``. + + Parameters + ---------- + model_file : str + The path to the frozen model file. + """ + + def __init__(self, model_file: str, **kwargs: Any) -> None: + super().__init__() + self.model_file = model_file + from deepmd.backend.backend import ( + Backend, + ) + + inp_backend: Backend = Backend.detect_backend_by_model(model_file)() + data = inp_backend.serialize_hook(model_file) + self.model = BaseModel.deserialize(data["model"]) + + def call(self, *args: Any, **kwargs: Any) -> Any: + """Forward pass.""" + return self.model(*args, **kwargs) + + def fitting_output_def(self) -> FittingOutputDef: + """Get the output def of developer implemented atomic models.""" + return self.model.fitting_output_def() + + def get_rcut(self) -> float: + """Get the cut-off radius.""" + return self.model.get_rcut() + + def get_type_map(self) -> list[str]: + """Get the type map.""" + return self.model.get_type_map() + + def get_sel(self) -> list[int]: + """Returns the number of selected atoms for each type.""" + return self.model.get_sel() + + def get_dim_fparam(self) -> int: + """Get the number (dimension) of frame parameters of this atomic model.""" + return self.model.get_dim_fparam() + + def get_dim_aparam(self) -> int: + """Get the number (dimension) of atomic parameters of this atomic model.""" + return self.model.get_dim_aparam() + + def get_sel_type(self) -> list[int]: + """Get the selected atom types of this model. + + Only atoms with selected atom types have atomic contribution + to the result of the model. + If returning an empty list, all atom types are selected. + """ + return self.model.get_sel_type() + + def is_aparam_nall(self) -> bool: + """Check whether the shape of atomic parameters is (nframes, nall, ndim). + + If False, the shape is (nframes, nloc, ndim). + """ + return self.model.is_aparam_nall() + + def mixed_types(self) -> bool: + """If true, the model + 1. assumes total number of atoms aligned across frames; + 2. uses a neighbor list that does not distinguish different atomic types. + + If false, the model + 1. assumes total number of atoms of each atom type aligned across frames; + 2. uses a neighbor list that distinguishes different atomic types. + """ + return self.model.mixed_types() + + def has_message_passing(self) -> bool: + """Returns whether the descriptor has message passing.""" + return self.model.has_message_passing() + + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the model needs sorted nlist when using `forward_lower`.""" + return self.model.need_sorted_nlist_for_lower() + + def get_model_def_script(self) -> str: + """Get the model definition script.""" + return self.model.get_model_def_script() + + def get_min_nbor_dist(self) -> float | None: + """Get the minimum neighbor distance.""" + return self.model.get_min_nbor_dist() + + def get_nnei(self) -> int: + """Returns the total number of selected neighboring atoms in the cut-off radius.""" + return self.model.get_nnei() + + def get_nsel(self) -> int: + """Returns the total number of selected neighboring atoms in the cut-off radius.""" + return self.model.get_nsel() + + def model_output_type(self) -> str: + """Get the output type for the model.""" + return self.model.model_output_type() + + def get_observed_type_list(self) -> list[str]: + """Get observed types (elements) of the model during data statistics.""" + return self.model.get_observed_type_list() + + def serialize(self) -> dict: + """Serialize the model. + + Returns the inner model's serialized data. + """ + return self.model.serialize() + + @classmethod + def deserialize(cls, data: dict) -> NoReturn: + raise RuntimeError("Should not touch here.") + + @classmethod + def update_sel( + cls, + train_data: DeepmdDataSystem, + type_map: list[str] | None, + local_jdata: dict, + ) -> tuple[dict, float | None]: + """Update the selection and perform neighbor statistics.""" + return local_jdata, None diff --git a/deepmd/pt_expt/model/frozen.py b/deepmd/pt_expt/model/frozen.py index db61200c16..37feb5012f 100644 --- a/deepmd/pt_expt/model/frozen.py +++ b/deepmd/pt_expt/model/frozen.py @@ -1,16 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( Any, - NoReturn, ) -import torch - -from deepmd.dpmodel.output_def import ( - FittingOutputDef, -) -from deepmd.utils.data_system import ( - DeepmdDataSystem, +from deepmd.dpmodel.model.frozen import FrozenModel as FrozenModelDP +from deepmd.pt_expt.common import ( + torch_module, ) from .model import ( @@ -19,192 +14,33 @@ @BaseModel.register("frozen") -class FrozenModel(BaseModel): - """Load model from a frozen model file, which cannot be trained. - - Parameters - ---------- - model_file : str - The path to the frozen model file. - """ - +@torch_module +class FrozenModel(FrozenModelDP): def __init__(self, model_file: str, **kwargs: Any) -> None: - super().__init__(**kwargs) + super(FrozenModelDP, self).__init__() self.model_file = model_file - # Use convert_backend approach: serialize the model file into a dict, - # then reconstruct via get_model. - from deepmd.backend.backend import ( - Backend, - ) - - inp_backend: Backend = Backend.detect_backend_by_model(model_file)() - data = inp_backend.serialize_hook(model_file) - # data has "model" key with serialized model data, and optionally - # "model_def_script" with model params. - from deepmd.pt_expt.model.model import BaseModel as BaseModelPtExpt - - self.model = BaseModelPtExpt.deserialize(data["model"]) + if model_file.endswith(".pte"): + from deepmd.pt_expt.utils.serialization import ( + serialize_from_file, + ) + + data = serialize_from_file(model_file) + self.model = BaseModel.deserialize(data["model"]) + else: + from deepmd.backend.backend import ( + Backend, + ) + + inp_backend: Backend = Backend.detect_backend_by_model(model_file)() + data = inp_backend.serialize_hook(model_file) + self.model = BaseModel.deserialize(data["model"]) self.model.eval() - def fitting_output_def(self) -> FittingOutputDef: - """Get the output def of developer implemented atomic models.""" - return self.model.fitting_output_def() - - def get_rcut(self) -> float: - """Get the cut-off radius.""" - return self.model.get_rcut() - - def get_type_map(self) -> list[str]: - """Get the type map.""" - return self.model.get_type_map() - - def get_sel(self) -> list[int]: - """Returns the number of selected atoms for each type.""" - return self.model.get_sel() - - def get_dim_fparam(self) -> int: - """Get the number (dimension) of frame parameters of this atomic model.""" - return self.model.get_dim_fparam() - - def get_dim_aparam(self) -> int: - """Get the number (dimension) of atomic parameters of this atomic model.""" - return self.model.get_dim_aparam() - - def get_sel_type(self) -> list[int]: - """Get the selected atom types of this model. - - Only atoms with selected atom types have atomic contribution - to the result of the model. - If returning an empty list, all atom types are selected. - """ - return self.model.get_sel_type() - - def is_aparam_nall(self) -> bool: - """Check whether the shape of atomic parameters is (nframes, nall, ndim). - - If False, the shape is (nframes, nloc, ndim). - """ - return self.model.is_aparam_nall() - - def mixed_types(self) -> bool: - """If true, the model - 1. assumes total number of atoms aligned across frames; - 2. uses a neighbor list that does not distinguish different atomic types. - - If false, the model - 1. assumes total number of atoms of each atom type aligned across frames; - 2. uses a neighbor list that distinguishes different atomic types. - """ - return self.model.mixed_types() - - def has_message_passing(self) -> bool: - """Returns whether the descriptor has message passing.""" - return self.model.has_message_passing() - - def need_sorted_nlist_for_lower(self) -> bool: - """Returns whether the model needs sorted nlist when using `forward_lower`.""" - return self.model.need_sorted_nlist_for_lower() - - def forward( - self, - coord: torch.Tensor, - atype: torch.Tensor, - box: torch.Tensor | None = None, - fparam: torch.Tensor | None = None, - aparam: torch.Tensor | None = None, - do_atomic_virial: bool = False, - ) -> dict[str, torch.Tensor]: - return self.model.forward( - coord, - atype, - box=box, - fparam=fparam, - aparam=aparam, - do_atomic_virial=do_atomic_virial, - ) - - def forward_lower( - self, - extended_coord: torch.Tensor, - extended_atype: torch.Tensor, - nlist: torch.Tensor, - mapping: torch.Tensor | None = None, - fparam: torch.Tensor | None = None, - aparam: torch.Tensor | None = None, - do_atomic_virial: bool = False, - ) -> dict[str, torch.Tensor]: - return self.model.forward_lower( - extended_coord, - extended_atype, - nlist, - mapping=mapping, - fparam=fparam, - aparam=aparam, - do_atomic_virial=do_atomic_virial, - ) - - def get_model_def_script(self) -> str: - """Get the model definition script.""" - # try to use the original script instead of "frozen model" - return self.model.get_model_def_script() - - def get_min_nbor_dist(self) -> float | None: - """Get the minimum neighbor distance.""" - return self.model.get_min_nbor_dist() - - def get_nnei(self) -> int: - """Returns the total number of selected neighboring atoms in the cut-off radius.""" - return self.model.get_nnei() - - def get_nsel(self) -> int: - """Returns the total number of selected neighboring atoms in the cut-off radius.""" - return self.model.get_nsel() - - def model_output_type(self) -> str: - """Get the output type for the model.""" - return self.model.model_output_type() - - def get_observed_type_list(self) -> list[str]: - """Get observed types (elements) of the model during data statistics.""" - return self.model.get_observed_type_list() - def serialize(self) -> dict: """Serialize the model. - Returns - ------- - dict - The serialized data. + Unlike the pt backend (which must reconstruct from model_def_script + because its inner model is an opaque ScriptModule), pt_expt's inner + model is a real pt_expt model that can serialize directly. """ return self.model.serialize() - - @classmethod - def deserialize(cls, data: dict) -> NoReturn: - raise RuntimeError("Should not touch here.") - - @classmethod - def update_sel( - cls, - train_data: DeepmdDataSystem, - type_map: list[str] | None, - local_jdata: dict, - ) -> tuple[dict, float | None]: - """Update the selection and perform neighbor statistics. - - Parameters - ---------- - train_data : DeepmdDataSystem - data used to do neighbor statistics - type_map : list[str], optional - The name of each type of atoms - local_jdata : dict - The local data refer to the current class - - Returns - ------- - dict - The updated local data - float - The minimum distance between two atoms - """ - return local_jdata, None diff --git a/source/tests/consistent/model/test_frozen.py b/source/tests/consistent/model/test_frozen.py index d2c33f3cd9..c19b1a8cb4 100644 --- a/source/tests/consistent/model/test_frozen.py +++ b/source/tests/consistent/model/test_frozen.py @@ -15,6 +15,7 @@ from ..common import ( INSTALLED_PT, + INSTALLED_PT_EXPT, INSTALLED_TF, CommonTest, parameterized, @@ -32,6 +33,10 @@ from deepmd.tf.model.model import Model as FrozenModelTF else: FrozenModelTF = None +if INSTALLED_PT_EXPT: + from deepmd.pt_expt.model.model import BaseModel as FrozenModelPTExpt +else: + FrozenModelPTExpt = None from deepmd.utils.argcheck import ( model_args, @@ -49,8 +54,10 @@ def setUpModule() -> None: case = get_cases()["se_e2_a"] case.get_model(".dp", dp_model) - case.get_model(".pb", tf_model) - case.get_model(".pth", pt_model) + if INSTALLED_TF: + case.get_model(".pb", tf_model) + if INSTALLED_PT: + case.get_model(".pth", pt_model) def tearDownModule() -> None: @@ -82,6 +89,7 @@ def data(self) -> dict: tf_class = FrozenModelTF dp_class = None pt_class = FrozenModelPT + pt_expt_class = FrozenModelPTExpt args = model_args() def skip_dp(self) -> bool: @@ -154,10 +162,19 @@ def eval_pt(self, pt_obj: Any) -> Any: self.box, ) + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + return self.eval_pt_expt_model( + pt_expt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: # shape not matched. ravel... if backend is self.RefBackend.TF: return (ret[0].ravel(), ret[1].ravel()) - elif backend in {self.RefBackend.PT}: + elif backend in {self.RefBackend.PT, self.RefBackend.PT_EXPT}: return (ret["energy"].ravel(), ret["atom_energy"].ravel()) raise ValueError(f"Unknown backend: {backend}") From 4815fbd24b027dc0f67537d8a420659e193fca2f Mon Sep 17 00:00:00 2001 From: Han Wang Date: Mon, 16 Mar 2026 21:42:38 +0800 Subject: [PATCH 3/4] fix(pt_expt): fix super() call in FrozenModel and remove redundant serialize Use explicit NativeOP.__init__(self) instead of super(FrozenModelDP, self) to fix CodeQL "first argument to super() is not enclosing class" error. Remove serialize() override that duplicates the parent class method. --- deepmd/pt_expt/model/frozen.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/deepmd/pt_expt/model/frozen.py b/deepmd/pt_expt/model/frozen.py index 37feb5012f..03f8048370 100644 --- a/deepmd/pt_expt/model/frozen.py +++ b/deepmd/pt_expt/model/frozen.py @@ -3,6 +3,9 @@ Any, ) +from deepmd.dpmodel.common import ( + NativeOP, +) from deepmd.dpmodel.model.frozen import FrozenModel as FrozenModelDP from deepmd.pt_expt.common import ( torch_module, @@ -17,7 +20,9 @@ @torch_module class FrozenModel(FrozenModelDP): def __init__(self, model_file: str, **kwargs: Any) -> None: - super(FrozenModelDP, self).__init__() + # Skip FrozenModelDP.__init__ which would load via Backend detection; + # pt_expt handles .pte natively and re-deserializes other formats itself. + NativeOP.__init__(self) self.model_file = model_file if model_file.endswith(".pte"): from deepmd.pt_expt.utils.serialization import ( @@ -35,12 +40,3 @@ def __init__(self, model_file: str, **kwargs: Any) -> None: data = inp_backend.serialize_hook(model_file) self.model = BaseModel.deserialize(data["model"]) self.model.eval() - - def serialize(self) -> dict: - """Serialize the model. - - Unlike the pt backend (which must reconstruct from model_def_script - because its inner model is an opaque ScriptModule), pt_expt's inner - model is a real pt_expt model that can serialize directly. - """ - return self.model.serialize() From 9e11539ac336ba65bfa3d87501fcb9d82c660b47 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Mon, 16 Mar 2026 21:42:38 +0800 Subject: [PATCH 4/4] fix(pt_expt): fix super() call in FrozenModel and remove redundant serialize Use explicit NativeOP.__init__(self) instead of super(FrozenModelDP, self) to fix CodeQL "first argument to super() is not enclosing class" error. Remove serialize() override that duplicates the parent class method. --- deepmd/pt_expt/model/frozen.py | 29 +++-------------------------- 1 file changed, 3 insertions(+), 26 deletions(-) diff --git a/deepmd/pt_expt/model/frozen.py b/deepmd/pt_expt/model/frozen.py index 37feb5012f..beb8ffaeb0 100644 --- a/deepmd/pt_expt/model/frozen.py +++ b/deepmd/pt_expt/model/frozen.py @@ -17,30 +17,7 @@ @torch_module class FrozenModel(FrozenModelDP): def __init__(self, model_file: str, **kwargs: Any) -> None: - super(FrozenModelDP, self).__init__() - self.model_file = model_file - if model_file.endswith(".pte"): - from deepmd.pt_expt.utils.serialization import ( - serialize_from_file, - ) - - data = serialize_from_file(model_file) - self.model = BaseModel.deserialize(data["model"]) - else: - from deepmd.backend.backend import ( - Backend, - ) - - inp_backend: Backend = Backend.detect_backend_by_model(model_file)() - data = inp_backend.serialize_hook(model_file) - self.model = BaseModel.deserialize(data["model"]) + super().__init__(model_file, **kwargs) + # Re-deserialize as a pt_expt model (parent creates a dpmodel model) + self.model = BaseModel.deserialize(self.model.serialize()) self.model.eval() - - def serialize(self) -> dict: - """Serialize the model. - - Unlike the pt backend (which must reconstruct from model_def_script - because its inner model is an opaque ScriptModule), pt_expt's inner - model is a real pt_expt model that can serialize directly. - """ - return self.model.serialize()