diff --git a/deepmd/pd/utils/serialization.py b/deepmd/pd/utils/serialization.py index bd70deb75c..a4db114f15 100644 --- a/deepmd/pd/utils/serialization.py +++ b/deepmd/pd/utils/serialization.py @@ -23,6 +23,32 @@ def serialize_from_file(model_file: str) -> dict: raise NotImplementedError("Paddle do not support jit.export yet.") +def _fparam_aparam_input_specs(model: "paddle.nn.Layer") -> tuple: + """Return the fparam/aparam static ``InputSpec``s for jit export. + + A spec is returned only when the model actually uses the corresponding + input (nonzero ``get_dim_fparam``/``get_dim_aparam``); otherwise ``None`` is + returned so the frozen signature keeps that argument optional. + """ + from paddle.static import ( + InputSpec, + ) + + dim_fparam = model.get_dim_fparam() + dim_aparam = model.get_dim_aparam() + fparam_spec = ( + InputSpec([-1, dim_fparam], dtype="float64", name="fparam") + if dim_fparam > 0 + else None + ) + aparam_spec = ( + InputSpec([-1, -1, dim_aparam], dtype="float64", name="aparam") + if dim_aparam > 0 + else None + ) + return fparam_spec, aparam_spec + + def deserialize_to_file(model_file: str, data: dict) -> None: """Deserialize the dictionary to a model file. @@ -57,6 +83,9 @@ def deserialize_to_file(model_file: str, data: dict) -> None: InputSpec, ) + # include fparam/aparam in the static signature when the model uses them + fparam_spec, aparam_spec = _fparam_aparam_input_specs(model) + """ example output shape and dtype of forward atom_energy: fetch_name_0 (1, 6, 1) float64 atom_virial: fetch_name_1 (1, 6, 1, 9) float64 @@ -72,8 +101,8 @@ def deserialize_to_file(model_file: str, data: dict) -> None: InputSpec([-1, -1, 3], dtype="float64", name="coord"), InputSpec([-1, -1], dtype="int64", name="atype"), InputSpec([-1, 9], dtype="float64", name="box"), - None, - None, + fparam_spec, + aparam_spec, True, ], ) @@ -92,8 +121,8 @@ def deserialize_to_file(model_file: str, data: dict) -> None: InputSpec([-1, -1], dtype="int32", name="atype"), InputSpec([-1, -1, -1], dtype="int32", name="nlist"), None, - None, - None, + fparam_spec, + aparam_spec, True, None, ], diff --git a/source/tests/pd/model/test_serialization_fparam.py b/source/tests/pd/model/test_serialization_fparam.py new file mode 100644 index 0000000000..9c56b245c2 --- /dev/null +++ b/source/tests/pd/model/test_serialization_fparam.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Test that Paddle freeze static signatures include fparam/aparam when used.""" + +import unittest + +from deepmd.pd.utils.serialization import ( + _fparam_aparam_input_specs, +) + + +class _StubModel: + def __init__(self, dim_fparam: int, dim_aparam: int) -> None: + self._dim_fparam = dim_fparam + self._dim_aparam = dim_aparam + + def get_dim_fparam(self) -> int: + return self._dim_fparam + + def get_dim_aparam(self) -> int: + return self._dim_aparam + + +class TestFparamAparamInputSpecs(unittest.TestCase): + def test_absent_when_unused(self) -> None: + fparam_spec, aparam_spec = _fparam_aparam_input_specs(_StubModel(0, 0)) + self.assertIsNone(fparam_spec) + self.assertIsNone(aparam_spec) + + def test_present_when_used(self) -> None: + fparam_spec, aparam_spec = _fparam_aparam_input_specs(_StubModel(2, 3)) + self.assertIsNotNone(fparam_spec) + self.assertIsNotNone(aparam_spec) + self.assertEqual(fparam_spec.name, "fparam") + self.assertEqual(aparam_spec.name, "aparam") + self.assertEqual(list(fparam_spec.shape), [-1, 2]) + self.assertEqual(list(aparam_spec.shape), [-1, -1, 3]) + + def test_only_fparam(self) -> None: + fparam_spec, aparam_spec = _fparam_aparam_input_specs(_StubModel(2, 0)) + self.assertIsNotNone(fparam_spec) + self.assertIsNone(aparam_spec)