From bcceaccf3290a76e33831ebceba9d462ba1d4ce9 Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Sat, 7 Feb 2026 15:18:10 -0700 Subject: [PATCH 01/20] feat: add warehouse primitives for handling protocol units --- src/openfe/storage/warehouse.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/src/openfe/storage/warehouse.py b/src/openfe/storage/warehouse.py index c10baa12c..e0f2c771a 100644 --- a/src/openfe/storage/warehouse.py +++ b/src/openfe/storage/warehouse.py @@ -5,6 +5,8 @@ import re from typing import Literal, TypedDict +from gufe.protocols.protocoldag import ProtocolDAG +from gufe.protocols.protocolunit import ProtocolUnit from gufe.storage.externalresource import ExternalStorage, FileStorage from gufe.tokenization import ( JSON_HANDLER, @@ -35,6 +37,8 @@ class WarehouseStores(TypedDict): setup: ExternalStorage result: ExternalStorage + shared: ExternalStorage + tasks: ExternalStorage class WarehouseBaseClass: @@ -83,6 +87,12 @@ def delete(self, store_name: Literal["setup", "result"], location: str): store: ExternalStorage = self.stores[store_name] store.delete(location) + def store_task(self, obj: ProtocolUnit): + self._store_gufe_tokenizable("tasks", obj) + + def load_task(self, obj: GufeKey): + self._load_gufe_tokenizable(obj) + def store_setup_tokenizable(self, obj: GufeTokenizable): """Store a GufeTokenizable object in the setup store. @@ -134,7 +144,7 @@ def load_result_tokenizable(self, obj: GufeKey) -> GufeTokenizable: return self._load_gufe_tokenizable(gufe_key=obj) def exists(self, key: GufeKey) -> bool: - """Check if an object with the given key exists in any store. + """Check if an object with the given key exists in any store that holds tokenizables. Parameters ---------- @@ -171,7 +181,12 @@ def _get_store_for_key(self, key: GufeKey) -> ExternalStorage: return self.stores[name] raise ValueError(f"GufeKey {key} is not stored") - def _store_gufe_tokenizable(self, store_name: Literal["setup", "result"], obj: GufeTokenizable): + def _store_gufe_tokenizable( + self, + store_name: Literal["setup", "result", "tasks"], + obj: GufeTokenizable, + name: str | None = None, + ): """Store a GufeTokenizable object with deduplication. Parameters @@ -197,7 +212,10 @@ def _store_gufe_tokenizable(self, store_name: Literal["setup", "result"], obj: G data = json.dumps(keyed_dict, cls=JSON_HANDLER.encoder, sort_keys=True).encode( "utf-8" ) - target.store_bytes(gufe_key, data) + if name: + target.store_bytes(name, data) + else: + target.store_bytes(gufe_key, data) def _load_gufe_tokenizable(self, gufe_key: GufeKey) -> GufeTokenizable: """Load a deduplicated object from a GufeKey. @@ -315,5 +333,9 @@ class FileSystemWarehouse(WarehouseBaseClass): def __init__(self, root_dir: str = "warehouse"): setup_store = FileStorage(f"{root_dir}/setup") result_store = FileStorage(f"{root_dir}/result") - stores = WarehouseStores(setup=setup_store, result=result_store) + shared_store = FileStorage(f"{root_dir}/shared") + tasks_store = FileStorage(f"{root_dir}/tasks") + stores = WarehouseStores( + setup=setup_store, result=result_store, shared=shared_store, tasks=tasks_store + ) super().__init__(stores) From 17eb5ff3819e8cacccd463062a5698a78b015e83 Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Sat, 7 Feb 2026 15:18:40 -0700 Subject: [PATCH 02/20] feat: inital worker for exorcist --- environment.yml | 1 + src/openfe/orchestration/__init__.py | 57 ++++++++++++++++++++++ src/openfe/orchestration/exorcist_utils.py | 53 ++++++++++++++++++++ 3 files changed, 111 insertions(+) create mode 100644 src/openfe/orchestration/exorcist_utils.py diff --git a/environment.yml b/environment.yml index c45d8102a..99e814252 100644 --- a/environment.yml +++ b/environment.yml @@ -54,6 +54,7 @@ dependencies: - threadpoolctl - pip: - git+https://github.com/OpenFreeEnergy/gufe@main + - git+https://github.com/OpenFreeEnergy/exorcist@main - run_constrained: # drop this pin when handled upstream in espaloma-feedstock - smirnoff99frosst>=1.1.0.1 #https://github.com/openforcefield/smirnoff99Frosst/issues/109 diff --git a/src/openfe/orchestration/__init__.py b/src/openfe/orchestration/__init__.py index e69de29bb..d8e31db06 100644 --- a/src/openfe/orchestration/__init__.py +++ b/src/openfe/orchestration/__init__.py @@ -0,0 +1,57 @@ +from dataclasses import dataclass +from pathlib import Path + +from exorcist.taskdb import TaskStatusDB +from gufe.protocols.protocoldag import _pu_to_pur +from gufe.protocols.protocolunit import ( + Context, + ProtocolUnit, + ProtocolUnitFailure, + ProtocolUnitResult, +) +from gufe.storage.externalresource.filestorage import FileStorage +from gufe.tokenization import GufeKey + +from openfe.storage.warehouse import FileSystemWarehouse + +from .exorcist_utils import ( + alchemical_network_to_task_graph, + build_task_db_from_alchemical_network, +) + + +@dataclass +class Worker: + warehouse: FileSystemWarehouse + + def _get_task(self) -> ProtocolUnit: + # Right now, we are just going to assume it exists in the warehouse folder + location = Path("./warehouse/tasks.db") + + db: TaskStatusDB = TaskStatusDB.from_filename(location) + # The format for the taskid is going to "Transformation-:Unit" + taskid = db.check_out_task() + # Load the unit from warehouse and return + unit = taskid.split(":") + + return self.warehouse.load_task(unit) + + def execute_unit(self, scratch: Path): + # 1. Get task/unit + unit = self._get_task() + # 2. Constrcut the context + # NOTE: On changes to context, this can easily be replaced with external storage objects + # However, to satisfy the current work, we will use this implementation where we + # force the use of a FileSystemWarehouse and in turn can assert that an object is FileStorage. + shared_store: FileStorage = self.warehouse.stores["shared"] + shared_root_dir = shared_store.root_dir + ctx = Context(scratch, shared=shared_root_dir) + results: dict[GufeKey, ProtocolUnitResult] = {} + inputs = _pu_to_pur(unit.inputs, results) + # 3. Execute unit + result = unit.execute(context=ctx, **inputs) + # if not result.ok(): + # Increment attempt in taskdb + # 4. output result to warehouse + # TODO: we may need to end up handling namespacing on the warehouse side for tokenizables + self.warehouse.store_result_tokenizable(result) diff --git a/src/openfe/orchestration/exorcist_utils.py b/src/openfe/orchestration/exorcist_utils.py new file mode 100644 index 000000000..fc30c8c11 --- /dev/null +++ b/src/openfe/orchestration/exorcist_utils.py @@ -0,0 +1,53 @@ +"""Utilities for building Exorcist task graphs and task databases.""" + +from pathlib import Path + +import exorcist +import networkx as nx +from gufe import AlchemicalNetwork + +from openfe.storage.warehouse import WarehouseBaseClass + + +def alchemical_network_to_task_graph( + alchemical_network: AlchemicalNetwork, warehouse: WarehouseBaseClass +) -> nx.DiGraph: + """Build a global task DAG from an AlchemicalNetwork.""" + + global_dag = nx.DiGraph() + for transformation in alchemical_network.edges: + dag = transformation.create() + for unit in dag.protocol_units: + node_id = f"{transformation.name}-{transformation.key}:{unit.name}-{unit.key}" + global_dag.add_node( + node_id, + label=f"{transformation.name}\n{unit.name}", + transformation_key=str(transformation.key), + protocol_unit_key=str(unit.key), + ) + warehouse.store_task(unit) + for u, v in dag.graph.edges: + u_id = f"{transformation.key}:{u.key}" + v_id = f"{transformation.key}:{v.key}" + global_dag.add_edge(u_id, v_id) + + if not nx.is_directed_acyclic_graph(global_dag): + raise ValueError("AlchemicalNetwork produced a task graph that is not a DAG.") + + return global_dag + + +def build_task_db_from_alchemical_network( + alchemical_network: AlchemicalNetwork, + warehouse: WarehouseBaseClass, + db_path: Path | None = None, + max_tries: int = 1, +) -> exorcist.TaskStatusDB: + """Create an Exorcist TaskStatusDB from an AlchemicalNetwork.""" + if db_path is None: + db_path = Path("tasks.db") + + global_dag = alchemical_network_to_task_graph(alchemical_network, warehouse) + db = exorcist.TaskStatusDB.from_filename(db_path) + db.add_task_network(global_dag, max_tries) + return db From 65bf138edf6cf6e8cc5094c45952f2d5fdc91746 Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Sat, 7 Feb 2026 15:57:22 -0700 Subject: [PATCH 03/20] test: add tests for warehouse --- src/openfe/tests/storage/test_warehouse.py | 96 +++++++++++++++++++--- 1 file changed, 84 insertions(+), 12 deletions(-) diff --git a/src/openfe/tests/storage/test_warehouse.py b/src/openfe/tests/storage/test_warehouse.py index 572769d70..d113cf1d7 100644 --- a/src/openfe/tests/storage/test_warehouse.py +++ b/src/openfe/tests/storage/test_warehouse.py @@ -19,18 +19,35 @@ class TestWarehouseBaseClass: def test_store_protocol_dag_result(self): pytest.skip("Not implemented yet") + @staticmethod + def _build_stores() -> WarehouseStores: + return WarehouseStores( + setup=MemoryStorage(), + result=MemoryStorage(), + shared=MemoryStorage(), + tasks=MemoryStorage(), + ) + + @staticmethod + def _get_protocol_unit(transformation): + dag = transformation.create() + return next(iter(dag.protocol_units)) + @staticmethod def _test_store_load_same_process( - obj, store_func_name, load_func_name, store_name: Literal["setup", "result"] + obj, + store_func_name, + load_func_name, + store_name: Literal["setup", "result", "tasks"], ): - setup_store = MemoryStorage() - result_store = MemoryStorage() - stores = WarehouseStores(setup=setup_store, result=result_store) + stores = TestWarehouseBaseClass._build_stores() client = WarehouseBaseClass(stores) store_func = getattr(client, store_func_name) load_func = getattr(client, load_func_name) - assert setup_store._data == {} - assert result_store._data == {} + assert stores["setup"]._data == {} + assert stores["result"]._data == {} + assert stores["shared"]._data == {} + assert stores["tasks"]._data == {} store_func(obj) store_under_test: MemoryStorage = stores[store_name] assert store_under_test._data != {} @@ -43,16 +60,16 @@ def _test_store_load_different_process( obj: GufeTokenizable, store_func_name, load_func_name, - store_name: Literal["setup", "result"], + store_name: Literal["setup", "result", "tasks"], ): - setup_store = MemoryStorage() - result_store = MemoryStorage() - stores = WarehouseStores(setup=setup_store, result=result_store) + stores = TestWarehouseBaseClass._build_stores() client = WarehouseBaseClass(stores) store_func = getattr(client, store_func_name) load_func = getattr(client, load_func_name) - assert setup_store._data == {} - assert result_store._data == {} + assert stores["setup"]._data == {} + assert stores["result"]._data == {} + assert stores["shared"]._data == {} + assert stores["tasks"]._data == {} store_func(obj) store_under_test: MemoryStorage = stores[store_name] assert store_under_test._data != {} @@ -65,6 +82,45 @@ def _test_store_load_different_process( assert reload == obj assert reload is not obj + def test_store_load_task_same_process(self, absolute_transformation): + unit = self._get_protocol_unit(absolute_transformation) + self._test_store_load_same_process(unit, "store_task", "load_task", "tasks") + + def test_store_load_task_different_process(self, absolute_transformation): + unit = self._get_protocol_unit(absolute_transformation) + self._test_store_load_different_process(unit, "store_task", "load_task", "tasks") + + def test_store_task_writes_to_tasks_store(self, absolute_transformation): + unit = self._get_protocol_unit(absolute_transformation) + stores = self._build_stores() + client = WarehouseBaseClass(stores) + client.store_task(unit) + + assert stores["tasks"]._data != {} + assert stores["setup"]._data == {} + assert stores["result"]._data == {} + assert stores["shared"]._data == {} + + def test_exists_finds_task_key(self, absolute_transformation): + unit = self._get_protocol_unit(absolute_transformation) + stores = self._build_stores() + client = WarehouseBaseClass(stores) + + client.store_task(unit) + + assert client.exists(unit.key) + + def test_load_task_returns_object(self, absolute_transformation): + unit = self._get_protocol_unit(absolute_transformation) + stores = self._build_stores() + client = WarehouseBaseClass(stores) + + client.store_task(unit) + loaded = client.load_task(unit.key) + + assert loaded is not None + assert isinstance(loaded, GufeTokenizable) + @pytest.mark.parametrize( "fixture", ["absolute_transformation", "complex_equilibrium"], @@ -164,6 +220,22 @@ def test_store_load_transformation_same_process(self, request, fixture): "load_setup_tokenizable", ) + def test_filesystemwarehouse_has_shared_and_tasks_stores(self, absolute_transformation): + unit = TestWarehouseBaseClass._get_protocol_unit(absolute_transformation) + + with tempfile.TemporaryDirectory() as tmpdir: + client = FileSystemWarehouse(tmpdir) + + assert "shared" in client.stores + assert "tasks" in client.stores + + client.stores["shared"].store_bytes("sentinel", b"shared-data") + with client.stores["shared"].load_stream("sentinel") as f: + assert f.read() == b"shared-data" + + client.store_task(unit) + assert client.exists(unit.key) + @pytest.mark.parametrize( "fixture", ["absolute_transformation", "complex_equilibrium"], From bab0e2f407a25141d013ec1fae9b190bb6ec77fc Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Sat, 7 Feb 2026 15:57:41 -0700 Subject: [PATCH 04/20] fix: can now return protocol unit --- src/openfe/storage/warehouse.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/openfe/storage/warehouse.py b/src/openfe/storage/warehouse.py index e0f2c771a..0b46a8c6c 100644 --- a/src/openfe/storage/warehouse.py +++ b/src/openfe/storage/warehouse.py @@ -90,8 +90,11 @@ def delete(self, store_name: Literal["setup", "result"], location: str): def store_task(self, obj: ProtocolUnit): self._store_gufe_tokenizable("tasks", obj) - def load_task(self, obj: GufeKey): - self._load_gufe_tokenizable(obj) + def load_task(self, obj: GufeKey) -> ProtocolUnit: + unit = self._load_gufe_tokenizable(obj) + if not isinstance(unit, ProtocolUnit): + raise ValueError("Unable to load ProtocolUnit") + return unit def store_setup_tokenizable(self, obj: GufeTokenizable): """Store a GufeTokenizable object in the setup store. From c1e0a2879365fc480cecbdbc98241365c2828f8c Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Tue, 10 Feb 2026 16:05:51 -0700 Subject: [PATCH 05/20] refactor: make things more consistent --- src/openfe/orchestration/__init__.py | 4 ++-- src/openfe/orchestration/exorcist_utils.py | 9 +++------ 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/openfe/orchestration/__init__.py b/src/openfe/orchestration/__init__.py index d8e31db06..42c169d96 100644 --- a/src/openfe/orchestration/__init__.py +++ b/src/openfe/orchestration/__init__.py @@ -32,9 +32,9 @@ def _get_task(self) -> ProtocolUnit: # The format for the taskid is going to "Transformation-:Unit" taskid = db.check_out_task() # Load the unit from warehouse and return - unit = taskid.split(":") + _, protocol_unit_key = taskid.split(":", maxsplit=1) - return self.warehouse.load_task(unit) + return self.warehouse.load_task(GufeKey(protocol_unit_key)) def execute_unit(self, scratch: Path): # 1. Get task/unit diff --git a/src/openfe/orchestration/exorcist_utils.py b/src/openfe/orchestration/exorcist_utils.py index fc30c8c11..51c269b2e 100644 --- a/src/openfe/orchestration/exorcist_utils.py +++ b/src/openfe/orchestration/exorcist_utils.py @@ -18,17 +18,14 @@ def alchemical_network_to_task_graph( for transformation in alchemical_network.edges: dag = transformation.create() for unit in dag.protocol_units: - node_id = f"{transformation.name}-{transformation.key}:{unit.name}-{unit.key}" + node_id = f"{str(transformation.key)}:{str(unit.key)}" global_dag.add_node( node_id, - label=f"{transformation.name}\n{unit.name}", - transformation_key=str(transformation.key), - protocol_unit_key=str(unit.key), ) warehouse.store_task(unit) for u, v in dag.graph.edges: - u_id = f"{transformation.key}:{u.key}" - v_id = f"{transformation.key}:{v.key}" + u_id = f"{str(transformation.key)}:{str(u.key)}" + v_id = f"{str(transformation.key)}:{str(v.key)}" global_dag.add_edge(u_id, v_id) if not nx.is_directed_acyclic_graph(global_dag): From e93af94b21b3d0303fb2fd2d5107e9399ba9f445 Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Tue, 10 Feb 2026 16:07:17 -0700 Subject: [PATCH 06/20] test: initial test setup for orchestration subpackage --- src/openfe/tests/orchestration/__init__.py | 2 + src/openfe/tests/orchestration/conftest.py | 118 +++++++++++++++++++++ 2 files changed, 120 insertions(+) create mode 100644 src/openfe/tests/orchestration/__init__.py create mode 100644 src/openfe/tests/orchestration/conftest.py diff --git a/src/openfe/tests/orchestration/__init__.py b/src/openfe/tests/orchestration/__init__.py new file mode 100644 index 000000000..efae32ddb --- /dev/null +++ b/src/openfe/tests/orchestration/__init__.py @@ -0,0 +1,2 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe diff --git a/src/openfe/tests/orchestration/conftest.py b/src/openfe/tests/orchestration/conftest.py new file mode 100644 index 000000000..1851b7c05 --- /dev/null +++ b/src/openfe/tests/orchestration/conftest.py @@ -0,0 +1,118 @@ +import gufe +import pytest +from gufe import ChemicalSystem, SolventComponent +from gufe.tests.test_protocol import DummyProtocol +from openff.units import unit + + +@pytest.fixture +def solv_comp(): + yield SolventComponent(positive_ion="K", negative_ion="Cl", ion_concentration=0.0 * unit.molar) + + +@pytest.fixture +def solvated_complex(T4_protein_component, benzene_transforms, solv_comp): + return ChemicalSystem( + { + "ligand": benzene_transforms["toluene"], + "protein": T4_protein_component, + "solvent": solv_comp, + } + ) + + +@pytest.fixture +def solvated_ligand(benzene_transforms, solv_comp): + return ChemicalSystem( + { + "ligand": benzene_transforms["toluene"], + "solvent": solv_comp, + } + ) + + +@pytest.fixture +def absolute_transformation(solvated_ligand, solvated_complex): + return gufe.Transformation( + solvated_ligand, + solvated_complex, + protocol=DummyProtocol(settings=DummyProtocol.default_settings()), + mapping=None, + ) + + +@pytest.fixture +def complex_equilibrium(solvated_complex): + return gufe.NonTransformation( + solvated_complex, + protocol=DummyProtocol(settings=DummyProtocol.default_settings()), + ) + + +@pytest.fixture +def benzene_variants_star_map(benzene_transforms, solv_comp, T4_protein_component): + variants = ["toluene", "phenol", "benzonitrile", "anisole", "benzaldehyde", "styrene"] + + # define the solvent chemical systems and transformations between + # benzene and the others + solvated_ligands = {} + solvated_ligand_transformations = {} + + solvated_ligands["benzene"] = ChemicalSystem( + { + "solvent": solv_comp, + "ligand": benzene_transforms["benzene"], + }, + name="benzene-solvent", + ) + + for ligand in variants: + solvated_ligands[ligand] = ChemicalSystem( + { + "solvent": solv_comp, + "ligand": benzene_transforms[ligand], + }, + name=f"{ligand}-solvent", + ) + + solvated_ligand_transformations[("benzene", ligand)] = gufe.Transformation( + solvated_ligands["benzene"], + solvated_ligands[ligand], + protocol=DummyProtocol(settings=DummyProtocol.default_settings()), + mapping=None, + ) + + # define the complex chemical systems and transformations between + # benzene and the others + solvated_complexes = {} + solvated_complex_transformations = {} + + solvated_complexes["benzene"] = gufe.ChemicalSystem( + { + "protein": T4_protein_component, + "solvent": solv_comp, + "ligand": benzene_transforms["benzene"], + }, + name="benzene-complex", + ) + + for ligand in variants: + solvated_complexes[ligand] = gufe.ChemicalSystem( + { + "protein": T4_protein_component, + "solvent": solv_comp, + "ligand": benzene_transforms[ligand], + }, + name=f"{ligand}-complex", + ) + solvated_complex_transformations[("benzene", ligand)] = gufe.Transformation( + solvated_complexes["benzene"], + solvated_complexes[ligand], + protocol=DummyProtocol(settings=DummyProtocol.default_settings()), + mapping=None, + ) + + return gufe.AlchemicalNetwork( + list(solvated_ligand_transformations.values()) + + list(solvated_complex_transformations.values()) + ) From 764ee54f6c3d6d3fc0fb89f0cc575ac1cfa66292 Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Tue, 10 Feb 2026 16:07:46 -0700 Subject: [PATCH 07/20] test: initial exorcist utility testing --- .../orchestration/test_exorcist_utils.py | 210 ++++++++++++++++++ 1 file changed, 210 insertions(+) create mode 100644 src/openfe/tests/orchestration/test_exorcist_utils.py diff --git a/src/openfe/tests/orchestration/test_exorcist_utils.py b/src/openfe/tests/orchestration/test_exorcist_utils.py new file mode 100644 index 000000000..3ae9a85ad --- /dev/null +++ b/src/openfe/tests/orchestration/test_exorcist_utils.py @@ -0,0 +1,210 @@ +from pathlib import Path +from unittest import mock + +import exorcist +import networkx as nx +import pytest +import sqlalchemy as sqla +from gufe.tokenization import GufeKey + +from openfe.orchestration.exorcist_utils import ( + alchemical_network_to_task_graph, + build_task_db_from_alchemical_network, +) +from openfe.storage.warehouse import FileSystemWarehouse + + +class _RecordingWarehouse: + def __init__(self): + self.stored_tasks = [] + + def store_task(self, task): + self.stored_tasks.append(task) + + +def _network_units(benzene_variants_star_map): + units = [] + for transformation in benzene_variants_star_map.edges: + units.extend(transformation.create().protocol_units) + return units + + +@pytest.mark.parametrize("fixture", ["benzene_variants_star_map"]) +def test_alchemical_network_to_task_graph_stores_all_units(request, fixture): + warehouse = _RecordingWarehouse() + network = request.getfixturevalue(fixture) + expected_units = _network_units(network) + + alchemical_network_to_task_graph(network, warehouse) + + stored_unit_names = [str(unit.name) for unit in warehouse.stored_tasks] + expected_unit_names = [str(unit.name) for unit in expected_units] + + assert len(stored_unit_names) == len(expected_unit_names) + assert sorted(stored_unit_names) == sorted(expected_unit_names) + + +@pytest.mark.parametrize("fixture", ["benzene_variants_star_map"]) +def test_alchemical_network_to_task_graph_uses_canonical_task_ids(request, fixture): + warehouse = _RecordingWarehouse() + network = request.getfixturevalue(fixture) + + graph = alchemical_network_to_task_graph(network, warehouse) + + transformation_keys = {str(transformation.key) for transformation in network.edges} + expected_protocol_unit_keys = sorted(str(unit.key) for unit in warehouse.stored_tasks) + observed_protocol_unit_keys = [] + + for node in graph.nodes: + transformation_key, protocol_unit_key = node.split(":", maxsplit=1) + assert transformation_key in transformation_keys + observed_protocol_unit_keys.append(protocol_unit_key) + + assert sorted(observed_protocol_unit_keys) == expected_protocol_unit_keys + + +@pytest.mark.parametrize("fixture", ["benzene_variants_star_map"]) +def test_alchemical_network_to_task_graph_edges_reference_existing_nodes(request, fixture): + warehouse = _RecordingWarehouse() + network = request.getfixturevalue(fixture) + + graph = alchemical_network_to_task_graph(network, warehouse) + + assert len(graph.edges) > 0 + for u, v in graph.edges: + assert u in graph.nodes + assert v in graph.nodes + + +def test_alchemical_network_to_task_graph_raises_for_cycle(): + class _Unit: + def __init__(self, name: str, key: str): + self.name = name + self.key = key + + class _Transformation: + name = "cyclic" + key = "Transformation-cycle" + + def create(self): + unit_a = _Unit("unit-a", "ProtocolUnit-a") + unit_b = _Unit("unit-b", "ProtocolUnit-b") + dag = mock.Mock() + dag.protocol_units = [unit_a, unit_b] + dag.graph = nx.DiGraph() + dag.graph.add_nodes_from([unit_a, unit_b]) + dag.graph.add_edges_from([(unit_a, unit_b), (unit_b, unit_a)]) + return dag + + network = mock.Mock() + network.edges = [_Transformation()] + warehouse = mock.Mock() + + with pytest.raises(ValueError, match="not a DAG"): + alchemical_network_to_task_graph(network, warehouse) + + +@pytest.mark.parametrize("fixture", ["benzene_variants_star_map"]) +def test_build_task_db_checkout_order_is_dependency_safe(tmp_path, request, fixture): + network = request.getfixturevalue(fixture) + warehouse = FileSystemWarehouse(str(tmp_path / "warehouse")) + # Build the real sqlite task DB from a real alchemical network fixture. + db = build_task_db_from_alchemical_network( + network, + warehouse, + db_path=tmp_path / "tasks.db", + ) + + # Read task IDs and dependency edges from the persisted DB state. + initial_task_rows = list(db.get_all_tasks()) + graph_taskids = {row.taskid for row in initial_task_rows} + with db.engine.connect() as conn: + dep_rows = conn.execute(sqla.select(db.dependencies_table)).all() + graph_edges = {(row._mapping["from"], row._mapping["to"]) for row in dep_rows} + + checkout_order = [] + # Hard upper bound prevents infinite checkout loops. + max_checkouts = len(graph_taskids) + print(f"Max Checkout={max_checkouts}") + for _ in range(max_checkouts): + taskid = db.check_out_task() + if taskid is None: + break + + checkout_order.append(taskid) + _, protocol_unit_key = taskid.split(":", maxsplit=1) + loaded_unit = warehouse.load_task(GufeKey(protocol_unit_key)) + assert str(loaded_unit.key) == protocol_unit_key + db.mark_task_completed(taskid, success=True) + + # Coverage/completion: every task is checked out exactly once. + observed_taskids = set(checkout_order) + assert observed_taskids == graph_taskids + assert len(checkout_order) == len(graph_taskids) + + # Dependency safety: upstream tasks must appear before downstream tasks. + checkout_index = {taskid: idx for idx, taskid in enumerate(checkout_order)} + for upstream, downstream in graph_edges: + assert checkout_index[upstream] < checkout_index[downstream] + + # Final DB state: all tasks are completed. + task_rows = list(db.get_all_tasks()) + assert len(task_rows) == len(graph_taskids) + assert {row.taskid for row in task_rows} == graph_taskids + assert {row.status for row in task_rows} == {exorcist.TaskStatus.COMPLETED.value} + + +@pytest.mark.parametrize("fixture", ["benzene_variants_star_map"]) +def test_build_task_db_default_path(request, fixture): + network = request.getfixturevalue(fixture) + warehouse = mock.Mock() + fake_graph = nx.DiGraph() + fake_db = mock.Mock() + + with ( + mock.patch( + "openfe.orchestration.exorcist_utils.alchemical_network_to_task_graph", + return_value=fake_graph, + ) as task_graph_mock, + mock.patch( + "openfe.orchestration.exorcist_utils.exorcist.TaskStatusDB.from_filename", + return_value=fake_db, + ) as db_ctor, + ): + result = build_task_db_from_alchemical_network(network, warehouse) + + task_graph_mock.assert_called_once_with(network, warehouse) + db_ctor.assert_called_once_with(Path("tasks.db")) + fake_db.add_task_network.assert_called_once_with(fake_graph, 1) + assert result is fake_db + + +@pytest.mark.parametrize("fixture", ["benzene_variants_star_map"]) +def test_build_task_db_forwards_graph_and_max_tries(request, tmp_path, fixture): + network = request.getfixturevalue(fixture) + warehouse = mock.Mock() + fake_graph = nx.DiGraph() + fake_db = mock.Mock() + db_path = tmp_path / "custom_tasks.db" + + with ( + mock.patch( + "openfe.orchestration.exorcist_utils.alchemical_network_to_task_graph", + return_value=fake_graph, + ) as task_graph_mock, + mock.patch( + "openfe.orchestration.exorcist_utils.exorcist.TaskStatusDB.from_filename", + return_value=fake_db, + ) as db_ctor, + ): + result = build_task_db_from_alchemical_network( + network, + warehouse, + db_path=db_path, + max_tries=7, + ) + + task_graph_mock.assert_called_once_with(network, warehouse) + db_ctor.assert_called_once_with(db_path) + fake_db.add_task_network.assert_called_once_with(fake_graph, 7) + assert result is fake_db From 4ca30de569009e5e7cbe992d769f0e2d0bc59277 Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Tue, 10 Feb 2026 16:42:30 -0700 Subject: [PATCH 08/20] refactor: provide a root path to the exorcist DB --- src/openfe/orchestration/__init__.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/openfe/orchestration/__init__.py b/src/openfe/orchestration/__init__.py index 42c169d96..98338df02 100644 --- a/src/openfe/orchestration/__init__.py +++ b/src/openfe/orchestration/__init__.py @@ -23,12 +23,10 @@ @dataclass class Worker: warehouse: FileSystemWarehouse + task_db_path: Path = Path("./warehouse/tasks.db") def _get_task(self) -> ProtocolUnit: - # Right now, we are just going to assume it exists in the warehouse folder - location = Path("./warehouse/tasks.db") - - db: TaskStatusDB = TaskStatusDB.from_filename(location) + db: TaskStatusDB = TaskStatusDB.from_filename(self.task_db_path) # The format for the taskid is going to "Transformation-:Unit" taskid = db.check_out_task() # Load the unit from warehouse and return @@ -43,7 +41,7 @@ def execute_unit(self, scratch: Path): # NOTE: On changes to context, this can easily be replaced with external storage objects # However, to satisfy the current work, we will use this implementation where we # force the use of a FileSystemWarehouse and in turn can assert that an object is FileStorage. - shared_store: FileStorage = self.warehouse.stores["shared"] + shared_store: FileStorage = self.warehouse.shared_store.root_dir shared_root_dir = shared_store.root_dir ctx = Context(scratch, shared=shared_root_dir) results: dict[GufeKey, ProtocolUnitResult] = {} From 14e73b6bae1ae6b7d2d00a5ebba061b48af3fa08 Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Tue, 10 Feb 2026 16:43:01 -0700 Subject: [PATCH 09/20] test: inital worker testing --- src/openfe/tests/orchestration/test_worker.py | 123 ++++++++++++++++++ 1 file changed, 123 insertions(+) create mode 100644 src/openfe/tests/orchestration/test_worker.py diff --git a/src/openfe/tests/orchestration/test_worker.py b/src/openfe/tests/orchestration/test_worker.py new file mode 100644 index 000000000..a99f872c0 --- /dev/null +++ b/src/openfe/tests/orchestration/test_worker.py @@ -0,0 +1,123 @@ +from pathlib import Path +from unittest import mock + +import exorcist +import gufe +import networkx as nx +import pytest +from gufe.protocols.protocolunit import ProtocolUnit + +from openfe.orchestration import Worker +from openfe.orchestration.exorcist_utils import build_task_db_from_alchemical_network +from openfe.storage.warehouse import FileSystemWarehouse + + +def _result_store_files(warehouse: FileSystemWarehouse) -> set[str]: + result_root = Path(warehouse.result_store.root_dir) + return {str(path.relative_to(result_root)) for path in result_root.rglob("*") if path.is_file()} + + +def _contains_protocol_unit(value) -> bool: + if isinstance(value, ProtocolUnit): + return True + if isinstance(value, dict): + return any(_contains_protocol_unit(item) for item in value.values()) + if isinstance(value, list): + return any(_contains_protocol_unit(item) for item in value) + return False + + +def _get_dependency_free_unit(absolute_transformation): + for unit in absolute_transformation.create().protocol_units: + if not _contains_protocol_unit(unit.inputs): + return unit + raise ValueError("No dependency-free protocol unit found for execution test setup.") + + +@pytest.fixture +def worker_with_real_db(tmp_path, absolute_transformation): + warehouse_root = tmp_path / "warehouse" + db_path = warehouse_root / "tasks.db" + warehouse = FileSystemWarehouse(str(warehouse_root)) + network = gufe.AlchemicalNetwork([absolute_transformation]) + db = build_task_db_from_alchemical_network(network, warehouse, db_path=db_path) + worker = Worker(warehouse=warehouse, task_db_path=db_path) + return worker, warehouse, db + + +@pytest.fixture +def worker_with_executable_task_db(tmp_path, absolute_transformation): + warehouse_root = tmp_path / "warehouse" + db_path = warehouse_root / "tasks.db" + warehouse = FileSystemWarehouse(str(warehouse_root)) + unit = _get_dependency_free_unit(absolute_transformation) + warehouse.store_task(unit) + + taskid = f"{absolute_transformation.key}:{unit.key}" + task_graph = nx.DiGraph() + task_graph.add_node(taskid) + + db = exorcist.TaskStatusDB.from_filename(db_path) + db.add_task_network(task_graph, 1) + + worker = Worker(warehouse=warehouse, task_db_path=db_path) + return worker, warehouse, db, unit + + +def test_get_task_uses_default_db_path_without_patching( + tmp_path, monkeypatch, absolute_transformation +): + monkeypatch.chdir(tmp_path) + warehouse = FileSystemWarehouse("warehouse") + db_path = Path("warehouse/tasks.db") + network = gufe.AlchemicalNetwork([absolute_transformation]) + db = build_task_db_from_alchemical_network(network, warehouse, db_path=db_path) + + worker = Worker(warehouse=warehouse) + loaded = worker._get_task() + + expected_keys = {task_row.taskid.split(":", maxsplit=1)[1] for task_row in db.get_all_tasks()} + assert worker.task_db_path == Path("./warehouse/tasks.db") + assert str(loaded.key) in expected_keys + + +def test_get_task_returns_task_with_canonical_protocol_unit_suffix(worker_with_real_db): + worker, warehouse, db = worker_with_real_db + + task_ids = [row.taskid for row in db.get_all_tasks()] + expected_protocol_unit_keys = {task_id.split(":", maxsplit=1)[1] for task_id in task_ids} + + loaded = worker._get_task() + reloaded = warehouse.load_task(loaded.key) + + assert str(loaded.key) in expected_protocol_unit_keys + assert loaded == reloaded + + +def test_execute_unit_stores_real_result(worker_with_executable_task_db, tmp_path): + worker, warehouse, _, _ = worker_with_executable_task_db + before = _result_store_files(warehouse) + + worker.execute_unit(scratch=tmp_path / "scratch") + + after = _result_store_files(warehouse) + assert len(after) > len(before) + + +def test_execute_unit_propagates_execute_error_without_store( + worker_with_executable_task_db, tmp_path +): + worker, warehouse, _, unit = worker_with_executable_task_db + before = _result_store_files(warehouse) + + with mock.patch.object( + type(unit), + "execute", + autospec=True, + side_effect=RuntimeError("unit execution failed"), + ): + with pytest.raises(RuntimeError, match="unit execution failed"): + worker.execute_unit(scratch=tmp_path / "scratch") + + after = _result_store_files(warehouse) + assert after == before From 08a1c561eec789797ca9d16b9c857d49333c48a9 Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Wed, 11 Feb 2026 09:08:27 -0700 Subject: [PATCH 10/20] feat: add shared_store --- src/openfe/storage/warehouse.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/openfe/storage/warehouse.py b/src/openfe/storage/warehouse.py index 0b46a8c6c..843dfc9f0 100644 --- a/src/openfe/storage/warehouse.py +++ b/src/openfe/storage/warehouse.py @@ -314,6 +314,17 @@ def result_store(self): """ return self.stores["result"] + @property + def shared_store(self): + """Get the shared store. + + Returns + ------- + ExternalStorage + The shared storage location + """ + return self.stores["shared"] + class FileSystemWarehouse(WarehouseBaseClass): """Warehouse implementation using local filesystem storage. From 6f41dcf8714a837c4793c3eae2bf1373625fbfa1 Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Wed, 11 Feb 2026 09:10:09 -0700 Subject: [PATCH 11/20] feat: add better handling for CLI application Signed-off-by: Ethan Holz --- src/openfe/orchestration/__init__.py | 39 ++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/src/openfe/orchestration/__init__.py b/src/openfe/orchestration/__init__.py index 98338df02..9c8410843 100644 --- a/src/openfe/orchestration/__init__.py +++ b/src/openfe/orchestration/__init__.py @@ -6,7 +6,6 @@ from gufe.protocols.protocolunit import ( Context, ProtocolUnit, - ProtocolUnitFailure, ProtocolUnitResult, ) from gufe.storage.externalresource.filestorage import FileStorage @@ -25,31 +24,49 @@ class Worker: warehouse: FileSystemWarehouse task_db_path: Path = Path("./warehouse/tasks.db") - def _get_task(self) -> ProtocolUnit: + def _checkout_task(self) -> tuple[str, ProtocolUnit] | None: db: TaskStatusDB = TaskStatusDB.from_filename(self.task_db_path) - # The format for the taskid is going to "Transformation-:Unit" + # The format for the taskid is "Transformation-:ProtocolUnit-" taskid = db.check_out_task() - # Load the unit from warehouse and return + if taskid is None: + return None + _, protocol_unit_key = taskid.split(":", maxsplit=1) + unit = self.warehouse.load_task(GufeKey(protocol_unit_key)) + return taskid, unit - return self.warehouse.load_task(GufeKey(protocol_unit_key)) + def _get_task(self) -> ProtocolUnit: + task = self._checkout_task() + if task is None: + raise RuntimeError("No AVAILABLE tasks found in the task database.") + _, unit = task + return unit - def execute_unit(self, scratch: Path): + def execute_unit(self, scratch: Path) -> tuple[str, ProtocolUnitResult] | None: # 1. Get task/unit - unit = self._get_task() + task = self._checkout_task() + if task is None: + return None + taskid, unit = task # 2. Constrcut the context # NOTE: On changes to context, this can easily be replaced with external storage objects # However, to satisfy the current work, we will use this implementation where we # force the use of a FileSystemWarehouse and in turn can assert that an object is FileStorage. - shared_store: FileStorage = self.warehouse.shared_store.root_dir + shared_store: FileStorage = self.warehouse.stores["shared"] shared_root_dir = shared_store.root_dir ctx = Context(scratch, shared=shared_root_dir) results: dict[GufeKey, ProtocolUnitResult] = {} inputs = _pu_to_pur(unit.inputs, results) + db: TaskStatusDB = TaskStatusDB.from_filename(self.task_db_path) # 3. Execute unit - result = unit.execute(context=ctx, **inputs) - # if not result.ok(): - # Increment attempt in taskdb + try: + result = unit.execute(context=ctx, **inputs) + except Exception: + db.mark_task_completed(taskid, success=False) + raise + + db.mark_task_completed(taskid, success=result.ok()) # 4. output result to warehouse # TODO: we may need to end up handling namespacing on the warehouse side for tokenizables self.warehouse.store_result_tokenizable(result) + return taskid, result From d915ca348a60dbe5e0bd393b05b254bf10b95056 Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Wed, 11 Feb 2026 09:21:16 -0700 Subject: [PATCH 12/20] test: add new worker tests --- src/openfe/tests/orchestration/test_worker.py | 37 +++++++++++++++++-- 1 file changed, 34 insertions(+), 3 deletions(-) diff --git a/src/openfe/tests/orchestration/test_worker.py b/src/openfe/tests/orchestration/test_worker.py index a99f872c0..dca306d26 100644 --- a/src/openfe/tests/orchestration/test_worker.py +++ b/src/openfe/tests/orchestration/test_worker.py @@ -95,20 +95,26 @@ def test_get_task_returns_task_with_canonical_protocol_unit_suffix(worker_with_r def test_execute_unit_stores_real_result(worker_with_executable_task_db, tmp_path): - worker, warehouse, _, _ = worker_with_executable_task_db + worker, warehouse, db, _ = worker_with_executable_task_db before = _result_store_files(warehouse) - worker.execute_unit(scratch=tmp_path / "scratch") + execution = worker.execute_unit(scratch=tmp_path / "scratch") + assert execution is not None + taskid, _ = execution after = _result_store_files(warehouse) assert len(after) > len(before) + rows = list(db.get_all_tasks()) + status_by_taskid = {row.taskid: row.status for row in rows} + assert status_by_taskid[taskid] == exorcist.TaskStatus.COMPLETED.value def test_execute_unit_propagates_execute_error_without_store( worker_with_executable_task_db, tmp_path ): - worker, warehouse, _, unit = worker_with_executable_task_db + worker, warehouse, db, unit = worker_with_executable_task_db before = _result_store_files(warehouse) + taskid = list(db.get_all_tasks())[0].taskid with mock.patch.object( type(unit), @@ -121,3 +127,28 @@ def test_execute_unit_propagates_execute_error_without_store( after = _result_store_files(warehouse) assert after == before + rows = list(db.get_all_tasks()) + status_by_taskid = {row.taskid: row.status for row in rows} + assert status_by_taskid[taskid] == exorcist.TaskStatus.TOO_MANY_RETRIES.value + + +def test_checkout_task_returns_none_when_no_available_tasks(tmp_path): + warehouse_root = tmp_path / "warehouse" + db_path = warehouse_root / "tasks.db" + warehouse_root.mkdir(parents=True, exist_ok=True) + warehouse = FileSystemWarehouse(str(warehouse_root)) + exorcist.TaskStatusDB.from_filename(db_path) + worker = Worker(warehouse=warehouse, task_db_path=db_path) + + assert worker._checkout_task() is None + + +def test_execute_unit_returns_none_when_no_available_tasks(tmp_path): + warehouse_root = tmp_path / "warehouse" + db_path = warehouse_root / "tasks.db" + warehouse_root.mkdir(parents=True, exist_ok=True) + warehouse = FileSystemWarehouse(str(warehouse_root)) + exorcist.TaskStatusDB.from_filename(db_path) + worker = Worker(warehouse=warehouse, task_db_path=db_path) + + assert worker.execute_unit(scratch=tmp_path / "scratch") is None From ff7598fedc3412e39e584fcc075d4f87ecf85ba5 Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Wed, 11 Feb 2026 09:22:51 -0700 Subject: [PATCH 13/20] feat: add exorcist worker to CLI --- src/openfecli/commands/worker.py | 83 ++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 src/openfecli/commands/worker.py diff --git a/src/openfecli/commands/worker.py b/src/openfecli/commands/worker.py new file mode 100644 index 000000000..dbc52b6b3 --- /dev/null +++ b/src/openfecli/commands/worker.py @@ -0,0 +1,83 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe + +import pathlib + +import click + +from openfecli import OFECommandPlugin +from openfecli.utils import print_duration, write + + +def _build_worker(warehouse_path: pathlib.Path, db_path: pathlib.Path): + from openfe.orchestration import Worker + from openfe.storage.warehouse import FileSystemWarehouse + + warehouse = FileSystemWarehouse(str(warehouse_path)) + return Worker(warehouse=warehouse, task_db_path=db_path) + + +def worker_main(warehouse_path: pathlib.Path, scratch: pathlib.Path | None): + db_path = warehouse_path / "tasks.db" + if not db_path.is_file(): + raise click.ClickException(f"Task database not found at: {db_path}") + + if scratch is None: + scratch = pathlib.Path.cwd() + + scratch.mkdir(parents=True, exist_ok=True) + + worker = _build_worker(warehouse_path, db_path) + + try: + execution = worker.execute_unit(scratch=scratch) + except Exception as exc: + raise click.ClickException(f"Task execution failed: {exc}") from exc + + if execution is None: + write("No available task in task graph.") + return None + + taskid, result = execution + if not result.ok(): + raise click.ClickException(f"Task '{taskid}' returned a failure result.") + + write(f"Completed task: {taskid}") + return result + + +@click.command("worker", short_help="Execute one available task from a filesystem warehouse") +@click.argument( + "warehouse_path", + type=click.Path( + exists=True, + readable=True, + file_okay=False, + dir_okay=True, + path_type=pathlib.Path, + ), +) +@click.option( + "--scratch", + "-s", + default=None, + type=click.Path( + writable=True, + file_okay=False, + dir_okay=True, + path_type=pathlib.Path, + ), + help="Directory for scratch files. Defaults to current working directory.", +) +@print_duration +def worker(warehouse_path: pathlib.Path, scratch: pathlib.Path | None): + """ + Execute one available task from a warehouse task graph. + + The warehouse directory must contain a ``tasks.db`` task database and task + payloads under ``tasks/`` created via OpenFE orchestration setup. + """ + worker_main(warehouse_path=warehouse_path, scratch=scratch) + + +PLUGIN = OFECommandPlugin(command=worker, section="Quickrun Executor", requires_ofe=(0, 3)) From e79a8dd7fbcbd899b2ef592cf9367e7964cdab8a Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Wed, 11 Feb 2026 09:24:15 -0700 Subject: [PATCH 14/20] test: add for worker CLI command --- src/openfecli/tests/commands/test_worker.py | 107 ++++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 src/openfecli/tests/commands/test_worker.py diff --git a/src/openfecli/tests/commands/test_worker.py b/src/openfecli/tests/commands/test_worker.py new file mode 100644 index 000000000..6d3b55f7e --- /dev/null +++ b/src/openfecli/tests/commands/test_worker.py @@ -0,0 +1,107 @@ +from pathlib import Path +from unittest import mock + +from click.testing import CliRunner + +from openfecli.commands.worker import worker + + +class _SuccessfulResult: + def ok(self): + return True + + +class _FailedResult: + def ok(self): + return False + + +def test_worker_requires_task_database(): + runner = CliRunner() + with runner.isolated_filesystem(): + Path("warehouse").mkdir() + result = runner.invoke(worker, ["warehouse"]) + assert result.exit_code == 1 + assert "Task database not found at" in result.output + + +def test_worker_no_available_task_exits_zero(): + runner = CliRunner() + with runner.isolated_filesystem(): + warehouse_path = Path("warehouse") + warehouse_path.mkdir() + (warehouse_path / "tasks.db").touch() + + mock_worker = mock.Mock() + mock_worker.execute_unit.return_value = None + + with mock.patch( + "openfecli.commands.worker._build_worker", return_value=mock_worker + ) as build_worker: + result = runner.invoke(worker, ["warehouse"]) + + assert result.exit_code == 0 + assert "No available task in task graph." in result.output + build_worker.assert_called_once_with(warehouse_path, warehouse_path / "tasks.db") + kwargs = mock_worker.execute_unit.call_args.kwargs + assert kwargs["scratch"] == Path.cwd() + + +def test_worker_executes_one_task_and_reports_completion(): + runner = CliRunner() + with runner.isolated_filesystem(): + warehouse_path = Path("warehouse") + warehouse_path.mkdir() + (warehouse_path / "tasks.db").touch() + + mock_worker = mock.Mock() + mock_worker.execute_unit.return_value = ( + "Transformation-abc:ProtocolUnit-def", + _SuccessfulResult(), + ) + + with mock.patch("openfecli.commands.worker._build_worker", return_value=mock_worker): + result = runner.invoke(worker, ["warehouse", "--scratch", "scratch"]) + + assert result.exit_code == 0 + assert "Completed task: Transformation-abc:ProtocolUnit-def" in result.output + assert Path("scratch").is_dir() + kwargs = mock_worker.execute_unit.call_args.kwargs + assert kwargs["scratch"] == Path("scratch") + + +def test_worker_raises_when_result_is_failure(): + runner = CliRunner() + with runner.isolated_filesystem(): + warehouse_path = Path("warehouse") + warehouse_path.mkdir() + (warehouse_path / "tasks.db").touch() + + mock_worker = mock.Mock() + mock_worker.execute_unit.return_value = ( + "Transformation-abc:ProtocolUnit-def", + _FailedResult(), + ) + + with mock.patch("openfecli.commands.worker._build_worker", return_value=mock_worker): + result = runner.invoke(worker, ["warehouse"]) + + assert result.exit_code == 1 + assert "returned a failure result" in result.output + + +def test_worker_raises_when_execution_throws(): + runner = CliRunner() + with runner.isolated_filesystem(): + warehouse_path = Path("warehouse") + warehouse_path.mkdir() + (warehouse_path / "tasks.db").touch() + + mock_worker = mock.Mock() + mock_worker.execute_unit.side_effect = RuntimeError("boom") + + with mock.patch("openfecli.commands.worker._build_worker", return_value=mock_worker): + result = runner.invoke(worker, ["warehouse"]) + + assert result.exit_code == 1 + assert "Task execution failed: boom" in result.output From 8d4a1399bc92b88cae7fc381152cb20c31eb77ea Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Wed, 11 Feb 2026 20:29:28 -0700 Subject: [PATCH 15/20] docs: add numpy docstrings --- src/openfe/orchestration/__init__.py | 54 ++++++++++++++++++++++ src/openfe/orchestration/exorcist_utils.py | 51 ++++++++++++++++++-- 2 files changed, 102 insertions(+), 3 deletions(-) diff --git a/src/openfe/orchestration/__init__.py b/src/openfe/orchestration/__init__.py index 9c8410843..c46809f44 100644 --- a/src/openfe/orchestration/__init__.py +++ b/src/openfe/orchestration/__init__.py @@ -1,3 +1,5 @@ +"""Task orchestration utilities backed by Exorcist and a warehouse.""" + from dataclasses import dataclass from pathlib import Path @@ -21,10 +23,29 @@ @dataclass class Worker: + """Execute protocol units from an Exorcist task database. + + Parameters + ---------- + warehouse : FileSystemWarehouse + Warehouse used to load queued tasks and store execution results. + task_db_path : pathlib.Path, default=Path("./warehouse/tasks.db") + Path to the Exorcist SQLite task database. + """ + warehouse: FileSystemWarehouse task_db_path: Path = Path("./warehouse/tasks.db") def _checkout_task(self) -> tuple[str, ProtocolUnit] | None: + """Check out one available task and load its protocol unit. + + Returns + ------- + tuple[str, ProtocolUnit] or None + The checked-out task ID and corresponding protocol unit, or + ``None`` if no task is currently available. + """ + db: TaskStatusDB = TaskStatusDB.from_filename(self.task_db_path) # The format for the taskid is "Transformation-:ProtocolUnit-" taskid = db.check_out_task() @@ -36,6 +57,19 @@ def _checkout_task(self) -> tuple[str, ProtocolUnit] | None: return taskid, unit def _get_task(self) -> ProtocolUnit: + """Return the next available protocol unit. + + Returns + ------- + ProtocolUnit + A protocol unit loaded from the warehouse. + + Raises + ------ + RuntimeError + Raised when no task is available in the task database. + """ + task = self._checkout_task() if task is None: raise RuntimeError("No AVAILABLE tasks found in the task database.") @@ -43,6 +77,26 @@ def _get_task(self) -> ProtocolUnit: return unit def execute_unit(self, scratch: Path) -> tuple[str, ProtocolUnitResult] | None: + """Execute one checked-out protocol unit and persist its result. + + Parameters + ---------- + scratch : pathlib.Path + Scratch directory passed to the protocol execution context. + + Returns + ------- + tuple[str, ProtocolUnitResult] or None + The task ID and execution result for the processed task, or + ``None`` if no task is currently available. + + Raises + ------ + Exception + Re-raises any exception thrown during protocol unit execution after + marking the task as failed. + """ + # 1. Get task/unit task = self._checkout_task() if task is None: diff --git a/src/openfe/orchestration/exorcist_utils.py b/src/openfe/orchestration/exorcist_utils.py index 51c269b2e..2ce8f1e9c 100644 --- a/src/openfe/orchestration/exorcist_utils.py +++ b/src/openfe/orchestration/exorcist_utils.py @@ -1,4 +1,8 @@ -"""Utilities for building Exorcist task graphs and task databases.""" +"""Utilities for building Exorcist task graphs and task databases. + +This module translates an :class:`gufe.AlchemicalNetwork` into Exorcist task +structures and can initialize an Exorcist task database from that graph. +""" from pathlib import Path @@ -12,7 +16,28 @@ def alchemical_network_to_task_graph( alchemical_network: AlchemicalNetwork, warehouse: WarehouseBaseClass ) -> nx.DiGraph: - """Build a global task DAG from an AlchemicalNetwork.""" + """Build a global task DAG from an alchemical network. + + Parameters + ---------- + alchemical_network : AlchemicalNetwork + Network containing transformations to execute. + warehouse : WarehouseBaseClass + Warehouse used to persist protocol units as tasks while the graph is + constructed. + + Returns + ------- + nx.DiGraph + A directed acyclic graph where each node is a task ID in the form + ``":"`` and edges encode + protocol-unit dependencies. + + Raises + ------ + ValueError + Raised if the assembled task graph is not acyclic. + """ global_dag = nx.DiGraph() for transformation in alchemical_network.edges: @@ -40,7 +65,27 @@ def build_task_db_from_alchemical_network( db_path: Path | None = None, max_tries: int = 1, ) -> exorcist.TaskStatusDB: - """Create an Exorcist TaskStatusDB from an AlchemicalNetwork.""" + """Create and populate a task database from an alchemical network. + + Parameters + ---------- + alchemical_network : AlchemicalNetwork + Network containing transformations to convert into task records. + warehouse : WarehouseBaseClass + Warehouse used to persist protocol units while building the task DAG. + db_path : pathlib.Path or None, optional + Location of the SQLite-backed Exorcist database. If ``None``, defaults + to ``Path("tasks.db")`` in the current working directory. + max_tries : int, default=1 + Maximum number of retries for each task before Exorcist marks it as + ``TOO_MANY_RETRIES``. + + Returns + ------- + exorcist.TaskStatusDB + Initialized task database populated with graph nodes and dependency + edges derived from ``alchemical_network``. + """ if db_path is None: db_path = Path("tasks.db") From deca0d8bdd4a436e7497f6cf7fb7cfde7dfb216e Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Fri, 30 Jan 2026 14:46:39 -0700 Subject: [PATCH 16/20] feat: add support for planning an RBFE to a Warehouse --- src/openfe/storage/warehouse.py | 2 + src/openfecli/commands/plan_rbfe_network.py | 10 +++- src/openfecli/parameters/__init__.py | 1 + src/openfecli/parameters/warehouse.py | 4 ++ .../plan_alchemical_networks_utils.py | 47 +++++++++++-------- 5 files changed, 43 insertions(+), 21 deletions(-) create mode 100644 src/openfecli/parameters/warehouse.py diff --git a/src/openfe/storage/warehouse.py b/src/openfe/storage/warehouse.py index 843dfc9f0..f983daaa0 100644 --- a/src/openfe/storage/warehouse.py +++ b/src/openfe/storage/warehouse.py @@ -179,6 +179,7 @@ def _get_store_for_key(self, key: GufeKey) -> ExternalStorage: ValueError If the key is not found in any store. """ + print(key) for name in self.stores: if key in self.stores[name]: return self.stores[name] @@ -345,6 +346,7 @@ class FileSystemWarehouse(WarehouseBaseClass): """ def __init__(self, root_dir: str = "warehouse"): + self.root_dir = root_dir setup_store = FileStorage(f"{root_dir}/setup") result_store = FileStorage(f"{root_dir}/result") shared_store = FileStorage(f"{root_dir}/shared") diff --git a/src/openfecli/commands/plan_rbfe_network.py b/src/openfecli/commands/plan_rbfe_network.py index 5b73b5618..e74b06d1d 100644 --- a/src/openfecli/commands/plan_rbfe_network.py +++ b/src/openfecli/commands/plan_rbfe_network.py @@ -1,8 +1,8 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe - import click +from openfe.storage.warehouse import FileSystemWarehouse from openfecli import OFECommandPlugin from openfecli.parameters import ( COFACTORS, @@ -12,6 +12,7 @@ OUTPUT_DIR, OVERWRITE, PROTEIN, + WAREHOUSE, YAML_OPTIONS, ) from openfecli.utils import print_duration, write @@ -130,6 +131,7 @@ def plan_rbfe_network_main( @N_PROTOCOL_REPEATS.parameter(multiple=False, required=False, default=3, help=N_PROTOCOL_REPEATS.kwargs["help"]) # fmt: skip @NCORES.parameter(help=NCORES.kwargs["help"], default=1) @OVERWRITE.parameter(help=OVERWRITE.kwargs["help"], default=OVERWRITE.kwargs["default"], is_flag=True) # fmt: skip +@WAREHOUSE.parameter(help=WAREHOUSE.kwargs["help"], is_flag=True) @print_duration def plan_rbfe_network( molecules: list[str], @@ -140,6 +142,7 @@ def plan_rbfe_network( n_protocol_repeats: int, n_cores: int, overwrite_charges: bool, + warehouse: bool, ): """ Plan a relative binding free energy network, saved as JSON files for use by @@ -243,10 +246,15 @@ def plan_rbfe_network( # OUTPUT write("Output:") write("\tSaving to: " + str(output_dir)) + warehouse_object = None + if warehouse: + warehouse_object = FileSystemWarehouse() + plan_alchemical_network_output( alchemical_network=alchemical_network, ligand_network=ligand_network, folder_path=OUTPUT_DIR.get(output_dir), + warehouse=warehouse_object, ) diff --git a/src/openfecli/parameters/__init__.py b/src/openfecli/parameters/__init__.py index fb8dcd0f9..81e96e18f 100644 --- a/src/openfecli/parameters/__init__.py +++ b/src/openfecli/parameters/__init__.py @@ -9,3 +9,4 @@ from .output_dir import OUTPUT_DIR from .plan_network_options import YAML_OPTIONS from .protein import PROTEIN +from .warehouse import WAREHOUSE diff --git a/src/openfecli/parameters/warehouse.py b/src/openfecli/parameters/warehouse.py new file mode 100644 index 000000000..5fb2f07f6 --- /dev/null +++ b/src/openfecli/parameters/warehouse.py @@ -0,0 +1,4 @@ +import click +from plugcli.params import Option + +WAREHOUSE = Option("--warehouse", type=click.BOOL, help="Use a warehouse", default=False) diff --git a/src/openfecli/plan_alchemical_networks_utils.py b/src/openfecli/plan_alchemical_networks_utils.py index 9636fb9b5..da0572603 100644 --- a/src/openfecli/plan_alchemical_networks_utils.py +++ b/src/openfecli/plan_alchemical_networks_utils.py @@ -4,8 +4,10 @@ import json import pathlib +from typing import Optional from openfe import AlchemicalNetwork, LigandNetwork +from openfe.storage.warehouse import WarehouseBaseClass from openfecli.utils import write @@ -13,26 +15,31 @@ def plan_alchemical_network_output( alchemical_network: AlchemicalNetwork, ligand_network: LigandNetwork, folder_path: pathlib.Path, + warehouse: Optional[WarehouseBaseClass], ): """Write the contents of an alchemical network into the structure""" - base_name = folder_path.name - folder_path.mkdir(parents=True, exist_ok=True) - - an_json = folder_path / f"{base_name}.json" - alchemical_network.to_json(an_json) - write("\t\t- " + base_name + ".json") - - ln_fname = "ligand_network.graphml" - with open(folder_path / ln_fname, mode="w") as f: - f.write(ligand_network.to_graphml()) - write(f"\t\t- {ln_fname}") - - transformations_dir = folder_path / "transformations" - transformations_dir.mkdir(parents=True, exist_ok=True) - - for transformation in alchemical_network.edges: - transformation_name = transformation.name or transformation.key - filename = f"{transformation_name}.json" - transformation.to_json(transformations_dir / filename) - write("\t\t\t\t- " + filename) + if warehouse: + warehouse.store_setup_tokenizable(alchemical_network) + warehouse.store_setup_tokenizable(ligand_network) + else: + base_name = folder_path.name + folder_path.mkdir(parents=True, exist_ok=True) + + an_json = folder_path / f"{base_name}.json" + alchemical_network.to_json(an_json) + write("\t\t- " + base_name + ".json") + + ln_fname = "ligand_network.graphml" + with open(folder_path / ln_fname, mode="w") as f: + f.write(ligand_network.to_graphml()) + write(f"\t\t- {ln_fname}") + + transformations_dir = folder_path / "transformations" + transformations_dir.mkdir(parents=True, exist_ok=True) + + for transformation in alchemical_network.edges: + transformation_name = transformation.name or transformation.key + filename = f"{transformation_name}.json" + transformation.to_json(transformations_dir / filename) + write("\t\t\t\t- " + filename) From a31973459b14263a792c2f2a8b5d7edd8c25790c Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Thu, 5 Mar 2026 18:04:58 -0700 Subject: [PATCH 17/20] fix: correct edge direction for task graph --- src/openfe/orchestration/exorcist_utils.py | 8 +++--- .../orchestration/test_exorcist_utils.py | 26 +++++++++++++++---- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/src/openfe/orchestration/exorcist_utils.py b/src/openfe/orchestration/exorcist_utils.py index 2ce8f1e9c..414f5da8c 100644 --- a/src/openfe/orchestration/exorcist_utils.py +++ b/src/openfe/orchestration/exorcist_utils.py @@ -48,10 +48,10 @@ def alchemical_network_to_task_graph( node_id, ) warehouse.store_task(unit) - for u, v in dag.graph.edges: - u_id = f"{str(transformation.key)}:{str(u.key)}" - v_id = f"{str(transformation.key)}:{str(v.key)}" - global_dag.add_edge(u_id, v_id) + for dependent_unit, dependency_unit in dag.graph.edges: + upstream_id = f"{str(transformation.key)}:{str(dependency_unit.key)}" + downstream_id = f"{str(transformation.key)}:{str(dependent_unit.key)}" + global_dag.add_edge(upstream_id, downstream_id) if not nx.is_directed_acyclic_graph(global_dag): raise ValueError("AlchemicalNetwork produced a task graph that is not a DAG.") diff --git a/src/openfe/tests/orchestration/test_exorcist_utils.py b/src/openfe/tests/orchestration/test_exorcist_utils.py index 3ae9a85ad..22413c771 100644 --- a/src/openfe/tests/orchestration/test_exorcist_utils.py +++ b/src/openfe/tests/orchestration/test_exorcist_utils.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import cast from unittest import mock import exorcist @@ -11,7 +12,7 @@ alchemical_network_to_task_graph, build_task_db_from_alchemical_network, ) -from openfe.storage.warehouse import FileSystemWarehouse +from openfe.storage.warehouse import FileSystemWarehouse, WarehouseBaseClass class _RecordingWarehouse: @@ -35,7 +36,7 @@ def test_alchemical_network_to_task_graph_stores_all_units(request, fixture): network = request.getfixturevalue(fixture) expected_units = _network_units(network) - alchemical_network_to_task_graph(network, warehouse) + alchemical_network_to_task_graph(network, cast(WarehouseBaseClass, warehouse)) stored_unit_names = [str(unit.name) for unit in warehouse.stored_tasks] expected_unit_names = [str(unit.name) for unit in expected_units] @@ -49,7 +50,7 @@ def test_alchemical_network_to_task_graph_uses_canonical_task_ids(request, fixtu warehouse = _RecordingWarehouse() network = request.getfixturevalue(fixture) - graph = alchemical_network_to_task_graph(network, warehouse) + graph = alchemical_network_to_task_graph(network, cast(WarehouseBaseClass, warehouse)) transformation_keys = {str(transformation.key) for transformation in network.edges} expected_protocol_unit_keys = sorted(str(unit.key) for unit in warehouse.stored_tasks) @@ -68,7 +69,7 @@ def test_alchemical_network_to_task_graph_edges_reference_existing_nodes(request warehouse = _RecordingWarehouse() network = request.getfixturevalue(fixture) - graph = alchemical_network_to_task_graph(network, warehouse) + graph = alchemical_network_to_task_graph(network, cast(WarehouseBaseClass, warehouse)) assert len(graph.edges) > 0 for u, v in graph.edges: @@ -76,6 +77,22 @@ def test_alchemical_network_to_task_graph_edges_reference_existing_nodes(request assert v in graph.nodes +@pytest.mark.parametrize("fixture", ["benzene_variants_star_map"]) +def test_alchemical_network_to_task_graph_edge_direction_matches_dependencies(request, fixture): + warehouse = _RecordingWarehouse() + network = request.getfixturevalue(fixture) + + graph = alchemical_network_to_task_graph(network, cast(WarehouseBaseClass, warehouse)) + units_by_key = {str(unit.key): unit for unit in warehouse.stored_tasks} + + for upstream_id, downstream_id in graph.edges: + _, upstream_key = upstream_id.split(":", maxsplit=1) + _, downstream_key = downstream_id.split(":", maxsplit=1) + upstream_unit = units_by_key[upstream_key] + downstream_unit = units_by_key[downstream_key] + assert upstream_unit in downstream_unit.dependencies + + def test_alchemical_network_to_task_graph_raises_for_cycle(): class _Unit: def __init__(self, name: str, key: str): @@ -125,7 +142,6 @@ def test_build_task_db_checkout_order_is_dependency_safe(tmp_path, request, fixt checkout_order = [] # Hard upper bound prevents infinite checkout loops. max_checkouts = len(graph_taskids) - print(f"Max Checkout={max_checkouts}") for _ in range(max_checkouts): taskid = db.check_out_task() if taskid is None: From 27480faea4a11a4a877c3b95734e922add75ed2c Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Thu, 5 Mar 2026 18:05:24 -0700 Subject: [PATCH 18/20] refactor: remove extra debugging from warehouse --- src/openfe/storage/warehouse.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/openfe/storage/warehouse.py b/src/openfe/storage/warehouse.py index f983daaa0..f1540b429 100644 --- a/src/openfe/storage/warehouse.py +++ b/src/openfe/storage/warehouse.py @@ -179,7 +179,6 @@ def _get_store_for_key(self, key: GufeKey) -> ExternalStorage: ValueError If the key is not found in any store. """ - print(key) for name in self.stores: if key in self.stores[name]: return self.stores[name] From 43159a10a8ad5ef8b1e9000350d7eb23e32241b7 Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Thu, 5 Mar 2026 18:06:09 -0700 Subject: [PATCH 19/20] refactor: cleanup handling of tasks for worker --- src/openfe/orchestration/__init__.py | 152 +++++++++++++++--- src/openfe/tests/orchestration/test_worker.py | 124 +++++++++++++- 2 files changed, 256 insertions(+), 20 deletions(-) diff --git a/src/openfe/orchestration/__init__.py b/src/openfe/orchestration/__init__.py index c46809f44..4317d3987 100644 --- a/src/openfe/orchestration/__init__.py +++ b/src/openfe/orchestration/__init__.py @@ -10,6 +10,7 @@ ProtocolUnit, ProtocolUnitResult, ) +from gufe.storage.externalresource.base import ExternalStorage from gufe.storage.externalresource.filestorage import FileStorage from gufe.tokenization import GufeKey @@ -36,14 +37,124 @@ class Worker: warehouse: FileSystemWarehouse task_db_path: Path = Path("./warehouse/tasks.db") - def _checkout_task(self) -> tuple[str, ProtocolUnit] | None: + _RESULT_INDEX_PREFIX = "protocol_unit_results" + _TASK_WORKDIR_PREFIX = "task_workdirs" + + @staticmethod + def _collect_protocol_unit_keys(value: object) -> set[GufeKey]: + """Collect `ProtocolUnit` keys from nested unit inputs.""" + + if isinstance(value, ProtocolUnit): + return {value.key} + + found: set[GufeKey] = set() + if isinstance(value, dict): + items = value.values() + elif isinstance(value, list): + items = value + else: + return found + + for item in items: + found.update(Worker._collect_protocol_unit_keys(item)) + return found + + @classmethod + def _result_index_location(cls, source_key: GufeKey) -> str: + return f"{cls._RESULT_INDEX_PREFIX}/{source_key}" + + @classmethod + def _task_workdir_name(cls, taskid: str) -> str: + return taskid.replace(":", "__") + + def _task_workspace_paths( + self, taskid: str, scratch_root: Path, shared_root: Path + ) -> tuple[Path, Path]: + workdir_name = self._task_workdir_name(taskid) + task_scratch = scratch_root / self._TASK_WORKDIR_PREFIX / workdir_name + task_shared = shared_root / self._TASK_WORKDIR_PREFIX / workdir_name + return task_scratch, task_shared + + def _store_result_index(self, result: ProtocolUnitResult) -> None: + shared_store: ExternalStorage = self.warehouse.stores["shared"] + location = self._result_index_location(result.source_key) + shared_store.store_bytes(location, str(result.key).encode("utf-8")) + + def _load_result_from_index(self, source_key: GufeKey) -> ProtocolUnitResult | None: + shared_store: ExternalStorage = self.warehouse.stores["shared"] + location = self._result_index_location(source_key) + + if not shared_store.exists(location): + return None + + with shared_store.load_stream(location) as stream: + result_key = stream.read().decode("utf-8").strip() + + loaded = self.warehouse.load_result_tokenizable(GufeKey(result_key)) + if isinstance(loaded, ProtocolUnitResult): + return loaded + + return None + + def _scan_result_store_for_sources( + self, source_keys: set[GufeKey] + ) -> dict[GufeKey, ProtocolUnitResult]: + found: dict[GufeKey, ProtocolUnitResult] = {} + + for location in self.warehouse.result_store.iter_contents(): + if len(found) == len(source_keys): + break + + loaded = self.warehouse.load_result_tokenizable(GufeKey(location)) + if not isinstance(loaded, ProtocolUnitResult): + continue + + source_key = loaded.source_key + if source_key in source_keys and source_key not in found: + found[source_key] = loaded + + return found + + def _build_input_result_mapping(self, unit: ProtocolUnit) -> dict[GufeKey, ProtocolUnitResult]: + required_keys = self._collect_protocol_unit_keys(unit.inputs) + if not required_keys: + return {} + + results: dict[GufeKey, ProtocolUnitResult] = {} + unresolved = set(required_keys) + + for source_key in required_keys: + loaded = self._load_result_from_index(source_key) + if loaded is not None: + results[source_key] = loaded + unresolved.discard(source_key) + + if unresolved: + scanned = self._scan_result_store_for_sources(unresolved) + for source_key, loaded in scanned.items(): + results[source_key] = loaded + self._store_result_index(loaded) + unresolved.discard(source_key) + + if unresolved: + missing_keys = ", ".join(sorted(str(k) for k in unresolved)) + raise RuntimeError( + "Missing ProtocolUnitResult(s) for dependency key(s): " + f"{missing_keys}. Ensure upstream tasks completed successfully." + ) + + return results + + def _checkout_task(self) -> tuple[TaskStatusDB, str, ProtocolUnit] | None: """Check out one available task and load its protocol unit. Returns ------- - tuple[str, ProtocolUnit] or None - The checked-out task ID and corresponding protocol unit, or - ``None`` if no task is currently available. + tuple[TaskStatusDB, str, ProtocolUnit] or None + The open database connection, checked-out task ID, and corresponding + protocol unit, or ``None`` if no task is currently available. + The caller is responsible for calling ``mark_task_completed`` on the + returned database using the returned task ID. """ db: TaskStatusDB = TaskStatusDB.from_filename(self.task_db_path) @@ -54,15 +165,15 @@ def _checkout_task(self) -> tuple[str, ProtocolUnit] | None: _, protocol_unit_key = taskid.split(":", maxsplit=1) unit = self.warehouse.load_task(GufeKey(protocol_unit_key)) - return taskid, unit + return db, taskid, unit - def _get_task(self) -> ProtocolUnit: - """Return the next available protocol unit. + def _get_task(self) -> tuple[str, ProtocolUnit]: + """Return the next available task ID and protocol unit. Returns ------- - ProtocolUnit - A protocol unit loaded from the warehouse. + tuple[str, ProtocolUnit] + The checked-out task ID and corresponding protocol unit. Raises ------ @@ -73,8 +184,8 @@ def _get_task(self) -> ProtocolUnit: task = self._checkout_task() if task is None: raise RuntimeError("No AVAILABLE tasks found in the task database.") - _, unit = task - return unit + db, taskid, unit = task + return taskid, unit def execute_unit(self, scratch: Path) -> tuple[str, ProtocolUnitResult] | None: """Execute one checked-out protocol unit and persist its result. @@ -101,19 +212,23 @@ def execute_unit(self, scratch: Path) -> tuple[str, ProtocolUnitResult] | None: task = self._checkout_task() if task is None: return None - taskid, unit = task - # 2. Constrcut the context + db, taskid, unit = task + # 2. Construct the context # NOTE: On changes to context, this can easily be replaced with external storage objects # However, to satisfy the current work, we will use this implementation where we # force the use of a FileSystemWarehouse and in turn can assert that an object is FileStorage. - shared_store: FileStorage = self.warehouse.stores["shared"] + shared_store = self.warehouse.stores["shared"] + if not isinstance(shared_store, FileStorage): + raise TypeError("Expected a FileStorage backend for the shared store") shared_root_dir = shared_store.root_dir - ctx = Context(scratch, shared=shared_root_dir) - results: dict[GufeKey, ProtocolUnitResult] = {} - inputs = _pu_to_pur(unit.inputs, results) - db: TaskStatusDB = TaskStatusDB.from_filename(self.task_db_path) + task_scratch, task_shared = self._task_workspace_paths(taskid, scratch, shared_root_dir) + task_scratch.mkdir(parents=True, exist_ok=True) + task_shared.mkdir(parents=True, exist_ok=True) + ctx = Context(task_scratch, shared=task_shared) # 3. Execute unit try: + results = self._build_input_result_mapping(unit) + inputs = _pu_to_pur(unit.inputs, results) result = unit.execute(context=ctx, **inputs) except Exception: db.mark_task_completed(taskid, success=False) @@ -123,4 +238,5 @@ def execute_unit(self, scratch: Path) -> tuple[str, ProtocolUnitResult] | None: # 4. output result to warehouse # TODO: we may need to end up handling namespacing on the warehouse side for tokenizables self.warehouse.store_result_tokenizable(result) + self._store_result_index(result) return taskid, result diff --git a/src/openfe/tests/orchestration/test_worker.py b/src/openfe/tests/orchestration/test_worker.py index dca306d26..98edcc8e2 100644 --- a/src/openfe/tests/orchestration/test_worker.py +++ b/src/openfe/tests/orchestration/test_worker.py @@ -27,6 +27,24 @@ def _contains_protocol_unit(value) -> bool: return False +class _ToyProtocolUnit(ProtocolUnit): + @staticmethod + def _execute(ctx, **inputs) -> dict[str, int]: + increment = inputs["increment"] + upstream = inputs.get("upstream") + base = 0 if upstream is None else upstream.outputs["value"] + return {"value": base + increment} + + +class _FileWritingUnit(ProtocolUnit): + @staticmethod + def _execute(ctx, **inputs) -> dict[str, str]: + shared_file = ctx.shared / "simulation.nc" + shared_file.parent.mkdir(parents=True, exist_ok=True) + shared_file.write_text("unit output", encoding="utf-8") + return {"shared_file": str(shared_file)} + + def _get_dependency_free_unit(absolute_transformation): for unit in absolute_transformation.create().protocol_units: if not _contains_protocol_unit(unit.inputs): @@ -74,11 +92,12 @@ def test_get_task_uses_default_db_path_without_patching( db = build_task_db_from_alchemical_network(network, warehouse, db_path=db_path) worker = Worker(warehouse=warehouse) - loaded = worker._get_task() + taskid, loaded = worker._get_task() expected_keys = {task_row.taskid.split(":", maxsplit=1)[1] for task_row in db.get_all_tasks()} assert worker.task_db_path == Path("./warehouse/tasks.db") assert str(loaded.key) in expected_keys + assert taskid.endswith(f":{loaded.key}") def test_get_task_returns_task_with_canonical_protocol_unit_suffix(worker_with_real_db): @@ -87,11 +106,12 @@ def test_get_task_returns_task_with_canonical_protocol_unit_suffix(worker_with_r task_ids = [row.taskid for row in db.get_all_tasks()] expected_protocol_unit_keys = {task_id.split(":", maxsplit=1)[1] for task_id in task_ids} - loaded = worker._get_task() + taskid, loaded = worker._get_task() reloaded = warehouse.load_task(loaded.key) assert str(loaded.key) in expected_protocol_unit_keys assert loaded == reloaded + assert taskid.endswith(f":{loaded.key}") def test_execute_unit_stores_real_result(worker_with_executable_task_db, tmp_path): @@ -152,3 +172,103 @@ def test_execute_unit_returns_none_when_no_available_tasks(tmp_path): worker = Worker(warehouse=warehouse, task_db_path=db_path) assert worker.execute_unit(scratch=tmp_path / "scratch") is None + + +def test_execute_unit_resolves_dependency_results(tmp_path): + warehouse_root = tmp_path / "warehouse" + db_path = warehouse_root / "tasks.db" + warehouse = FileSystemWarehouse(str(warehouse_root)) + + first_unit = _ToyProtocolUnit(name="first", increment=1) + second_unit = _ToyProtocolUnit(name="second", upstream=first_unit, increment=2) + + warehouse.store_task(first_unit) + warehouse.store_task(second_unit) + + transformation_key = "Transformation-toy" + first_taskid = f"{transformation_key}:{first_unit.key}" + second_taskid = f"{transformation_key}:{second_unit.key}" + + task_graph = nx.DiGraph() + task_graph.add_edge(first_taskid, second_taskid) + + db = exorcist.TaskStatusDB.from_filename(db_path) + db.add_task_network(task_graph, max_tries=1) + + worker = Worker(warehouse=warehouse, task_db_path=db_path) + + first_execution = worker.execute_unit(scratch=tmp_path / "scratch") + second_execution = worker.execute_unit(scratch=tmp_path / "scratch") + + assert first_execution is not None + assert first_execution[0] == first_taskid + assert second_execution is not None + assert second_execution[0] == second_taskid + assert second_execution[1].outputs["value"] == 3 + + status_by_taskid = {row.taskid: row.status for row in db.get_all_tasks()} + assert status_by_taskid[first_taskid] == exorcist.TaskStatus.COMPLETED.value + assert status_by_taskid[second_taskid] == exorcist.TaskStatus.COMPLETED.value + + +def test_execute_unit_marks_missing_dependency_as_failed(tmp_path): + warehouse_root = tmp_path / "warehouse" + db_path = warehouse_root / "tasks.db" + warehouse = FileSystemWarehouse(str(warehouse_root)) + + missing_upstream = _ToyProtocolUnit(name="missing", increment=1) + dependent_unit = _ToyProtocolUnit(name="dependent", upstream=missing_upstream, increment=2) + warehouse.store_task(dependent_unit) + + taskid = f"Transformation-toy:{dependent_unit.key}" + task_graph = nx.DiGraph() + task_graph.add_node(taskid) + + db = exorcist.TaskStatusDB.from_filename(db_path) + db.add_task_network(task_graph, max_tries=1) + + worker = Worker(warehouse=warehouse, task_db_path=db_path) + + with pytest.raises(RuntimeError, match="Missing ProtocolUnitResult"): + worker.execute_unit(scratch=tmp_path / "scratch") + + status_by_taskid = {row.taskid: row.status for row in db.get_all_tasks()} + assert status_by_taskid[taskid] == exorcist.TaskStatus.TOO_MANY_RETRIES.value + + +def test_execute_unit_uses_isolated_shared_workspace_per_task(tmp_path): + warehouse_root = tmp_path / "warehouse" + db_path = warehouse_root / "tasks.db" + warehouse = FileSystemWarehouse(str(warehouse_root)) + + first_unit = _FileWritingUnit(name="first") + second_unit = _FileWritingUnit(name="second") + + warehouse.store_task(first_unit) + warehouse.store_task(second_unit) + + first_taskid = f"Transformation-toy:{first_unit.key}" + second_taskid = f"Transformation-toy:{second_unit.key}" + + task_graph = nx.DiGraph() + task_graph.add_node(first_taskid) + task_graph.add_node(second_taskid) + + db = exorcist.TaskStatusDB.from_filename(db_path) + db.add_task_network(task_graph, max_tries=1) + + worker = Worker(warehouse=warehouse, task_db_path=db_path) + + first_execution = worker.execute_unit(scratch=tmp_path / "scratch") + second_execution = worker.execute_unit(scratch=tmp_path / "scratch") + + assert first_execution is not None + assert second_execution is not None + + first_path = Path(first_execution[1].outputs["shared_file"]) + second_path = Path(second_execution[1].outputs["shared_file"]) + + assert first_path != second_path + assert first_path.name == "simulation.nc" + assert second_path.name == "simulation.nc" + assert first_path.parent != second_path.parent From 4ec61eed32e7c1f01e170313ee4fe6d76439b226 Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Thu, 5 Mar 2026 18:08:16 -0700 Subject: [PATCH 20/20] refactor: fix issues in the CLI for running the worker --- src/openfecli/commands/worker.py | 52 ++++++++++++++++++- .../plan_alchemical_networks_utils.py | 9 ++-- src/openfecli/tests/commands/test_worker.py | 36 +++++++++++++ 3 files changed, 93 insertions(+), 4 deletions(-) diff --git a/src/openfecli/commands/worker.py b/src/openfecli/commands/worker.py index dbc52b6b3..c312e3ff0 100644 --- a/src/openfecli/commands/worker.py +++ b/src/openfecli/commands/worker.py @@ -6,7 +6,7 @@ import click from openfecli import OFECommandPlugin -from openfecli.utils import print_duration, write +from openfecli.utils import configure_logger, print_duration, write def _build_worker(warehouse_path: pathlib.Path, db_path: pathlib.Path): @@ -17,7 +17,54 @@ def _build_worker(warehouse_path: pathlib.Path, db_path: pathlib.Path): return Worker(warehouse=warehouse, task_db_path=db_path) +def _write_failure_result_details(taskid: str, result) -> None: + source_key = getattr(result, "source_key", None) + exception = getattr(result, "exception", None) + traceback_text = getattr(result, "traceback", None) + + write(f"Task '{taskid}' returned a failure result.") + if source_key is not None: + write(f"Failed unit source key: {source_key}") + + if isinstance(exception, tuple) and len(exception) == 2: + exc_type, exc_args = exception + write(f"Protocol unit exception: {exc_type}: {exc_args}") + + if isinstance(traceback_text, str) and traceback_text: + write("Protocol unit traceback:") + write(traceback_text) + + def worker_main(warehouse_path: pathlib.Path, scratch: pathlib.Path | None): + import logging + import os + import sys + import traceback + + from openfe.utils import logging_control + + # avoid problems with output not showing if queueing system kills a job + sys.stdout.reconfigure(line_buffering=True) + + stdout_handler = logging.StreamHandler(sys.stdout) + + configure_logger("gufekey", handler=stdout_handler) + configure_logger("gufe", handler=stdout_handler) + configure_logger("openfe", handler=stdout_handler) + + # silence the openmmtools.multistate API warning + logging_control._silence_message( + msg=[ + "The openmmtools.multistate API is experimental and may change in future releases", + ], + logger_names=[ + "openmmtools.multistate.multistatereporter", + "openmmtools.multistate.multistateanalyzer", + "openmmtools.multistate.multistatesampler", + ], + ) + # turn warnings into log message (don't show stack trace) + logging.captureWarnings(True) db_path = warehouse_path / "tasks.db" if not db_path.is_file(): raise click.ClickException(f"Task database not found at: {db_path}") @@ -30,8 +77,10 @@ def worker_main(warehouse_path: pathlib.Path, scratch: pathlib.Path | None): worker = _build_worker(warehouse_path, db_path) try: + write("Executing unit...") execution = worker.execute_unit(scratch=scratch) except Exception as exc: + write(traceback.format_exc()) raise click.ClickException(f"Task execution failed: {exc}") from exc if execution is None: @@ -40,6 +89,7 @@ def worker_main(warehouse_path: pathlib.Path, scratch: pathlib.Path | None): taskid, result = execution if not result.ok(): + _write_failure_result_details(taskid, result) raise click.ClickException(f"Task '{taskid}' returned a failure result.") write(f"Completed task: {taskid}") diff --git a/src/openfecli/plan_alchemical_networks_utils.py b/src/openfecli/plan_alchemical_networks_utils.py index da0572603..e85965715 100644 --- a/src/openfecli/plan_alchemical_networks_utils.py +++ b/src/openfecli/plan_alchemical_networks_utils.py @@ -4,10 +4,12 @@ import json import pathlib +from pathlib import Path from typing import Optional from openfe import AlchemicalNetwork, LigandNetwork -from openfe.storage.warehouse import WarehouseBaseClass +from openfe.orchestration.exorcist_utils import build_task_db_from_alchemical_network +from openfe.storage.warehouse import FileSystemWarehouse from openfecli.utils import write @@ -15,13 +17,14 @@ def plan_alchemical_network_output( alchemical_network: AlchemicalNetwork, ligand_network: LigandNetwork, folder_path: pathlib.Path, - warehouse: Optional[WarehouseBaseClass], + warehouse: Optional[FileSystemWarehouse] = None, ): """Write the contents of an alchemical network into the structure""" if warehouse: warehouse.store_setup_tokenizable(alchemical_network) - warehouse.store_setup_tokenizable(ligand_network) + db_path = Path(warehouse.root_dir) / "tasks.db" + _ = build_task_db_from_alchemical_network(alchemical_network, warehouse, db_path) else: base_name = folder_path.name folder_path.mkdir(parents=True, exist_ok=True) diff --git a/src/openfecli/tests/commands/test_worker.py b/src/openfecli/tests/commands/test_worker.py index 6d3b55f7e..b060b57e2 100644 --- a/src/openfecli/tests/commands/test_worker.py +++ b/src/openfecli/tests/commands/test_worker.py @@ -16,6 +16,15 @@ def ok(self): return False +class _FailedResultWithDetails: + source_key = "HybridTopologyMultiStateSimulationUnit-deadbeef" + exception = ("RuntimeError", ("simulation blew up",)) + traceback = 'Traceback (most recent call last):\n File "sim.py", line 1\nRuntimeError: simulation blew up' + + def ok(self): + return False + + def test_worker_requires_task_database(): runner = CliRunner() with runner.isolated_filesystem(): @@ -90,6 +99,31 @@ def test_worker_raises_when_result_is_failure(): assert "returned a failure result" in result.output +def test_worker_prints_failure_result_details_when_available(): + runner = CliRunner() + with runner.isolated_filesystem(): + warehouse_path = Path("warehouse") + warehouse_path.mkdir() + (warehouse_path / "tasks.db").touch() + + mock_worker = mock.Mock() + mock_worker.execute_unit.return_value = ( + "Transformation-abc:ProtocolUnit-def", + _FailedResultWithDetails(), + ) + + with mock.patch("openfecli.commands.worker._build_worker", return_value=mock_worker): + result = runner.invoke(worker, ["warehouse"]) + + assert result.exit_code == 1 + assert ( + "Failed unit source key: HybridTopologyMultiStateSimulationUnit-deadbeef" + in result.output + ) + assert "Protocol unit exception: RuntimeError: ('simulation blew up',)" in result.output + assert "Protocol unit traceback:" in result.output + + def test_worker_raises_when_execution_throws(): runner = CliRunner() with runner.isolated_filesystem(): @@ -104,4 +138,6 @@ def test_worker_raises_when_execution_throws(): result = runner.invoke(worker, ["warehouse"]) assert result.exit_code == 1 + assert "Traceback (most recent call last):" in result.output + assert "RuntimeError: boom" in result.output assert "Task execution failed: boom" in result.output