diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 38db47809d..0db472dca9 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -530,6 +530,9 @@ def get_single_frame(self, index: int, num_worker: int) -> dict: frame_data, self, ) + # propagate any exception raised inside the modifier instead of + # silently returning (and caching) an unmodified frame + future.result() if self.use_modifier_cache: # Cache the modified frame to avoid recomputation self._modified_frame_cache[index] = copy.deepcopy(frame_data) diff --git a/source/tests/common/test_deepmd_data.py b/source/tests/common/test_deepmd_data.py index c7a26491cd..0f26d2466a 100644 --- a/source/tests/common/test_deepmd_data.py +++ b/source/tests/common/test_deepmd_data.py @@ -46,3 +46,41 @@ def test_remap_with_unused_types(self) -> None: loaded = data._load_set(self.set_dir) expected_sorted = expected_atom_types[data.idx_map] np.testing.assert_array_equal(loaded["type"], np.tile(expected_sorted, (1, 1))) + + +class _RaisingModifier: + """A modifier whose ``modify_data`` always fails.""" + + use_cache = True + + def modify_data(self, data: dict, data_sys: DeepmdData) -> None: + raise ValueError("modifier failure") + + +class TestDeepmdDataModifierError(unittest.TestCase): + def setUp(self) -> None: + self.tmpdir = tempfile.TemporaryDirectory() + self.root = Path(self.tmpdir.name) + set_dir = self.root / "set.000" + set_dir.mkdir() + atom_types = np.array([0, 1], dtype=np.int32) + np.savetxt(self.root / "type.raw", atom_types, fmt="%d") + np.save( + set_dir / "coord.npy", + np.zeros((3, atom_types.size * 3), dtype=np.float32), + ) + np.save( + set_dir / "box.npy", + np.tile(np.eye(3, dtype=np.float32).reshape(9), (3, 1)), + ) + + def tearDown(self) -> None: + self.tmpdir.cleanup() + + def test_get_single_frame_propagates_modifier_error(self) -> None: + data = DeepmdData(str(self.root), modifier=_RaisingModifier()) + # a failing modifier must surface its error, not be swallowed + with self.assertRaises(ValueError): + data.get_single_frame(0, num_worker=1) + # and the unmodified frame must not be cached + self.assertNotIn(0, data._modified_frame_cache)