diff --git a/deepmd/pt_expt/__init__.py b/deepmd/pt_expt/__init__.py index 6ceb116d85..f18bb9749c 100644 --- a/deepmd/pt_expt/__init__.py +++ b/deepmd/pt_expt/__init__.py @@ -1 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later + +from deepmd.utils.entry_point import ( + load_entry_point, +) + +load_entry_point("deepmd.pt_expt") diff --git a/source/tests/pt_expt/test_plugin.py b/source/tests/pt_expt/test_plugin.py new file mode 100644 index 0000000000..a59242e592 --- /dev/null +++ b/source/tests/pt_expt/test_plugin.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +import importlib +import importlib.metadata +import sys + + +class _FakeEntryPoint: + def __init__(self, calls): + self.calls = calls + + def load(self): + self.calls.append("load") + + +def test_pt_expt_loads_plugin_entry_points(monkeypatch): + groups = [] + calls = [] + + def fake_entry_points(*, group=None): + groups.append(group) + return [_FakeEntryPoint(calls)] + + monkeypatch.setattr(importlib.metadata, "entry_points", fake_entry_points) + sys.modules.pop("deepmd.pt_expt", None) + + try: + importlib.import_module("deepmd.pt_expt") + finally: + sys.modules.pop("deepmd.pt_expt", None) + + assert groups == ["deepmd.pt_expt"] + assert calls == ["load"]