Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ dependencies:
# Control blas/openmp threads
- threadpoolctl
- pip:
- git+https://github.com/OpenFreeEnergy/gufe@main
- git+https://github.com/OpenFreeEnergy/gufe@restart_execute
- run_constrained:
# drop this pin when handled upstream in espaloma-feedstock
- smirnoff99frosst>=1.1.0.1 #https://github.com/openforcefield/smirnoff99Frosst/issues/109
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,12 @@ def patcher():
yield


def test_gather(benzene_complex_dag, patcher, tmpdir):
def test_gather(benzene_complex_dag, patcher, tmp_path):
# check that .gather behaves as expected
dagres = gufe.protocols.execute_DAG(
benzene_complex_dag,
shared_basedir=tmpdir,
scratch_basedir=tmpdir,
shared_basedir=tmp_path,
scratch_basedir=tmp_path,
keep_shared=True,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,12 @@ def patcher():
yield


def test_gather(benzene_solvation_dag, patcher, tmpdir):
def test_gather(benzene_solvation_dag, patcher, tmp_path):
# check that .gather behaves as expected
dagres = gufe.protocols.execute_DAG(
benzene_solvation_dag,
shared_basedir=tmpdir,
scratch_basedir=tmpdir,
shared_basedir=tmp_path,
scratch_basedir=tmp_path,
keep_shared=True,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ def test_unit_tagging(solvent_protocol_dag, tmpdir):
assert len(repeats) == 3


def test_gather(solvent_protocol_dag, tmpdir):
def test_gather(solvent_protocol_dag, tmp_path):
# check .gather behaves as expected
with mock.patch(
"openfe.protocols.openmm_md.plain_md_methods.PlainMDProtocolUnit.run",
Expand All @@ -519,8 +519,8 @@ def test_gather(solvent_protocol_dag, tmpdir):
):
dagres = gufe.protocols.execute_DAG(
solvent_protocol_dag,
shared_basedir=tmpdir,
scratch_basedir=tmpdir,
shared_basedir=tmp_path,
scratch_basedir=tmp_path,
keep_shared=True,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1257,12 +1257,12 @@ def test_unit_tagging(solvent_protocol_dag, unit_mock_patcher, tmpdir):
assert len(setup_results) == len(sim_results) == len(analysis_results) == 3


def test_gather(solvent_protocol_dag, unit_mock_patcher, tmpdir):
def test_gather(solvent_protocol_dag, unit_mock_patcher, tmp_path):
# check .gather behaves as expected
dagres = gufe.protocols.execute_DAG(
solvent_protocol_dag,
shared_basedir=tmpdir,
scratch_basedir=tmpdir,
shared_basedir=tmp_path,
scratch_basedir=tmp_path,
keep_shared=True,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1295,7 +1295,7 @@ def test_unit_tagging(benzene_toluene_dag, tmpdir):
assert len(complex_repeats) == len(solv_repeats) == 2


def test_gather(benzene_toluene_dag, tmpdir):
def test_gather(benzene_toluene_dag, tmp_path):
# check that .gather behaves as expected
with (
mock.patch(
Expand Down Expand Up @@ -1339,8 +1339,8 @@ def test_gather(benzene_toluene_dag, tmpdir):
):
dagres = gufe.protocols.execute_DAG(
benzene_toluene_dag,
shared_basedir=tmpdir,
scratch_basedir=tmpdir,
shared_basedir=tmp_path,
scratch_basedir=tmp_path,
keep_shared=True,
)

Expand Down
21 changes: 19 additions & 2 deletions src/openfecli/commands/quickrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ def quickrun(transformation, work_dir, output):
import logging
import os
import sys
from json import JSONDecodeError

from gufe import ProtocolDAG
from gufe.protocols.protocoldag import execute_DAG
from gufe.tokenization import JSON_HANDLER
from gufe.transformations.transformation import Transformation
Expand Down Expand Up @@ -94,13 +96,28 @@ def quickrun(transformation, work_dir, output):
else:
output.parent.mkdir(exist_ok=True, parents=True)

write("Planning simulations for this edge...")
dag = trans.create()
# Attempt to either deserialize or freshly create DAG
trans_DAG_json = work_dir / f"Transformation-{trans.key}-protocolDAG.json"

if trans_DAG_json.is_file():
write(f"Attempting to resume execution using existing edges from '{trans_DAG_json}'")
try:
dag = ProtocolDAG.from_json(trans_DAG_json)
except JSONDecodeError:
errmsg = f"Recovery failed, please remove {trans_DAG_json} and any results from your working directory before continuing to create a new protocol."
raise click.ClickException(errmsg)
else:
# Create the DAG instead and then serialize for later resuming
write("Planning simulations for this edge...")
dag = trans.create()
dag.to_json(trans_DAG_json)

write("Starting the simulations for this edge...")
dagresult = execute_DAG(
dag,
shared_basedir=work_dir,
scratch_basedir=work_dir,
unitresults_basedir=work_dir,
keep_shared=True,
raise_error=False,
n_retries=2,
Expand Down
20 changes: 20 additions & 0 deletions src/openfecli/tests/commands/test_quickrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
import click
import pytest
from click.testing import CliRunner
from gufe import Transformation
from gufe.tokenization import JSON_HANDLER

from openfecli.commands.quickrun import quickrun

# from ..utils import assert_click_success


@pytest.fixture
def json_file():
Expand All @@ -33,6 +36,10 @@ def test_quickrun(extra_args, json_file):
result = runner.invoke(quickrun, [json_file] + extras)
assert result.exit_code == 0
assert "Here is the result" in result.output
trans = Transformation.from_json(json_file)
assert pathlib.Path(
extra_args.get("-d", ""), f"Transformation-{trans.key}-protocolDAG.json"
).exists()

if outfile := extra_args.get("-o"):
assert pathlib.Path(outfile).exists()
Expand Down Expand Up @@ -92,3 +99,16 @@ def test_quickrun_unit_error():
# to be stored in JSON
# not sure whether that means we should always be storing all
# protocol dag results maybe?


def test_quickrun_resume(json_file):
    """Resume path: a pre-serialized ProtocolDAG on disk makes quickrun resume.

    Writes ``Transformation-<key>-protocolDAG.json`` into an isolated working
    directory before invoking the CLI, then checks that the command succeeds
    and reports that it is resuming rather than planning a fresh DAG.
    """
    transformation = Transformation.from_json(json_file)
    protocol_dag = transformation.create()

    cli_runner = CliRunner()
    with cli_runner.isolated_filesystem():
        # Serialize the DAG under the exact filename quickrun probes for.
        protocol_dag.to_json(f"Transformation-{transformation.key}-protocolDAG.json")
        invocation = cli_runner.invoke(quickrun, [json_file])

    assert invocation.exit_code == 0
    assert "Attempting to resume" in invocation.output
Loading