Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ dependencies:
- threadpoolctl
- pip:
- git+https://github.com/OpenFreeEnergy/gufe@main
- git+https://github.com/OpenFreeEnergy/exorcist@main
- run_constrained:
# drop this pin when handled upstream in espaloma-feedstock
- smirnoff99frosst>=1.1.0.1 #https://github.com/openforcefield/smirnoff99Frosst/issues/109
242 changes: 242 additions & 0 deletions src/openfe/orchestration/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
"""Task orchestration utilities backed by Exorcist and a warehouse."""

from dataclasses import dataclass
from pathlib import Path

from exorcist.taskdb import TaskStatusDB
from gufe.protocols.protocoldag import _pu_to_pur
from gufe.protocols.protocolunit import (
Context,
ProtocolUnit,
ProtocolUnitResult,
)
from gufe.storage.externalresource.base import ExternalStorage
from gufe.storage.externalresource.filestorage import FileStorage
from gufe.tokenization import GufeKey

from openfe.storage.warehouse import FileSystemWarehouse

from .exorcist_utils import (
alchemical_network_to_task_graph,
build_task_db_from_alchemical_network,
)


@dataclass
class Worker:
"""Execute protocol units from an Exorcist task database.

Parameters
----------
warehouse : FileSystemWarehouse
Warehouse used to load queued tasks and store execution results.
task_db_path : pathlib.Path, default=Path("./warehouse/tasks.db")
Path to the Exorcist SQLite task database.
"""

warehouse: FileSystemWarehouse
task_db_path: Path = Path("./warehouse/tasks.db")

_RESULT_INDEX_PREFIX = "protocol_unit_results"
_TASK_WORKDIR_PREFIX = "task_workdirs"

@staticmethod
def _collect_protocol_unit_keys(value: object) -> set[GufeKey]:
"""Collect `ProtocolUnit` keys from nested unit inputs."""

if isinstance(value, ProtocolUnit):
return {value.key}

found: set[GufeKey] = set()
if isinstance(value, dict):
items = value.values()
elif isinstance(value, list):
items = value
else:
return found

for item in items:
found.update(Worker._collect_protocol_unit_keys(item))
return found

@classmethod
def _result_index_location(cls, source_key: GufeKey) -> str:
return f"{cls._RESULT_INDEX_PREFIX}/{source_key}"

@classmethod
def _task_workdir_name(cls, taskid: str) -> str:
return taskid.replace(":", "__")

def _task_workspace_paths(
self, taskid: str, scratch_root: Path, shared_root: Path
) -> tuple[Path, Path]:
workdir_name = self._task_workdir_name(taskid)
task_scratch = scratch_root / self._TASK_WORKDIR_PREFIX / workdir_name
task_shared = shared_root / self._TASK_WORKDIR_PREFIX / workdir_name
return task_scratch, task_shared

def _store_result_index(self, result: ProtocolUnitResult) -> None:
shared_store: ExternalStorage = self.warehouse.stores["shared"]
location = self._result_index_location(result.source_key)
shared_store.store_bytes(location, str(result.key).encode("utf-8"))

def _load_result_from_index(self, source_key: GufeKey) -> ProtocolUnitResult | None:
shared_store: ExternalStorage = self.warehouse.stores["shared"]
location = self._result_index_location(source_key)

if not shared_store.exists(location):
return None

with shared_store.load_stream(location) as stream:
result_key = stream.read().decode("utf-8").strip()

loaded = self.warehouse.load_result_tokenizable(GufeKey(result_key))
if isinstance(loaded, ProtocolUnitResult):
return loaded

return None

def _scan_result_store_for_sources(
self, source_keys: set[GufeKey]
) -> dict[GufeKey, ProtocolUnitResult]:
found: dict[GufeKey, ProtocolUnitResult] = {}

for location in self.warehouse.result_store.iter_contents():
if len(found) == len(source_keys):
break

loaded = self.warehouse.load_result_tokenizable(GufeKey(location))
if not isinstance(loaded, ProtocolUnitResult):
continue

source_key = loaded.source_key
if source_key in source_keys and source_key not in found:
found[source_key] = loaded

return found

def _build_input_result_mapping(self, unit: ProtocolUnit) -> dict[GufeKey, ProtocolUnitResult]:
required_keys = self._collect_protocol_unit_keys(unit.inputs)
if not required_keys:
return {}

results: dict[GufeKey, ProtocolUnitResult] = {}
unresolved = set(required_keys)

for source_key in required_keys:
loaded = self._load_result_from_index(source_key)
if loaded is not None:
results[source_key] = loaded
unresolved.discard(source_key)

if unresolved:
scanned = self._scan_result_store_for_sources(unresolved)
for source_key, loaded in scanned.items():
results[source_key] = loaded
self._store_result_index(loaded)
unresolved.discard(source_key)

if unresolved:
missing_keys = ", ".join(sorted(str(k) for k in unresolved))
raise RuntimeError(
"Missing ProtocolUnitResult(s) for dependency key(s): "
f"{missing_keys}. Ensure upstream tasks completed successfully."
)

return results

def _checkout_task(self) -> tuple[TaskStatusDB, str, ProtocolUnit] | None:
"""Check out one available task and load its protocol unit.

Returns
-------
tuple[TaskStatusDB, str, ProtocolUnit] or None
The open database connection, checked-out task ID, and corresponding
protocol unit, or ``None`` if no task is currently available.
The caller is responsible for calling ``mark_task_completed`` on the
returned database using the returned task ID.
"""

db: TaskStatusDB = TaskStatusDB.from_filename(self.task_db_path)
# The format for the taskid is "Transformation-<HASH>:ProtocolUnit-<HASH>"
taskid = db.check_out_task()
if taskid is None:
return None

