From 239bb4e27577064362efe955b1199e0e73a7ceea Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 2 Jul 2026 15:01:26 +0800 Subject: [PATCH] fix(pd): preserve fparam/aparam inputs when freezing models The Paddle freeze entrypoint converted forward/forward_lower to static with fparam and aparam hardcoded to None in the input_spec. Models trained with required frame or atomic parameters therefore exported a static graph whose signature baked both values as None, so inference through the frozen model could not supply the required fparam/aparam. Build the fparam/aparam InputSpec from the model's get_dim_fparam/get_dim_aparam (None only when the dim is 0) and use it in both forward and forward_lower signatures. Add a unit test for the spec builder covering the unused, used, and fparam-only cases. Fix #5687 --- deepmd/pd/utils/serialization.py | 37 +++++++++++++++-- .../pd/model/test_serialization_fparam.py | 41 +++++++++++++++++++ 2 files changed, 74 insertions(+), 4 deletions(-) create mode 100644 source/tests/pd/model/test_serialization_fparam.py diff --git a/deepmd/pd/utils/serialization.py b/deepmd/pd/utils/serialization.py index bd70deb75c..a4db114f15 100644 --- a/deepmd/pd/utils/serialization.py +++ b/deepmd/pd/utils/serialization.py @@ -23,6 +23,32 @@ def serialize_from_file(model_file: str) -> dict: raise NotImplementedError("Paddle do not support jit.export yet.") +def _fparam_aparam_input_specs(model: "paddle.nn.Layer") -> tuple: + """Return the fparam/aparam static ``InputSpec``s for jit export. + + A spec is returned only when the model actually uses the corresponding + input (nonzero ``get_dim_fparam``/``get_dim_aparam``); otherwise ``None`` is + returned so the frozen signature keeps that argument optional. + """ + from paddle.static import ( + InputSpec, + ) + + dim_fparam = model.get_dim_fparam() + dim_aparam = model.get_dim_aparam() + fparam_spec = ( + InputSpec([-1, dim_fparam], dtype="float64", name="fparam") + if dim_fparam > 0 + else None + ) + aparam_spec = ( + InputSpec([-1, -1, dim_aparam], dtype="float64", name="aparam") + if dim_aparam > 0 + else None + ) + return fparam_spec, aparam_spec + + def deserialize_to_file(model_file: str, data: dict) -> None: """Deserialize the dictionary to a model file. @@ -57,6 +83,9 @@ def deserialize_to_file(model_file: str, data: dict) -> None: InputSpec, ) + # include fparam/aparam in the static signature when the model uses them + fparam_spec, aparam_spec = _fparam_aparam_input_specs(model) + """ example output shape and dtype of forward atom_energy: fetch_name_0 (1, 6, 1) float64 atom_virial: fetch_name_1 (1, 6, 1, 9) float64 @@ -72,8 +101,8 @@ def deserialize_to_file(model_file: str, data: dict) -> None: InputSpec([-1, -1, 3], dtype="float64", name="coord"), InputSpec([-1, -1], dtype="int64", name="atype"), InputSpec([-1, 9], dtype="float64", name="box"), - None, - None, + fparam_spec, + aparam_spec, True, ], ) @@ -92,8 +121,8 @@ def deserialize_to_file(model_file: str, data: dict) -> None: InputSpec([-1, -1], dtype="int32", name="atype"), InputSpec([-1, -1, -1], dtype="int32", name="nlist"), None, - None, - None, + fparam_spec, + aparam_spec, True, None, ], diff --git a/source/tests/pd/model/test_serialization_fparam.py b/source/tests/pd/model/test_serialization_fparam.py new file mode 100644 index 0000000000..9c56b245c2 --- /dev/null +++ b/source/tests/pd/model/test_serialization_fparam.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Test that Paddle freeze static signatures include fparam/aparam when used.""" + +import unittest + +from deepmd.pd.utils.serialization import ( + _fparam_aparam_input_specs, +) + + +class _StubModel: + def __init__(self, dim_fparam: int, dim_aparam: int) -> None: + self._dim_fparam = dim_fparam + self._dim_aparam = dim_aparam + + def get_dim_fparam(self) -> int: + return self._dim_fparam + + def get_dim_aparam(self) -> int: + return self._dim_aparam + + +class TestFparamAparamInputSpecs(unittest.TestCase): + def test_absent_when_unused(self) -> None: + fparam_spec, aparam_spec = _fparam_aparam_input_specs(_StubModel(0, 0)) + self.assertIsNone(fparam_spec) + self.assertIsNone(aparam_spec) + + def test_present_when_used(self) -> None: + fparam_spec, aparam_spec = _fparam_aparam_input_specs(_StubModel(2, 3)) + self.assertIsNotNone(fparam_spec) + self.assertIsNotNone(aparam_spec) + self.assertEqual(fparam_spec.name, "fparam") + self.assertEqual(aparam_spec.name, "aparam") + self.assertEqual(list(fparam_spec.shape), [-1, 2]) + self.assertEqual(list(aparam_spec.shape), [-1, -1, 3]) + + def test_only_fparam(self) -> None: + fparam_spec, aparam_spec = _fparam_aparam_input_specs(_StubModel(2, 0)) + self.assertIsNotNone(fparam_spec) + self.assertIsNone(aparam_spec)