diff --git a/news/resume.rst b/news/resume.rst new file mode 100644 index 000000000..1c123da2e --- /dev/null +++ b/news/resume.rst @@ -0,0 +1,27 @@ +**Added:** + +* Added API support to resume `RelativeHybridTopologyProtocol` + simulations (`PR 1774 `_). +* Added API support to resume `AbsoluteBindingProtocol` and + `AbsoluteSolvationProtocol` simulations + (`PR 1808 `_). + +**Changed:** + +* + +**Deprecated:** + +* + +**Removed:** + +* + +**Fixed:** + +* + +**Security:** + +* diff --git a/src/openfe/protocols/openmm_afe/base_afe_units.py b/src/openfe/protocols/openmm_afe/base_afe_units.py index 994873b48..8bbe3545a 100644 --- a/src/openfe/protocols/openmm_afe/base_afe_units.py +++ b/src/openfe/protocols/openmm_afe/base_afe_units.py @@ -34,6 +34,7 @@ SolventComponent, ) from gufe.components import Component +from gufe.protocols.errors import ProtocolUnitExecutionError from openff.toolkit.topology import Molecule as OFFMolecule from openff.units import Quantity from openff.units import unit as offunit @@ -54,6 +55,7 @@ create_thermodynamic_state_protocol, ) +import openfe from openfe.protocols.openmm_afe.equil_afe_settings import ( AlchemicalSettings, BaseSolvationSettings, @@ -61,8 +63,6 @@ MultiStateOutputSettings, MultiStateSimulationSettings, OpenFFPartialChargeSettings, - OpenMMEngineSettings, - OpenMMSystemGeneratorFFSettings, ThermoSettings, ) from openfe.protocols.openmm_md.plain_md_methods import PlainMDProtocolUnit @@ -72,6 +72,7 @@ omm_compute, settings_validation, system_creation, + system_validation, ) from openfe.protocols.openmm_utils.omm_settings import ( SettingsBaseModel, @@ -149,6 +150,26 @@ def _get_settings(self) -> dict[str, SettingsBaseModel]: """ ... + @staticmethod + def _verify_execution_environment( + setup_outputs: dict[str, Any], + ) -> None: + """ + Check that the Python environment hasn't changed based on the + relevant Python library versions stored in the setup outputs. + """ + try: + if ( + (gufe.__version__ != setup_outputs["gufe_version"]) + or (openfe.__version__ != setup_outputs["openfe_version"]) + or (openmm.__version__ != setup_outputs["openmm_version"]) + ): + errmsg = "Python environment has changed, cannot continue Protocol execution." + raise ProtocolUnitExecutionError(errmsg) + except KeyError: + errmsg = "Missing environment information from setup outputs." + raise ProtocolUnitExecutionError(errmsg) + class BaseAbsoluteSetupUnit(gufe.ProtocolUnit, AbsoluteUnitMixin): """ @@ -782,11 +803,47 @@ def _execute( "repeat_id": self._inputs["repeat_id"], "generation": self._inputs["generation"], "simtype": self.simtype, + "openmm_version": openmm.__version__, + "openfe_version": openfe.__version__, + "gufe_version": gufe.__version__, **outputs, } class BaseAbsoluteMultiStateSimulationUnit(gufe.ProtocolUnit, AbsoluteUnitMixin): + @staticmethod + def _check_restart(output_settings: SettingsBaseModel, shared_path: pathlib.Path): + """ + Check if we are doing a restart. + + Parameters + ---------- + output_settings : SettingsBaseModel + The simulation output settings + shared_path : pathlib.Path + The shared directory where we should be looking for existing files. + + Notes + ----- + For now this just checks if the netcdf files are present in the + shared directory but in the future this may expand depending on + how warehouse works. + """ + trajectory = shared_path / output_settings.output_filename + checkpoint = shared_path / output_settings.checkpoint_storage_filename + + if trajectory.is_file() ^ checkpoint.is_file(): + errmsg = ( + "One of either the trajectory or checkpoint files are missing but " + "the other is not. This should not happen under normal circumstances." + ) + raise IOError(errmsg) + + if trajectory.is_file() and checkpoint.is_file(): + return True + + return False + @abc.abstractmethod def _get_components( self, @@ -1003,7 +1060,13 @@ def _get_reporter( ------- reporter : multistate.MultiStateReporter The reporter for the simulation. + + Notes + ----- + All this does is create the reporter, it works for both + new reporters and if we are doing a restart. """ + # Define the trajectory & checkpoint files nc = storage_path / output_settings.output_filename # The checkpoint file in openmmtools is taken as a file relative # to the location of the nc file, so you only want the filename @@ -1034,7 +1097,7 @@ def _get_reporter( time_per_iteration=simulation_settings.time_per_iteration, ) - reporter = multistate.MultiStateReporter( + return multistate.MultiStateReporter( storage=nc, analysis_particle_indices=selection_indices, checkpoint_interval=chk_intervals, @@ -1043,8 +1106,6 @@ def _get_reporter( velocity_interval=vel_interval, ) - return reporter - @staticmethod def _get_sampler( integrator: openmmtools.mcmc.LangevinDynamicsMove, @@ -1054,6 +1115,7 @@ def _get_sampler( compound_states: list[ThermodynamicState], sampler_states: list[SamplerState], platform: openmm.Platform, + restart: bool, ) -> multistate.MultiStateSampler: """ Get a sampler based on the equilibrium sampling method requested. @@ -1074,51 +1136,115 @@ def _get_sampler( A list of sampler states. platform : openmm.Platform The compute platform to use. + restart : bool + ``True`` if we are doing a simulation restart. Returns ------- sampler : multistate.MultistateSampler A sampler configured for the chosen sampling method. """ + _SAMPLERS = { + "repex": multistate.ReplicaExchangeSampler, + "sams": multistate.SAMSSampler, + "independent": multistate.MultiStateSampler, + } + + sampler_method = simulation_settings.sampler_method.lower() + try: + sampler_class = _SAMPLERS[sampler_method] + except KeyError: + errmsg = f"Unknown sampler {sampler_method}" + raise AttributeError(errmsg) + + # Get the real time analysis values to use rta_its, rta_min_its = settings_validation.convert_real_time_analysis_iterations( simulation_settings=simulation_settings, ) - et_target_err = settings_validation.convert_target_error_from_kcal_per_mole_to_kT( - thermodynamic_settings.temperature, - simulation_settings.early_termination_target_error, - ) - # Select the right sampler - # Note: doesn't need else, settings already validates choices - if simulation_settings.sampler_method.lower() == "repex": - sampler = multistate.ReplicaExchangeSampler( - mcmc_moves=integrator, - online_analysis_interval=rta_its, - online_analysis_target_error=et_target_err, - online_analysis_minimum_iterations=rta_min_its, - ) - elif simulation_settings.sampler_method.lower() == "sams": - sampler = multistate.SAMSSampler( - mcmc_moves=integrator, - online_analysis_interval=rta_its, - online_analysis_minimum_iterations=rta_min_its, - flatness_criteria=simulation_settings.sams_flatness_criteria, - gamma0=simulation_settings.sams_gamma0, - ) - elif simulation_settings.sampler_method.lower() == "independent": - sampler = multistate.MultiStateSampler( - mcmc_moves=integrator, - online_analysis_interval=rta_its, - online_analysis_target_error=et_target_err, - online_analysis_minimum_iterations=rta_min_its, + # Get the number of production iterations to run for + steps_per_iteration = integrator.n_steps + timestep = from_openmm(integrator.timestep) + number_of_iterations = int( + settings_validation.get_simsteps( + sim_length=simulation_settings.production_length, + timestep=timestep, + mc_steps=steps_per_iteration, ) + / steps_per_iteration + ) - sampler.create( - thermodynamic_states=compound_states, - sampler_states=sampler_states, - storage=reporter, + # convert early_termination_target_error from kcal/mol to kT + early_termination_target_error = ( + settings_validation.convert_target_error_from_kcal_per_mole_to_kT( + thermodynamic_settings.temperature, + simulation_settings.early_termination_target_error, + ) ) + sampler_kwargs = { + "mcmc_moves": integrator, + "online_analysis_interval": rta_its, + "online_analysis_target_error": early_termination_target_error, + "online_analysis_minimum_iterations": rta_min_its, + "number_of_iterations": number_of_iterations, + } + + if sampler_method == "sams": + sampler_kwargs |= { + "flatness_criteria": simulation_settings.sams_flatness_criteria, + "gamma0": simulation_settings.sams_gamma0, + } + + if sampler_method == "repex": + sampler_kwargs |= { + "replica_mixing_scheme": "swap-all", + } + + # Restarting so we just rebuild from storage. + if restart: + sampler = sampler_class.from_storage(reporter) + + # We do some checks to make sure we are running the same system + # including ensuring that we have the same thermodynamic parameters and + # that the lambda schedule is the same. + for index, thermostate in enumerate(sampler._thermodynamic_states): + system_validation.assert_multistate_system_equality( + ref_system=compound_states[index].get_system(remove_thermostat=True), + stored_system=thermostate.get_system(remove_thermostat=True), + ) + + # Loop over each composable state (e.g. GlobalParameterState object) + # get the parameters and check that the values are the same. + for composable_state in compound_states[index]._composable_states: + for param in composable_state._parameters: + expected = getattr(compound_states[index], param) + stored = getattr(thermostate, param) + if expected != stored: + errmsg = ( + f"System parameter {param} in checkpoint does " + "not match protocol system, cannot resume" + ) + raise ValueError(errmsg) + + if ( + (simulation_settings.n_replicas != sampler.n_states) + or (simulation_settings.n_replicas != sampler.n_replicas) + or (sampler.mcmc_moves[0].n_steps != steps_per_iteration) + or (sampler.mcmc_moves[0].timestep != integrator.timestep) + ): + errmsg = "System in checkpoint does not match protocol system, cannot resume" + raise ValueError(errmsg) + else: + sampler = sampler_class(**sampler_kwargs) + + sampler.create( + thermodynamic_states=compound_states, + sampler_states=sampler_states, + storage=reporter, + ) + + # Get and set the context caches sampler.energy_context_cache = openmmtools.cache.ContextCache( capacity=None, time_to_live=None, @@ -1172,22 +1298,27 @@ def _run_simulation( ) if not dry: # pragma: no-cover - # minimize - if self.verbose: - self.logger.info("minimizing systems") + # No production steps have been taken, so start from scratch + if sampler._iteration == 0: + # minimize + if self.verbose: + self.logger.info("minimizing systems") - sampler.minimize(max_iterations=settings["simulation_settings"].minimization_steps) + sampler.minimize(max_iterations=settings["simulation_settings"].minimization_steps) - # equilibrate - if self.verbose: - self.logger.info("equilibrating systems") + # equilibrate + if self.verbose: + self.logger.info("equilibrating systems") - sampler.equilibrate(int(equil_steps / mc_steps)) + sampler.equilibrate(int(equil_steps / mc_steps)) - # production + # At this point we are ready for production if self.verbose: self.logger.info("running production phase") - sampler.extend(int(prod_steps / mc_steps)) + + # We use `run` so that we're limited by the number of iterations + # we passed when we built the sampler. + sampler.run(n_iterations=int(prod_steps / mc_steps) - sampler._iteration) if self.verbose: self.logger.info("production phase complete") @@ -1257,6 +1388,12 @@ def run( # Get the settings settings = self._get_settings() + # Check for a restart + self.restart = self._check_restart( + output_settings=settings["output_settings"], + shared_path=self.shared_basepath, + ) + # Get the components alchem_comps, solv_comp, prot_comp, small_mols = self._get_components() @@ -1299,7 +1436,7 @@ def run( output_settings=settings["output_settings"], ) - # Get sampler + # Get the sampler sampler = self._get_sampler( integrator=integrator, reporter=reporter, @@ -1308,9 +1445,10 @@ def run( compound_states=cmp_states, sampler_states=sampler_states, platform=platform, + restart=self.restart, ) - # Run simulation + # Run the simulation self._run_simulation( sampler=sampler, reporter=reporter, @@ -1367,6 +1505,10 @@ def _execute( ) -> dict[str, Any]: log_system_probe(logging.INFO, paths=[ctx.scratch]) + # Ensure the environment hasn't changed + self._verify_execution_environment(setup_results.outputs) + + # Get the relevant inputs for running the unit system = deserialize(setup_results.outputs["system"]) positions = to_openmm(np.load(setup_results.outputs["positions"]) * offunit.nanometer) selection_indices = setup_results.outputs["selection_indices"] @@ -1509,6 +1651,10 @@ def _execute( ) -> dict[str, Any]: log_system_probe(logging.INFO, paths=[ctx.scratch]) + # Ensure the environment hasn't changed + self._verify_execution_environment(setup_results.outputs) + + # Get the relevant inputs for running the unit pdb_file = setup_results.outputs["pdb_structure"] selection_indices = setup_results.outputs["selection_indices"] restraint_geometry = setup_results.outputs["restraint_geometry"] diff --git a/src/openfe/protocols/openmm_rfe/hybridtop_units.py b/src/openfe/protocols/openmm_rfe/hybridtop_units.py index b4cd6b744..4d286a5e7 100644 --- a/src/openfe/protocols/openmm_rfe/hybridtop_units.py +++ b/src/openfe/protocols/openmm_rfe/hybridtop_units.py @@ -803,9 +803,6 @@ def run( "positions": positions_outfile, "pdb_structure": self.shared_basepath / settings["output_settings"].output_structure, "selection_indices": selection_indices, - "openmm_version": openmm.__version__, - "openfe_version": openfe.__version__, - "gufe_version": gufe.__version__, } if dry: @@ -830,6 +827,9 @@ def _execute( return { "repeat_id": self._inputs["repeat_id"], "generation": self._inputs["generation"], + "openmm_version": openmm.__version__, + "openfe_version": openfe.__version__, + "gufe_version": gufe.__version__, **outputs, } @@ -1108,7 +1108,7 @@ def _get_sampler( # Restarting doesn't need any setup, we just rebuild from storage. if restart: - sampler = _SAMPLERS[sampler_method].from_storage(reporter) # type: ignore[attr-defined] + sampler = sampler_class.from_storage(reporter) # type: ignore[attr-defined] # We do some checks to make sure we are running the same system system_validation.assert_multistate_system_equality( @@ -1139,7 +1139,7 @@ def _get_sampler( raise ValueError(errmsg) else: - sampler = _SAMPLERS[sampler_method](**sampler_kwargs) + sampler = sampler_class(**sampler_kwargs) sampler.setup( n_replicas=simulation_settings.n_replicas, diff --git a/src/openfe/tests/protocols/conftest.py b/src/openfe/tests/protocols/conftest.py index 6978148a4..1744ad2bd 100644 --- a/src/openfe/tests/protocols/conftest.py +++ b/src/openfe/tests/protocols/conftest.py @@ -360,6 +360,42 @@ def htop_checkpoint_path(): return pathlib.Path(pooch.os_cache("openfe") / f"{topdir}/{subdir}/{filename}") +@pytest.fixture(scope="module") +def ahfe_vac_trajectory_path(): + pooch_resume_data.fetch("multistate_checkpoints.zip", processor=pooch.Unzip()) + topdir = "multistate_checkpoints.zip.unzip/multistate_checkpoints" + subdir = "ahfes" + filename = "vacuum.nc" + return pathlib.Path(pooch.os_cache("openfe") / f"{topdir}/{subdir}/{filename}") + + +@pytest.fixture(scope="module") +def vac_checkpoint_path(): + pooch_resume_data.fetch("multistate_checkpoints.zip", processor=pooch.Unzip()) + topdir = "multistate_checkpoints.zip.unzip/multistate_checkpoints" + subdir = "ahfes" + filename = "vacuum_checkpoint.nc" + return pathlib.Path(pooch.os_cache("openfe") / f"{topdir}/{subdir}/{filename}") + + +@pytest.fixture(scope="module") +def ahfe_solv_trajectory_path(): + pooch_resume_data.fetch("multistate_checkpoints.zip", processor=pooch.Unzip()) + topdir = "multistate_checkpoints.zip.unzip/multistate_checkpoints" + subdir = "ahfes" + filename = "solvent.nc" + return pathlib.Path(pooch.os_cache("openfe") / f"{topdir}/{subdir}/{filename}") + + +@pytest.fixture(scope="module") +def ahfe_solv_checkpoint_path(): + pooch_resume_data.fetch("multistate_checkpoints.zip", processor=pooch.Unzip()) + topdir = "multistate_checkpoints.zip.unzip/multistate_checkpoints" + subdir = "ahfes" + filename = "solvent_checkpoint.nc" + return pathlib.Path(pooch.os_cache("openfe") / f"{topdir}/{subdir}/{filename}") + + @pytest.fixture def get_available_openmm_platforms() -> set[str]: """ diff --git a/src/openfe/tests/protocols/openmm_abfe/test_abfe_protocol.py b/src/openfe/tests/protocols/openmm_abfe/test_abfe_protocol.py index 21d1009a8..bf280e51e 100644 --- a/src/openfe/tests/protocols/openmm_abfe/test_abfe_protocol.py +++ b/src/openfe/tests/protocols/openmm_abfe/test_abfe_protocol.py @@ -37,12 +37,6 @@ from openfe.protocols.openmm_afe import ( AbsoluteBindingProtocol, ) -from openfe.protocols.openmm_afe.abfe_units import ( - ABFEComplexSetupUnit, - ABFEComplexSimUnit, - ABFESolventSetupUnit, - ABFESolventSimUnit, -) from .utils import UNIT_TYPES, _get_units diff --git a/src/openfe/tests/protocols/openmm_abfe/test_abfe_protocol_results.py b/src/openfe/tests/protocols/openmm_abfe/test_abfe_protocol_results.py index 5d815c713..a6519ddab 100644 --- a/src/openfe/tests/protocols/openmm_abfe/test_abfe_protocol_results.py +++ b/src/openfe/tests/protocols/openmm_abfe/test_abfe_protocol_results.py @@ -8,6 +8,7 @@ import gufe import numpy as np +import openmm import pytest from openff.units import unit as offunit @@ -31,6 +32,9 @@ def patcher(): "box_vectors": [np.zeros(3), np.zeros(3), np.zeros(3)] * offunit.nm, "standard_state_correction": 0 * offunit.kilocalorie_per_mole, "restraint_geometry": None, + "gufe_version": gufe.__version__, + "openfe_version": openfe.__version__, + "openmm_version": openmm.__version__, }, ), mock.patch( diff --git a/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_protocol.py b/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_protocol.py index b2f2e387d..536b7fd44 100644 --- a/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_protocol.py +++ b/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_protocol.py @@ -4,6 +4,7 @@ from math import sqrt from unittest import mock +import gufe import mdtraj as mdt import numpy as np import pytest @@ -25,6 +26,7 @@ from openfe.protocols import openmm_afe from openfe.protocols.openmm_afe import ( AbsoluteSolvationProtocol, + AHFESolventSimUnit, ) from openfe.protocols.openmm_utils.charge_generation import ( HAS_ESPALOMA_CHARGE, @@ -63,6 +65,24 @@ def test_serialize_protocol(default_settings): assert protocol == ret +def test_bad_sampler(): + class FakeSimSettings(gufe.settings.SettingsBaseModel): + sampler_method: str = "foo bar" + + errmsg = "Unknown sampler foo bar" + with pytest.raises(AttributeError, match=errmsg): + AHFESolventSimUnit._get_sampler( + integrator=None, + reporter=None, + simulation_settings=FakeSimSettings(), + thermodynamic_settings=None, + compound_states=None, + sampler_states=None, + platform=None, + restart=False, + ) + + def test_repeat_units(benzene_system): protocol = openmm_afe.AbsoluteSolvationProtocol( settings=openmm_afe.AbsoluteSolvationProtocol.default_settings() diff --git a/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_protocol_results.py b/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_protocol_results.py index 0cb2d2d25..66e12d302 100644 --- a/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_protocol_results.py +++ b/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_protocol_results.py @@ -7,6 +7,7 @@ import gufe import numpy as np +import openmm import pytest from openff.units import unit as offunit @@ -51,6 +52,9 @@ def patcher(): "box_vectors": [np.zeros(3), np.zeros(3), np.zeros(3)] * offunit.nm, "standard_state_correction": 0 * offunit.kilocalorie_per_mole, "restraint_geometry": None, + "gufe_version": gufe.__version__, + "openfe_version": openfe.__version__, + "openmm_version": openmm.__version__, }, ), mock.patch( diff --git a/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_resume.py b/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_resume.py new file mode 100644 index 000000000..ac3448442 --- /dev/null +++ b/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_resume.py @@ -0,0 +1,483 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe + +import copy +import os +import pathlib +import shutil + +import gufe +import openmm +import pytest +from gufe.protocols.errors import ProtocolUnitExecutionError +from numpy.testing import assert_allclose +from openfe_analysis.utils.multistate import _determine_position_indices +from openff.units import unit as offunit +from openff.units.openmm import from_openmm +from openmmtools.multistate import MultiStateReporter, ReplicaExchangeSampler + +import openfe +from openfe.data._registry import POOCH_CACHE +from openfe.protocols import openmm_afe + +from ...conftest import HAS_INTERNET +from .utils import _get_units + + +@pytest.fixture() +def protocol_settings(): + settings = openmm_afe.AbsoluteSolvationProtocol.default_settings() + settings.protocol_repeats = 1 + settings.solvent_output_settings.output_indices = "resname UNK" + settings.solvation_settings.solvent_padding = None + settings.solvation_settings.number_of_solvent_molecules = 750 + settings.solvation_settings.box_shape = "dodecahedron" + settings.vacuum_simulation_settings.equilibration_length = 100 * offunit.picosecond + settings.vacuum_simulation_settings.production_length = 200 * offunit.picosecond + settings.solvent_simulation_settings.equilibration_length = 100 * offunit.picosecond + settings.solvent_simulation_settings.production_length = 200 * offunit.picosecond + settings.vacuum_engine_settings.compute_platform = "CUDA" + settings.solvent_engine_settings.compute_platform = "CUDA" + settings.vacuum_simulation_settings.time_per_iteration = 2.5 * offunit.picosecond + settings.solvent_simulation_settings.time_per_iteration = 2.5 * offunit.picosecond + settings.vacuum_output_settings.checkpoint_interval = 100 * offunit.picosecond + settings.solvent_output_settings.checkpoint_interval = 100 * offunit.picosecond + return settings + + +def test_verify_execution_environment(): + # Verification should pass + openmm_afe.AHFESolventSimUnit._verify_execution_environment( + setup_outputs={ + "gufe_version": gufe.__version__, + "openfe_version": openfe.__version__, + "openmm_version": openmm.__version__, + }, + ) + + +def test_verify_execution_environment_fail(): + # Passing a bad version should fail + with pytest.raises(ProtocolUnitExecutionError, match="Python environment"): + openmm_afe.AHFESolventSimUnit._verify_execution_environment( + setup_outputs={ + "gufe_version": 0.1, + "openfe_version": openfe.__version__, + "openmm_version": openmm.__version__, + }, + ) + + +def test_verify_execution_env_missing_key(): + errmsg = "Missing environment information from setup outputs." + with pytest.raises(ProtocolUnitExecutionError, match=errmsg): + openmm_afe.AHFESolventSimUnit._verify_execution_environment( + setup_outputs={ + "foo_version": 0.1, + "openfe_version": openfe.__version__, + "openmm_version": openmm.__version__, + }, + ) + + +@pytest.mark.skipif( + not os.path.exists(POOCH_CACHE) and not HAS_INTERNET, + reason="Internet unavailable and test data is not cached locally", +) +def test_solvent_check_restart(protocol_settings, ahfe_solv_trajectory_path): + assert openmm_afe.AHFESolventSimUnit._check_restart( + output_settings=protocol_settings.solvent_output_settings, + shared_path=ahfe_solv_trajectory_path.parent, + ) + + assert not openmm_afe.AHFESolventSimUnit._check_restart( + output_settings=protocol_settings.solvent_output_settings, + shared_path=pathlib.Path("."), + ) + + +@pytest.mark.skipif( + not os.path.exists(POOCH_CACHE) and not HAS_INTERNET, + reason="Internet unavailable and test data is not cached locally", +) +def test_vacuum_check_restart(protocol_settings, ahfe_vac_trajectory_path): + assert openmm_afe.AHFEVacuumSimUnit._check_restart( + output_settings=protocol_settings.vacuum_output_settings, + shared_path=ahfe_vac_trajectory_path.parent, + ) + + assert not openmm_afe.AHFEVacuumSimUnit._check_restart( + output_settings=protocol_settings.vacuum_output_settings, + shared_path=pathlib.Path("."), + ) + + +@pytest.mark.skipif( + not os.path.exists(POOCH_CACHE) and not HAS_INTERNET, + reason="Internet unavailable and test data is not cached locally", +) +def test_check_restart_one_file_missing(protocol_settings, ahfe_vac_trajectory_path): + protocol_settings.vacuum_output_settings.checkpoint_storage_filename = "foo.nc" + + errmsg = "One of either the trajectory or checkpoint files are missing" + with pytest.raises(IOError, match=errmsg): + openmm_afe.AHFEVacuumSimUnit._check_restart( + output_settings=protocol_settings.vacuum_output_settings, + shared_path=ahfe_vac_trajectory_path.parent, + ) + + +class TestCheckpointResuming: + @pytest.fixture() + def protocol_dag( + self, + protocol_settings, + benzene_modifications, + ): + stateA = openfe.ChemicalSystem( + { + "benzene": benzene_modifications["benzene"], + "solvent": openfe.SolventComponent(), + } + ) + + stateB = openfe.ChemicalSystem({"solvent": openfe.SolventComponent()}) + + protocol = openmm_afe.AbsoluteSolvationProtocol(settings=protocol_settings) + + # Create DAG from protocol, get the vacuum and solvent units + # and eventually dry run the first solvent unit + return protocol.create( + stateA=stateA, + stateB=stateB, + mapping=None, + ) + + @staticmethod + def _check_sampler(sampler, num_iterations: int): + # Helper method to do some checks on the sampler + assert sampler._iteration == num_iterations + assert sampler.number_of_iterations == 80 + assert sampler.is_completed is (num_iterations == 80) + assert sampler.n_states == sampler.n_replicas == 14 + assert sampler.is_periodic + assert sampler.mcmc_moves[0].n_steps == 625 + assert from_openmm(sampler.mcmc_moves[0].timestep) == 4 * offunit.fs + + @staticmethod + def _get_positions(dataset): + frame_list = _determine_position_indices(dataset) + positions = [] + for frame in frame_list: + positions.append(copy.deepcopy(dataset.variables["positions"][frame].data)) + return positions + + @staticmethod + def _copy_simfiles(cwd: pathlib.Path, filepath): + shutil.copyfile(filepath, f"{cwd}/{filepath.name}") + + @pytest.mark.integration + def test_resume( + self, protocol_dag, ahfe_solv_trajectory_path, ahfe_solv_checkpoint_path, tmpdir + ): + """ + Attempt to resume a simulation unit with pre-existing checkpoint & + trajectory files. + """ + cwd = pathlib.Path(str(tmpdir)) + self._copy_simfiles(cwd, ahfe_solv_trajectory_path) + self._copy_simfiles(cwd, ahfe_solv_checkpoint_path) + + # 1. Check that the trajectory / checkpoint contain what we expect + reporter = MultiStateReporter( + f"{cwd}/solvent.nc", + checkpoint_storage="solvent_checkpoint.nc", + ) + sampler = ReplicaExchangeSampler.from_storage(reporter) + + self._check_sampler(sampler, num_iterations=40) + + # Deep copy energies & positions for later comparison + init_energies = copy.deepcopy(reporter.read_energies())[0] + assert init_energies.shape == (41, 14, 14) + init_positions = self._get_positions(reporter._storage[0]) + assert len(init_positions) == 2 + + reporter.close() + del sampler + + # 2. get & run the units + pus = list(protocol_dag.protocol_units) + setup_unit = _get_units(pus, openmm_afe.AHFESolventSetupUnit)[0] + sim_unit = _get_units(pus, openmm_afe.AHFESolventSimUnit)[0] + analysis_unit = _get_units(pus, openmm_afe.AHFESolventAnalysisUnit)[0] + + # Dry run the setup since it'll be easier to use the objects directly + setup_results = setup_unit.run( + dry=True, + scratch_basepath=cwd, + shared_basepath=cwd, + ) + + # Now we run the simultion in resume mode + sim_results = sim_unit.run( + system=setup_results["alchem_system"], + positions=setup_results["debug_positions"], + selection_indices=setup_results["selection_indices"], + box_vectors=setup_results["box_vectors"], + alchemical_restraints=False, + scratch_basepath=cwd, + shared_basepath=cwd, + ) + + # Finally we analyze the results + _ = analysis_unit.run( + trajectory=sim_results["trajectory"], + checkpoint=sim_results["checkpoint"], + scratch_basepath=cwd, + shared_basepath=cwd, + ) + + # Analyze the trajectory / checkpoint again + reporter = MultiStateReporter( + f"{cwd}/solvent.nc", + checkpoint_storage="solvent_checkpoint.nc", + ) + + sampler = ReplicaExchangeSampler.from_storage(reporter) + + self._check_sampler(sampler, num_iterations=80) + + # Check the energies and positions + energies = reporter.read_energies()[0] + assert energies.shape == (81, 14, 14) + assert_allclose(init_energies, energies[:41]) + + positions = self._get_positions(reporter._storage[0]) + assert len(positions) == 3 + for i in range(2): + assert_allclose(positions[i], init_positions[i]) + + reporter.close() + del sampler + + # Check the free energy plots are there + mbar_overlap_file = cwd / "mbar_overlap_matrix.png" + assert (mbar_overlap_file).exists() + + @pytest.mark.slow + def test_resume_fail_particles( + self, protocol_dag, ahfe_solv_trajectory_path, ahfe_solv_checkpoint_path, tmpdir + ): + """ + Test that the run unit will fail with a system incompatible + to the one present in the trajectory/checkpoint files. + + Here we check that we don't have the same particles / mass. + """ + # define a temp directory path & copy files + cwd = pathlib.Path(str(tmpdir)) + self._copy_simfiles(cwd, ahfe_solv_trajectory_path) + self._copy_simfiles(cwd, ahfe_solv_checkpoint_path) + + pus = list(protocol_dag.protocol_units) + setup_unit = _get_units(pus, openmm_afe.AHFESolventSetupUnit)[0] + sim_unit = _get_units(pus, openmm_afe.AHFESolventSimUnit)[0] + + # Dry run the setup since it'll be easier to use the objects directly + setup_results = setup_unit.run(dry=True, scratch_basepath=cwd, shared_basepath=cwd) + + # Create a fake system where we will add a particle + fake_system = copy.deepcopy(setup_results["alchem_system"]) + fake_system.addParticle(42) + + # Fake system should trigger a mismatch + errmsg = "Stored checkpoint System particles do not" + with pytest.raises(ValueError, match=errmsg): + _ = sim_unit.run( + system=fake_system, + positions=setup_results["debug_positions"], + selection_indices=setup_results["selection_indices"], + box_vectors=setup_results["box_vectors"], + alchemical_restraints=False, + scratch_basepath=cwd, + shared_basepath=cwd, + ) + + @pytest.mark.slow + def test_resume_fail_constraints( + self, protocol_dag, ahfe_solv_trajectory_path, ahfe_solv_checkpoint_path, tmpdir + ): + """ + Test that the run unit will fail with a system incompatible + to the one present in the trajectory/checkpoint files. + + Here we check that we don't have the same constraints. + """ + # define a temp directory path & copy files + cwd = pathlib.Path(str(tmpdir)) + self._copy_simfiles(cwd, ahfe_solv_trajectory_path) + self._copy_simfiles(cwd, ahfe_solv_checkpoint_path) + + pus = list(protocol_dag.protocol_units) + setup_unit = _get_units(pus, openmm_afe.AHFESolventSetupUnit)[0] + sim_unit = _get_units(pus, openmm_afe.AHFESolventSimUnit)[0] + + # Dry run the setup since it'll be easier to use the objects directly + setup_results = setup_unit.run(dry=True, scratch_basepath=cwd, shared_basepath=cwd) + + # Create a fake system without constraints + fake_system = copy.deepcopy(setup_results["alchem_system"]) + + for i in reversed(range(fake_system.getNumConstraints())): + fake_system.removeConstraint(i) + + # Fake system should trigger a mismatch + errmsg = "Stored checkpoint System constraints do not" + with pytest.raises(ValueError, match=errmsg): + _ = sim_unit.run( + system=fake_system, + positions=setup_results["debug_positions"], + selection_indices=setup_results["selection_indices"], + box_vectors=setup_results["box_vectors"], + alchemical_restraints=False, + scratch_basepath=cwd, + shared_basepath=cwd, + ) + + @pytest.mark.slow + def test_resume_fail_forces( + self, protocol_dag, ahfe_solv_trajectory_path, ahfe_solv_checkpoint_path, tmpdir + ): + """ + Test that the run unit will fail with a system incompatible + to the one present in the trajectory/checkpoint files. + + Here we check we don't have the same forces. + """ + # define a temp directory path & copy files + cwd = pathlib.Path(str(tmpdir)) + self._copy_simfiles(cwd, ahfe_solv_trajectory_path) + self._copy_simfiles(cwd, ahfe_solv_checkpoint_path) + + pus = list(protocol_dag.protocol_units) + setup_unit = _get_units(pus, openmm_afe.AHFESolventSetupUnit)[0] + sim_unit = _get_units(pus, openmm_afe.AHFESolventSimUnit)[0] + + # Dry run the setup since it'll be easier to use the objects directly + setup_results = setup_unit.run(dry=True, scratch_basepath=cwd, shared_basepath=cwd) + + # Create a fake system without the last force + fake_system = copy.deepcopy(setup_results["alchem_system"]) + fake_system.removeForce(fake_system.getNumForces() - 1) + + # Fake system should trigger a mismatch + errmsg = "Number of forces stored in checkpoint System" + with pytest.raises(ValueError, match=errmsg): + _ = sim_unit.run( + system=fake_system, + positions=setup_results["debug_positions"], + selection_indices=setup_results["selection_indices"], + box_vectors=setup_results["box_vectors"], + alchemical_restraints=False, + scratch_basepath=cwd, + shared_basepath=cwd, + ) + + @pytest.mark.slow + @pytest.mark.parametrize("forcetype", [openmm.NonbondedForce, openmm.MonteCarloBarostat]) + def test_resume_differ_forces( + self, forcetype, protocol_dag, ahfe_solv_trajectory_path, ahfe_solv_checkpoint_path, tmpdir + ): + """ + Test that the run unit will fail with a system incompatible + to the one present in the trajectory/checkpoint files. + + Here we check we have a different force + """ + # define a temp directory path & copy files + cwd = pathlib.Path(str(tmpdir)) + self._copy_simfiles(cwd, ahfe_solv_trajectory_path) + self._copy_simfiles(cwd, ahfe_solv_checkpoint_path) + + pus = list(protocol_dag.protocol_units) + setup_unit = _get_units(pus, openmm_afe.AHFESolventSetupUnit)[0] + sim_unit = _get_units(pus, openmm_afe.AHFESolventSimUnit)[0] + + # Dry run the setup since it'll be easier to use the objects directly + setup_results = setup_unit.run(dry=True, scratch_basepath=cwd, shared_basepath=cwd) + + # Create a fake system with the fake forcetype + fake_system = copy.deepcopy(setup_results["alchem_system"]) + + # Loop through forces and remove the force matching forcetype + for i, f in enumerate(fake_system.getForces()): + if isinstance(f, forcetype): + findex = i + + fake_system.removeForce(findex) + + # Now add a fake force + if forcetype == openmm.MonteCarloBarostat: + new_force = forcetype(1 * openmm.unit.atmosphere, 300 * openmm.unit.kelvin, 100) + elif forcetype == openmm.NonbondedForce: + new_force = forcetype() + new_force.setNonbondedMethod(openmm.NonbondedForce.PME) + new_force.addGlobalParameter("lambda_electrostatics", 1.0) + + fake_system.addForce(new_force) + + # Fake system should trigger a mismatch + errmsg = "stored checkpoint System does not match the same force" + with pytest.raises(ValueError, match=errmsg): + _ = sim_unit.run( + system=fake_system, + positions=setup_results["debug_positions"], + selection_indices=setup_results["selection_indices"], + box_vectors=setup_results["box_vectors"], + alchemical_restraints=False, + scratch_basepath=cwd, + shared_basepath=cwd, + ) + + @pytest.mark.slow + @pytest.mark.parametrize("bad_file", ["trajectory", "checkpoint"]) + def test_resume_bad_files( + self, protocol_dag, ahfe_solv_trajectory_path, ahfe_solv_checkpoint_path, bad_file, tmpdir + ): + """ + Test what happens when you have a bad trajectory and/or checkpoint + files. + """ + # define a temp directory path & copy files + cwd = pathlib.Path(str(tmpdir)) + + if bad_file == "trajectory": + with open(f"{cwd}/solvent.nc", "w") as f: + f.write("foo") + else: + self._copy_simfiles(cwd, ahfe_solv_trajectory_path) + + if bad_file == "checkpoint": + with open(f"{cwd}/solvent_checkpoint.nc", "w") as f: + f.write("bar") + else: + self._copy_simfiles(cwd, ahfe_solv_checkpoint_path) + + pus = list(protocol_dag.protocol_units) + setup_unit = _get_units(pus, openmm_afe.AHFESolventSetupUnit)[0] + sim_unit = _get_units(pus, openmm_afe.AHFESolventSimUnit)[0] + + # Dry run the setup since it'll be easier to use the objects directly + setup_results = setup_unit.run(dry=True, scratch_basepath=cwd, shared_basepath=cwd) + + with pytest.raises(OSError, match="Unknown file format"): + _ = sim_unit.run( + system=setup_results["alchem_system"], + positions=setup_results["debug_positions"], + selection_indices=setup_results["selection_indices"], + box_vectors=setup_results["box_vectors"], + alchemical_restraints=False, + scratch_basepath=cwd, + shared_basepath=cwd, + ) diff --git a/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_resume.py b/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_resume.py index a14f32c3f..8884cf23b 100644 --- a/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_resume.py +++ b/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_resume.py @@ -183,7 +183,7 @@ def test_resume(self, protocol_dag, htop_trajectory_path, htop_checkpoint_path, ) # Finally we analyze the results - analysis_results = analysis_unit.run( + _ = analysis_unit.run( pdb_file=setup_results["pdb_structure"], trajectory=sim_results["nc"], checkpoint=sim_results["checkpoint"], @@ -235,7 +235,6 @@ def test_resume_fail_particles( pus = list(protocol_dag.protocol_units) setup_unit = _get_units(pus, HybridTopologySetupUnit)[0] simulation_unit = _get_units(pus, HybridTopologyMultiStateSimulationUnit)[0] - analysis_unit = _get_units(pus, HybridTopologyMultiStateAnalysisUnit)[0] # Dry run the setup since it'll be easier to use the objects directly setup_results = setup_unit.run(dry=True, scratch_basepath=cwd, shared_basepath=cwd) @@ -243,7 +242,7 @@ def test_resume_fail_particles( # Fake system should trigger a mismatch errmsg = "Stored checkpoint System particles do not" with pytest.raises(ValueError, match=errmsg): - sim_results = simulation_unit.run( + _ = simulation_unit.run( system=openmm.System(), positions=setup_results["hybrid_positions"], selection_indices=setup_results["selection_indices"], @@ -269,7 +268,6 @@ def test_resume_fail_constraints( pus = list(protocol_dag.protocol_units) setup_unit = _get_units(pus, HybridTopologySetupUnit)[0] simulation_unit = _get_units(pus, HybridTopologyMultiStateSimulationUnit)[0] - analysis_unit = _get_units(pus, HybridTopologyMultiStateAnalysisUnit)[0] # Dry run the setup since it'll be easier to use the objects directly setup_results = setup_unit.run(dry=True, scratch_basepath=cwd, shared_basepath=cwd) @@ -283,7 +281,7 @@ def test_resume_fail_constraints( # Fake system should trigger a mismatch errmsg = "Stored checkpoint System constraints do not" with pytest.raises(ValueError, match=errmsg): - sim_results = simulation_unit.run( + _ = simulation_unit.run( system=fake_system, positions=setup_results["hybrid_positions"], selection_indices=setup_results["selection_indices"], @@ -309,7 +307,6 @@ def test_resume_fail_forces( pus = list(protocol_dag.protocol_units) setup_unit = _get_units(pus, HybridTopologySetupUnit)[0] simulation_unit = _get_units(pus, HybridTopologyMultiStateSimulationUnit)[0] - analysis_unit = _get_units(pus, HybridTopologyMultiStateAnalysisUnit)[0] # Dry run the setup since it'll be easier to use the objects directly setup_results = setup_unit.run(dry=True, scratch_basepath=cwd, shared_basepath=cwd) @@ -321,7 +318,7 @@ def test_resume_fail_forces( # Fake system should trigger a mismatch errmsg = "Number of forces stored in checkpoint System" with pytest.raises(ValueError, match=errmsg): - sim_results = simulation_unit.run( + _ = simulation_unit.run( system=fake_system, positions=setup_results["hybrid_positions"], selection_indices=setup_results["selection_indices"], @@ -348,7 +345,6 @@ def test_resume_differ_forces( pus = list(protocol_dag.protocol_units) setup_unit = _get_units(pus, HybridTopologySetupUnit)[0] simulation_unit = _get_units(pus, HybridTopologyMultiStateSimulationUnit)[0] - analysis_unit = _get_units(pus, HybridTopologyMultiStateAnalysisUnit)[0] # Dry run the setup since it'll be easier to use the objects directly setup_results = setup_unit.run(dry=True, scratch_basepath=cwd, shared_basepath=cwd) @@ -374,7 +370,7 @@ def test_resume_differ_forces( # Fake system should trigger a mismatch errmsg = "stored checkpoint System does not match the same force" with pytest.raises(ValueError, match=errmsg): - sim_results = simulation_unit.run( + _ = simulation_unit.run( system=fake_system, positions=setup_results["hybrid_positions"], selection_indices=setup_results["selection_indices"], @@ -409,13 +405,12 @@ def test_resume_bad_files( pus = list(protocol_dag.protocol_units) setup_unit = _get_units(pus, HybridTopologySetupUnit)[0] simulation_unit = _get_units(pus, HybridTopologyMultiStateSimulationUnit)[0] - analysis_unit = _get_units(pus, HybridTopologyMultiStateAnalysisUnit)[0] # Dry run the setup since it'll be easier to use the objects directly setup_results = setup_unit.run(dry=True, scratch_basepath=cwd, shared_basepath=cwd) with pytest.raises(OSError, match="Unknown file format"): - sim_results = simulation_unit.run( + _ = simulation_unit.run( system=setup_results["hybrid_system"], positions=setup_results["hybrid_positions"], selection_indices=setup_results["selection_indices"], @@ -447,14 +442,13 @@ def test_missing_file( pus = list(protocol_dag.protocol_units) setup_unit = _get_units(pus, HybridTopologySetupUnit)[0] simulation_unit = _get_units(pus, HybridTopologyMultiStateSimulationUnit)[0] - analysis_unit = _get_units(pus, HybridTopologyMultiStateAnalysisUnit)[0] # Dry run the setup since it'll be easier to use the objects directly setup_results = setup_unit.run(dry=True, scratch_basepath=cwd, shared_basepath=cwd) errmsg = "One of either the trajectory or checkpoint files are missing" with pytest.raises(IOError, match=errmsg): - sim_results = simulation_unit.run( + _ = simulation_unit.run( system=setup_results["hybrid_system"], positions=setup_results["hybrid_positions"], selection_indices=setup_results["selection_indices"],