diff --git a/deepmd/tf/infer/ewald_recp.py b/deepmd/tf/infer/ewald_recp.py index 4c0b7b7cbc..4b1a0b27be 100644 --- a/deepmd/tf/infer/ewald_recp.py +++ b/deepmd/tf/infer/ewald_recp.py @@ -1,6 +1,9 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import numpy as np +from typing_extensions import ( + Self, +) from deepmd.tf.env import ( GLOBAL_TF_FLOAT_PRECISION, @@ -93,3 +96,18 @@ def eval( ) return energy, force, virial + + def close(self) -> None: + """Close the TensorFlow session held by this object.""" + sess = getattr(self, "sess", None) + if sess is not None: + sess.close() + + def __enter__(self) -> Self: + return self + + def __exit__(self, exc_type: object, exc_value: object, traceback: object) -> None: + self.close() + + def __del__(self) -> None: + self.close() diff --git a/deepmd/tf/modifier/dipole_charge.py b/deepmd/tf/modifier/dipole_charge.py index a5d1fbf975..c6ec07b234 100644 --- a/deepmd/tf/modifier/dipole_charge.py +++ b/deepmd/tf/modifier/dipole_charge.py @@ -105,6 +105,12 @@ def __init__( self.force = None self.ntypes = len(self.sel_a) + def close(self) -> None: + """Close the TensorFlow session held by the Ewald reciprocal evaluator.""" + er = getattr(self, "er", None) + if er is not None: + er.close() + def serialize(self) -> dict: """Serialize the modifier. diff --git a/source/tests/tf/test_dipolecharge.py b/source/tests/tf/test_dipolecharge.py index 71c46446f6..f677d791f9 100644 --- a/source/tests/tf/test_dipolecharge.py +++ b/source/tests/tf/test_dipolecharge.py @@ -1,6 +1,9 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import os import unittest +from unittest import ( + mock, +) import numpy as np @@ -125,6 +128,13 @@ def tearDownClass(cls) -> None: os.remove("dipolecharge_d.pb") cls.dp = None + def test_close_forwards_to_ewald(self) -> None: + # closing the modifier must release the EwaldRecp session. Patch the + # evaluator so the shared class-level modifier is not disturbed. + with mock.patch.object(self.dp, "er") as mock_er: + self.dp.close() + mock_er.close.assert_called_once() + def test_attrs(self) -> None: self.assertEqual(self.dp.get_ntypes(), 5) self.assertAlmostEqual(self.dp.get_rcut(), 4.0, places=default_places) diff --git a/source/tests/tf/test_ewald.py b/source/tests/tf/test_ewald.py index 270546fbc8..3b9cee66b2 100644 --- a/source/tests/tf/test_ewald.py +++ b/source/tests/tf/test_ewald.py @@ -224,3 +224,26 @@ def test_virial(self) -> None: np.testing.assert_almost_equal( t_esti.ravel(), virial.ravel(), places, err_msg="virial component failed" ) + + +class TestEwaldRecpClose(tf.test.TestCase): + """EwaldRecp owns a TensorFlow session that must be closeable.""" + + coord = np.array([[0.0, 0.0, 0.0, 1.0, 0.0, 0.0]]) + charge = np.array([[1.0, -1.0]]) + box = np.array([[10.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 10.0]]) + + def test_close_releases_session(self) -> None: + er = EwaldRecp(1.0, 1.0) + # a fresh evaluator works + er.eval(self.coord, self.charge, self.box) + er.close() + # after close the underlying session is unusable + with self.assertRaises(RuntimeError): + er.eval(self.coord, self.charge, self.box) + + def test_context_manager_closes_session(self) -> None: + with EwaldRecp(1.0, 1.0) as er: + er.eval(self.coord, self.charge, self.box) + with self.assertRaises(RuntimeError): + er.eval(self.coord, self.charge, self.box)