diff --git a/openfe/protocols/openmm_rfe/hybridtop_units.py b/openfe/protocols/openmm_rfe/hybridtop_units.py index b1ca0b786..03102435f 100644 --- a/openfe/protocols/openmm_rfe/hybridtop_units.py +++ b/openfe/protocols/openmm_rfe/hybridtop_units.py @@ -22,16 +22,26 @@ import matplotlib.pyplot as plt import mdtraj import numpy as np +import numpy.typing as npt +import openmm import openmmtools from gufe import ( ChemicalSystem, + Component, LigandAtomMapping, + ProteinComponent, SmallMoleculeComponent, - settings, + SolventComponent, +) +from gufe.settings import ( + SettingsBaseModel, + ThermoSettings, ) from openff.toolkit.topology import Molecule as OFFMolecule +from openff.units import Quantity from openff.units import unit as offunit from openff.units.openmm import ensure_quantity, from_openmm, to_openmm +from openmmforcefields.generators import SystemGenerator from openmmtools import multistate from openfe.protocols.openmm_utils.omm_settings import ( @@ -49,6 +59,7 @@ system_validation, ) from . import _rfe_utils +from ._rfe_utils.relative import HybridTopologyFactory from .equil_rfe_settings import ( AlchemicalSettings, IntegratorSettings, @@ -56,6 +67,7 @@ MultiStateOutputSettings, MultiStateSimulationSettings, OpenFFPartialChargeSettings, + OpenMMEngineSettings, OpenMMSolvationSettings, RelativeHybridTopologyProtocolSettings, ) @@ -112,25 +124,128 @@ def __init__( generation=generation, ) + def _prepare( + self, + verbose: bool, + scratch_basepath: pathlib.Path | None, + shared_basepath: pathlib.Path | None, + ): + """ + Set basepaths and do some initial logging. + + Parameters + ---------- + verbose : bool + Verbose output of the simulation progress. Output is provided at the + INFO level logging. + scratch_basepath : pathlib.Path | None + Optional scratch base path to write scratch files to. + shared_basepath : pathlib.Path | None + Optional shared base path to write shared files to. 
+ """ + self.verbose = verbose + + if self.verbose: + self.logger.info("Setting up the hybrid topology simulation") + + # set basepaths + def _set_optional_path(basepath): + if basepath is None: + return pathlib.Path(".") + return basepath + + self.scratch_basepath = _set_optional_path(scratch_basepath) + self.shared_basepath = _set_optional_path(shared_basepath) + + @staticmethod + def _get_settings( + settings: RelativeHybridTopologyProtocolSettings, + ) -> dict[str, SettingsBaseModel]: + """ + Get a dictionary of Protocol settings. + + Returns + ------- + protocol_settings : dict[str, SettingsBaseModel] + + Notes + ----- + We return a dict so that we can duck type behaviour between phases. + For example subclasses may contain both `solvent` and `complex` + settings, using this approach we can extract the relevant entry + to the same key and pass it to other methods in a seamless manner. + """ + protocol_settings: dict[str, SettingsBaseModel] = {} + protocol_settings["forcefield_settings"] = settings.forcefield_settings + protocol_settings["thermo_settings"] = settings.thermo_settings + protocol_settings["alchemical_settings"] = settings.alchemical_settings + protocol_settings["lambda_settings"] = settings.lambda_settings + protocol_settings["charge_settings"] = settings.partial_charge_settings + protocol_settings["solvation_settings"] = settings.solvation_settings + protocol_settings["simulation_settings"] = settings.simulation_settings + protocol_settings["output_settings"] = settings.output_settings + protocol_settings["integrator_settings"] = settings.integrator_settings + protocol_settings["engine_settings"] = settings.engine_settings + return protocol_settings + + @staticmethod + def _get_components( + stateA: ChemicalSystem, stateB: ChemicalSystem + ) -> tuple[ + dict[str, Component], + SolventComponent, + ProteinComponent, + dict[SmallMoleculeComponent, OFFMolecule], + ]: + """ + Get the components from the ChemicalSystem inputs. 
+ + Parameters + ---------- + stateA : ChemicalSystem + ChemicalSystem defining the state A components. + stateB : ChemicalSystem + ChemicalSystem defining the state B components. + + Returns + ------- + alchem_comps : dict[str, Component] + Dictionary of alchemical components. + solv_comp : SolventComponent + The solvent component. + protein_comp : ProteinComponent + The protein component. + small_mols : dict[SmallMoleculeComponent, openff.toolkit.Molecule] + Dictionary of small molecule components paired + with their OpenFF Molecule. + """ + alchem_comps = system_validation.get_alchemical_components(stateA, stateB) + + solvent_comp, protein_comp, smcs_A = system_validation.get_components(stateA) + _, _, smcs_B = system_validation.get_components(stateB) + + small_mols = {m: m.to_openff() for m in set(smcs_A).union(set(smcs_B))} + + return alchem_comps, solvent_comp, protein_comp, small_mols + @staticmethod def _assign_partial_charges( charge_settings: OpenFFPartialChargeSettings, - off_small_mols: dict[str, list[tuple[SmallMoleculeComponent, OFFMolecule]]], + small_mols: dict[SmallMoleculeComponent, OFFMolecule], ) -> None: """ - Assign partial charges to SMCs. + Assign partial charges to the OpenFF Molecules associated with all + the SmallMoleculeComponents in the transformation. Parameters ---------- charge_settings : OpenFFPartialChargeSettings Settings for controlling how the partial charges are assigned. - off_small_mols : dict[str, list[tuple[SmallMoleculeComponent, OFFMolecule]]] - Dictionary of dictionary of OpenFF Molecules to add, keyed by - state and SmallMoleculeComponent. + small_mols : dict[SmallMoleculeComponent, openff.toolkit.Molecule] + Dictionary of OpenFF Molecules to add, keyed by + their associated SmallMoleculeComponent.
""" - for smc, mol in chain( - off_small_mols["stateA"], off_small_mols["stateB"], off_small_mols["both"] - ): + for smc, mol in small_mols.items(): charge_generation.assign_offmol_partial_charges( offmol=mol, overwrite=False, @@ -140,227 +255,407 @@ def _assign_partial_charges( nagl_model=charge_settings.nagl_model, ) - def run( - self, *, dry=False, verbose=True, scratch_basepath=None, shared_basepath=None - ) -> dict[str, Any]: - """Run the relative free energy calculation. + @staticmethod + def _get_system_generator( + shared_basepath: pathlib.Path, + settings: dict[str, SettingsBaseModel], + solvent_comp: SolventComponent | None, + openff_molecules: list[OFFMolecule] | None, + ) -> SystemGenerator: + """ + Get an OpenMM SystemGenerator. Parameters ---------- - dry : bool - Do a dry run of the calculation, creating all necessary hybrid - system components (topology, system, sampler, etc...) but without - running the simulation. - verbose : bool - Verbose output of the simulation progress. Output is provided via - INFO level logging. - scratch_basepath: Pathlike, optional - Where to store temporary files, defaults to current working directory - shared_basepath : Pathlike, optional - Where to run the calculation, defaults to current working directory + settings : dict[str, SettingsBaseModel] + A dictionary of protocol settings. + solvent_comp : SolventComponent | None + The solvent component of the system, if any. + openff_molecules : list[openff.Toolkit] | None + A list of openff molecules to generate templates for, if any. Returns ------- - dict - Outputs created in the basepath directory or the debug objects - (i.e. sampler) if ``dry==True``. + system_generator : openmmtools.SystemGenerator + The SystemGenerator for the protocol. 
+ """ + ffcache = settings["output_settings"].forcefield_cache - Raises - ------ - error - Exception if anything failed + if ffcache is not None: + ffcache = shared_basepath / ffcache + + # Block out oechem backend in system_generator calls to avoid + # any issues with smiles roundtripping between rdkit and oechem + with without_oechem_backend(): + system_generator = system_creation.get_system_generator( + forcefield_settings=settings["forcefield_settings"], + integrator_settings=settings["integrator_settings"], + thermo_settings=settings["thermo_settings"], + cache=ffcache, + has_solvent=solvent_comp is not None, + ) + + # Handle openff Molecule templates + # TODO: revisit this once the SystemGenerator update happens + # and we start loading the whole protein into OpenFF Topologies + + # First deduplicate isomorphic molecules, if there are any + if openff_molecules is None: + return system_generator + + unique_offmols: list[OFFMolecule] = [] + for mol in openff_molecules: + unique = all([not mol.is_isomorphic_with(umol) for umol in unique_offmols]) + if unique: + unique_offmols.append(mol) + + # register all the templates + system_generator.add_molecules(unique_offmols) + + return system_generator + + @staticmethod + def _create_stateA_system( + small_mols: dict[SmallMoleculeComponent, OFFMolecule], + protein_component: ProteinComponent | None, + solvent_component: SolventComponent | None, + system_generator: SystemGenerator, + solvation_settings: OpenMMSolvationSettings, + ) -> tuple[ + openmm.System, openmm.app.Topology, openmm.unit.Quantity, dict[Component, npt.NDArray] + ]: """ - if verbose: - self.logger.info("Preparing the hybrid topology simulation") - if scratch_basepath is None: - scratch_basepath = pathlib.Path(".") - if shared_basepath is None: - # use cwd - shared_basepath = pathlib.Path(".") + Create an OpenMM System for state A. 
+ + Parameters + ---------- + small_mols : dict[SmallMoleculeComponent, openff.toolkit.Molecule] + A dictionary of small molecules to include in the System. + protein_component : ProteinComponent | None + Optionally, the protein component to include in the System. + solvent_component : SolventComponent | None + Optionally, the solvent component to include in the System. + system_generator : SystemGenerator + The SystemGenerator object to use to construct the System. + solvation_settings : OpenMMSolvationSettings + Settings defining how to build the System. - # 0. General setup and settings dependency resolution step + Returns + ------- + system : openmm.System + The System that defines state A. + topology : openmm.app.Topology + The Topology defining the returned System. + positions : openmm.unit.Quantity + The positions of the particles in the System. + comp_residues : dict[Component, npt.NDArray] + A dictionary defining which residues in the System + belong to which ChemicalSystem Component.
+ """ + modeller, comp_resids = system_creation.get_omm_modeller( + protein_comp=protein_component, + solvent_comp=solvent_component, + small_mols=small_mols, + omm_forcefield=system_generator.forcefield, + solvent_settings=solvation_settings, + ) - # Extract relevant settings - protocol_settings: RelativeHybridTopologyProtocolSettings = self._inputs[ - "protocol" - ].settings - stateA = self._inputs["stateA"] - stateB = self._inputs["stateB"] - mapping = self._inputs["ligandmapping"] + topology = modeller.getTopology() + # Note: roundtrip positions to remove vec3 issues + positions = to_openmm(from_openmm(modeller.getPositions())) - forcefield_settings: settings.OpenMMSystemGeneratorFFSettings = ( - protocol_settings.forcefield_settings + system = system_generator.create_system( + modeller.topology, + molecules=list(small_mols.values()), ) - thermo_settings: settings.ThermoSettings = protocol_settings.thermo_settings - alchem_settings: AlchemicalSettings = protocol_settings.alchemical_settings - lambda_settings: LambdaSettings = protocol_settings.lambda_settings - charge_settings: BasePartialChargeSettings = protocol_settings.partial_charge_settings - solvation_settings: OpenMMSolvationSettings = protocol_settings.solvation_settings - sampler_settings: MultiStateSimulationSettings = protocol_settings.simulation_settings - output_settings: MultiStateOutputSettings = protocol_settings.output_settings - integrator_settings: IntegratorSettings = protocol_settings.integrator_settings - - # TODO: Also validate various conversions? 
- # Convert various time based inputs to steps/iterations - steps_per_iteration = settings_validation.convert_steps_per_iteration( - simulation_settings=sampler_settings, - integrator_settings=integrator_settings, + + return system, topology, positions, comp_resids + + @staticmethod + def _create_stateB_system( + small_mols: dict[SmallMoleculeComponent, OFFMolecule], + mapping: LigandAtomMapping, + stateA_topology: openmm.app.Topology, + exclude_resids: npt.NDArray, + system_generator: SystemGenerator, + ) -> tuple[openmm.System, openmm.app.Topology, npt.NDArray]: + """ + Create the state B System from the state A Topology. + + Parameters + ---------- + small_mols : dict[SmallMoleculeComponent, openff.toolkit.Molecule] + Dictionary of OpenFF Molecules keyed by SmallMoleculeComponent + to be present in system B. + mapping : LigandAtomMapping + LigandAtomMapping defining the correspondence between state A + and B's alchemical ligand. + stateA_topology : openmm.app.Topology + The OpenMM topology for state A. + exclude_resids : npt.NDArray + A list of residues to exclude from state A when building state B. + system_generator : SystemGenerator + The SystemGenerator to use to build System B. + + Returns + ------- + system : openmm.System + The state B System. + topology : openmm.app.Topology + The OpenMM Topology associated with the state B System. + alchem_resids : npt.NDArray + The residue indices of the state B alchemical species.
+ """ + topology, alchem_resids = _rfe_utils.topologyhelpers.combined_topology( + topology1=stateA_topology, + topology2=small_mols[mapping.componentB].to_topology().to_openmm(), + exclude_resids=exclude_resids, ) - equil_steps = settings_validation.get_simsteps( - sim_length=sampler_settings.equilibration_length, - timestep=integrator_settings.timestep, - mc_steps=steps_per_iteration, + system = system_generator.create_system( + topology, + molecules=list(small_mols.values()), ) - prod_steps = settings_validation.get_simsteps( - sim_length=sampler_settings.production_length, - timestep=integrator_settings.timestep, - mc_steps=steps_per_iteration, + + return system, topology, alchem_resids + + @staticmethod + def _handle_net_charge( + stateA_topology: openmm.app.Topology, + stateA_positions: openmm.unit.Quantity, + stateB_topology: openmm.app.Topology, + stateB_system: openmm.System, + charge_difference: int, + system_mappings: dict[str, dict[int, int]], + distance_cutoff: Quantity, + solvent_component: SolventComponent | None, + ) -> None: + """ + Handle system net charge by adding an alchemical water. 
+ + Parameters + ---------- + stateA_topology : openmm.app.Topology + stateA_positions : openmm.unit.Quantity + stateB_topology : openmm.app.Topology + stateB_system : openmm.System + charge_difference : int + system_mappings : dict[str, dict[int, int]] + distance_cutoff : Quantity + solvent_component : SolventComponent | None + """ + # Base case, return if no net charge + if charge_difference == 0: + return + + # Get the residue ids for waters to turn alchemical + alchem_water_resids = _rfe_utils.topologyhelpers.get_alchemical_waters( + topology=stateA_topology, + positions=stateA_positions, + charge_difference=charge_difference, + distance_cutoff=distance_cutoff, ) - solvent_comp, protein_comp, small_mols = system_validation.get_components(stateA) - - # Get the change difference between the end states - # and check if the charge correction used is appropriate - charge_difference = mapping.get_alchemical_charge_difference() - - # 1. Create stateA system - self.logger.info("Parameterizing molecules") - - # a. 
create offmol dictionaries and assign partial charges - # workaround for conformer generation failures - # see openfe issue #576 - # calculate partial charges manually if not already given - # convert to OpenFF here, - # and keep the molecule around to maintain the partial charges - off_small_mols: dict[str, list[tuple[SmallMoleculeComponent, OFFMolecule]]] - off_small_mols = { - "stateA": [(mapping.componentA, mapping.componentA.to_openff())], - "stateB": [(mapping.componentB, mapping.componentB.to_openff())], - "both": [ - (m, m.to_openff()) - for m in small_mols - if (m != mapping.componentA and m != mapping.componentB) - ], - } + # In-place modify state B alchemical waters to ions + _rfe_utils.topologyhelpers.handle_alchemical_waters( + water_resids=alchem_water_resids, + topology=stateB_topology, + system=stateB_system, + system_mapping=system_mappings, + charge_difference=charge_difference, + solvent_component=solvent_component, + ) - self._assign_partial_charges(charge_settings, off_small_mols) + def _get_omm_objects( + self, + stateA: ChemicalSystem, + stateB: ChemicalSystem, + mapping: LigandAtomMapping, + settings: dict[str, SettingsBaseModel], + protein_component: ProteinComponent | None, + solvent_component: SolventComponent | None, + small_mols: dict[SmallMoleculeComponent, OFFMolecule], + ) -> tuple[ + openmm.System, + openmm.app.Topology, + openmm.unit.Quantity, + openmm.System, + openmm.app.Topology, + openmm.unit.Quantity, + dict[str, dict[int, int]], + ]: + """ + Get OpenMM objects for both end states A and B. - # b. get a system generator - if output_settings.forcefield_cache is not None: - ffcache = shared_basepath / output_settings.forcefield_cache - else: - ffcache = None + Parameters + ---------- + stateA : ChemicalSystem + ChemicalSystem defining end state A. + stateB : ChemicalSystem + ChemicalSystem defining end state B. + mapping : LigandAtomMapping + The mapping for alchemical components between state A and B. 
+ settings : dict[str, SettingsBaseModel] + Settings for the transformation. + protein_component : ProteinComponent | None + The common ProteinComponent between the end states, if there is one. + solvent_component : SolventComponent | None + The common SolventComponent between the end states, if there is one. + small_mols : dict[SmallMoleculeComponent, openff.toolkit.Molecule] + The small molecules for both end states. - # Block out oechem backend in system_generator calls to avoid - # any issues with smiles roundtripping between rdkit and oechem - with without_oechem_backend(): - system_generator = system_creation.get_system_generator( - forcefield_settings=forcefield_settings, - integrator_settings=integrator_settings, - thermo_settings=thermo_settings, - cache=ffcache, - has_solvent=solvent_comp is not None, - ) + Returns + ------- + stateA_system : openmm.System + OpenMM System for state A. + stateA_topology : openmm.app.Topology + OpenMM Topology for the state A System. + stateA_positions : openmm.unit.Quantity + Positions of particles for state A System. + stateB_system : openmm.System + OpenMM System for state B. + stateB_topology : openmm.app.Topology + OpenMM Topology for the state B System. + stateB_positions : openmm.unit.Quantity + Positions of particles for state B System. + system_mapping : dict[str, dict[int, int]] + Dictionary of mappings defining the correspondence between + the two state Systems. + """ + if self.verbose: + self.logger.info("Parameterizing systems") - # c. force the creation of parameters - # This is necessary because we need to have the FF templates - # registered ahead of solvating the system. - for smc, mol in chain( - off_small_mols["stateA"], off_small_mols["stateB"], off_small_mols["both"] - ): - system_generator.create_system(mol.to_topology().to_openmm(), molecules=[mol]) - - # c.
get OpenMM Modeller + a dictionary of resids for each component - stateA_modeller, comp_resids = system_creation.get_omm_modeller( - protein_comp=protein_comp, - solvent_comp=solvent_comp, - small_mols=dict(chain(off_small_mols["stateA"], off_small_mols["both"])), - omm_forcefield=system_generator.forcefield, - solvent_settings=solvation_settings, - ) + def _filter_small_mols(smols, state): + return {smc: offmol for smc, offmol in smols.items() if state.contains(smc)} - # d. get topology & positions - # Note: roundtrip positions to remove vec3 issues - stateA_topology = stateA_modeller.getTopology() - stateA_positions = to_openmm(from_openmm(stateA_modeller.getPositions())) + small_mols_stateA = _filter_small_mols(small_mols, stateA) + small_mols_stateB = _filter_small_mols(small_mols, stateB) - # e. create the stateA System - # Block out oechem backend in system_generator calls to avoid - # any issues with smiles roundtripping between rdkit and oechem + # Everything involving systemgenerator handling has a risk of + # oechem <-> rdkit smiles conversion clashes, cautiously ban it. with without_oechem_backend(): - stateA_system = system_generator.create_system( - stateA_modeller.topology, - molecules=[m for _, m in chain(off_small_mols["stateA"], off_small_mols["both"])], + # TODO: get two generators, one for state A and one for stateB + # See issue #1120 + # Get the system generator with all the templates registered + system_generator = self._get_system_generator( + shared_basepath=self.shared_basepath, + settings=settings, + solvent_comp=solvent_component, + openff_molecules=list(small_mols.values()), ) - # 2. Get stateB system - # a. 
get the topology - stateB_topology, stateB_alchem_resids = _rfe_utils.topologyhelpers.combined_topology( - stateA_topology, - # zeroth item (there's only one) then get the OFF representation - off_small_mols["stateB"][0][1].to_topology().to_openmm(), - exclude_resids=comp_resids[mapping.componentA], - ) + (stateA_system, stateA_topology, stateA_positions, comp_resids) = ( + self._create_stateA_system( + small_mols=small_mols_stateA, + protein_component=protein_component, + solvent_component=solvent_component, + system_generator=system_generator, + solvation_settings=settings["solvation_settings"], + ) + ) - # b. get a list of small molecules for stateB - # Block out oechem backend in system_generator calls to avoid - # any issues with smiles roundtripping between rdkit and oechem - with without_oechem_backend(): - stateB_system = system_generator.create_system( - stateB_topology, - molecules=[m for _, m in chain(off_small_mols["stateB"], off_small_mols["both"])], + (stateB_system, stateB_topology, stateB_alchem_resids) = self._create_stateB_system( + small_mols=small_mols_stateB, + mapping=mapping, + stateA_topology=stateA_topology, + exclude_resids=comp_resids[mapping.componentA], + system_generator=system_generator, ) - # c. Define correspondence mappings between the two systems - ligand_mappings = _rfe_utils.topologyhelpers.get_system_mappings( - mapping.componentA_to_componentB, - stateA_system, - stateA_topology, - comp_resids[mapping.componentA], - stateB_system, - stateB_topology, - stateB_alchem_resids, + # Get the mapping between the two systems + system_mappings = _rfe_utils.topologyhelpers.get_system_mappings( + old_to_new_atom_map=mapping.componentA_to_componentB, + old_system=stateA_system, + old_topology=stateA_topology, + old_resids=comp_resids[mapping.componentA], + new_system=stateB_system, + new_topology=stateB_topology, + new_resids=stateB_alchem_resids, # These are non-optional settings for this method fix_constraints=True, ) - # d. 
if a charge correction is necessary, select alchemical waters - # and transform them - if alchem_settings.explicit_charge_correction: - alchem_water_resids = _rfe_utils.topologyhelpers.get_alchemical_waters( - stateA_topology, - stateA_positions, - charge_difference, - alchem_settings.explicit_charge_correction_cutoff, - ) - _rfe_utils.topologyhelpers.handle_alchemical_waters( - alchem_water_resids, - stateB_topology, - stateB_system, - ligand_mappings, - charge_difference, - solvent_comp, + # Net charge: add alchemical water if needed + # Must be done here as we in-place modify the particles of state B. + if settings["alchemical_settings"].explicit_charge_correction: + self._handle_net_charge( + stateA_topology=stateA_topology, + stateA_positions=stateA_positions, + stateB_topology=stateB_topology, + stateB_system=stateB_system, + charge_difference=mapping.get_alchemical_charge_difference(), + system_mappings=system_mappings, + distance_cutoff=settings["alchemical_settings"].explicit_charge_correction_cutoff, + solvent_component=solvent_component, ) - # e. Finally get the positions + # Finally get the state B positions stateB_positions = _rfe_utils.topologyhelpers.set_and_check_new_positions( - ligand_mappings, + system_mappings, stateA_topology, stateB_topology, old_positions=ensure_quantity(stateA_positions, "openmm"), insert_positions=ensure_quantity( - off_small_mols["stateB"][0][1].conformers[0], "openmm" + small_mols[mapping.componentB].conformers[0], "openmm" ), ) - # 3. Create the hybrid topology - # a. 
Get softcore potential settings - if alchem_settings.softcore_LJ.lower() == "gapsys": + return ( + stateA_system, + stateA_topology, + stateA_positions, + stateB_system, + stateB_topology, + stateB_positions, + system_mappings, + ) + + @staticmethod + def _get_alchemical_system( + stateA_system: openmm.System, + stateA_positions: openmm.unit.Quantity, + stateA_topology: openmm.app.Topology, + stateB_system: openmm.System, + stateB_positions: openmm.unit.Quantity, + stateB_topology: openmm.app.Topology, + system_mappings: dict[str, dict[int, int]], + alchemical_settings: AlchemicalSettings, + ): + """ + Get the hybrid topology alchemical system. + + Parameters + ---------- + stateA_system : openmm.System + State A OpenMM System + stateA_positions : openmm.unit.Quantity + Positions of state A System + stateA_topology : openmm.app.Topology + Topology of state A System + stateB_system : openmm.System + State B OpenMM System + stateB_positions : openmm.unit.Quantity + Positions of state B System + stateB_topology : openmm.app.Topology + Topology of state B System + system_mappings : dict[str, dict[int, int]] + Mapping of corresponding atoms between the two Systems. + alchemical_settings : AlchemicalSettings + The alchemical settings defining how the alchemical system + will be built. + + Returns + ------- + hybrid_factory : HybridTopologyFactory + The factory creating the hybrid system. + hybrid_system : openmm.System + The hybrid System. + """ + if alchemical_settings.softcore_LJ.lower() == "gapsys": softcore_LJ_v2 = True - elif alchem_settings.softcore_LJ.lower() == "beutler": + elif alchemical_settings.softcore_LJ.lower() == "beutler": softcore_LJ_v2 = False - # b. 
Get hybrid topology factory + hybrid_factory = _rfe_utils.relative.HybridTopologyFactory( stateA_system, stateA_positions, @@ -368,54 +663,161 @@ def run( stateB_system, stateB_positions, stateB_topology, - old_to_new_atom_map=ligand_mappings["old_to_new_atom_map"], - old_to_new_core_atom_map=ligand_mappings["old_to_new_core_atom_map"], - use_dispersion_correction=alchem_settings.use_dispersion_correction, - softcore_alpha=alchem_settings.softcore_alpha, + old_to_new_atom_map=system_mappings["old_to_new_atom_map"], + old_to_new_core_atom_map=system_mappings["old_to_new_core_atom_map"], + use_dispersion_correction=alchemical_settings.use_dispersion_correction, + softcore_alpha=alchemical_settings.softcore_alpha, softcore_LJ_v2=softcore_LJ_v2, - softcore_LJ_v2_alpha=alchem_settings.softcore_alpha, - interpolate_old_and_new_14s=alchem_settings.turn_off_core_unique_exceptions, + softcore_LJ_v2_alpha=alchemical_settings.softcore_alpha, + interpolate_old_and_new_14s=alchemical_settings.turn_off_core_unique_exceptions, ) - # 4. Create lambda schedule - # TODO - this should be exposed to users, maybe we should offer the - # ability to print the schedule directly in settings? 
- # fmt: off - lambdas = _rfe_utils.lambdaprotocol.LambdaProtocol( - functions=lambda_settings.lambda_functions, - windows=lambda_settings.lambda_windows - ) - # fmt: on - # PR #125 temporarily pin lambda schedule spacing to n_replicas - n_replicas = sampler_settings.n_replicas - if n_replicas != len(lambdas.lambda_schedule): - errmsg = ( - f"Number of replicas {n_replicas} " - f"does not equal the number of lambda windows " - f"{len(lambdas.lambda_schedule)}" + return hybrid_factory, hybrid_factory.hybrid_system + + def _subsample_topology( + self, + hybrid_topology: openmm.app.Topology, + hybrid_positions: openmm.unit.Quantity, + output_selection: str, + output_filename: str, + atom_classes: dict[str, set[int]], + ) -> npt.NDArray: + """ + Subsample the hybrid topology based on user-selected output selection + and write the subsampled topology to a PDB file. + + Parameters + ---------- + hybrid_topology : openmm.app.Topology + The hybrid system topology to subsample. + hybrid_positions : openmm.unit.Quantity + The hybrid system positions. + output_selection : str + An MDTraj selection string to subsample the topology with. + output_filename : str + The name of the file to write the PDB to. + atom_classes : dict[str, set[int]] + A dictionary defining what atoms belong to the different + components of the hybrid system. + + Returns + ------- + selection_indices : npt.NDArray + The indices of the subselected system. + + TODO + ---- + Modify this to also store the full system. 
+ """ + selection_indices = hybrid_topology.select(output_selection) + + # Write out a PDB containing the subsampled hybrid state + # We use bfactors as a hack to label different states + # bfactor of 0 is environment atoms + # bfactor of 0.25 is unique old atoms + # bfactor of 0.5 is core atoms + # bfactor of 0.75 is unique new atoms + bfactors = np.zeros_like(selection_indices, dtype=float) + bfactors[np.isin(selection_indices, list(atom_classes["unique_old_atoms"]))] = 0.25 + bfactors[np.isin(selection_indices, list(atom_classes["core_atoms"]))] = 0.50 + bfactors[np.isin(selection_indices, list(atom_classes["unique_new_atoms"]))] = 0.75 + + if len(selection_indices) > 0: + traj = mdtraj.Trajectory( + hybrid_positions[selection_indices, :], + hybrid_topology.subset(selection_indices), + ).save_pdb( + self.shared_basepath / output_filename, + bfactors=bfactors, ) - raise ValueError(errmsg) - # 9. Create the multistate reporter - # Get the sub selection of the system to print coords for - selection_indices = hybrid_factory.hybrid_topology.select(output_settings.output_indices) + return selection_indices - # a. Create the multistate reporter - # convert checkpoint_interval from time to iterations - chk_intervals = settings_validation.convert_checkpoint_interval_to_iterations( - checkpoint_interval=output_settings.checkpoint_interval, - time_per_iteration=sampler_settings.time_per_iteration, + @staticmethod + def _get_integrator( + integrator_settings: IntegratorSettings, + simulation_settings: MultiStateSimulationSettings, + system: openmm.System, + ) -> openmmtools.mcmc.LangevinDynamicsMove: + """ + Get and validate the integrator + + Parameters + ---------- + integrator_settings : IntegratorSettings + Settings controlling the Langevin integrator. + simulation_settings : MultiStateSimulationSettings + Settings controlling the simulation. + system : openmm.System + The OpenMM System. 
+ + Returns + ------- + integrator : openmmtools.mcmc.LangevinDynamicsMove + The LangevinDynamicsMove integrator. + + Raises + ------ + ValueError + If there are virtual sites in the system, but velocities + are not being reassigned after every MCMC move. + """ + steps_per_iteration = settings_validation.convert_steps_per_iteration( + simulation_settings, integrator_settings ) - nc = shared_basepath / output_settings.output_filename + integrator = openmmtools.mcmc.LangevinDynamicsMove( + timestep=to_openmm(integrator_settings.timestep), + collision_rate=to_openmm(integrator_settings.langevin_collision_rate), + n_steps=steps_per_iteration, + reassign_velocities=integrator_settings.reassign_velocities, + n_restart_attempts=integrator_settings.n_restart_attempts, + constraint_tolerance=integrator_settings.constraint_tolerance, + ) + + # Validate for known issue when dealing with virtual sites + # and multistate simulations + if not integrator_settings.reassign_velocities: + for particle_idx in range(system.getNumParticles()): + if system.isVirtualSite(particle_idx): + errmsg = ( + "Simulations with virtual sites without velocity " + "reassignments are unstable with MCMC integrators." + ) + raise ValueError(errmsg) + + return integrator + + @staticmethod + def _get_reporter( + storage_path: pathlib.Path, + selection_indices: npt.NDArray, + output_settings: MultiStateOutputSettings, + simulation_settings: MultiStateSimulationSettings, + ) -> multistate.MultiStateReporter: + """ + Get the multistate reporter. + + Parameters + ---------- + storage_path : pathlib.Path + Path to the directory where files should be written. + selection_indices : npt.NDArray + The set of system indices to report positions & velocities for. + output_settings : MultiStateOutputSettings + Settings defining how outputs should be written. + simulation_settings : MultiStateSimulationSettings + Settings defining how the simulation should be run.
+ """ + nc = storage_path / output_settings.output_filename chk = output_settings.checkpoint_storage_filename if output_settings.positions_write_frequency is not None: pos_interval = settings_validation.divmod_time_and_check( numerator=output_settings.positions_write_frequency, - denominator=sampler_settings.time_per_iteration, + denominator=simulation_settings.time_per_iteration, numerator_name="output settings' position_write_frequency", - denominator_name="sampler settings' time_per_iteration", + denominator_name="simulation settings' time_per_iteration", ) else: pos_interval = 0 @@ -423,14 +825,19 @@ def run( if output_settings.velocities_write_frequency is not None: vel_interval = settings_validation.divmod_time_and_check( numerator=output_settings.velocities_write_frequency, - denominator=sampler_settings.time_per_iteration, + denominator=simulation_settings.time_per_iteration, numerator_name="output settings' velocity_write_frequency", denominator_name="sampler settings' time_per_iteration", ) else: vel_interval = 0 - reporter = multistate.MultiStateReporter( + chk_intervals = settings_validation.convert_checkpoint_interval_to_iterations( + checkpoint_interval=output_settings.checkpoint_interval, + time_per_iteration=simulation_settings.time_per_iteration, + ) + + return multistate.MultiStateReporter( storage=nc, analysis_particle_indices=selection_indices, checkpoint_interval=chk_intervals, @@ -439,100 +846,98 @@ def run( velocity_interval=vel_interval, ) - # b. 
Write out a PDB containing the subsampled hybrid state - # fmt: off - bfactors = np.zeros_like(selection_indices, dtype=float) # solvent - bfactors[np.in1d(selection_indices, list(hybrid_factory._atom_classes['unique_old_atoms']))] = 0.25 # lig A - bfactors[np.in1d(selection_indices, list(hybrid_factory._atom_classes['core_atoms']))] = 0.50 # core - bfactors[np.in1d(selection_indices, list(hybrid_factory._atom_classes['unique_new_atoms']))] = 0.75 # lig B - # bfactors[np.in1d(selection_indices, protein)] = 1.0 # prot+cofactor - if len(selection_indices) > 0: - traj = mdtraj.Trajectory( - hybrid_factory.hybrid_positions[selection_indices, :], - hybrid_factory.hybrid_topology.subset(selection_indices), - ).save_pdb( - shared_basepath / output_settings.output_structure, - bfactors=bfactors, - ) - # fmt: on - - # 10. Get compute platform - # restrict to a single CPU if running vacuum - restrict_cpu = forcefield_settings.nonbonded_method.lower() == "nocutoff" - platform = omm_compute.get_openmm_platform( - platform_name=protocol_settings.engine_settings.compute_platform, - gpu_device_index=protocol_settings.engine_settings.gpu_device_index, - restrict_cpu_count=restrict_cpu, - ) - - # 11. Set the integrator - # a. 
Validate integrator settings for current system - # Virtual sites sanity check - ensure we restart velocities when - # there are virtual sites in the system - if hybrid_factory.has_virtual_sites: - if not integrator_settings.reassign_velocities: - errmsg = ( - "Simulations with virtual sites without velocity " - "reassignments are unstable in openmmtools" - ) - raise ValueError(errmsg) + @staticmethod + def _get_sampler( + system: openmm.System, + positions: openmm.unit.Quantity, + lambdas: _rfe_utils.lambdaprotocol.LambdaProtocol, + integrator: openmmtools.mcmc.MCMCMove, + reporter: multistate.MultiStateReporter, + simulation_settings: MultiStateSimulationSettings, + thermo_settings: ThermoSettings, + alchem_settings: AlchemicalSettings, + platform: openmm.Platform, + dry: bool, + ) -> multistate.MultiStateSampler: + """ + Get the MultiStateSampler. - # b. create langevin integrator - integrator = openmmtools.mcmc.LangevinDynamicsMove( - timestep=to_openmm(integrator_settings.timestep), - collision_rate=to_openmm(integrator_settings.langevin_collision_rate), - n_steps=steps_per_iteration, - reassign_velocities=integrator_settings.reassign_velocities, - n_restart_attempts=integrator_settings.n_restart_attempts, - constraint_tolerance=integrator_settings.constraint_tolerance, - ) + Parameters + ---------- + system : openmm.System + The OpenMM System to simulate. + positions : openmm.unit.Quantity + The positions of the OpenMM System. + lambdas : LambdaProtocol + The lambda protocol to sample along. + integrator : openmmtools.mcmc.MCMCMove + The integrator to use. + reporter : multistate.MultiStateReporter + The reporter to attach to the sampler. + simulation_settings : MultiStateSimulationSettings + The simulation control settings. + thermo_settings : ThermoSettings + The thermodynamic control settings. + alchem_settings : AlchemicalSettings + The alchemical transformation settings. + platform : openmm.Platform + The compute platform to use. 
+ dry : bool + Whether or not this is a dry run. - # 12. Create sampler - self.logger.info("Creating and setting up the sampler") + Returns + ------- + sampler : multistate.MultiStateSampler + The requested sampler. + """ rta_its, rta_min_its = settings_validation.convert_real_time_analysis_iterations( - simulation_settings=sampler_settings, + simulation_settings=simulation_settings, ) + # convert early_termination_target_error from kcal/mol to kT early_termination_target_error = ( settings_validation.convert_target_error_from_kcal_per_mole_to_kT( thermo_settings.temperature, - sampler_settings.early_termination_target_error, + simulation_settings.early_termination_target_error, ) ) - if sampler_settings.sampler_method.lower() == "repex": + if simulation_settings.sampler_method.lower() == "repex": sampler = _rfe_utils.multistate.HybridRepexSampler( mcmc_moves=integrator, - hybrid_system=hybrid_factory.hybrid_system, - hybrid_positions=hybrid_factory.hybrid_positions, + hybrid_system=system, + hybrid_positions=positions, online_analysis_interval=rta_its, online_analysis_target_error=early_termination_target_error, online_analysis_minimum_iterations=rta_min_its, ) - elif sampler_settings.sampler_method.lower() == "sams": + + elif simulation_settings.sampler_method.lower() == "sams": sampler = _rfe_utils.multistate.HybridSAMSSampler( mcmc_moves=integrator, - hybrid_system=hybrid_factory.hybrid_system, - hybrid_positions=hybrid_factory.hybrid_positions, + hybrid_system=system, + hybrid_positions=positions, online_analysis_interval=rta_its, online_analysis_minimum_iterations=rta_min_its, - flatness_criteria=sampler_settings.sams_flatness_criteria, - gamma0=sampler_settings.sams_gamma0, + flatness_criteria=simulation_settings.sams_flatness_criteria, + gamma0=simulation_settings.sams_gamma0, ) - elif sampler_settings.sampler_method.lower() == "independent": + + elif simulation_settings.sampler_method.lower() == "independent": sampler = 
_rfe_utils.multistate.HybridMultiStateSampler( mcmc_moves=integrator, - hybrid_system=hybrid_factory.hybrid_system, - hybrid_positions=hybrid_factory.hybrid_positions, + hybrid_system=system, + hybrid_positions=positions, online_analysis_interval=rta_its, online_analysis_target_error=early_termination_target_error, online_analysis_minimum_iterations=rta_min_its, ) + else: - raise AttributeError(f"Unknown sampler {sampler_settings.sampler_method}") + raise AttributeError(f"Unknown sampler {simulation_settings.sampler_method}") sampler.setup( - n_replicas=sampler_settings.n_replicas, + n_replicas=simulation_settings.n_replicas, reporter=reporter, lambda_protocol=lambdas, temperature=to_openmm(thermo_settings.temperature), @@ -543,63 +948,259 @@ def run( minimization_steps=100 if not dry else None, ) - try: - # Create context caches (energy + sampler) - energy_context_cache = openmmtools.cache.ContextCache( - capacity=None, - time_to_live=None, - platform=platform, - ) + # Get and set the context caches + sampler.energy_context_cache = openmmtools.cache.ContextCache( + capacity=None, + time_to_live=None, + platform=platform, + ) + sampler.sampler_context_cache = openmmtools.cache.ContextCache( + capacity=None, + time_to_live=None, + platform=platform, + ) - sampler_context_cache = openmmtools.cache.ContextCache( - capacity=None, - time_to_live=None, - platform=platform, + return sampler + + def _run_simulation( + self, + sampler: multistate.MultiStateSampler, + reporter: multistate.MultiStateReporter, + simulation_settings: MultiStateSimulationSettings, + integrator_settings: IntegratorSettings, + output_settings: MultiStateOutputSettings, + dry: bool, + ): + """ + Run the simulation. + + Parameters + ---------- + sampler : multistate.MultiStateSampler. + The sampler associated with the simulation to run. + reporter : multistate.MultiStateReporter + The reporter associated with the sampler. 
+ simulation_settings : MultiStateSimulationSettings + Simulation control settings. + integrator_settings : IntegratorSettings + Integrator control settings. + output_settings : MultiStateOutputSettings + Simulation output control settings. + dry : bool + Whether or not to dry run the simulation. + + Returns + ------- + unit_results_dict : dict | None + A dictionary containing the free energy results to report. + ``None`` if it is a dry run. + """ + # Get the relevant simulation steps + mc_steps = settings_validation.convert_steps_per_iteration( + simulation_settings=simulation_settings, + integrator_settings=integrator_settings, + ) + + equil_steps = settings_validation.get_simsteps( + sim_length=simulation_settings.equilibration_length, + timestep=integrator_settings.timestep, + mc_steps=mc_steps, + ) + prod_steps = settings_validation.get_simsteps( + sim_length=simulation_settings.production_length, + timestep=integrator_settings.timestep, + mc_steps=mc_steps, + ) + + if not dry: # pragma: no-cover + # minimize + if self.verbose: + self.logger.info("minimizing systems") + + sampler.minimize(max_iterations=simulation_settings.minimization_steps) + + # equilibrate + if self.verbose: + self.logger.info("equilibrating systems") + + sampler.equilibrate(int(equil_steps / mc_steps)) + + # production + if self.verbose: + self.logger.info("running production phase") + + sampler.extend(int(prod_steps / mc_steps)) + + if self.verbose: + self.logger.info("production phase complete") + + if self.verbose: + self.logger.info("post-simulation result analysis") + + # calculate relevant analysis of the free energies & sampling + analyzer = multistate_analysis.MultistateEquilFEAnalysis( + reporter, + sampling_method=simulation_settings.sampler_method.lower(), + result_units=offunit.kilocalorie_per_mole, ) + analyzer.plot(filepath=self.shared_basepath, filename_prefix="") + analyzer.close() - sampler.energy_context_cache = energy_context_cache - sampler.sampler_context_cache = 
sampler_context_cache + return analyzer.unit_results_dict - if not dry: # pragma: no-cover - # minimize - if verbose: - self.logger.info("Running minimization") + else: + # We ran a dry simulation + # close reporter when you're done, prevent file handle clashes + reporter.close() - sampler.minimize(max_iterations=sampler_settings.minimization_steps) + # TODO: review this is likely no longer necessary + # clean up the reporter file + fns = [ + self.shared_basepath / output_settings.output_filename, + self.shared_basepath / output_settings.checkpoint_storage_filename, + ] + for fn in fns: + os.remove(fn) - # equilibrate - if verbose: - self.logger.info("Running equilibration phase") + return None - sampler.equilibrate(int(equil_steps / steps_per_iteration)) + def run( + self, *, dry=False, verbose=True, scratch_basepath=None, shared_basepath=None + ) -> dict[str, Any]: + """Run the relative free energy calculation. - # production - if verbose: - self.logger.info("Running production phase") + Parameters + ---------- + dry : bool + Do a dry run of the calculation, creating all necessary hybrid + system components (topology, system, sampler, etc...) but without + running the simulation. + verbose : bool + Verbose output of the simulation progress. Output is provided via + INFO level logging. + scratch_basepath: Pathlike, optional + Where to store temporary files, defaults to current working directory + shared_basepath : Pathlike, optional + Where to run the calculation, defaults to current working directory - sampler.extend(int(prod_steps / steps_per_iteration)) + Returns + ------- + dict + Outputs created in the basepath directory or the debug objects + (i.e. sampler) if ``dry==True``. 
- self.logger.info("Production phase complete") + Raises + ------ + error + Exception if anything failed + """ + # Prepare paths & verbosity + self._prepare(verbose, scratch_basepath, shared_basepath) - self.logger.info("Post-simulation analysis of results") - # calculate relevant analyses of the free energies & sampling - # First close & reload the reporter to avoid netcdf clashes - analyzer = multistate_analysis.MultistateEquilFEAnalysis( - reporter, - sampling_method=sampler_settings.sampler_method.lower(), - result_units=offunit.kilocalorie_per_mole, - ) - analyzer.plot(filepath=shared_basepath, filename_prefix="") - analyzer.close() - - else: - # clean up the reporter file - fns = [ - shared_basepath / output_settings.output_filename, - shared_basepath / output_settings.checkpoint_storage_filename, - ] - for fn in fns: - os.remove(fn) + # Get settings + settings = self._get_settings(self._inputs["protocol"].settings) + + # Get components + stateA = self._inputs["stateA"] + stateB = self._inputs["stateB"] + mapping = self._inputs["ligandmapping"] + alchem_comps, solvent_comp, protein_comp, small_mols = self._get_components(stateA, stateB) + + # Assign partial charges now to avoid any discrepancies later + self._assign_partial_charges(settings["charge_settings"], small_mols) + + ( + stateA_system, + stateA_topology, + stateA_positions, + stateB_system, + stateB_topology, + stateB_positions, + system_mappings, + ) = self._get_omm_objects( + stateA=stateA, + stateB=stateB, + mapping=mapping, + settings=settings, + protein_component=protein_comp, + solvent_component=solvent_comp, + small_mols=small_mols, + ) + + # Get the hybrid factory & system + hybrid_factory, hybrid_system = self._get_alchemical_system( + stateA_system=stateA_system, + stateA_positions=stateA_positions, + stateA_topology=stateA_topology, + stateB_system=stateB_system, + stateB_positions=stateB_positions, + stateB_topology=stateB_topology, + system_mappings=system_mappings, + 
alchemical_settings=settings["alchemical_settings"], + ) + + # Subselect system based on user inputs & write initial PDB + selection_indices = self._subsample_topology( + hybrid_topology=hybrid_factory.hybrid_topology, + hybrid_positions=hybrid_factory.hybrid_positions, + output_selection=settings["output_settings"].output_indices, + output_filename=settings["output_settings"].output_structure, + atom_classes=hybrid_factory._atom_classes, + ) + + # Get the lambda schedule + # TODO - this should be better exposed to users + lambdas = _rfe_utils.lambdaprotocol.LambdaProtocol( + functions=settings["lambda_settings"].lambda_functions, + windows=settings["lambda_settings"].lambda_windows, + ) + + # Get the compute platform + restrict_cpu = settings["forcefield_settings"].nonbonded_method.lower() == "nocutoff" + platform = omm_compute.get_openmm_platform( + platform_name=settings["engine_settings"].compute_platform, + gpu_device_index=settings["engine_settings"].gpu_device_index, + restrict_cpu_count=restrict_cpu, + ) + + # Get the integrator + integrator = self._get_integrator( + integrator_settings=settings["integrator_settings"], + simulation_settings=settings["simulation_settings"], + system=hybrid_system, + ) + + try: + # get the reporter + reporter = self._get_reporter( + storage_path=self.shared_basepath, + selection_indices=selection_indices, + output_settings=settings["output_settings"], + simulation_settings=settings["simulation_settings"], + ) + + # Get sampler + sampler = self._get_sampler( + system=hybrid_system, + positions=hybrid_factory.hybrid_positions, + lambdas=lambdas, + integrator=integrator, + reporter=reporter, + simulation_settings=settings["simulation_settings"], + thermo_settings=settings["thermo_settings"], + alchem_settings=settings["alchemical_settings"], + platform=platform, + dry=dry, + ) + + unit_results_dict = self._run_simulation( + sampler=sampler, + reporter=reporter, + simulation_settings=settings["simulation_settings"], + 
integrator_settings=settings["integrator_settings"], + output_settings=settings["output_settings"], + dry=dry, + ) finally: # close reporter when you're done, prevent # file handle clashes @@ -608,23 +1209,33 @@ def run( # clear GPU contexts # TODO: use cache.empty() calls when openmmtools #690 is resolved # replace with above - for context in list(energy_context_cache._lru._data.keys()): - del energy_context_cache._lru._data[context] - for context in list(sampler_context_cache._lru._data.keys()): - del sampler_context_cache._lru._data[context] + for context in list(sampler.energy_context_cache._lru._data.keys()): + del sampler.energy_context_cache._lru._data[context] + for context in list(sampler.sampler_context_cache._lru._data.keys()): + del sampler.sampler_context_cache._lru._data[context] # cautiously clear out the global context cache too for context in list(openmmtools.cache.global_context_cache._lru._data.keys()): del openmmtools.cache.global_context_cache._lru._data[context] - del sampler_context_cache, energy_context_cache + del sampler.sampler_context_cache, sampler.energy_context_cache if not dry: del integrator, sampler if not dry: # pragma: no-cover - return {"nc": nc, "last_checkpoint": chk, **analyzer.unit_results_dict} + nc = self.shared_basepath / settings["output_settings"].output_filename + chk = settings["output_settings"].checkpoint_storage_filename + unit_results_dict["nc"] = nc + unit_results_dict["last_checkpoint"] = chk + unit_results_dict["selection_indices"] = selection_indices + return unit_results_dict else: - return {"debug": {"sampler": sampler, "hybrid_factory": hybrid_factory}} + return { + "debug": { + "sampler": sampler, + "hybrid_factory": hybrid_factory, + } + } @staticmethod def structural_analysis(scratch, shared) -> dict: