diff --git a/environment.yml b/environment.yml index 5b23e9065..b4fdbe9c1 100644 --- a/environment.yml +++ b/environment.yml @@ -53,7 +53,7 @@ dependencies: # Control blas/openmp threads - threadpoolctl - pip: - - git+https://github.com/OpenFreeEnergy/gufe@main + - git+https://github.com/OpenFreeEnergy/gufe@restart_execute - run_constrained: # drop this pin when handled upstream in espaloma-feedstock - smirnoff99frosst>=1.1.0.1 #https://github.com/openforcefield/smirnoff99Frosst/issues/109 diff --git a/src/openfe/tests/protocols/openmm_abfe/test_abfe_protocol_results.py b/src/openfe/tests/protocols/openmm_abfe/test_abfe_protocol_results.py index 5d815c713..53574f972 100644 --- a/src/openfe/tests/protocols/openmm_abfe/test_abfe_protocol_results.py +++ b/src/openfe/tests/protocols/openmm_abfe/test_abfe_protocol_results.py @@ -79,12 +79,12 @@ def patcher(): yield -def test_gather(benzene_complex_dag, patcher, tmpdir): +def test_gather(benzene_complex_dag, patcher, tmp_path): # check that .gather behaves as expected dagres = gufe.protocols.execute_DAG( benzene_complex_dag, - shared_basedir=tmpdir, - scratch_basedir=tmpdir, + shared_basedir=tmp_path, + scratch_basedir=tmp_path, keep_shared=True, ) diff --git a/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_protocol_results.py b/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_protocol_results.py index 0cb2d2d25..619b199d1 100644 --- a/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_protocol_results.py +++ b/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_protocol_results.py @@ -99,12 +99,12 @@ def patcher(): yield -def test_gather(benzene_solvation_dag, patcher, tmpdir): +def test_gather(benzene_solvation_dag, patcher, tmp_path): # check that .gather behaves as expected dagres = gufe.protocols.execute_DAG( benzene_solvation_dag, - shared_basedir=tmpdir, - scratch_basedir=tmpdir, + shared_basedir=tmp_path, + scratch_basedir=tmp_path, keep_shared=True, ) diff --git a/src/openfe/tests/protocols/openmm_md/test_plain_md_protocol.py b/src/openfe/tests/protocols/openmm_md/test_plain_md_protocol.py index 60c7e8c47..b8e3153ff 100644 --- a/src/openfe/tests/protocols/openmm_md/test_plain_md_protocol.py +++ b/src/openfe/tests/protocols/openmm_md/test_plain_md_protocol.py @@ -508,7 +508,7 @@ def test_unit_tagging(solvent_protocol_dag, tmpdir): assert len(repeats) == 3 -def test_gather(solvent_protocol_dag, tmpdir): +def test_gather(solvent_protocol_dag, tmp_path): # check .gather behaves as expected with mock.patch( "openfe.protocols.openmm_md.plain_md_methods.PlainMDProtocolUnit.run", @@ -519,8 +519,8 @@ def test_gather(solvent_protocol_dag, tmpdir): ): dagres = gufe.protocols.execute_DAG( solvent_protocol_dag, - shared_basedir=tmpdir, - scratch_basedir=tmpdir, + shared_basedir=tmp_path, + scratch_basedir=tmp_path, keep_shared=True, ) diff --git a/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py b/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py index 7e257e865..83b795eb2 100644 --- a/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py +++ b/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py @@ -1257,12 +1257,12 @@ def test_unit_tagging(solvent_protocol_dag, unit_mock_patcher, tmpdir): assert len(setup_results) == len(sim_results) == len(analysis_results) == 3 -def test_gather(solvent_protocol_dag, unit_mock_patcher, tmpdir): +def test_gather(solvent_protocol_dag, unit_mock_patcher, tmp_path): # check .gather behaves as expected dagres = gufe.protocols.execute_DAG( solvent_protocol_dag, - shared_basedir=tmpdir, - scratch_basedir=tmpdir, + shared_basedir=tmp_path, + scratch_basedir=tmp_path, keep_shared=True, ) diff --git a/src/openfe/tests/protocols/openmm_septop/test_septop_protocol.py b/src/openfe/tests/protocols/openmm_septop/test_septop_protocol.py index e5e4a9f91..4a22aebcf 100644 --- a/src/openfe/tests/protocols/openmm_septop/test_septop_protocol.py +++ b/src/openfe/tests/protocols/openmm_septop/test_septop_protocol.py @@ -1295,7 +1295,7 @@ def test_unit_tagging(benzene_toluene_dag, tmpdir): assert len(complex_repeats) == len(solv_repeats) == 2 -def test_gather(benzene_toluene_dag, tmpdir): +def test_gather(benzene_toluene_dag, tmp_path): # check that .gather behaves as expected with ( mock.patch( @@ -1339,8 +1339,8 @@ def test_gather(benzene_toluene_dag, tmpdir): ): dagres = gufe.protocols.execute_DAG( benzene_toluene_dag, - shared_basedir=tmpdir, - scratch_basedir=tmpdir, + shared_basedir=tmp_path, + scratch_basedir=tmp_path, keep_shared=True, ) diff --git a/src/openfecli/commands/quickrun.py b/src/openfecli/commands/quickrun.py index f34410d69..308d8f7b0 100644 --- a/src/openfecli/commands/quickrun.py +++ b/src/openfecli/commands/quickrun.py @@ -51,7 +51,9 @@ def quickrun(transformation, work_dir, output): import logging import os import sys + from json import JSONDecodeError + from gufe import ProtocolDAG from gufe.protocols.protocoldag import execute_DAG from gufe.tokenization import JSON_HANDLER from gufe.transformations.transformation import Transformation @@ -94,13 +96,28 @@ def quickrun(transformation, work_dir, output): else: output.parent.mkdir(exist_ok=True, parents=True) - write("Planning simulations for this edge...") - dag = trans.create() + # Attempt to either deserialize or freshly create DAG + trans_DAG_json = work_dir / f"Transformation-{trans.key}-protocolDAG.json" + + if trans_DAG_json.is_file(): + write(f"Attempting to resume execution using existing edges from '{trans_DAG_json}'") + try: + dag = ProtocolDAG.from_json(trans_DAG_json) + except JSONDecodeError: + errmsg = f"Recovery failed, please remove {trans_DAG_json} and any results from your working directory before continuing to create a new protocol." + raise click.ClickException(errmsg) + else: + # Create the DAG instead and then serialize for later resuming + write("Planning simulations for this edge...") + dag = trans.create() + dag.to_json(trans_DAG_json) + write("Starting the simulations for this edge...") dagresult = execute_DAG( dag, shared_basedir=work_dir, scratch_basedir=work_dir, + unitresults_basedir=work_dir, keep_shared=True, raise_error=False, n_retries=2, diff --git a/src/openfecli/tests/commands/test_quickrun.py b/src/openfecli/tests/commands/test_quickrun.py index 86fe00b26..6a290cbac 100644 --- a/src/openfecli/tests/commands/test_quickrun.py +++ b/src/openfecli/tests/commands/test_quickrun.py @@ -5,10 +5,13 @@ import click import pytest from click.testing import CliRunner +from gufe import Transformation from gufe.tokenization import JSON_HANDLER from openfecli.commands.quickrun import quickrun +# from ..utils import assert_click_success + @pytest.fixture def json_file(): @@ -33,6 +36,10 @@ def test_quickrun(extra_args, json_file): result = runner.invoke(quickrun, [json_file] + extras) assert result.exit_code == 0 assert "Here is the result" in result.output + trans = Transformation.from_json(json_file) + assert pathlib.Path( + extra_args.get("-d", ""), f"Transformation-{trans.key}-protocolDAG.json" + ).exists() if outfile := extra_args.get("-o"): assert pathlib.Path(outfile).exists() @@ -92,3 +99,16 @@ def test_quickrun_unit_error(): # to be stored in JSON # not sure whether that means we should always be storing all # protocol dag results maybe? + + +def test_quickrun_resume(json_file): + trans = Transformation.from_json(json_file) + dag = trans.create() + + runner = CliRunner() + with runner.isolated_filesystem(): + dag.to_json(f"Transformation-{trans.key}-protocolDAG.json") + result = runner.invoke(quickrun, [json_file]) + + assert result.exit_code == 0 + assert "Attempting to resume" in result.output