Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 33 additions & 4 deletions deepmd/pd/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,32 @@ def serialize_from_file(model_file: str) -> dict:
raise NotImplementedError("Paddle do not support jit.export yet.")


def _fparam_aparam_input_specs(model: "paddle.nn.Layer") -> tuple:
"""Return the fparam/aparam static ``InputSpec``s for jit export.

A spec is returned only when the model actually uses the corresponding
input (nonzero ``get_dim_fparam``/``get_dim_aparam``); otherwise ``None`` is
returned so the frozen signature keeps that argument optional.
"""
from paddle.static import (
InputSpec,
)

dim_fparam = model.get_dim_fparam()
dim_aparam = model.get_dim_aparam()
fparam_spec = (
InputSpec([-1, dim_fparam], dtype="float64", name="fparam")
if dim_fparam > 0
else None
)
aparam_spec = (
InputSpec([-1, -1, dim_aparam], dtype="float64", name="aparam")
if dim_aparam > 0
else None
)
return fparam_spec, aparam_spec


def deserialize_to_file(model_file: str, data: dict) -> None:
"""Deserialize the dictionary to a model file.

Expand Down Expand Up @@ -57,6 +83,9 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
InputSpec,
)

# include fparam/aparam in the static signature when the model uses them
fparam_spec, aparam_spec = _fparam_aparam_input_specs(model)

""" example output shape and dtype of forward
atom_energy: fetch_name_0 (1, 6, 1) float64
atom_virial: fetch_name_1 (1, 6, 1, 9) float64
Expand All @@ -72,8 +101,8 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
InputSpec([-1, -1, 3], dtype="float64", name="coord"),
InputSpec([-1, -1], dtype="int64", name="atype"),
InputSpec([-1, 9], dtype="float64", name="box"),
None,
None,
fparam_spec,
aparam_spec,
True,
],
)
Expand All @@ -92,8 +121,8 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
InputSpec([-1, -1], dtype="int32", name="atype"),
InputSpec([-1, -1, -1], dtype="int32", name="nlist"),
None,
None,
None,
fparam_spec,
aparam_spec,
True,
None,
],
Expand Down
41 changes: 41 additions & 0 deletions source/tests/pd/model/test_serialization_fparam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Test that Paddle freeze static signatures include fparam/aparam when used."""

import unittest

from deepmd.pd.utils.serialization import (
_fparam_aparam_input_specs,
)


class _StubModel:
def __init__(self, dim_fparam: int, dim_aparam: int) -> None:
self._dim_fparam = dim_fparam
self._dim_aparam = dim_aparam

def get_dim_fparam(self) -> int:
return self._dim_fparam

def get_dim_aparam(self) -> int:
return self._dim_aparam


class TestFparamAparamInputSpecs(unittest.TestCase):
def test_absent_when_unused(self) -> None:
fparam_spec, aparam_spec = _fparam_aparam_input_specs(_StubModel(0, 0))
self.assertIsNone(fparam_spec)
self.assertIsNone(aparam_spec)

def test_present_when_used(self) -> None:
fparam_spec, aparam_spec = _fparam_aparam_input_specs(_StubModel(2, 3))
self.assertIsNotNone(fparam_spec)
self.assertIsNotNone(aparam_spec)
self.assertEqual(fparam_spec.name, "fparam")
self.assertEqual(aparam_spec.name, "aparam")
self.assertEqual(list(fparam_spec.shape), [-1, 2])
self.assertEqual(list(aparam_spec.shape), [-1, -1, 3])

def test_only_fparam(self) -> None:
fparam_spec, aparam_spec = _fparam_aparam_input_specs(_StubModel(2, 0))
self.assertIsNotNone(fparam_spec)
self.assertIsNone(aparam_spec)
Loading