_, protocol_unit_key = taskid.split(":", maxsplit=1)
unit = self.warehouse.load_task(GufeKey(protocol_unit_key))
return db, taskid, unit

def _get_task(self) -> tuple[str, ProtocolUnit]:
"""Return the next available task ID and protocol unit.

Returns
-------
tuple[str, ProtocolUnit]
The checked-out task ID and corresponding protocol unit.

Raises
------
RuntimeError
Raised when no task is available in the task database.
"""

task = self._checkout_task()
if task is None:
raise RuntimeError("No AVAILABLE tasks found in the task database.")
db, taskid, unit = task
return taskid, unit

def execute_unit(self, scratch: Path) -> tuple[str, ProtocolUnitResult] | None:
"""Execute one checked-out protocol unit and persist its result.

Parameters
----------
scratch : pathlib.Path
Scratch directory passed to the protocol execution context.

Returns
-------
tuple[str, ProtocolUnitResult] or None
The task ID and execution result for the processed task, or
``None`` if no task is currently available.

Raises
------
Exception
Re-raises any exception thrown during protocol unit execution after
marking the task as failed.
"""

# 1. Get task/unit
task = self._checkout_task()
if task is None:
return None
db, taskid, unit = task
# 2. Construct the context
# NOTE: On changes to context, this can easily be replaced with external storage objects
# However, to satisfy the current work, we will use this implementation where we
# force the use of a FileSystemWarehouse and in turn can assert that an object is FileStorage.
shared_store = self.warehouse.stores["shared"]
if not isinstance(shared_store, FileStorage):
raise TypeError("Expected a FileStorage backend for the shared store")
shared_root_dir = shared_store.root_dir
task_scratch, task_shared = self._task_workspace_paths(taskid, scratch, shared_root_dir)
task_scratch.mkdir(parents=True, exist_ok=True)
task_shared.mkdir(parents=True, exist_ok=True)
ctx = Context(task_scratch, shared=task_shared)
# 3. Execute unit
try:
results = self._build_input_result_mapping(unit)
inputs = _pu_to_pur(unit.inputs, results)
result = unit.execute(context=ctx, **inputs)
except Exception:
db.mark_task_completed(taskid, success=False)
raise

db.mark_task_completed(taskid, success=result.ok())
# 4. output result to warehouse
# TODO: we may need to end up handling namespacing on the warehouse side for tokenizables
self.warehouse.store_result_tokenizable(result)
self._store_result_index(result)
return taskid, result
95 changes: 95 additions & 0 deletions src/openfe/orchestration/exorcist_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""Utilities for building Exorcist task graphs and task databases.

This module translates an :class:`gufe.AlchemicalNetwork` into Exorcist task
structures and can initialize an Exorcist task database from that graph.
"""

from pathlib import Path

import exorcist
import networkx as nx
from gufe import AlchemicalNetwork

from openfe.storage.warehouse import WarehouseBaseClass


def alchemical_network_to_task_graph(
alchemical_network: AlchemicalNetwork, warehouse: WarehouseBaseClass
) -> nx.DiGraph:
"""Build a global task DAG from an alchemical network.

Parameters
----------
alchemical_network : AlchemicalNetwork
Network containing transformations to execute.
warehouse : WarehouseBaseClass
Warehouse used to persist protocol units as tasks while the graph is
constructed.

Returns
-------
nx.DiGraph
A directed acyclic graph where each node is a task ID in the form
``"<transformation_key>:<protocol_unit_key>"`` and edges encode
protocol-unit dependencies.

Raises
------
ValueError
Raised if the assembled task graph is not acyclic.
"""

global_dag = nx.DiGraph()
for transformation in alchemical_network.edges:
dag = transformation.create()
for unit in dag.protocol_units:
node_id = f"{str(transformation.key)}:{str(unit.key)}"
global_dag.add_node(
node_id,
)
warehouse.store_task(unit)
for dependent_unit, dependency_unit in dag.graph.edges:
upstream_id = f"{str(transformation.key)}:{str(dependency_unit.key)}"
downstream_id = f"{str(transformation.key)}:{str(dependent_unit.key)}"
global_dag.add_edge(upstream_id, downstream_id)

if not nx.is_directed_acyclic_graph(global_dag):
raise ValueError("AlchemicalNetwork produced a task graph that is not a DAG.")

return global_dag


def build_task_db_from_alchemical_network(
alchemical_network: AlchemicalNetwork,
warehouse: WarehouseBaseClass,
db_path: Path | None = None,
max_tries: int = 1,
) -> exorcist.TaskStatusDB:
"""Create and populate a task database from an alchemical network.

Parameters
----------
alchemical_network : AlchemicalNetwork
Network containing transformations to convert into task records.
warehouse : WarehouseBaseClass
Warehouse used to persist protocol units while building the task DAG.
db_path : pathlib.Path or None, optional
Location of the SQLite-backed Exorcist database. If ``None``, defaults
to ``Path("tasks.db")`` in the current working directory.
max_tries : int, default=1
Maximum number of retries for each task before Exorcist marks it as
``TOO_MANY_RETRIES``.

Returns
-------
exorcist.TaskStatusDB
Initialized task database populated with graph nodes and dependency
edges derived from ``alchemical_network``.
"""
if db_path is None:
db_path = Path("tasks.db")

global_dag = alchemical_network_to_task_graph(alchemical_network, warehouse)
db = exorcist.TaskStatusDB.from_filename(db_path)
db.add_task_network(global_dag, max_tries)
return db
Loading
Loading