From 6bc3a4d9c25157df9846d6463aec597559306340 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 2 Jul 2026 09:32:28 +0800 Subject: [PATCH] fix(tf): close EwaldRecp TensorFlow sessions EwaldRecp creates a dedicated tf.Session in its constructor but exposes no way to close it, and DipoleChargeModifier holds an EwaldRecp without releasing it. Repeatedly constructing and discarding these objects leaks TensorFlow sessions and graph resources in long-running processes. Add close(), context-manager support, and a defensive __del__ to EwaldRecp, and add a DipoleChargeModifier.close() that forwards to its EwaldRecp. Add tests that the session is released after close()/context-manager exit and that the modifier forwards close to its evaluator. Fix #5685 --- deepmd/tf/infer/ewald_recp.py | 18 ++++++++++++++++++ deepmd/tf/modifier/dipole_charge.py | 6 ++++++ source/tests/tf/test_dipolecharge.py | 10 ++++++++++ source/tests/tf/test_ewald.py | 23 +++++++++++++++++++++++ 4 files changed, 57 insertions(+) diff --git a/deepmd/tf/infer/ewald_recp.py b/deepmd/tf/infer/ewald_recp.py index 4c0b7b7cbc..4b1a0b27be 100644 --- a/deepmd/tf/infer/ewald_recp.py +++ b/deepmd/tf/infer/ewald_recp.py @@ -1,6 +1,9 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import numpy as np +from typing_extensions import ( + Self, +) from deepmd.tf.env import ( GLOBAL_TF_FLOAT_PRECISION, @@ -93,3 +96,18 @@ def eval( ) return energy, force, virial + + def close(self) -> None: + """Close the TensorFlow session held by this object.""" + sess = getattr(self, "sess", None) + if sess is not None: + sess.close() + + def __enter__(self) -> Self: + return self + + def __exit__(self, exc_type: object, exc_value: object, traceback: object) -> None: + self.close() + + def __del__(self) -> None: + self.close() diff --git a/deepmd/tf/modifier/dipole_charge.py b/deepmd/tf/modifier/dipole_charge.py index a5d1fbf975..c6ec07b234 100644 --- a/deepmd/tf/modifier/dipole_charge.py +++ b/deepmd/tf/modifier/dipole_charge.py @@ -105,6
+105,12 @@ def __init__( self.force = None self.ntypes = len(self.sel_a) + def close(self) -> None: + """Close the TensorFlow session held by the Ewald reciprocal evaluator.""" + er = getattr(self, "er", None) + if er is not None: + er.close() + def serialize(self) -> dict: """Serialize the modifier. diff --git a/source/tests/tf/test_dipolecharge.py b/source/tests/tf/test_dipolecharge.py index 71c46446f6..f677d791f9 100644 --- a/source/tests/tf/test_dipolecharge.py +++ b/source/tests/tf/test_dipolecharge.py @@ -1,6 +1,9 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import os import unittest +from unittest import ( + mock, +) import numpy as np @@ -125,6 +128,13 @@ def tearDownClass(cls) -> None: os.remove("dipolecharge_d.pb") cls.dp = None + def test_close_forwards_to_ewald(self) -> None: + # closing the modifier must release the EwaldRecp session. Patch the + # evaluator so the shared class-level modifier is not disturbed. + with mock.patch.object(self.dp, "er") as mock_er: + self.dp.close() + mock_er.close.assert_called_once() + def test_attrs(self) -> None: self.assertEqual(self.dp.get_ntypes(), 5) self.assertAlmostEqual(self.dp.get_rcut(), 4.0, places=default_places) diff --git a/source/tests/tf/test_ewald.py b/source/tests/tf/test_ewald.py index 270546fbc8..3b9cee66b2 100644 --- a/source/tests/tf/test_ewald.py +++ b/source/tests/tf/test_ewald.py @@ -224,3 +224,26 @@ def test_virial(self) -> None: np.testing.assert_almost_equal( t_esti.ravel(), virial.ravel(), places, err_msg="virial component failed" ) + + +class TestEwaldRecpClose(tf.test.TestCase): + """EwaldRecp owns a TensorFlow session that must be closeable.""" + + coord = np.array([[0.0, 0.0, 0.0, 1.0, 0.0, 0.0]]) + charge = np.array([[1.0, -1.0]]) + box = np.array([[10.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 10.0]]) + + def test_close_releases_session(self) -> None: + er = EwaldRecp(1.0, 1.0) + # a fresh evaluator works + er.eval(self.coord, self.charge, self.box) + er.close() + # after close the 
underlying session is unusable + with self.assertRaises(RuntimeError): + er.eval(self.coord, self.charge, self.box) + + def test_context_manager_closes_session(self) -> None: + with EwaldRecp(1.0, 1.0) as er: + er.eval(self.coord, self.charge, self.box) + with self.assertRaises(RuntimeError): + er.eval(self.coord, self.charge, self.box)