From 43d60073540a84f2338709932e0de9d4450c5fe0 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Fri, 5 Dec 2025 13:18:52 +0000 Subject: [PATCH 01/36] Migrate validation to Protocol._validate --- .../protocols/openmm_rfe/equil_rfe_methods.py | 434 +++++++++++------- 1 file changed, 260 insertions(+), 174 deletions(-) diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index 514237634..f272631b8 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -117,161 +117,6 @@ def _get_resname(off_mol) -> str: return names[0] -def _get_alchemical_charge_difference( - mapping: LigandAtomMapping, - nonbonded_method: str, - explicit_charge_correction: bool, - solvent_component: SolventComponent, -) -> int: - """ - Checks and returns the difference in formal charge between state A and B. - - Raises - ------ - ValueError - * If an explicit charge correction is attempted and the - nonbonded method is not PME. - * If the absolute charge difference is greater than one - and an explicit charge correction is attempted. - UserWarning - If there is any charge difference. - - Parameters - ---------- - mapping : dict[str, ComponentMapping] - Dictionary of mappings between transforming components. - nonbonded_method : str - The OpenMM nonbonded method used for the simulation. - explicit_charge_correction : bool - Whether or not to use an explicit charge correction. - solvent_component : openfe.SolventComponent - The SolventComponent of the simulation. - - Returns - ------- - int - The formal charge difference between states A and B. - This is defined as sum(charge state A) - sum(charge state B) - """ - - difference = mapping.get_alchemical_charge_difference() - - if abs(difference) > 0: - if explicit_charge_correction: - if nonbonded_method.lower() != "pme": - errmsg = "Explicit charge correction when not using PME is not currently supported." 
- raise ValueError(errmsg) - if abs(difference) > 1: - errmsg = ( - f"A charge difference of {difference} is observed " - "between the end states and an explicit charge " - "correction has been requested. Unfortunately " - "only absolute differences of 1 are supported." - ) - raise ValueError(errmsg) - - ion = {-1: solvent_component.positive_ion, 1: solvent_component.negative_ion}[ - difference - ] - wmsg = ( - f"A charge difference of {difference} is observed " - "between the end states. This will be addressed by " - f"transforming a water into a {ion} ion" - ) - logger.warning(wmsg) - warnings.warn(wmsg) - else: - wmsg = ( - f"A charge difference of {difference} is observed " - "between the end states. No charge correction has " - "been requested, please account for this in your " - "final results." - ) - logger.warning(wmsg) - warnings.warn(wmsg) - - return difference - - -def _validate_alchemical_components( - alchemical_components: dict[str, list[Component]], - mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]], -): - """ - Checks that the alchemical components are suitable for the RFE protocol. - - Specifically we check: - 1. That all alchemical components are mapped. - 2. That all alchemical components are SmallMoleculeComponents. - 3. If the mappings involves element changes in core atoms - - Parameters - ---------- - alchemical_components : dict[str, list[Component]] - Dictionary contatining the alchemical components for - states A and B. - mapping : Optional[Union[ComponentMapping, list[ComponentMapping]]] - all mappings between transforming components. - - Raises - ------ - ValueError - * If there are more than one mapping or mapping is None - * If there are any unmapped alchemical components. - * If there are any alchemical components that are not - SmallMoleculeComponents. 
- UserWarning - * Mappings which involve element changes in core atoms - """ - if isinstance(mapping, ComponentMapping): - mapping = [mapping] - # Check mapping - # For now we only allow for a single mapping, this will likely change - if mapping is None or len(mapping) != 1: - errmsg = "A single LigandAtomMapping is expected for this Protocol" - raise ValueError(errmsg) - - # Check that all alchemical components are mapped & small molecules - mapped = { - "stateA": [m.componentA for m in mapping], - "stateB": [m.componentB for m in mapping], - } - - for idx in ["stateA", "stateB"]: - if len(alchemical_components[idx]) != len(mapped[idx]): - errmsg = f"missing alchemical components in {idx}" - raise ValueError(errmsg) - for comp in alchemical_components[idx]: - if comp not in mapped[idx]: - raise ValueError(f"Unmapped alchemical component {comp}") - if not isinstance(comp, SmallMoleculeComponent): # pragma: no-cover - errmsg = ( - "Transformations involving non " - "SmallMoleculeComponent species {comp} " - "are not currently supported" - ) - raise ValueError(errmsg) - - # Validate element changes in mappings - for m in mapping: - molA = m.componentA.to_rdkit() - molB = m.componentB.to_rdkit() - for i, j in m.componentA_to_componentB.items(): - atomA = molA.GetAtomWithIdx(i) - atomB = molB.GetAtomWithIdx(j) - if atomA.GetAtomicNum() != atomB.GetAtomicNum(): - wmsg = ( - f"Element change in mapping between atoms " - f"Ligand A: {i} (element {atomA.GetAtomicNum()}) and " - f"Ligand B: {j} (element {atomB.GetAtomicNum()})\n" - "No mass scaling is attempted in the hybrid topology, " - "the average mass of the two atoms will be used in the " - "simulation" - ) - logger.warning(wmsg) - warnings.warn(wmsg) # TODO: remove this once logging is fixed - - class RelativeHybridTopologyProtocolResult(gufe.ProtocolResult): """Dict-like container for the output of a RelativeHybridTopologyProtocol""" @@ -612,21 +457,204 @@ def _adaptive_settings( return protocol_settings - def 
_create( + @staticmethod + def _validate_endstates( + stateA: ChemicalSystem, + stateB: ChemicalSystem, + ) -> None: + """ + Validates the end states for the RFE protocol. + + Parameters + ---------- + stateA : ChemicalSystem + The chemical system of end state A. + stateB : ChemicalSystem + The chemical system of end state B. + + Raises + ------ + ValueError + * If either state contains more than one unique Component. + * If unique components are not SmallMoleculeComponents. + """ + # Get the difference in Components between each state + diff = stateA.component_diff(stateB) + + for i, entry in enumerate(diff): + state_label = "A" if i == 0 else "B" + + # Check that there is only one unique Component in each state + if len(entry) != 0: + errmsg = ( + "Only one alchemical component is allowed per end state. " + f"Found {len(entry)} in state {state_label}." + ) + raise ValueError(errmsg) + + # Check that the unique Component is a SmallMoleculeComponent + if not isinstance(entry[0], SmallMoleculeComponent): + errmsg = ( + f"Alchemical component in state {state_label} is of type " + f"{type(entry[0])}, but only SmallMoleculeComponents " + "transformations are currently supported." + ) + raise ValueError(errmsg) + + @staticmethod + def _validate_mapping( + mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]], + alchemical_components: dict[str, list[Component]], + ) -> None: + """ + Validates that the provided mapping(s) are suitable for the RFE protocol. + + Parameters + ---------- + mapping : Optional[Union[ComponentMapping, list[ComponentMapping]]] + all mappings between transforming components. + alchemical_components : dict[str, list[Component]] + Dictionary contatining the alchemical components for + states A and B. + + Raises + ------ + ValueError + * If there are more than one mapping or mapping is None + * If the mapping components are not in the alchemical components. 
+ UserWarning + * Mappings which involve element changes in core atoms + """ + # if a single mapping is provided, convert to list + if isinstance(mapping, ComponentMapping): + mapping = [mapping] + + # For now we only support a single mapping + if mapping is None or len(mapping) > 1: + errmsg = "A single LigandAtomMapping is expected for this Protocol" + raise ValueError(errmsg) + + # check that the mapping components are in the alchemical components + for m in mapping: + if m.componentA not in alchemical_components["stateA"]: + raise ValueError(f"Mapping componentA {m.componentA} not in alchemical components of stateA") + if m.componentB not in alchemical_components["stateB"]: + raise ValueError(f"Mapping componentB {m.componentB} not in alchemical components of stateB") + + # TODO: remove - this is now the default behaviour? + # Check for element changes in mappings + for m in mapping: + molA = m.componentA.to_rdkit() + molB = m.componentB.to_rdkit() + for i, j in m.componentA_to_componentB.items(): + atomA = molA.GetAtomWithIdx(i) + atomB = molB.GetAtomWithIdx(j) + if atomA.GetAtomicNum() != atomB.GetAtomicNum(): + wmsg = ( + f"Element change in mapping between atoms " + f"Ligand A: {i} (element {atomA.GetAtomicNum()}) and " + f"Ligand B: {j} (element {atomB.GetAtomicNum()})\n" + "No mass scaling is attempted in the hybrid topology, " + "the average mass of the two atoms will be used in the " + "simulation" + ) + logger.warning(wmsg) + warnings.warn(wmsg) + + @staticmethod + def _validate_charge_difference( + mapping: LigandAtomMapping, + nonbonded_method: str, + explicit_charge_correction: bool, + solvent_component: SolventComponent | None, + ): + """ + Validates the net charge difference between the two states. + + Parameters + ---------- + mapping : dict[str, ComponentMapping] + Dictionary of mappings between transforming components. + nonbonded_method : str + The OpenMM nonbonded method used for the simulation. 
+ explicit_charge_correction : bool + Whether or not to use an explicit charge correction. + solvent_component : openfe.SolventComponent | None + The SolventComponent of the simulation. + + Raises + ------ + ValueError + * If an explicit charge correction is attempted and the + nonbonded method is not PME. + * If the absolute charge difference is greater than one + and an explicit charge correction is attempted. + UserWarning + * If there is any charge difference. + """ + difference = mapping.get_alchemical_charge_difference() + + if abs(difference) == 0: + return + + if not explicit_charge_correction: + wmsg = ( + f"A charge difference of {difference} is observed " + "between the end states. No charge correction has " + "been requested, please account for this in your " + "final results." + ) + logger.warning(wmsg) + warnings.warn(wmsg) + return + + # We implicitly check earlier that we have to have pme for a solvated + # system, so we only need to check the nonbonded method here + if nonbonded_method.lower() != "pme": + errmsg = "Explicit charge correction when not using PME is not currently supported." + raise ValueError(errmsg) + + if abs(difference) > 1: + errmsg = ( + f"A charge difference of {difference} is observed " + "between the end states and an explicit charge " + "correction has been requested. Unfortunately " + "only absolute differences of 1 are supported." + ) + raise ValueError(errmsg) + + ion = { + -1: solvent_component.positive_ion, + 1: solvent_component.negative_ion + }[difference] + + wmsg = ( + f"A charge difference of {difference} is observed " + "between the end states. This will be addressed by " + f"transforming a water into a {ion} ion" + ) + logger.info(wmsg) + + def _validate( self, stateA: ChemicalSystem, stateB: ChemicalSystem, - mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]], - extends: Optional[gufe.ProtocolDAGResult] = None, - ) -> list[gufe.ProtocolUnit]: - # TODO: Extensions? 
+ mapping: gufe.ComponentMapping | list[gufe.ComponentMapping] | None, + extends: gufe.ProtocolDAGResult | None = None, + ) -> None: + # Check we're not trying to extend if extends: - raise NotImplementedError("Can't extend simulations yet") + # This technically should be NotImplementedError + # but gufe.Protocol.validate calls `_validate` wrapped around an + # except for NotImplementedError, so we can't raise it here + raise ValueError("Can't extend simulations yet") - # Get alchemical components & validate them + mapping + # Validate the end states + self._validate_endstates(stateA, stateB) + + # Valildate the mapping alchem_comps = system_validation.get_alchemical_components(stateA, stateB) - _validate_alchemical_components(alchem_comps, mapping) - ligandmapping = mapping[0] if isinstance(mapping, list) else mapping + self._validate_mapping(mapping, alchem_comps) # Validate solvent component nonbond = self.settings.forcefield_settings.nonbonded_method @@ -638,11 +666,78 @@ def _create( # Validate protein component system_validation.validate_protein(stateA) + # Validate charge difference + # Note: validation depends on the mapping & solvent component checks + if stateA.contains(SolventComponent): + solv_comp = stateA.get_components_of_type(SolventComponent)[0] + else: + solv_comp = None + + self._validate_charge_difference( + mapping=mapping[0] if isinstance(mapping, list) else mapping, + nonbonded_method=self.settings.forcefield_settings.nonbonded_method, + explicit_charge_correction=self.settings.alchemical_settings.explicit_charge_correction, + solvent_component=solv_comp, + ) + + # Validate integrator things + settings_validation.validate_timestep( + self.settings.forcefield_settings.hydrogen_mass, + self.settings.integrator_settings.timestep, + ) + + _ = settings_validation.convert_steps_per_iteration( + simulation_settings=self.settings.simulation_settings, + integrator_settings=self.settings.integrator_settings, + ) + + _ = 
settings_validation.get_simsteps( + sim_length=self.settings.simulation_settings.equilibration_length, + timestep=self.settings.integrator_settings.timestep, + mc_steps=steps_per_iteration, + ) + + _ = settings_validation.get_simsteps( + sim_length=self.settings.simulation_settings.production_length, + timestep=self.settings.integrator_settings.timestep, + mc_steps=steps_per_iteration, + ) + + _ = settings_validation.convert_checkpoint_interval_to_iterations( + checkpoint_interval=self.settings.output_settings.checkpoint_interval, + time_per_iteration=self.settings.simulation_settings.time_per_iteration, + ) + + # Validate alchemical settings + # PR #125 temporarily pin lambda schedule spacing to n_replicas + if self.settings.simulation_settings.n_replicas != self.settings.lambda_settings.n_windows: + errmsg = ( + "Number of replicas in simulation_settings must equal " + "number of lambda windows in lambda_settings." + ) + raise ValueError(errmsg) + + def _create( + self, + stateA: ChemicalSystem, + stateB: ChemicalSystem, + mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]], + extends: Optional[gufe.ProtocolDAGResult] = None, + ) -> list[gufe.ProtocolUnit]: + # validate inputs + self.validate(stateA=stateA, stateB=stateB, mapping=mapping, extends=extends) + + # get alchemical components and mapping + alchem_comps = system_validation.get_alchemical_components(stateA, stateB) + ligandmapping = mapping[0] if isinstance(mapping, list) else mapping + # actually create and return Units Anames = ",".join(c.name for c in alchem_comps["stateA"]) Bnames = ",".join(c.name for c in alchem_comps["stateB"]) + # our DAG has no dependencies, so just list units n_repeats = self.settings.protocol_repeats + units = [ RelativeHybridTopologyProtocolUnit( protocol=self, @@ -816,10 +911,6 @@ def run( output_settings: MultiStateOutputSettings = protocol_settings.output_settings integrator_settings: IntegratorSettings = protocol_settings.integrator_settings - # 
is the timestep good for the mass? - settings_validation.validate_timestep( - forcefield_settings.hydrogen_mass, integrator_settings.timestep - ) # TODO: Also validate various conversions? # Convert various time based inputs to steps/iterations steps_per_iteration = settings_validation.convert_steps_per_iteration( @@ -842,12 +933,7 @@ def run( # Get the change difference between the end states # and check if the charge correction used is appropriate - charge_difference = _get_alchemical_charge_difference( - mapping, - forcefield_settings.nonbonded_method, - alchem_settings.explicit_charge_correction, - solvent_comp, - ) + charge_difference = mapping.get_alchemical_charge_difference() # 1. Create stateA system self.logger.info("Parameterizing molecules") From 2cd56ba75a7f5b602521a94def4a72a11750dee7 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Fri, 5 Dec 2025 13:25:43 +0000 Subject: [PATCH 02/36] some fixes --- openfe/protocols/openmm_rfe/equil_rfe_methods.py | 6 +++--- .../tests/protocols/openmm_rfe/test_hybrid_top_protocol.py | 4 ---- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index f272631b8..be6c1d39a 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -485,7 +485,7 @@ def _validate_endstates( state_label = "A" if i == 0 else "B" # Check that there is only one unique Component in each state - if len(entry) != 0: + if len(entry) != 1: errmsg = ( "Only one alchemical component is allowed per end state. " f"Found {len(entry)} in state {state_label}." 
@@ -686,7 +686,7 @@ def _validate( self.settings.integrator_settings.timestep, ) - _ = settings_validation.convert_steps_per_iteration( + steps_per_iteration = settings_validation.convert_steps_per_iteration( simulation_settings=self.settings.simulation_settings, integrator_settings=self.settings.integrator_settings, ) @@ -710,7 +710,7 @@ def _validate( # Validate alchemical settings # PR #125 temporarily pin lambda schedule spacing to n_replicas - if self.settings.simulation_settings.n_replicas != self.settings.lambda_settings.n_windows: + if self.settings.simulation_settings.n_replicas != self.settings.lambda_settings.lambda_windows: errmsg = ( "Number of replicas in simulation_settings must equal " "number of lambda windows in lambda_settings." diff --git a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py index f5ea92cff..1f178991a 100644 --- a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py +++ b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py @@ -30,10 +30,6 @@ from openfe import setup from openfe.protocols import openmm_rfe from openfe.protocols.openmm_rfe._rfe_utils import topologyhelpers -from openfe.protocols.openmm_rfe.equil_rfe_methods import ( - _get_alchemical_charge_difference, - _validate_alchemical_components, -) from openfe.protocols.openmm_utils import omm_compute, system_creation from openfe.protocols.openmm_utils.charge_generation import ( HAS_ESPALOMA_CHARGE, From f3305622f3d147bd22464bb8726cb12551c41c4a Mon Sep 17 00:00:00 2001 From: IAlibay Date: Mon, 8 Dec 2025 13:34:41 +0000 Subject: [PATCH 03/36] move some things around --- .../protocols/openmm_rfe/equil_rfe_methods.py | 74 +++- .../openmm_rfe/test_hybrid_top_validation.py | 391 ++++++++++++++++++ 2 files changed, 445 insertions(+), 20 deletions(-) create mode 100644 openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py diff --git 
a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index be6c1d39a..59e33ef52 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -635,6 +635,55 @@ def _validate_charge_difference( ) logger.info(wmsg) + @staticmethod + def _validate_simulation_settings( + simulation_settings, + integrator_settings, + output_settings, + ): + + steps_per_iteration = settings_validation.convert_steps_per_iteration( + simulation_settings=simulation_settings, + integrator_settings=integrator_settings, + ) + + _ = settings_validation.get_simsteps( + sim_length=simulation_settings.equilibration_length, + timestep=integrator_settings.timestep, + mc_steps=steps_per_iteration, + ) + + _ = settings_validation.get_simsteps( + sim_length=simulation_settings.production_length, + timestep=integrator_settings.timestep, + mc_steps=steps_per_iteration, + ) + + _ = settings_validation.convert_checkpoint_interval_to_iterations( + checkpoint_interval=output_settings.checkpoint_interval, + time_per_iteration=simulation_settings.time_per_iteration, + ) + + if output_settings.positions_write_frequency is not None: + _ = settings_validation.divmod_time_and_check( + numerator=output_settings.positions_write_frequency, + denominator=sampler_settings.time_per_iteration, + numerator_name="output settings' position_write_frequency", + denominator_name="sampler settings' time_per_iteration", + ) + + if output_settings.velocities_write_frequency is not None: + _ = settings_validation.divmod_time_and_check( + numerator=output_settings.velocities_write_frequency, + denominator=sampler_settings.time_per_iteration, + numerator_name="output settings' velocity_write_frequency", + denominator_name="sampler settings' time_per_iteration", + ) + + _, _ = settings_validation.convert_real_time_analysis_iterations( + simulation_settings=sampler_settings, + ) + def _validate( self, stateA: ChemicalSystem, @@ 
-686,26 +735,11 @@ def _validate( self.settings.integrator_settings.timestep, ) - steps_per_iteration = settings_validation.convert_steps_per_iteration( - simulation_settings=self.settings.simulation_settings, - integrator_settings=self.settings.integrator_settings, - ) - - _ = settings_validation.get_simsteps( - sim_length=self.settings.simulation_settings.equilibration_length, - timestep=self.settings.integrator_settings.timestep, - mc_steps=steps_per_iteration, - ) - - _ = settings_validation.get_simsteps( - sim_length=self.settings.simulation_settings.production_length, - timestep=self.settings.integrator_settings.timestep, - mc_steps=steps_per_iteration, - ) - - _ = settings_validation.convert_checkpoint_interval_to_iterations( - checkpoint_interval=self.settings.output_settings.checkpoint_interval, - time_per_iteration=self.settings.simulation_settings.time_per_iteration, + # Validate simulation & output settings + self._validate_simulation_settings( + self.settings.simulation_settings, + self.settings.integrator_settings, + self.settings.output_settings, ) # Validate alchemical settings diff --git a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py new file mode 100644 index 000000000..62c8e5d23 --- /dev/null +++ b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py @@ -0,0 +1,391 @@ +# This code is part of OpenFE and is licensed under the MIT license. 
+# For details, see https://github.com/OpenFreeEnergy/openfe +import copy +import json +import sys +import xml.etree.ElementTree as ET +from importlib import resources +from math import sqrt +from pathlib import Path +from unittest import mock + +import gufe +import mdtraj as mdt +import numpy as np +import pytest +from kartograf import KartografAtomMapper +from kartograf.atom_aligner import align_mol_shape +from numpy.testing import assert_allclose +from openff.toolkit import Molecule +from openff.units import unit +from openff.units.openmm import ensure_quantity, from_openmm, to_openmm +from openmm import CustomNonbondedForce, MonteCarloBarostat, NonbondedForce, XmlSerializer, app +from openmm import unit as omm_unit +from openmmforcefields.generators import SMIRNOFFTemplateGenerator +from openmmtools.multistate.multistatesampler import MultiStateSampler +from rdkit import Chem +from rdkit.Geometry import Point3D + +import openfe +from openfe import setup +from openfe.protocols import openmm_rfe +from openfe.protocols.openmm_rfe._rfe_utils import topologyhelpers +from openfe.protocols.openmm_utils import omm_compute, system_creation +from openfe.protocols.openmm_utils.charge_generation import ( + HAS_ESPALOMA_CHARGE, + HAS_NAGL, + HAS_OPENEYE, +) + + +@pytest.fixture() +def vac_settings(): + settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() + settings.forcefield_settings.nonbonded_method = "nocutoff" + settings.engine_settings.compute_platform = None + settings.protocol_repeats = 1 + return settings + + +@pytest.fixture() +def solv_settings(): + settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() + settings.engine_settings.compute_platform = None + settings.protocol_repeats = 1 + return settings + + +def test_invalid_protocol_repeats(): + settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() + with pytest.raises(ValueError, match="must be a positive value"): + settings.protocol_repeats = -1 + + 
+@pytest.mark.parametrize( + "mapping", + [None, [], ["A", "B"]], +) +def test_validate_alchemical_components_wrong_mappings(mapping): + with pytest.raises(ValueError, match="A single LigandAtomMapping"): + _validate_alchemical_components({"stateA": [], "stateB": []}, mapping) + + +def test_validate_alchemical_components_missing_alchem_comp(benzene_to_toluene_mapping): + alchem_comps = {"stateA": [openfe.SolventComponent()], "stateB": []} + with pytest.raises(ValueError, match="Unmapped alchemical component"): + _validate_alchemical_components(alchem_comps, benzene_to_toluene_mapping) + + +def test_hightimestep( + benzene_vacuum_system, + toluene_vacuum_system, + benzene_to_toluene_mapping, + vac_settings, + tmpdir, +): + vac_settings.forcefield_settings.hydrogen_mass = 1.0 + + p = openmm_rfe.RelativeHybridTopologyProtocol( + settings=vac_settings, + ) + + dag = p.create( + stateA=benzene_vacuum_system, + stateB=toluene_vacuum_system, + mapping=benzene_to_toluene_mapping, + ) + dag_unit = list(dag.protocol_units)[0] + + errmsg = "too large for hydrogen mass" + with tmpdir.as_cwd(): + with pytest.raises(ValueError, match=errmsg): + dag_unit.run(dry=True) + + +def test_n_replicas_not_n_windows( + benzene_vacuum_system, + toluene_vacuum_system, + benzene_to_toluene_mapping, + vac_settings, + tmpdir, +): + # For PR #125 we pin such that the number of lambda windows + # equals the numbers of replicas used - TODO: remove limitation + # default lambda windows is 11 + vac_settings.simulation_settings.n_replicas = 13 + + errmsg = "Number of replicas 13 does not equal the number of lambda windows 11" + + with tmpdir.as_cwd(): + with pytest.raises(ValueError, match=errmsg): + p = openmm_rfe.RelativeHybridTopologyProtocol( + settings=vac_settings, + ) + dag = p.create( + stateA=benzene_vacuum_system, + stateB=toluene_vacuum_system, + mapping=benzene_to_toluene_mapping, + ) + dag_unit = list(dag.protocol_units)[0] + dag_unit.run(dry=True) + + +def 
test_missing_ligand(benzene_system, benzene_to_toluene_mapping): + # state B doesn't have a ligand component + stateB = openfe.ChemicalSystem({"solvent": openfe.SolventComponent()}) + + p = openmm_rfe.RelativeHybridTopologyProtocol( + settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), + ) + + match_str = "missing alchemical components in stateB" + with pytest.raises(ValueError, match=match_str): + _ = p.create( + stateA=benzene_system, + stateB=stateB, + mapping=benzene_to_toluene_mapping, + ) + + +def test_vaccuum_PME_error( + benzene_vacuum_system, benzene_modifications, benzene_to_toluene_mapping +): + # state B doesn't have a solvent component (i.e. its vacuum) + stateB = openfe.ChemicalSystem({"ligand": benzene_modifications["toluene"]}) + + p = openmm_rfe.RelativeHybridTopologyProtocol( + settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), + ) + errmsg = "PME cannot be used for vacuum transform" + with pytest.raises(ValueError, match=errmsg): + _ = p.create( + stateA=benzene_vacuum_system, + stateB=stateB, + mapping=benzene_to_toluene_mapping, + ) + + +def test_incompatible_solvent(benzene_system, benzene_modifications, benzene_to_toluene_mapping): + # the solvents are different + stateB = openfe.ChemicalSystem( + { + "ligand": benzene_modifications["toluene"], + "solvent": openfe.SolventComponent(positive_ion="K", negative_ion="Cl"), + } + ) + + p = openmm_rfe.RelativeHybridTopologyProtocol( + settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), + ) + # We don't have a way to map non-ligand components so for now it + # just triggers that it's not a mapped component + errmsg = "missing alchemical components in stateA" + with pytest.raises(ValueError, match=errmsg): + _ = p.create( + stateA=benzene_system, + stateB=stateB, + mapping=benzene_to_toluene_mapping, + ) + + +def test_mapping_mismatch_A(benzene_system, toluene_system, benzene_modifications): + # the atom mapping doesn't refer to the ligands in 
the systems + mapping = setup.LigandAtomMapping( + componentA=benzene_system.components["ligand"], + componentB=benzene_modifications["phenol"], + componentA_to_componentB=dict(), + ) + + p = openmm_rfe.RelativeHybridTopologyProtocol( + settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), + ) + errmsg = ( + r"Unmapped alchemical component " + r"SmallMoleculeComponent\(name=toluene\)" + ) + with pytest.raises(ValueError, match=errmsg): + _ = p.create( + stateA=benzene_system, + stateB=toluene_system, + mapping=mapping, + ) + + +def test_mapping_mismatch_B(benzene_system, toluene_system, benzene_modifications): + mapping = setup.LigandAtomMapping( + componentA=benzene_modifications["phenol"], + componentB=toluene_system.components["ligand"], + componentA_to_componentB=dict(), + ) + + p = openmm_rfe.RelativeHybridTopologyProtocol( + settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), + ) + errmsg = ( + r"Unmapped alchemical component " + r"SmallMoleculeComponent\(name=benzene\)" + ) + with pytest.raises(ValueError, match=errmsg): + _ = p.create( + stateA=benzene_system, + stateB=toluene_system, + mapping=mapping, + ) + + +def test_complex_mismatch(benzene_system, toluene_complex_system, benzene_to_toluene_mapping): + # only one complex + p = openmm_rfe.RelativeHybridTopologyProtocol( + settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), + ) + with pytest.raises(ValueError): + _ = p.create( + stateA=benzene_system, + stateB=toluene_complex_system, + mapping=benzene_to_toluene_mapping, + ) + + +def test_too_many_specified_mappings(benzene_system, toluene_system, benzene_to_toluene_mapping): + # mapping dict requires 'ligand' key + p = openmm_rfe.RelativeHybridTopologyProtocol( + settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), + ) + errmsg = "A single LigandAtomMapping is expected for this Protocol" + with pytest.raises(ValueError, match=errmsg): + _ = p.create( + stateA=benzene_system, + 
stateB=toluene_system, + mapping=[benzene_to_toluene_mapping, benzene_to_toluene_mapping], + ) + + +def test_protein_mismatch( + benzene_complex_system, toluene_complex_system, benzene_to_toluene_mapping +): + # hack one protein to be labelled differently + prot = toluene_complex_system["protein"] + alt_prot = openfe.ProteinComponent(prot.to_rdkit(), name="Mickey Mouse") + alt_toluene_complex_system = openfe.ChemicalSystem( + { + "ligand": toluene_complex_system["ligand"], + "solvent": toluene_complex_system["solvent"], + "protein": alt_prot, + } + ) + + p = openmm_rfe.RelativeHybridTopologyProtocol( + settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), + ) + with pytest.raises(ValueError): + _ = p.create( + stateA=benzene_complex_system, + stateB=alt_toluene_complex_system, + mapping=benzene_to_toluene_mapping, + ) + + +def test_element_change_warning(atom_mapping_basic_test_files): + # check a mapping with element change gets rejected early + l1 = atom_mapping_basic_test_files["2-methylnaphthalene"] + l2 = atom_mapping_basic_test_files["2-naftanol"] + + # We use the 'old' lomap defaults because the + # basic test files inputs we use aren't fully aligned + mapper = setup.LomapAtomMapper( + time=20, threed=True, max3d=1000.0, element_change=True, seed="", shift=True + ) + + mapping = next(mapper.suggest_mappings(l1, l2)) + + sys1 = openfe.ChemicalSystem( + {"ligand": l1, "solvent": openfe.SolventComponent()}, + ) + sys2 = openfe.ChemicalSystem( + {"ligand": l2, "solvent": openfe.SolventComponent()}, + ) + + p = openmm_rfe.RelativeHybridTopologyProtocol( + settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), + ) + with pytest.warns(UserWarning, match="Element change"): + _ = p.create( + stateA=sys1, + stateB=sys2, + mapping=mapping, + ) + + +@pytest.mark.parametrize( + "mapping_name,result", + [ + ["benzene_to_toluene_mapping", 0], + ["benzene_to_benzoic_mapping", 1], + ["benzene_to_aniline_mapping", -1], + 
["aniline_to_benzene_mapping", 1], + ], +) +def test_get_charge_difference(mapping_name, result, request): + mapping = request.getfixturevalue(mapping_name) + if result != 0: + ion = r"Na\+" if result == -1 else r"Cl\-" + wmsg = ( + f"A charge difference of {result} is observed " + "between the end states. This will be addressed by " + f"transforming a water into a {ion} ion" + ) + with pytest.warns(UserWarning, match=wmsg): + val = _get_alchemical_charge_difference(mapping, "pme", True, openfe.SolventComponent()) + assert result == pytest.approx(val) + else: + val = _get_alchemical_charge_difference(mapping, "pme", True, openfe.SolventComponent()) + assert result == pytest.approx(val) + + +def test_get_charge_difference_no_pme(benzene_to_benzoic_mapping): + errmsg = "Explicit charge correction when not using PME" + with pytest.raises(ValueError, match=errmsg): + _get_alchemical_charge_difference( + benzene_to_benzoic_mapping, + "nocutoff", + True, + openfe.SolventComponent(), + ) + + +def test_get_charge_difference_no_corr(benzene_to_benzoic_mapping): + wmsg = ( + "A charge difference of 1 is observed between the end states. 
" + "No charge correction has been requested" + ) + with pytest.warns(UserWarning, match=wmsg): + _get_alchemical_charge_difference( + benzene_to_benzoic_mapping, + "pme", + False, + openfe.SolventComponent(), + ) + + +def test_greater_than_one_charge_difference_error(aniline_to_benzoic_mapping): + errmsg = "A charge difference of 2" + with pytest.raises(ValueError, match=errmsg): + _get_alchemical_charge_difference( + aniline_to_benzoic_mapping, + "pme", + True, + openfe.SolventComponent(), + ) + + +def test_get_alchemical_waters_no_waters( + benzene_solvent_openmm_system, +): + system, topology, positions = benzene_solvent_openmm_system + + errmsg = "There are no waters" + + with pytest.raises(ValueError, match=errmsg): + topologyhelpers.get_alchemical_waters( + topology, positions, charge_difference=1, distance_cutoff=3.0 * unit.nanometer + ) From 1e0153e949fbeb4a947ffb7beafd614eefaebc80 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Mon, 15 Dec 2025 10:26:58 -1000 Subject: [PATCH 04/36] add validate endstate tests --- .../openmm_rfe/test_hybrid_top_validation.py | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py index 62c8e5d23..50a518606 100644 --- a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py +++ b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py @@ -61,6 +61,51 @@ def test_invalid_protocol_repeats(): settings.protocol_repeats = -1 +@pytest.mark.parametrize('state', ['A', 'B']) +def test_endstate_two_alchemcomp_stateA(state, benzene_modifications): + first_state = openfe.ChemicalSystem({ + 'ligandA': benzene_modifications['benzene'], + 'ligandB': benzene_modifications['toluene'], + 'solvent': openfe.SolventComponent(), + }) + other_state = openfe.ChemicalSystem({ + 'ligandC': benzene_modifications['phenol'], + 'solvent': openfe.SolventComponent(), + }) + + if state == 'A': + args 
= (first_state, other_state) + else: + args = (other_state, first_state) + + with pytest.raises(ValueError, match="Only one alchemical component"): + openmm_rfe.RelativeHybridTopologyProtocol._validate_endstates( + *args + ) + +@pytest.mark.parametrize('state', ['A', 'B']) +def test_endstates_not_smc(state, benzene_modifications): + first_state = openfe.ChemicalSystem({ + 'ligand': benzene_modifications['benzene'], + 'foo': openfe.SolventComponent(), + }) + other_state = openfe.ChemicalSystem({ + 'ligand': benzene_modifications['benzene'], + 'foo': benzene_modifications['toluene'], + }) + + if state == 'A': + args = (first_state, other_state) + else: + args = (other_state, first_state) + + errmsg = "only SmallMoleculeComponents transformations" + with pytest.raises(ValueError, match=errmsg): + openmm_rfe.RelativeHybridTopologyProtocol._validate_endstates( + *args + ) + + @pytest.mark.parametrize( "mapping", [None, [], ["A", "B"]], From fbc455416c5b3eb9683c31515deec80b2d664da4 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Mon, 15 Dec 2025 10:52:30 -1000 Subject: [PATCH 05/36] validate mapping tests --- .../openmm_rfe/test_hybrid_top_validation.py | 84 ++++++++++++------- 1 file changed, 53 insertions(+), 31 deletions(-) diff --git a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py index 50a518606..241be3ebd 100644 --- a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py +++ b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py @@ -106,6 +106,59 @@ def test_endstates_not_smc(state, benzene_modifications): ) +def test_validate_mapping_none_mapping(): + errmsg = "A single LigandAtomMapping is expected" + with pytest.raises(ValueError, match=errmsg): + openmm_rfe.RelativeHybridTopologyProtocol._validate_mapping(None, None) + + +def test_validate_mapping_multi_mapping(benzene_to_toluene_mapping): + errmsg = "A single LigandAtomMapping is expected" + with 
pytest.raises(ValueError, match=errmsg): + openmm_rfe.RelativeHybridTopologyProtocol._validate_mapping( + [benzene_to_toluene_mapping] * 2, + None + ) + + +@pytest.mark.parametrize('state', ['A', 'B']) +def test_validate_mapping_alchem_not_in(state, benzene_to_toluene_mapping): + errmsg = f"not in alchemical components of state{state}" + + if state == "A": + alchem_comps = {"stateA": [], "stateB": [benzene_to_toluene_mapping.componentB]} + else: + alchem_comps = {"stateA": [benzene_to_toluene_mapping.componentA], "stateB": []} + + with pytest.raises(ValueError, match=errmsg): + openmm_rfe.RelativeHybridTopologyProtocol._validate_mapping( + [benzene_to_toluene_mapping], + alchem_comps, + ) + + +def test_element_change_warning(atom_mapping_basic_test_files): + # check a mapping with element change gets rejected early + l1 = atom_mapping_basic_test_files["2-methylnaphthalene"] + l2 = atom_mapping_basic_test_files["2-naftanol"] + + # We use the 'old' lomap defaults because the + # basic test files inputs we use aren't fully aligned + mapper = setup.LomapAtomMapper( + time=20, threed=True, max3d=1000.0, element_change=True, seed="", shift=True + ) + + mapping = next(mapper.suggest_mappings(l1, l2)) + + alchem_comps = {"stateA": [l1], "stateB": [l2]} + + with pytest.warns(UserWarning, match="Element change"): + openmm_rfe.RelativeHybridTopologyProtocol._validate_mapping( + [mapping], + alchem_comps, + ) + + @pytest.mark.parametrize( "mapping", [None, [], ["A", "B"]], @@ -330,37 +383,6 @@ def test_protein_mismatch( ) -def test_element_change_warning(atom_mapping_basic_test_files): - # check a mapping with element change gets rejected early - l1 = atom_mapping_basic_test_files["2-methylnaphthalene"] - l2 = atom_mapping_basic_test_files["2-naftanol"] - - # We use the 'old' lomap defaults because the - # basic test files inputs we use aren't fully aligned - mapper = setup.LomapAtomMapper( - time=20, threed=True, max3d=1000.0, element_change=True, seed="", shift=True - ) - - 
mapping = next(mapper.suggest_mappings(l1, l2)) - - sys1 = openfe.ChemicalSystem( - {"ligand": l1, "solvent": openfe.SolventComponent()}, - ) - sys2 = openfe.ChemicalSystem( - {"ligand": l2, "solvent": openfe.SolventComponent()}, - ) - - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - with pytest.warns(UserWarning, match="Element change"): - _ = p.create( - stateA=sys1, - stateB=sys2, - mapping=mapping, - ) - - @pytest.mark.parametrize( "mapping_name,result", [ From c2f49d25af7ead72a8f31ad22224bf5510e6b500 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Mon, 15 Dec 2025 11:45:04 -1000 Subject: [PATCH 06/36] net charge validation tests --- .../protocols/openmm_rfe/equil_rfe_methods.py | 4 + .../openmm_rfe/test_hybrid_top_validation.py | 186 ++++++++---------- 2 files changed, 83 insertions(+), 107 deletions(-) diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index 59e33ef52..e20c7360b 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -608,6 +608,10 @@ def _validate_charge_difference( warnings.warn(wmsg) return + if solvent_component is None: + errmsg = "Cannot use eplicit charge correction without solvent" + raise ValueError(errmsg) + # We implicitly check earlier that we have to have pme for a solvated # system, so we only need to check the nonbonded method here if nonbonded_method.lower() != "pme": diff --git a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py index 241be3ebd..f2b64d292 100644 --- a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py +++ b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py @@ -1,5 +1,6 @@ # This code is part of OpenFE and is licensed under the MIT license. 
# For details, see https://github.com/OpenFreeEnergy/openfe +import logging import copy import json import sys @@ -159,19 +160,87 @@ def test_element_change_warning(atom_mapping_basic_test_files): ) +def test_charge_difference_no_corr(benzene_to_benzoic_mapping): + wmsg = ( + "A charge difference of 1 is observed between the end states. " + "No charge correction has been requested" + ) + + with pytest.warns(UserWarning, match=wmsg): + openmm_rfe.RelativeHybridTopologyProtocol._validate_charge_difference( + benzene_to_benzoic_mapping, + "pme", + False, + openfe.SolventComponent(), + ) + + +def test_charge_difference_no_solvent(benzene_to_benzoic_mapping): + errmsg = "Cannot use eplicit charge correction without solvent" + + with pytest.raises(ValueError, errmsg): + openmm_rfe.RelativeHybridTopologyProtocol._validate_charge_difference( + benzene_to_benzoic_mapping, + "pme", + True, + None, + ) + + +def test_charge_difference_no_pme(benzene_to_benzoic_mapping): + errmsg = "Explicit charge correction when not using PME" + + with pytest.raises(ValueError, match=errmsg): + openmm_rfe.RelativeHybridTopologyProtocol._validate_charge_difference( + benzene_to_benzoic_mapping, + "nocutoff", + True, + openfe.SolventComponent(), + ) + + +def test_greater_than_one_charge_difference_error(aniline_to_benzoic_mapping): + errmsg = "A charge difference of 2" + with pytest.raises(ValueError, match=errmsg): + openmm_rfe.RelativeHybridTopologyProtocol._validate_charge_difference( + aniline_to_benzoic_mapping, + "pme", + True, + openfe.SolventComponent(), + ) + + @pytest.mark.parametrize( - "mapping", - [None, [], ["A", "B"]], + "mapping_name,result", + [ + ["benzene_to_toluene_mapping", 0], + ["benzene_to_benzoic_mapping", 1], + ["benzene_to_aniline_mapping", -1], + ["aniline_to_benzene_mapping", 1], + ], ) -def test_validate_alchemical_components_wrong_mappings(mapping): - with pytest.raises(ValueError, match="A single LigandAtomMapping"): - _validate_alchemical_components({"stateA": 
[], "stateB": []}, mapping) - +def test_get_charge_difference(mapping_name, result, request, caplog): + mapping = request.getfixturevalue(mapping_name) + caplog.set_level(logging.INFO) + + ion = r"Na\+" if result == -1 else r"Cl\-" + msg = ( + f"A charge difference of {result} is observed " + "between the end states. This will be addressed by " + f"transforming a water into a {ion} ion" + ) + + openmm_rfe.RelativeHybridTopologyProtocol._validate_charge_difference( + mapping, + "pme", + True, + openfe.SolventComponent() + ) -def test_validate_alchemical_components_missing_alchem_comp(benzene_to_toluene_mapping): - alchem_comps = {"stateA": [openfe.SolventComponent()], "stateB": []} - with pytest.raises(ValueError, match="Unmapped alchemical component"): - _validate_alchemical_components(alchem_comps, benzene_to_toluene_mapping) + if result != 0: + assert msg in caplog.text + else: + assert msg not in caplog.text def test_hightimestep( @@ -286,78 +355,6 @@ def test_incompatible_solvent(benzene_system, benzene_modifications, benzene_to_ ) -def test_mapping_mismatch_A(benzene_system, toluene_system, benzene_modifications): - # the atom mapping doesn't refer to the ligands in the systems - mapping = setup.LigandAtomMapping( - componentA=benzene_system.components["ligand"], - componentB=benzene_modifications["phenol"], - componentA_to_componentB=dict(), - ) - - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - errmsg = ( - r"Unmapped alchemical component " - r"SmallMoleculeComponent\(name=toluene\)" - ) - with pytest.raises(ValueError, match=errmsg): - _ = p.create( - stateA=benzene_system, - stateB=toluene_system, - mapping=mapping, - ) - - -def test_mapping_mismatch_B(benzene_system, toluene_system, benzene_modifications): - mapping = setup.LigandAtomMapping( - componentA=benzene_modifications["phenol"], - componentB=toluene_system.components["ligand"], - componentA_to_componentB=dict(), - ) - 
- p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - errmsg = ( - r"Unmapped alchemical component " - r"SmallMoleculeComponent\(name=benzene\)" - ) - with pytest.raises(ValueError, match=errmsg): - _ = p.create( - stateA=benzene_system, - stateB=toluene_system, - mapping=mapping, - ) - - -def test_complex_mismatch(benzene_system, toluene_complex_system, benzene_to_toluene_mapping): - # only one complex - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - with pytest.raises(ValueError): - _ = p.create( - stateA=benzene_system, - stateB=toluene_complex_system, - mapping=benzene_to_toluene_mapping, - ) - - -def test_too_many_specified_mappings(benzene_system, toluene_system, benzene_to_toluene_mapping): - # mapping dict requires 'ligand' key - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - errmsg = "A single LigandAtomMapping is expected for this Protocol" - with pytest.raises(ValueError, match=errmsg): - _ = p.create( - stateA=benzene_system, - stateB=toluene_system, - mapping=[benzene_to_toluene_mapping, benzene_to_toluene_mapping], - ) - - def test_protein_mismatch( benzene_complex_system, toluene_complex_system, benzene_to_toluene_mapping ): @@ -409,31 +406,6 @@ def test_get_charge_difference(mapping_name, result, request): assert result == pytest.approx(val) -def test_get_charge_difference_no_pme(benzene_to_benzoic_mapping): - errmsg = "Explicit charge correction when not using PME" - with pytest.raises(ValueError, match=errmsg): - _get_alchemical_charge_difference( - benzene_to_benzoic_mapping, - "nocutoff", - True, - openfe.SolventComponent(), - ) - - -def test_get_charge_difference_no_corr(benzene_to_benzoic_mapping): - wmsg = ( - "A charge difference of 1 is observed between the end states. 
" - "No charge correction has been requested" - ) - with pytest.warns(UserWarning, match=wmsg): - _get_alchemical_charge_difference( - benzene_to_benzoic_mapping, - "pme", - False, - openfe.SolventComponent(), - ) - - def test_greater_than_one_charge_difference_error(aniline_to_benzoic_mapping): errmsg = "A charge difference of 2" with pytest.raises(ValueError, match=errmsg): From c50f99cad7daf09272bd81033d080d10cc5193af Mon Sep 17 00:00:00 2001 From: IAlibay Date: Mon, 22 Dec 2025 10:23:47 -1000 Subject: [PATCH 07/36] more stuff --- .../openmm_rfe/test_hybrid_top_validation.py | 102 ------------------ 1 file changed, 102 deletions(-) diff --git a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py index f2b64d292..d0452eb9a 100644 --- a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py +++ b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py @@ -297,23 +297,6 @@ def test_n_replicas_not_n_windows( dag_unit.run(dry=True) -def test_missing_ligand(benzene_system, benzene_to_toluene_mapping): - # state B doesn't have a ligand component - stateB = openfe.ChemicalSystem({"solvent": openfe.SolventComponent()}) - - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - - match_str = "missing alchemical components in stateB" - with pytest.raises(ValueError, match=match_str): - _ = p.create( - stateA=benzene_system, - stateB=stateB, - mapping=benzene_to_toluene_mapping, - ) - - def test_vaccuum_PME_error( benzene_vacuum_system, benzene_modifications, benzene_to_toluene_mapping ): @@ -332,91 +315,6 @@ def test_vaccuum_PME_error( ) -def test_incompatible_solvent(benzene_system, benzene_modifications, benzene_to_toluene_mapping): - # the solvents are different - stateB = openfe.ChemicalSystem( - { - "ligand": benzene_modifications["toluene"], - "solvent": openfe.SolventComponent(positive_ion="K", 
negative_ion="Cl"), - } - ) - - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - # We don't have a way to map non-ligand components so for now it - # just triggers that it's not a mapped component - errmsg = "missing alchemical components in stateA" - with pytest.raises(ValueError, match=errmsg): - _ = p.create( - stateA=benzene_system, - stateB=stateB, - mapping=benzene_to_toluene_mapping, - ) - - -def test_protein_mismatch( - benzene_complex_system, toluene_complex_system, benzene_to_toluene_mapping -): - # hack one protein to be labelled differently - prot = toluene_complex_system["protein"] - alt_prot = openfe.ProteinComponent(prot.to_rdkit(), name="Mickey Mouse") - alt_toluene_complex_system = openfe.ChemicalSystem( - { - "ligand": toluene_complex_system["ligand"], - "solvent": toluene_complex_system["solvent"], - "protein": alt_prot, - } - ) - - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - with pytest.raises(ValueError): - _ = p.create( - stateA=benzene_complex_system, - stateB=alt_toluene_complex_system, - mapping=benzene_to_toluene_mapping, - ) - - -@pytest.mark.parametrize( - "mapping_name,result", - [ - ["benzene_to_toluene_mapping", 0], - ["benzene_to_benzoic_mapping", 1], - ["benzene_to_aniline_mapping", -1], - ["aniline_to_benzene_mapping", 1], - ], -) -def test_get_charge_difference(mapping_name, result, request): - mapping = request.getfixturevalue(mapping_name) - if result != 0: - ion = r"Na\+" if result == -1 else r"Cl\-" - wmsg = ( - f"A charge difference of {result} is observed " - "between the end states. 
This will be addressed by " - f"transforming a water into a {ion} ion" - ) - with pytest.warns(UserWarning, match=wmsg): - val = _get_alchemical_charge_difference(mapping, "pme", True, openfe.SolventComponent()) - assert result == pytest.approx(val) - else: - val = _get_alchemical_charge_difference(mapping, "pme", True, openfe.SolventComponent()) - assert result == pytest.approx(val) - - -def test_greater_than_one_charge_difference_error(aniline_to_benzoic_mapping): - errmsg = "A charge difference of 2" - with pytest.raises(ValueError, match=errmsg): - _get_alchemical_charge_difference( - aniline_to_benzoic_mapping, - "pme", - True, - openfe.SolventComponent(), - ) - - def test_get_alchemical_waters_no_waters( benzene_solvent_openmm_system, ): From 9e0d29be83df2a7f747d6bf385d0853c1d92589e Mon Sep 17 00:00:00 2001 From: IAlibay Date: Wed, 24 Dec 2025 00:43:36 -0500 Subject: [PATCH 08/36] remove old tests --- .../openmm_rfe/test_hybrid_top_protocol.py | 260 ------------------ 1 file changed, 260 deletions(-) diff --git a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py index 1f178991a..148f32fe1 100644 --- a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py +++ b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py @@ -192,21 +192,6 @@ def test_create_independent_repeat_ids(benzene_system, toluene_system, benzene_t assert len(repeat_ids) == 6 -@pytest.mark.parametrize( - "mapping", - [None, [], ["A", "B"]], -) -def test_validate_alchemical_components_wrong_mappings(mapping): - with pytest.raises(ValueError, match="A single LigandAtomMapping"): - _validate_alchemical_components({"stateA": [], "stateB": []}, mapping) - - -def test_validate_alchemical_components_missing_alchem_comp(benzene_to_toluene_mapping): - alchem_comps = {"stateA": [openfe.SolventComponent()], "stateB": []} - with pytest.raises(ValueError, match="Unmapped alchemical component"): - 
_validate_alchemical_components(alchem_comps, benzene_to_toluene_mapping) - - @pytest.mark.parametrize("method", ["repex", "sams", "independent", "InDePeNdENT"]) def test_dry_run_default_vacuum( benzene_vacuum_system, @@ -989,189 +974,6 @@ def test_hightimestep( dag_unit.run(dry=True) -def test_n_replicas_not_n_windows( - benzene_vacuum_system, - toluene_vacuum_system, - benzene_to_toluene_mapping, - vac_settings, - tmpdir, -): - # For PR #125 we pin such that the number of lambda windows - # equals the numbers of replicas used - TODO: remove limitation - # default lambda windows is 11 - vac_settings.simulation_settings.n_replicas = 13 - - errmsg = "Number of replicas 13 does not equal the number of lambda windows 11" - - with tmpdir.as_cwd(): - with pytest.raises(ValueError, match=errmsg): - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=vac_settings, - ) - dag = p.create( - stateA=benzene_vacuum_system, - stateB=toluene_vacuum_system, - mapping=benzene_to_toluene_mapping, - ) - dag_unit = list(dag.protocol_units)[0] - dag_unit.run(dry=True) - - -def test_missing_ligand(benzene_system, benzene_to_toluene_mapping): - # state B doesn't have a ligand component - stateB = openfe.ChemicalSystem({"solvent": openfe.SolventComponent()}) - - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - - match_str = "missing alchemical components in stateB" - with pytest.raises(ValueError, match=match_str): - _ = p.create( - stateA=benzene_system, - stateB=stateB, - mapping=benzene_to_toluene_mapping, - ) - - -def test_vaccuum_PME_error( - benzene_vacuum_system, benzene_modifications, benzene_to_toluene_mapping -): - # state B doesn't have a solvent component (i.e. 
its vacuum) - stateB = openfe.ChemicalSystem({"ligand": benzene_modifications["toluene"]}) - - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - errmsg = "PME cannot be used for vacuum transform" - with pytest.raises(ValueError, match=errmsg): - _ = p.create( - stateA=benzene_vacuum_system, - stateB=stateB, - mapping=benzene_to_toluene_mapping, - ) - - -def test_incompatible_solvent(benzene_system, benzene_modifications, benzene_to_toluene_mapping): - # the solvents are different - stateB = openfe.ChemicalSystem( - { - "ligand": benzene_modifications["toluene"], - "solvent": openfe.SolventComponent(positive_ion="K", negative_ion="Cl"), - } - ) - - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - # We don't have a way to map non-ligand components so for now it - # just triggers that it's not a mapped component - errmsg = "missing alchemical components in stateA" - with pytest.raises(ValueError, match=errmsg): - _ = p.create( - stateA=benzene_system, - stateB=stateB, - mapping=benzene_to_toluene_mapping, - ) - - -def test_mapping_mismatch_A(benzene_system, toluene_system, benzene_modifications): - # the atom mapping doesn't refer to the ligands in the systems - mapping = setup.LigandAtomMapping( - componentA=benzene_system.components["ligand"], - componentB=benzene_modifications["phenol"], - componentA_to_componentB=dict(), - ) - - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - errmsg = ( - r"Unmapped alchemical component " - r"SmallMoleculeComponent\(name=toluene\)" - ) - with pytest.raises(ValueError, match=errmsg): - _ = p.create( - stateA=benzene_system, - stateB=toluene_system, - mapping=mapping, - ) - - -def test_mapping_mismatch_B(benzene_system, toluene_system, benzene_modifications): - mapping = setup.LigandAtomMapping( - 
componentA=benzene_modifications["phenol"], - componentB=toluene_system.components["ligand"], - componentA_to_componentB=dict(), - ) - - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - errmsg = ( - r"Unmapped alchemical component " - r"SmallMoleculeComponent\(name=benzene\)" - ) - with pytest.raises(ValueError, match=errmsg): - _ = p.create( - stateA=benzene_system, - stateB=toluene_system, - mapping=mapping, - ) - - -def test_complex_mismatch(benzene_system, toluene_complex_system, benzene_to_toluene_mapping): - # only one complex - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - with pytest.raises(ValueError): - _ = p.create( - stateA=benzene_system, - stateB=toluene_complex_system, - mapping=benzene_to_toluene_mapping, - ) - - -def test_too_many_specified_mappings(benzene_system, toluene_system, benzene_to_toluene_mapping): - # mapping dict requires 'ligand' key - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - errmsg = "A single LigandAtomMapping is expected for this Protocol" - with pytest.raises(ValueError, match=errmsg): - _ = p.create( - stateA=benzene_system, - stateB=toluene_system, - mapping=[benzene_to_toluene_mapping, benzene_to_toluene_mapping], - ) - - -def test_protein_mismatch( - benzene_complex_system, toluene_complex_system, benzene_to_toluene_mapping -): - # hack one protein to be labelled differently - prot = toluene_complex_system["protein"] - alt_prot = openfe.ProteinComponent(prot.to_rdkit(), name="Mickey Mouse") - alt_toluene_complex_system = openfe.ChemicalSystem( - { - "ligand": toluene_complex_system["ligand"], - "solvent": toluene_complex_system["solvent"], - "protein": alt_prot, - } - ) - - p = openmm_rfe.RelativeHybridTopologyProtocol( - 
settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - with pytest.raises(ValueError): - _ = p.create( - stateA=benzene_complex_system, - stateB=alt_toluene_complex_system, - mapping=benzene_to_toluene_mapping, - ) - - def test_element_change_warning(atom_mapping_basic_test_files): # check a mapping with element change gets rejected early l1 = atom_mapping_basic_test_files["2-methylnaphthalene"] @@ -1745,68 +1547,6 @@ def test_filenotfound_replica_states(self, protocolresult): protocolresult.get_replica_states() -@pytest.mark.parametrize( - "mapping_name,result", - [ - ["benzene_to_toluene_mapping", 0], - ["benzene_to_benzoic_mapping", 1], - ["benzene_to_aniline_mapping", -1], - ["aniline_to_benzene_mapping", 1], - ], -) -def test_get_charge_difference(mapping_name, result, request): - mapping = request.getfixturevalue(mapping_name) - if result != 0: - ion = r"Na\+" if result == -1 else r"Cl\-" - wmsg = ( - f"A charge difference of {result} is observed " - "between the end states. This will be addressed by " - f"transforming a water into a {ion} ion" - ) - with pytest.warns(UserWarning, match=wmsg): - val = _get_alchemical_charge_difference(mapping, "pme", True, openfe.SolventComponent()) - assert result == pytest.approx(val) - else: - val = _get_alchemical_charge_difference(mapping, "pme", True, openfe.SolventComponent()) - assert result == pytest.approx(val) - - -def test_get_charge_difference_no_pme(benzene_to_benzoic_mapping): - errmsg = "Explicit charge correction when not using PME" - with pytest.raises(ValueError, match=errmsg): - _get_alchemical_charge_difference( - benzene_to_benzoic_mapping, - "nocutoff", - True, - openfe.SolventComponent(), - ) - - -def test_get_charge_difference_no_corr(benzene_to_benzoic_mapping): - wmsg = ( - "A charge difference of 1 is observed between the end states. 
" - "No charge correction has been requested" - ) - with pytest.warns(UserWarning, match=wmsg): - _get_alchemical_charge_difference( - benzene_to_benzoic_mapping, - "pme", - False, - openfe.SolventComponent(), - ) - - -def test_greater_than_one_charge_difference_error(aniline_to_benzoic_mapping): - errmsg = "A charge difference of 2" - with pytest.raises(ValueError, match=errmsg): - _get_alchemical_charge_difference( - aniline_to_benzoic_mapping, - "pme", - True, - openfe.SolventComponent(), - ) - - @pytest.fixture(scope="session") def benzene_solvent_openmm_system(benzene_modifications): smc = benzene_modifications["benzene"] From 2fe8ff937683e2d237cddb0b54ea9ace79be1cf7 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Wed, 24 Dec 2025 01:49:58 -0500 Subject: [PATCH 09/36] make hybrid samplers not rely on htf --- .../openmm_rfe/_rfe_utils/multistate.py | 61 ++++++++++++------- .../protocols/openmm_rfe/equil_rfe_methods.py | 17 ++++-- .../openmm_rfe/test_hybrid_top_protocol.py | 27 ++++---- 3 files changed, 65 insertions(+), 40 deletions(-) diff --git a/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py b/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py index 8c6b4eddc..299a846f6 100644 --- a/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py +++ b/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py @@ -26,14 +26,15 @@ logger = logging.getLogger(__name__) -class HybridCompatibilityMixin(object): +class HybridCompatibilityMixin: """ Mixin that allows the MultistateSampler to accommodate the situation where unsampled endpoints have a different number of degrees of freedom. 
""" - def __init__(self, *args, hybrid_factory=None, **kwargs): - self._hybrid_factory = hybrid_factory + def __init__(self, *args, hybrid_system, hybrid_positions, **kwargs): + self._hybrid_system = hybrid_system + self._hybrid_positions = hybrid_positions super(HybridCompatibilityMixin, self).__init__(*args, **kwargs) def setup(self, reporter, lambda_protocol, @@ -73,15 +74,17 @@ class creation of LambdaProtocol. """ n_states = len(lambda_protocol.lambda_schedule) - hybrid_system = self._factory.hybrid_system + lambda_zero_state = RelativeAlchemicalState.from_system(self._hybrid_system) - lambda_zero_state = RelativeAlchemicalState.from_system(hybrid_system) + thermostate = ThermodynamicState( + self._hybrid_system, + temperature=temperature + ) - thermostate = ThermodynamicState(hybrid_system, - temperature=temperature) compound_thermostate = CompoundThermodynamicState( - thermostate, - composable_states=[lambda_zero_state]) + thermostate, + composable_states=[lambda_zero_state] + ) # create lists for storing thermostates and sampler states thermodynamic_state_list = [] @@ -105,16 +108,20 @@ class creation of LambdaProtocol. raise ValueError(errmsg) # starting with the hybrid factory positions - box = hybrid_system.getDefaultPeriodicBoxVectors() - sampler_state = SamplerState(self._factory.hybrid_positions, - box_vectors=box) + box = self._hybrid_system.getDefaultPeriodicBoxVectors() + sampler_state = SamplerState( + self._hybrid_positions, + box_vectors=box + ) # Loop over the lambdas and create & store a compound thermostate at # that lambda value for lambda_val in lambda_schedule: compound_thermostate_copy = copy.deepcopy(compound_thermostate) compound_thermostate_copy.set_alchemical_parameters( - lambda_val, lambda_protocol) + lambda_val, + lambda_protocol + ) thermodynamic_state_list.append(compound_thermostate_copy) # now generating a sampler_state for each thermodyanmic state, @@ -143,7 +150,8 @@ class creation of LambdaProtocol. 
# generating unsampled endstates unsampled_dispersion_endstates = create_endstates( copy.deepcopy(thermodynamic_state_list[0]), - copy.deepcopy(thermodynamic_state_list[-1])) + copy.deepcopy(thermodynamic_state_list[-1]) + ) self.create(thermodynamic_states=thermodynamic_state_list, sampler_states=sampler_state_list, storage=reporter, unsampled_thermodynamic_states=unsampled_dispersion_endstates) @@ -159,10 +167,13 @@ class HybridRepexSampler(HybridCompatibilityMixin, number of positions """ - def __init__(self, *args, hybrid_factory=None, **kwargs): + def __init__(self, *args, hybrid_system, hybrid_positions, **kwargs): super(HybridRepexSampler, self).__init__( - *args, hybrid_factory=hybrid_factory, **kwargs) - self._factory = hybrid_factory + *args, + hybrid_system=hybrid_system, + hybrid_positions=hybrid_positions, + **kwargs + ) class HybridSAMSSampler(HybridCompatibilityMixin, sams.SAMSSampler): @@ -171,11 +182,13 @@ class HybridSAMSSampler(HybridCompatibilityMixin, sams.SAMSSampler): of positions """ - def __init__(self, *args, hybrid_factory=None, **kwargs): + def __init__(self, *args, hybrid_system, hybrid_positions, **kwargs): super(HybridSAMSSampler, self).__init__( - *args, hybrid_factory=hybrid_factory, **kwargs + *args, + hybrid_system=hybrid_system, + hybrid_positions=hybrid_positions, + **kwargs ) - self._factory = hybrid_factory class HybridMultiStateSampler(HybridCompatibilityMixin, @@ -184,11 +197,13 @@ class HybridMultiStateSampler(HybridCompatibilityMixin, MultiStateSampler that supports unsample end states with a different number of positions """ - def __init__(self, *args, hybrid_factory=None, **kwargs): + def __init__(self, *args, hybrid_system, hybrid_positions, **kwargs): super(HybridMultiStateSampler, self).__init__( - *args, hybrid_factory=hybrid_factory, **kwargs + *args, + hybrid_system=hybrid_system, + hybrid_positions=hybrid_positions, + **kwargs ) - self._factory = hybrid_factory def create_endstates(first_thermostate, 
last_thermostate): diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index 514237634..c1ddab71e 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -1128,7 +1128,8 @@ def run( if sampler_settings.sampler_method.lower() == "repex": sampler = _rfe_utils.multistate.HybridRepexSampler( mcmc_moves=integrator, - hybrid_factory=hybrid_factory, + hybrid_system=hybrid_factory.hybrid_system, + hybrid_positions=hybrid_factory.hybrid_positions, online_analysis_interval=rta_its, online_analysis_target_error=early_termination_target_error, online_analysis_minimum_iterations=rta_min_its, @@ -1136,7 +1137,8 @@ def run( elif sampler_settings.sampler_method.lower() == "sams": sampler = _rfe_utils.multistate.HybridSAMSSampler( mcmc_moves=integrator, - hybrid_factory=hybrid_factory, + hybrid_system=hybrid_factory.hybrid_system, + hybrid_positions=hybrid_factory.hybrid_positions, online_analysis_interval=rta_its, online_analysis_minimum_iterations=rta_min_its, flatness_criteria=sampler_settings.sams_flatness_criteria, @@ -1145,12 +1147,12 @@ def run( elif sampler_settings.sampler_method.lower() == "independent": sampler = _rfe_utils.multistate.HybridMultiStateSampler( mcmc_moves=integrator, - hybrid_factory=hybrid_factory, + hybrid_system=hybrid_factory.hybrid_system, + hybrid_positions=hybrid_factory.hybrid_positions, online_analysis_interval=rta_its, online_analysis_target_error=early_termination_target_error, online_analysis_minimum_iterations=rta_min_its, ) - else: raise AttributeError(f"Unknown sampler {sampler_settings.sampler_method}") @@ -1247,7 +1249,12 @@ def run( if not dry: # pragma: no-cover return {"nc": nc, "last_checkpoint": chk, **analyzer.unit_results_dict} else: - return {"debug": {"sampler": sampler}} + return {"debug": + { + "sampler": sampler, + "hybrid_factory": hybrid_factory + } + } @staticmethod def structural_analysis(scratch, 
shared) -> dict: diff --git a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py index f5ea92cff..711f72d2a 100644 --- a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py +++ b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py @@ -236,13 +236,14 @@ def test_dry_run_default_vacuum( dag_unit = list(dag.protocol_units)[0] with tmpdir.as_cwd(): - sampler = dag_unit.run(dry=True)["debug"]["sampler"] + debug = dag_unit.run(dry=True)["debug"] + sampler = debug["sampler"] assert isinstance(sampler, MultiStateSampler) assert not sampler.is_periodic assert sampler._thermodynamic_states[0].barostat is None # Check hybrid OMM and MDTtraj Topologies - htf = sampler._hybrid_factory + htf = debug["hybrid_factory"] # 16 atoms: # 11 common atoms, 1 extra hydrogen in benzene, 4 extra in toluene # 12 bonds in benzene + 4 extra toluene bonds @@ -414,7 +415,7 @@ def test_dry_core_element_change(vac_settings, tmpdir): with tmpdir.as_cwd(): sampler = dag_unit.run(dry=True)["debug"]["sampler"] - system = sampler._hybrid_factory.hybrid_system + system = sampler._hybrid_system assert system.getNumParticles() == 12 # Average mass between nitrogen and carbon assert system.getParticleMass(1) == 12.0127235 * omm_unit.amu @@ -518,7 +519,7 @@ def tip4p_hybrid_factory( shared_basepath=shared_temp, ) - return dag_unit_result["debug"]["sampler"]._factory + return dag_unit_result["debug"]["hybrid_factory"] def test_tip4p_particle_count(tip4p_hybrid_factory): @@ -624,7 +625,7 @@ def test_dry_run_ligand_system_cutoff( with tmpdir.as_cwd(): sampler = dag_unit.run(dry=True)["debug"]["sampler"] - hs = sampler._factory.hybrid_system + hs = sampler._hybrid_system nbfs = [ f @@ -691,9 +692,10 @@ def test_dry_run_charge_backends( dag_unit = list(dag.protocol_units)[0] with tmpdir.as_cwd(): - sampler = dag_unit.run(dry=True)["debug"]["sampler"] - htf = sampler._factory - hybrid_system = htf.hybrid_system + debug 
= dag_unit.run(dry=True)["debug"] + sampler = debug["sampler"] + htf = debug["hybrid_factory"] + hybrid_system = sampler._hybrid_system # get the standard nonbonded force nonbond = [f for f in hybrid_system.getForces() if isinstance(f, NonbondedForce)] @@ -785,9 +787,10 @@ def check_propchgs(smc, charge_array): dag_unit = list(dag.protocol_units)[0] with tmpdir.as_cwd(): - sampler = dag_unit.run(dry=True)["debug"]["sampler"] - htf = sampler._factory - hybrid_system = htf.hybrid_system + debug = dag_unit.run(dry=True)["debug"] + sampler = debug["sampler"] + htf = debug["hybrid_factory"] + hybrid_system = sampler._hybrid_system # get the standard nonbonded force nonbond = [f for f in hybrid_system.getForces() if isinstance(f, NonbondedForce)] @@ -902,7 +905,7 @@ def test_dodecahdron_ligand_box( with tmpdir.as_cwd(): sampler = dag_unit.run(dry=True)["debug"]["sampler"] - hs = sampler._factory.hybrid_system + hs = sampler._hybrid_system vectors = hs.getDefaultPeriodicBoxVectors() From 4a0bd26308868b48b1519a7c7fe0fef8a90f78fb Mon Sep 17 00:00:00 2001 From: IAlibay Date: Wed, 24 Dec 2025 01:54:00 -0500 Subject: [PATCH 10/36] fix up test --- openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py index 711f72d2a..814cc899f 100644 --- a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py +++ b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py @@ -2156,8 +2156,8 @@ def test_dry_run_alchemwater_solvent(benzene_to_benzoic_mapping, solv_settings, unit = list(dag.protocol_units)[0] with tmpdir.as_cwd(): - sampler = unit.run(dry=True)["debug"]["sampler"] - htf = sampler._factory + debug = unit.run(dry=True)["debug"] + htf = debug["hybrid_factory"] _assert_total_charge(htf.hybrid_system, htf._atom_classes, 0, 0) assert len(htf._atom_classes["core_atoms"]) == 
14 From 5848adcb54f4e17e2d7c5f1ee5ce33aa79eb70ce Mon Sep 17 00:00:00 2001 From: IAlibay Date: Wed, 24 Dec 2025 01:56:56 -0500 Subject: [PATCH 11/36] fix up some slow tests --- .../tests/protocols/openmm_rfe/test_hybrid_top_protocol.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py index 814cc899f..fc11cf164 100644 --- a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py +++ b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py @@ -1601,7 +1601,7 @@ def tyk2_xml(tmp_path_factory): dryrun = pu.run(dry=True, shared_basepath=tmp) - system = dryrun["debug"]["sampler"]._hybrid_factory.hybrid_system + system = dryrun["debug"]["sampler"]._hybrid_system return ET.fromstring(XmlSerializer.serialize(system)) @@ -2225,8 +2225,8 @@ def test_dry_run_complex_alchemwater_totcharge( unit = list(dag.protocol_units)[0] with tmpdir.as_cwd(): - sampler = unit.run(dry=True)["debug"]["sampler"] - htf = sampler._factory + debug = unit.run(dry=True)["debug"] + htf = debug["hybrid_factory"] _assert_total_charge(htf.hybrid_system, htf._atom_classes, chgA, chgB) assert len(htf._atom_classes["core_atoms"]) == core_atoms From b6d5ecd315b6ca186c62494e8aed37e4e0f69cde Mon Sep 17 00:00:00 2001 From: IAlibay Date: Thu, 25 Dec 2025 19:27:36 -0500 Subject: [PATCH 12/36] Fix up the one test --- .../protocols/openmm_rfe/equil_rfe_methods.py | 30 +++++++++++++++---- 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index e20c7360b..356423922 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -641,10 +641,28 @@ def _validate_charge_difference( @staticmethod def _validate_simulation_settings( - simulation_settings, - integrator_settings, - output_settings, + 
simulation_settings: MultiStateSimulationSettings, + integrator_settings: IntegratorSettings, + output_settings: MultiStateOutputSettings, ): + """ + Validate various simulation settings, including but not limited to + timestep conversions, and output file write frequencies. + + Parameters + ---------- + simulation_settings : MultiStateSimulationSettings + The sampler simulation settings. + integrator_settings : IntegratorSettings + Settings defining the behaviour of the integrator. + output_settings : MultiStateOutputSettings + Settings defining the simulation file writing behaviour. + + Raises + ------ + ValueError + * If the + """ steps_per_iteration = settings_validation.convert_steps_per_iteration( simulation_settings=simulation_settings, @@ -671,7 +689,7 @@ def _validate_simulation_settings( if output_settings.positions_write_frequency is not None: _ = settings_validation.divmod_time_and_check( numerator=output_settings.positions_write_frequency, - denominator=sampler_settings.time_per_iteration, + denominator=simulation_settings.time_per_iteration, numerator_name="output settings' position_write_frequency", denominator_name="sampler settings' time_per_iteration", ) @@ -679,13 +697,13 @@ def _validate_simulation_settings( if output_settings.velocities_write_frequency is not None: _ = settings_validation.divmod_time_and_check( numerator=output_settings.velocities_write_frequency, - denominator=sampler_settings.time_per_iteration, + denominator=simulation_settings.time_per_iteration, numerator_name="output settings' velocity_write_frequency", denominator_name="sampler settings' time_per_iteration", ) _, _ = settings_validation.convert_real_time_analysis_iterations( - simulation_settings=sampler_settings, + simulation_settings=simulation_settings, ) def _validate( From 0605d11049934524bafa5d14d4b12362dce7d29b Mon Sep 17 00:00:00 2001 From: IAlibay Date: Fri, 26 Dec 2025 01:03:39 +0000 Subject: [PATCH 13/36] fix a few things --- 
.../tests/protocols/openmm_rfe/test_hybrid_top_validation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py index d0452eb9a..c4a686a01 100644 --- a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py +++ b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py @@ -178,7 +178,7 @@ def test_charge_difference_no_corr(benzene_to_benzoic_mapping): def test_charge_difference_no_solvent(benzene_to_benzoic_mapping): errmsg = "Cannot use eplicit charge correction without solvent" - with pytest.raises(ValueError, errmsg): + with pytest.raises(ValueError, match=errmsg): openmm_rfe.RelativeHybridTopologyProtocol._validate_charge_difference( benzene_to_benzoic_mapping, "pme", @@ -223,7 +223,7 @@ def test_get_charge_difference(mapping_name, result, request, caplog): mapping = request.getfixturevalue(mapping_name) caplog.set_level(logging.INFO) - ion = r"Na\+" if result == -1 else r"Cl\-" + ion = r"Na+" if result == -1 else r"Cl-" msg = ( f"A charge difference of {result} is observed " "between the end states. 
This will be addressed by " From 48106a297237fffb9e76ca16d2ed99a3d6834bac Mon Sep 17 00:00:00 2001 From: IAlibay Date: Thu, 25 Dec 2025 23:45:07 -0500 Subject: [PATCH 14/36] fix the remaining tests --- .../protocols/openmm_rfe/equil_rfe_methods.py | 10 +- .../openmm_utils/system_validation.py | 33 +- .../openmm_rfe/test_hybrid_top_protocol.py | 94 ----- .../openmm_rfe/test_hybrid_top_validation.py | 363 +++++++++++++++--- 4 files changed, 337 insertions(+), 163 deletions(-) diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index 356423922..710f4db8e 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -690,7 +690,7 @@ def _validate_simulation_settings( _ = settings_validation.divmod_time_and_check( numerator=output_settings.positions_write_frequency, denominator=simulation_settings.time_per_iteration, - numerator_name="output settings' position_write_frequency", + numerator_name="output settings' positions_write_frequency", denominator_name="sampler settings' time_per_iteration", ) @@ -698,7 +698,7 @@ def _validate_simulation_settings( _ = settings_validation.divmod_time_and_check( numerator=output_settings.velocities_write_frequency, denominator=simulation_settings.time_per_iteration, - numerator_name="output settings' velocity_write_frequency", + numerator_name="output settings' velocities_write_frequency", denominator_name="sampler settings' time_per_iteration", ) @@ -768,8 +768,10 @@ def _validate( # PR #125 temporarily pin lambda schedule spacing to n_replicas if self.settings.simulation_settings.n_replicas != self.settings.lambda_settings.lambda_windows: errmsg = ( - "Number of replicas in simulation_settings must equal " - "number of lambda windows in lambda_settings." 
+ "Number of replicas in ``simulation_settings``: " + f"{self.settings.simulation_settings.n_replicas} must equal " + "the number of lambda windows in lambda_settings: " + f"{self.settings.lambda_settings.lambda_windows}." ) raise ValueError(errmsg) diff --git a/openfe/protocols/openmm_utils/system_validation.py b/openfe/protocols/openmm_utils/system_validation.py index 0fd3c3518..9d67e108f 100644 --- a/openfe/protocols/openmm_utils/system_validation.py +++ b/openfe/protocols/openmm_utils/system_validation.py @@ -95,23 +95,24 @@ def validate_solvent(state: ChemicalSystem, nonbonded_method: str): `nocutoff`. * If the SolventComponent solvent is not water. """ - solv = [comp for comp in state.values() if isinstance(comp, SolventComponent)] + solv_comps = state.get_components_of_type(SolventComponent) - if len(solv) > 0 and nonbonded_method.lower() == "nocutoff": - errmsg = "nocutoff cannot be used for solvent transformations" - raise ValueError(errmsg) + if len(solv_comps) > 0: + if nonbonded_method.lower() == "nocutoff": + errmsg = "nocutoff cannot be used for solvent transformations" + raise ValueError(errmsg) - if len(solv) == 0 and nonbonded_method.lower() == "pme": - errmsg = "PME cannot be used for vacuum transform" - raise ValueError(errmsg) + if len(solv_comps) > 1: + errmsg = "Multiple SolventComponent found, only one is supported" + raise ValueError(errmsg) - if len(solv) > 1: - errmsg = "Multiple SolventComponent found, only one is supported" - raise ValueError(errmsg) - - if len(solv) > 0 and solv[0].smiles != "O": - errmsg = "Non water solvent is not currently supported" - raise ValueError(errmsg) + if solv_comps[0].smiles != "O": + errmsg = "Non water solvent is not currently supported" + raise ValueError(errmsg) + else: + if nonbonded_method.lower() == "pme": + errmsg = "PME cannot be used for vacuum transform" + raise ValueError(errmsg) def validate_protein(state: ChemicalSystem): @@ -129,9 +130,9 @@ def validate_protein(state: ChemicalSystem): 
ValueError If there are multiple ProteinComponent in the ChemicalSystem. """ - nprot = sum(1 for comp in state.values() if isinstance(comp, ProteinComponent)) + prot_comps = state.get_components_of_type(ProteinComponent) - if nprot > 1: + if len(prot_comps) > 1: errmsg = "Multiple ProteinComponent found, only one is supported" raise ValueError(errmsg) diff --git a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py index 148f32fe1..d4818f355 100644 --- a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py +++ b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py @@ -948,63 +948,6 @@ def test_lambda_schedule(windows): assert len(lambdas.lambda_schedule) == windows -def test_hightimestep( - benzene_vacuum_system, - toluene_vacuum_system, - benzene_to_toluene_mapping, - vac_settings, - tmpdir, -): - vac_settings.forcefield_settings.hydrogen_mass = 1.0 - - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=vac_settings, - ) - - dag = p.create( - stateA=benzene_vacuum_system, - stateB=toluene_vacuum_system, - mapping=benzene_to_toluene_mapping, - ) - dag_unit = list(dag.protocol_units)[0] - - errmsg = "too large for hydrogen mass" - with tmpdir.as_cwd(): - with pytest.raises(ValueError, match=errmsg): - dag_unit.run(dry=True) - - -def test_element_change_warning(atom_mapping_basic_test_files): - # check a mapping with element change gets rejected early - l1 = atom_mapping_basic_test_files["2-methylnaphthalene"] - l2 = atom_mapping_basic_test_files["2-naftanol"] - - # We use the 'old' lomap defaults because the - # basic test files inputs we use aren't fully aligned - mapper = setup.LomapAtomMapper( - time=20, threed=True, max3d=1000.0, element_change=True, seed="", shift=True - ) - - mapping = next(mapper.suggest_mappings(l1, l2)) - - sys1 = openfe.ChemicalSystem( - {"ligand": l1, "solvent": openfe.SolventComponent()}, - ) - sys2 = openfe.ChemicalSystem( - {"ligand": l2, 
"solvent": openfe.SolventComponent()}, - ) - - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - with pytest.warns(UserWarning, match="Element change"): - _ = p.create( - stateA=sys1, - stateB=sys2, - mapping=mapping, - ) - - def test_ligand_overlap_warning( benzene_vacuum_system, toluene_vacuum_system, benzene_to_toluene_mapping, vac_settings, tmpdir ): @@ -2023,40 +1966,3 @@ def test_dry_run_vacuum_write_frequency( assert reporter.velocity_interval == velocities_write_frequency.m else: assert reporter.velocity_interval == 0 - - -@pytest.mark.parametrize( - "positions_write_frequency,velocities_write_frequency", - [ - [100.1 * unit.picosecond, 100 * unit.picosecond], - [100 * unit.picosecond, 100.1 * unit.picosecond], - ], -) -def test_pos_write_frequency_not_divisible( - benzene_vacuum_system, - toluene_vacuum_system, - benzene_to_toluene_mapping, - positions_write_frequency, - velocities_write_frequency, - tmpdir, - vac_settings, -): - vac_settings.output_settings.positions_write_frequency = positions_write_frequency - vac_settings.output_settings.velocities_write_frequency = velocities_write_frequency - - protocol = openmm_rfe.RelativeHybridTopologyProtocol( - settings=vac_settings, - ) - - # create DAG from protocol and take first (and only) work unit from within - dag = protocol.create( - stateA=benzene_vacuum_system, - stateB=toluene_vacuum_system, - mapping=benzene_to_toluene_mapping, - ) - dag_unit = list(dag.protocol_units)[0] - - with tmpdir.as_cwd(): - errmsg = "The output settings' " - with pytest.raises(ValueError, match=errmsg): - dag_unit.run(dry=True)["debug"]["sampler"] diff --git a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py index c4a686a01..984c4d6bb 100644 --- a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py +++ 
b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py @@ -18,7 +18,7 @@ from kartograf.atom_aligner import align_mol_shape from numpy.testing import assert_allclose from openff.toolkit import Molecule -from openff.units import unit +from openff.units import unit as offunit from openff.units.openmm import ensure_quantity, from_openmm, to_openmm from openmm import CustomNonbondedForce, MonteCarloBarostat, NonbondedForce, XmlSerializer, app from openmm import unit as omm_unit @@ -138,6 +138,169 @@ def test_validate_mapping_alchem_not_in(state, benzene_to_toluene_mapping): ) +def test_vaccuum_PME_error( + benzene_vacuum_system, + toluene_vacuum_system, + benzene_to_toluene_mapping, + solv_settings +): + p = openmm_rfe.RelativeHybridTopologyProtocol(settings=solv_settings) + + errmsg = "PME cannot be used for vacuum transform" + with pytest.raises(ValueError, match=errmsg): + p.validate( + stateA=benzene_vacuum_system, + stateB=toluene_vacuum_system, + mapping=benzene_to_toluene_mapping, + ) + + +def test_solvent_nocutoff_error( + benzene_system, + toluene_system, + benzene_to_toluene_mapping, + vac_settings, +): + p = openmm_rfe.RelativeHybridTopologyProtocol(settings=vac_settings) + + errmsg = "nocutoff cannot be used for solvent transformation" + + with pytest.raises(ValueError, match=errmsg): + p.validate( + stateA=benzene_system, + stateB=toluene_system, + mapping=benzene_to_toluene_mapping, + ) + + +def test_nonwater_solvent_error( + benzene_modifications, + benzene_to_toluene_mapping, + solv_settings, +): + solvent = openfe.SolventComponent(smiles='C') + stateA = openfe.ChemicalSystem( + { + 'ligand': benzene_modifications['benzene'], + 'solvent': solvent, + } + ) + + stateB = openfe.ChemicalSystem( + { + 'ligand': benzene_modifications['toluene'], + 'solvent': solvent + } + ) + + p = openmm_rfe.RelativeHybridTopologyProtocol(settings=solv_settings) + + errmsg = "Non water solvent is not currently supported" + + with pytest.raises(ValueError, 
match=errmsg): + p.validate( + stateA=stateA, + stateB=stateB, + mapping=benzene_to_toluene_mapping, + ) + + +def test_too_many_solv_comps_error( + benzene_modifications, + benzene_to_toluene_mapping, + solv_settings, +): + stateA = openfe.ChemicalSystem( + { + 'ligand': benzene_modifications['benzene'], + 'solvent!': openfe.SolventComponent(neutralize=True), + 'solvent2': openfe.SolventComponent(neutralize=False), + } + ) + + stateB = openfe.ChemicalSystem( + { + 'ligand': benzene_modifications['toluene'], + 'solvent!': openfe.SolventComponent(neutralize=True), + 'solvent2': openfe.SolventComponent(neutralize=False), + } + ) + + p = openmm_rfe.RelativeHybridTopologyProtocol(settings=solv_settings) + + errmsg = "Multiple SolventComponent found, only one is supported" + + with pytest.raises(ValueError, match=errmsg): + p.validate( + stateA=stateA, + stateB=stateB, + mapping=benzene_to_toluene_mapping, + ) + + +def test_bad_solv_settings( + benzene_system, + toluene_system, + benzene_to_toluene_mapping, + solv_settings, +): + """ + Test a case where the solvent settings would be wrong. + Not doing every cases since those are covered under + ``test_openmmutils.py``. 
+ """ + solv_settings.solvation_settings.solvent_padding = 1.2 * offunit.nanometer + solv_settings.solvation_settings.number_of_solvent_molecules = 20 + + p = openmm_rfe.RelativeHybridTopologyProtocol(settings=solv_settings) + + errmsg = "Only one of solvent_padding, number_of_solvent_molecules," + with pytest.raises(ValueError, match=errmsg): + p.validate( + stateA=benzene_system, + stateB=toluene_system, + mapping=benzene_to_toluene_mapping + ) + + +def test_too_many_prot_comps_error( + benzene_modifications, + benzene_to_toluene_mapping, + T4_protein_component, + eg5_protein, + solv_settings, +): + + stateA = openfe.ChemicalSystem( + { + 'ligand': benzene_modifications['benzene'], + 'solvent': openfe.SolventComponent(), + 'protein1': T4_protein_component, + 'protein2': eg5_protein, + } + ) + + stateB = openfe.ChemicalSystem( + { + 'ligand': benzene_modifications['toluene'], + 'solvent': openfe.SolventComponent(), + 'protein1': T4_protein_component, + 'protein2': eg5_protein, + } + ) + + p = openmm_rfe.RelativeHybridTopologyProtocol(settings=solv_settings) + + errmsg = "Multiple ProteinComponent found, only one is supported" + + with pytest.raises(ValueError, match=errmsg): + p.validate( + stateA=stateA, + stateB=stateB, + mapping=benzene_to_toluene_mapping, + ) + + def test_element_change_warning(atom_mapping_basic_test_files): # check a mapping with element change gets rejected early l1 = atom_mapping_basic_test_files["2-methylnaphthalene"] @@ -248,81 +411,183 @@ def test_hightimestep( toluene_vacuum_system, benzene_to_toluene_mapping, vac_settings, - tmpdir, ): vac_settings.forcefield_settings.hydrogen_mass = 1.0 - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=vac_settings, - ) - - dag = p.create( - stateA=benzene_vacuum_system, - stateB=toluene_vacuum_system, - mapping=benzene_to_toluene_mapping, - ) - dag_unit = list(dag.protocol_units)[0] + p = openmm_rfe.RelativeHybridTopologyProtocol(settings=vac_settings) errmsg = "too large for hydrogen 
mass" - with tmpdir.as_cwd(): - with pytest.raises(ValueError, match=errmsg): - dag_unit.run(dry=True) + + with pytest.raises(ValueError, match=errmsg): + p.validate( + stateA=benzene_vacuum_system, + stateB=toluene_vacuum_system, + mapping=benzene_to_toluene_mapping, + extends=None + ) -def test_n_replicas_not_n_windows( +def test_time_per_iteration_divmod( benzene_vacuum_system, toluene_vacuum_system, benzene_to_toluene_mapping, vac_settings, - tmpdir, ): - # For PR #125 we pin such that the number of lambda windows - # equals the numbers of replicas used - TODO: remove limitation - # default lambda windows is 11 - vac_settings.simulation_settings.n_replicas = 13 + vac_settings.simulation_settings.time_per_iteration = 10 * offunit.ps + vac_settings.integrator_settings.timestep = 4 * offunit.ps - errmsg = "Number of replicas 13 does not equal the number of lambda windows 11" + p = openmm_rfe.RelativeHybridTopologyProtocol(settings=vac_settings) - with tmpdir.as_cwd(): - with pytest.raises(ValueError, match=errmsg): - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=vac_settings, - ) - dag = p.create( - stateA=benzene_vacuum_system, - stateB=toluene_vacuum_system, - mapping=benzene_to_toluene_mapping, - ) - dag_unit = list(dag.protocol_units)[0] - dag_unit.run(dry=True) + errmsg = "does not evenly divide by the timestep" + with pytest.raises(ValueError, match=errmsg): + p.validate( + stateA=benzene_vacuum_system, + stateB=toluene_vacuum_system, + mapping=benzene_to_toluene_mapping, + extends=None + ) -def test_vaccuum_PME_error( - benzene_vacuum_system, benzene_modifications, benzene_to_toluene_mapping + +@pytest.mark.parametrize( + "attribute", ["equilibration_length", "production_length"] +) +def test_simsteps_not_timestep_divisible( + attribute, + benzene_vacuum_system, + toluene_vacuum_system, + benzene_to_toluene_mapping, + vac_settings, ): - # state B doesn't have a solvent component (i.e. 
its vacuum) - stateB = openfe.ChemicalSystem({"ligand": benzene_modifications["toluene"]}) + setattr(vac_settings.simulation_settings, attribute, 102 * offunit.fs) + p = openmm_rfe.RelativeHybridTopologyProtocol(settings=vac_settings) - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), + errmsg = "Simulation time not divisible by timestep" + + with pytest.raises(ValueError, match=errmsg): + p.validate( + stateA=benzene_vacuum_system, + stateB=toluene_vacuum_system, + mapping=benzene_to_toluene_mapping, + extends=None + ) + + +@pytest.mark.parametrize( + "attribute", ["equilibration_length", "production_length"] +) +def test_simsteps_not_mcstep_divisible( + attribute, + benzene_vacuum_system, + toluene_vacuum_system, + benzene_to_toluene_mapping, + vac_settings, +): + setattr(vac_settings.simulation_settings, attribute, 102 * offunit.ps) + p = openmm_rfe.RelativeHybridTopologyProtocol(settings=vac_settings) + + errmsg = ( + "should contain a number of steps divisible by the number of " + "integrator timesteps" ) - errmsg = "PME cannot be used for vacuum transform" + with pytest.raises(ValueError, match=errmsg): - _ = p.create( + p.validate( stateA=benzene_vacuum_system, - stateB=stateB, + stateB=toluene_vacuum_system, mapping=benzene_to_toluene_mapping, + extends=None ) -def test_get_alchemical_waters_no_waters( - benzene_solvent_openmm_system, +def test_checkpoint_interval_not_divisible_time_per_iter( + benzene_vacuum_system, + toluene_vacuum_system, + benzene_to_toluene_mapping, + vac_settings, ): - system, topology, positions = benzene_solvent_openmm_system + vac_settings.output_settings.checkpoint_interval = 4 * offunit.ps + vac_settings.simulation_settings.time_per_iteration = 2.5 * offunit.ps + p = openmm_rfe.RelativeHybridTopologyProtocol(settings=vac_settings) - errmsg = "There are no waters" + errmsg = "does not evenly divide by the amount of time per state MCMC" with 
pytest.raises(ValueError, match=errmsg): - topologyhelpers.get_alchemical_waters( - topology, positions, charge_difference=1, distance_cutoff=3.0 * unit.nanometer + p.validate( + stateA=benzene_vacuum_system, + stateB=toluene_vacuum_system, + mapping=benzene_to_toluene_mapping, + extends=None + ) + + +@pytest.mark.parametrize( + "attribute", + ["positions_write_frequency", "velocities_write_frequency"] +) +def test_pos_vel_write_frequency_not_divisible( + benzene_vacuum_system, + toluene_vacuum_system, + benzene_to_toluene_mapping, + attribute, + vac_settings, +): + setattr(vac_settings.output_settings, attribute, 100.1 * offunit.picosecond) + + p = openmm_rfe.RelativeHybridTopologyProtocol(settings=vac_settings) + + errmsg = f"The output settings' {attribute}" + with pytest.raises(ValueError, match=errmsg): + p.validate( + stateA=benzene_vacuum_system, + stateB=toluene_vacuum_system, + mapping=benzene_to_toluene_mapping, + extends=None + ) + + +@pytest.mark.parametrize( + "attribute", + ["real_time_analysis_interval", "real_time_analysis_interval"] +) +def test_pos_vel_write_frequency_not_divisible( + benzene_vacuum_system, + toluene_vacuum_system, + benzene_to_toluene_mapping, + attribute, + vac_settings, +): + setattr(vac_settings.simulation_settings, attribute, 100.1 * offunit.picosecond) + + p = openmm_rfe.RelativeHybridTopologyProtocol(settings=vac_settings) + + errmsg = f"The {attribute}" + with pytest.raises(ValueError, match=errmsg): + p.validate( + stateA=benzene_vacuum_system, + stateB=toluene_vacuum_system, + mapping=benzene_to_toluene_mapping, + extends=None + ) + +def test_n_replicas_not_n_windows( + benzene_vacuum_system, + toluene_vacuum_system, + benzene_to_toluene_mapping, + vac_settings, + tmpdir, +): + # For PR #125 we pin such that the number of lambda windows + # equals the numbers of replicas used - TODO: remove limitation + vac_settings.simulation_settings.n_replicas = 13 + p = openmm_rfe.RelativeHybridTopologyProtocol(settings=vac_settings) 
+ + errmsg = "Number of replicas in ``simulation_settings``:" + + with pytest.raises(ValueError, match=errmsg): + p.validate( + stateA=benzene_vacuum_system, + stateB=toluene_vacuum_system, + mapping=benzene_to_toluene_mapping, + extends=None ) From 5af66e81688603d997282176fd4d11c73c50e454 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Thu, 25 Dec 2025 23:50:36 -0500 Subject: [PATCH 15/36] cleanup imports --- .../protocols/openmm_rfe/equil_rfe_methods.py | 1 - .../openmm_rfe/test_hybrid_top_validation.py | 31 +------------------ 2 files changed, 1 insertion(+), 31 deletions(-) diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index 710f4db8e..189fffe79 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -53,7 +53,6 @@ from openff.units import Quantity, unit from openff.units.openmm import ensure_quantity, from_openmm, to_openmm from openmmtools import multistate -from rdkit import Chem from openfe.due import Doi, due from openfe.protocols.openmm_utils.omm_settings import ( diff --git a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py index 984c4d6bb..0f7db50d6 100644 --- a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py +++ b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py @@ -1,42 +1,13 @@ # This code is part of OpenFE and is licensed under the MIT license. 
# For details, see https://github.com/OpenFreeEnergy/openfe import logging -import copy -import json -import sys -import xml.etree.ElementTree as ET -from importlib import resources -from math import sqrt -from pathlib import Path -from unittest import mock - -import gufe -import mdtraj as mdt -import numpy as np + import pytest -from kartograf import KartografAtomMapper -from kartograf.atom_aligner import align_mol_shape -from numpy.testing import assert_allclose -from openff.toolkit import Molecule from openff.units import unit as offunit -from openff.units.openmm import ensure_quantity, from_openmm, to_openmm -from openmm import CustomNonbondedForce, MonteCarloBarostat, NonbondedForce, XmlSerializer, app -from openmm import unit as omm_unit -from openmmforcefields.generators import SMIRNOFFTemplateGenerator -from openmmtools.multistate.multistatesampler import MultiStateSampler -from rdkit import Chem -from rdkit.Geometry import Point3D import openfe from openfe import setup from openfe.protocols import openmm_rfe -from openfe.protocols.openmm_rfe._rfe_utils import topologyhelpers -from openfe.protocols.openmm_utils import omm_compute, system_creation -from openfe.protocols.openmm_utils.charge_generation import ( - HAS_ESPALOMA_CHARGE, - HAS_NAGL, - HAS_OPENEYE, -) @pytest.fixture() From 58dd71cceb1154af70b9e6652d125ace5e23fe8f Mon Sep 17 00:00:00 2001 From: IAlibay Date: Fri, 26 Dec 2025 00:26:36 -0500 Subject: [PATCH 16/36] Migrate protocol, units, and results for the hybridtop protocol --- openfe/protocols/openmm_rfe/__init__.py | 12 +- .../protocols/openmm_rfe/equil_rfe_methods.py | 1460 +---------------- .../openmm_rfe/hybridtop_protocols.py | 571 +++++++ .../openmm_rfe/hybridtop_unit_results.py | 240 +++ .../protocols/openmm_rfe/hybridtop_units.py | 697 ++++++++ 5 files changed, 1520 insertions(+), 1460 deletions(-) create mode 100644 openfe/protocols/openmm_rfe/hybridtop_protocols.py create mode 100644 openfe/protocols/openmm_rfe/hybridtop_unit_results.py 
create mode 100644 openfe/protocols/openmm_rfe/hybridtop_units.py diff --git a/openfe/protocols/openmm_rfe/__init__.py b/openfe/protocols/openmm_rfe/__init__.py index e400cc3d3..137b641c0 100644 --- a/openfe/protocols/openmm_rfe/__init__.py +++ b/openfe/protocols/openmm_rfe/__init__.py @@ -2,11 +2,7 @@ # For details, see https://github.com/OpenFreeEnergy/openfe from . import _rfe_utils -from .equil_rfe_methods import ( - RelativeHybridTopologyProtocol, - RelativeHybridTopologyProtocolResult, - RelativeHybridTopologyProtocolUnit, -) -from .equil_rfe_settings import ( - RelativeHybridTopologyProtocolSettings, -) +from .hybridtop_protocols import RelativeHybridTopologyProtocol +from .hybridtop_unit_results import RelativeHybridTopologyProtocolResult +from .hybridtop_units import RelativeHybridTopologyProtocolUnit +from .equil_rfe_settings import RelativeHybridTopologyProtocolSettings diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index 208ec912c..22106b484 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -1,1466 +1,22 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe -"""Equilibrium Relative Free Energy methods using OpenMM and OpenMMTools in a +"""Equilibrium Relative Free Energy Protocol using OpenMM and OpenMMTools in a Perses-like manner. 
-This module implements the necessary methodology toolking to run calculate a -ligand relative free energy transformation using OpenMM tools and one of the -following methods: +This module implements the necessary methodology toolkit to calculate the +relative free energy of a ligand transformation using OpenMM tools and one of +the following methods: - Hamiltonian Replica Exchange - Self-adjusted mixture sampling - Independent window sampling -TODO ----- -* Improve this docstring by adding an example use case. - Acknowledgements ---------------- This Protocol is based on, and leverages components originating from the Perses toolkit (https://github.com/choderalab/perses). """ -from __future__ import annotations - -import json -import logging -import os -import pathlib -import subprocess -import uuid -import warnings -from collections import defaultdict -from itertools import chain -from typing import Any, Iterable, Optional, Union - -import gufe -import matplotlib.pyplot as plt -import mdtraj -import numpy as np -import numpy.typing as npt -import openmmtools -from gufe import ( - ChemicalSystem, - Component, - ComponentMapping, - LigandAtomMapping, - ProteinComponent, - SmallMoleculeComponent, - SolventComponent, - settings, -) -from openff.toolkit.topology import Molecule as OFFMolecule -from openff.units import Quantity, unit -from openff.units.openmm import ensure_quantity, from_openmm, to_openmm -from openmmtools import multistate - -from openfe.due import Doi, due -from openfe.protocols.openmm_utils.omm_settings import ( - BasePartialChargeSettings, -) - -from ...analysis import plotting -from ...utils import log_system_probe, without_oechem_backend -from ..openmm_utils import ( - charge_generation, - multistate_analysis, - omm_compute, - settings_validation, - system_creation, - system_validation, -) -from . 
import _rfe_utils -from .equil_rfe_settings import ( - AlchemicalSettings, - IntegratorSettings, - LambdaSettings, - MultiStateOutputSettings, - MultiStateSimulationSettings, - OpenFFPartialChargeSettings, - OpenMMEngineSettings, - OpenMMSolvationSettings, - RelativeHybridTopologyProtocolSettings, -) - -logger = logging.getLogger(__name__) - - -due.cite( - Doi("10.5281/zenodo.1297683"), - description="Perses", - path="openfe.protocols.openmm_rfe.equil_rfe_methods", - cite_module=True, -) - -due.cite( - Doi("10.5281/zenodo.596622"), - description="OpenMMTools", - path="openfe.protocols.openmm_rfe.equil_rfe_methods", - cite_module=True, -) - -due.cite( - Doi("10.1371/journal.pcbi.1005659"), - description="OpenMM", - path="openfe.protocols.openmm_rfe.equil_rfe_methods", - cite_module=True, -) - - -def _get_resname(off_mol) -> str: - # behaviour changed between 0.10 and 0.11 - omm_top = off_mol.to_topology().to_openmm() - names = [r.name for r in omm_top.residues()] - if len(names) > 1: - raise ValueError("We assume single residue") - return names[0] - - -class RelativeHybridTopologyProtocolResult(gufe.ProtocolResult): - """Dict-like container for the output of a RelativeHybridTopologyProtocol""" - - def __init__(self, **data): - super().__init__(**data) - # data is mapping of str(repeat_id): list[protocolunitresults] - # TODO: Detect when we have extensions and stitch these together? 
- if any(len(pur_list) > 2 for pur_list in self.data.values()): - raise NotImplementedError("Can't stitch together results yet") - - @staticmethod - def compute_mean_estimate(dGs: list[Quantity]) -> Quantity: - u = dGs[0].u - # convert all values to units of the first value, then take average of magnitude - # this would avoid a screwy case where each value was in different units - vals = np.asarray([dG.to(u).m for dG in dGs]) - - return np.average(vals) * u - - def get_estimate(self) -> Quantity: - """Average free energy difference of this transformation - - Returns - ------- - dG : openff.units.Quantity - The free energy difference between the first and last states. This is - a Quantity defined with units. - """ - # TODO: Check this holds up completely for SAMS. - dGs = [pus[0].outputs["unit_estimate"] for pus in self.data.values()] - return self.compute_mean_estimate(dGs) - - @staticmethod - def compute_uncertainty(dGs: list[Quantity]) -> Quantity: - u = dGs[0].u - # convert all values to units of the first value, then take average of magnitude - # this would avoid a screwy case where each value was in different units - vals = np.asarray([dG.to(u).m for dG in dGs]) - - return np.std(vals) * u - - def get_uncertainty(self) -> Quantity: - """The uncertainty/error in the dG value: The std of the estimates of - each independent repeat - """ - - dGs = [pus[0].outputs["unit_estimate"] for pus in self.data.values()] - return self.compute_uncertainty(dGs) - - def get_individual_estimates(self) -> list[tuple[Quantity, Quantity]]: - """Return a list of tuples containing the individual free energy - estimates and associated MBAR errors for each repeat. - - Returns - ------- - dGs : list[tuple[openff.units.Quantity]] - n_replicate simulation list of tuples containing the free energy - estimates (first entry) and associated MBAR estimate errors - (second entry). 
- """ - dGs = [ - (pus[0].outputs["unit_estimate"], pus[0].outputs["unit_estimate_error"]) - for pus in self.data.values() - ] - return dGs - - def get_forward_and_reverse_energy_analysis( - self, - ) -> list[Optional[dict[str, Union[npt.NDArray, Quantity]]]]: - """ - Get a list of forward and reverse analysis of the free energies - for each repeat using uncorrelated production samples. - - The returned dicts have keys: - 'fractions' - the fraction of data used for this estimate - 'forward_DGs', 'reverse_DGs' - for each fraction of data, the estimate - 'forward_dDGs', 'reverse_dDGs' - for each estimate, the uncertainty - - The 'fractions' values are a numpy array, while the other arrays are - Quantity arrays, with units attached. - - If the list entry is ``None`` instead of a dictionary, this indicates - that the analysis could not be carried out for that repeat. This - is most likely caused by MBAR convergence issues when attempting to - calculate free energies from too few samples. - - - Returns - ------- - forward_reverse : list[Optional[dict[str, Union[npt.NDArray, openff.units.Quantity]]]] - - - Raises - ------ - UserWarning - If any of the forward and reverse entries are ``None``. - """ - forward_reverse = [ - pus[0].outputs["forward_and_reverse_energies"] for pus in self.data.values() - ] - - if None in forward_reverse: - wmsg = ( - "One or more ``None`` entries were found in the list of " - "forward and reverse analyses. This is likely caused by " - "an MBAR convergence failure caused by too few independent " - "samples when calculating the free energies of the 10% " - "timeseries slice." - ) - warnings.warn(wmsg) - - return forward_reverse - - def get_overlap_matrices(self) -> list[dict[str, npt.NDArray]]: - """ - Return a list of dictionary containing the MBAR overlap estimates - calculated for each repeat. 
- - Returns - ------- - overlap_stats : list[dict[str, npt.NDArray]] - A list of dictionaries containing the following keys: - * ``scalar``: One minus the largest nontrivial eigenvalue - * ``eigenvalues``: The sorted (descending) eigenvalues of the - overlap matrix - * ``matrix``: Estimated overlap matrix of observing a sample from - state i in state j - """ - # Loop through and get the repeats and get the matrices - overlap_stats = [pus[0].outputs["unit_mbar_overlap"] for pus in self.data.values()] - - return overlap_stats - - def get_replica_transition_statistics(self) -> list[dict[str, npt.NDArray]]: - """The replica lambda state transition statistics for each repeat. - - Note - ---- - This is currently only available in cases where a replica exchange - simulation was run. - - Returns - ------- - repex_stats : list[dict[str, npt.NDArray]] - A list of dictionaries containing the following: - * ``eigenvalues``: The sorted (descending) eigenvalues of the - lambda state transition matrix - * ``matrix``: The transition matrix estimate of a replica switching - from state i to state j. - """ - try: - repex_stats = [ - pus[0].outputs["replica_exchange_statistics"] for pus in self.data.values() - ] - except KeyError: - errmsg = "Replica exchange statistics were not found, did you run a repex calculation?" - raise ValueError(errmsg) - - return repex_stats - - def get_replica_states(self) -> list[npt.NDArray]: - """ - Returns the timeseries of replica states for each repeat. 
- - Returns - ------- - replica_states : List[npt.NDArray] - List of replica states for each repeat - """ - - def is_file(filename: str): - p = pathlib.Path(filename) - if not p.exists(): - errmsg = f"File could not be found {p}" - raise ValueError(errmsg) - return p - - replica_states = [] - - for pus in self.data.values(): - nc = is_file(pus[0].outputs["nc"]) - dir_path = nc.parents[0] - chk = is_file(dir_path / pus[0].outputs["last_checkpoint"]).name - reporter = multistate.MultiStateReporter( - storage=nc, checkpoint_storage=chk, open_mode="r" - ) - replica_states.append(np.asarray(reporter.read_replica_thermodynamic_states())) - reporter.close() - - return replica_states - - def equilibration_iterations(self) -> list[float]: - """ - Returns the number of equilibration iterations for each repeat - of the calculation. - - Returns - ------- - equilibration_lengths : list[float] - """ - equilibration_lengths = [ - pus[0].outputs["equilibration_iterations"] for pus in self.data.values() - ] - - return equilibration_lengths - - def production_iterations(self) -> list[float]: - """ - Returns the number of uncorrelated production samples for each - repeat of the calculation. - - Returns - ------- - production_lengths : list[float] - """ - production_lengths = [pus[0].outputs["production_iterations"] for pus in self.data.values()] - - return production_lengths - - -class RelativeHybridTopologyProtocol(gufe.Protocol): - """ - Relative Free Energy calculations using OpenMM and OpenMMTools. 
- - Based on `Perses `_ - - See Also - -------- - :mod:`openfe.protocols` - :class:`openfe.protocols.openmm_rfe.RelativeHybridTopologySettings` - :class:`openfe.protocols.openmm_rfe.RelativeHybridTopologyResult` - :class:`openfe.protocols.openmm_rfe.RelativeHybridTopologyProtocolUnit` - """ - - result_cls = RelativeHybridTopologyProtocolResult - _settings_cls = RelativeHybridTopologyProtocolSettings - _settings: RelativeHybridTopologyProtocolSettings - - @classmethod - def _default_settings(cls): - """A dictionary of initial settings for this creating this Protocol - - These settings are intended as a suitable starting point for creating - an instance of this protocol. It is recommended, however that care is - taken to inspect and customize these before performing a Protocol. - - Returns - ------- - Settings - a set of default settings - """ - return RelativeHybridTopologyProtocolSettings( - protocol_repeats=3, - forcefield_settings=settings.OpenMMSystemGeneratorFFSettings(), - thermo_settings=settings.ThermoSettings( - temperature=298.15 * unit.kelvin, - pressure=1 * unit.bar, - ), - partial_charge_settings=OpenFFPartialChargeSettings(), - solvation_settings=OpenMMSolvationSettings(), - alchemical_settings=AlchemicalSettings(softcore_LJ="gapsys"), - lambda_settings=LambdaSettings(), - simulation_settings=MultiStateSimulationSettings( - equilibration_length=1.0 * unit.nanosecond, - production_length=5.0 * unit.nanosecond, - ), - engine_settings=OpenMMEngineSettings(), - integrator_settings=IntegratorSettings(), - output_settings=MultiStateOutputSettings(), - ) - - @classmethod - def _adaptive_settings( - cls, - stateA: ChemicalSystem, - stateB: ChemicalSystem, - mapping: gufe.LigandAtomMapping | list[gufe.LigandAtomMapping], - initial_settings: None | RelativeHybridTopologyProtocolSettings = None, - ) -> RelativeHybridTopologyProtocolSettings: - """ - Get the recommended OpenFE settings for this protocol based on the input states involved in the - transformation. 
- - These are intended as a suitable starting point for creating an instance of this protocol, which can be further - customized before performing a Protocol. - - Parameters - ---------- - stateA : ChemicalSystem - The initial state of the transformation. - stateB : ChemicalSystem - The final state of the transformation. - mapping : LigandAtomMapping | list[LigandAtomMapping] - The mapping(s) between transforming components in stateA and stateB. - initial_settings : None | RelativeHybridTopologyProtocolSettings, optional - Initial settings to base the adaptive settings on. If None, default settings are used. - - Returns - ------- - RelativeHybridTopologyProtocolSettings - The recommended settings for this protocol based on the input states. - - Notes - ----- - - If the transformation involves a change in net charge, the settings are adapted to use a more expensive - protocol with 22 lambda windows and 20 ns production length per window. - - If both states contain a ProteinComponent, the solvation padding is set to 1 nm. - - If initial_settings is provided, the adaptive settings are based on a copy of these settings. - """ - # use initial settings or default settings - # this is needed for the CLI so we don't override user settings - if initial_settings is not None: - protocol_settings = initial_settings.copy(deep=True) - else: - protocol_settings = cls.default_settings() - - if isinstance(mapping, list): - mapping = mapping[0] - - if mapping.get_alchemical_charge_difference() != 0: - # apply the recommended charge change settings taken from the industry benchmarking as fast settings not validated - # - info = ( - "Charge changing transformation between ligands " - f"{mapping.componentA.name} and {mapping.componentB.name}. " - "A more expensive protocol with 22 lambda windows, sampled " - "for 20 ns each, will be used here." 
- ) - logger.info(info) - protocol_settings.alchemical_settings.explicit_charge_correction = True - protocol_settings.simulation_settings.production_length = 20 * unit.nanosecond - protocol_settings.simulation_settings.n_replicas = 22 - protocol_settings.lambda_settings.lambda_windows = 22 - - # adapt the solvation padding based on the system components - if stateA.contains(ProteinComponent) and stateB.contains(ProteinComponent): - protocol_settings.solvation_settings.solvent_padding = 1 * unit.nanometer - - return protocol_settings - - @staticmethod - def _validate_endstates( - stateA: ChemicalSystem, - stateB: ChemicalSystem, - ) -> None: - """ - Validates the end states for the RFE protocol. - - Parameters - ---------- - stateA : ChemicalSystem - The chemical system of end state A. - stateB : ChemicalSystem - The chemical system of end state B. - - Raises - ------ - ValueError - * If either state contains more than one unique Component. - * If unique components are not SmallMoleculeComponents. - """ - # Get the difference in Components between each state - diff = stateA.component_diff(stateB) - - for i, entry in enumerate(diff): - state_label = "A" if i == 0 else "B" - - # Check that there is only one unique Component in each state - if len(entry) != 1: - errmsg = ( - "Only one alchemical component is allowed per end state. " - f"Found {len(entry)} in state {state_label}." - ) - raise ValueError(errmsg) - - # Check that the unique Component is a SmallMoleculeComponent - if not isinstance(entry[0], SmallMoleculeComponent): - errmsg = ( - f"Alchemical component in state {state_label} is of type " - f"{type(entry[0])}, but only SmallMoleculeComponents " - "transformations are currently supported." 
- ) - raise ValueError(errmsg) - - @staticmethod - def _validate_mapping( - mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]], - alchemical_components: dict[str, list[Component]], - ) -> None: - """ - Validates that the provided mapping(s) are suitable for the RFE protocol. - - Parameters - ---------- - mapping : Optional[Union[ComponentMapping, list[ComponentMapping]]] - all mappings between transforming components. - alchemical_components : dict[str, list[Component]] - Dictionary contatining the alchemical components for - states A and B. - - Raises - ------ - ValueError - * If there are more than one mapping or mapping is None - * If the mapping components are not in the alchemical components. - UserWarning - * Mappings which involve element changes in core atoms - """ - # if a single mapping is provided, convert to list - if isinstance(mapping, ComponentMapping): - mapping = [mapping] - - # For now we only support a single mapping - if mapping is None or len(mapping) > 1: - errmsg = "A single LigandAtomMapping is expected for this Protocol" - raise ValueError(errmsg) - - # check that the mapping components are in the alchemical components - for m in mapping: - if m.componentA not in alchemical_components["stateA"]: - raise ValueError(f"Mapping componentA {m.componentA} not in alchemical components of stateA") - if m.componentB not in alchemical_components["stateB"]: - raise ValueError(f"Mapping componentB {m.componentB} not in alchemical components of stateB") - - # TODO: remove - this is now the default behaviour? 
- # Check for element changes in mappings - for m in mapping: - molA = m.componentA.to_rdkit() - molB = m.componentB.to_rdkit() - for i, j in m.componentA_to_componentB.items(): - atomA = molA.GetAtomWithIdx(i) - atomB = molB.GetAtomWithIdx(j) - if atomA.GetAtomicNum() != atomB.GetAtomicNum(): - wmsg = ( - f"Element change in mapping between atoms " - f"Ligand A: {i} (element {atomA.GetAtomicNum()}) and " - f"Ligand B: {j} (element {atomB.GetAtomicNum()})\n" - "No mass scaling is attempted in the hybrid topology, " - "the average mass of the two atoms will be used in the " - "simulation" - ) - logger.warning(wmsg) - warnings.warn(wmsg) - - @staticmethod - def _validate_charge_difference( - mapping: LigandAtomMapping, - nonbonded_method: str, - explicit_charge_correction: bool, - solvent_component: SolventComponent | None, - ): - """ - Validates the net charge difference between the two states. - - Parameters - ---------- - mapping : dict[str, ComponentMapping] - Dictionary of mappings between transforming components. - nonbonded_method : str - The OpenMM nonbonded method used for the simulation. - explicit_charge_correction : bool - Whether or not to use an explicit charge correction. - solvent_component : openfe.SolventComponent | None - The SolventComponent of the simulation. - - Raises - ------ - ValueError - * If an explicit charge correction is attempted and the - nonbonded method is not PME. - * If the absolute charge difference is greater than one - and an explicit charge correction is attempted. - UserWarning - * If there is any charge difference. - """ - difference = mapping.get_alchemical_charge_difference() - - if abs(difference) == 0: - return - - if not explicit_charge_correction: - wmsg = ( - f"A charge difference of {difference} is observed " - "between the end states. No charge correction has " - "been requested, please account for this in your " - "final results." 
- ) - logger.warning(wmsg) - warnings.warn(wmsg) - return - - if solvent_component is None: - errmsg = "Cannot use eplicit charge correction without solvent" - raise ValueError(errmsg) - - # We implicitly check earlier that we have to have pme for a solvated - # system, so we only need to check the nonbonded method here - if nonbonded_method.lower() != "pme": - errmsg = "Explicit charge correction when not using PME is not currently supported." - raise ValueError(errmsg) - - if abs(difference) > 1: - errmsg = ( - f"A charge difference of {difference} is observed " - "between the end states and an explicit charge " - "correction has been requested. Unfortunately " - "only absolute differences of 1 are supported." - ) - raise ValueError(errmsg) - - ion = { - -1: solvent_component.positive_ion, - 1: solvent_component.negative_ion - }[difference] - - wmsg = ( - f"A charge difference of {difference} is observed " - "between the end states. This will be addressed by " - f"transforming a water into a {ion} ion" - ) - logger.info(wmsg) - - @staticmethod - def _validate_simulation_settings( - simulation_settings: MultiStateSimulationSettings, - integrator_settings: IntegratorSettings, - output_settings: MultiStateOutputSettings, - ): - """ - Validate various simulation settings, including but not limited to - timestep conversions, and output file write frequencies. - - Parameters - ---------- - simulation_settings : MultiStateSimulationSettings - The sampler simulation settings. - integrator_settings : IntegratorSettings - Settings defining the behaviour of the integrator. - output_settings : MultiStateOutputSettings - Settings defining the simulation file writing behaviour. 
- - Raises - ------ - ValueError - * If the - """ - - steps_per_iteration = settings_validation.convert_steps_per_iteration( - simulation_settings=simulation_settings, - integrator_settings=integrator_settings, - ) - - _ = settings_validation.get_simsteps( - sim_length=simulation_settings.equilibration_length, - timestep=integrator_settings.timestep, - mc_steps=steps_per_iteration, - ) - - _ = settings_validation.get_simsteps( - sim_length=simulation_settings.production_length, - timestep=integrator_settings.timestep, - mc_steps=steps_per_iteration, - ) - - _ = settings_validation.convert_checkpoint_interval_to_iterations( - checkpoint_interval=output_settings.checkpoint_interval, - time_per_iteration=simulation_settings.time_per_iteration, - ) - - if output_settings.positions_write_frequency is not None: - _ = settings_validation.divmod_time_and_check( - numerator=output_settings.positions_write_frequency, - denominator=simulation_settings.time_per_iteration, - numerator_name="output settings' positions_write_frequency", - denominator_name="sampler settings' time_per_iteration", - ) - - if output_settings.velocities_write_frequency is not None: - _ = settings_validation.divmod_time_and_check( - numerator=output_settings.velocities_write_frequency, - denominator=simulation_settings.time_per_iteration, - numerator_name="output settings' velocities_write_frequency", - denominator_name="sampler settings' time_per_iteration", - ) - - _, _ = settings_validation.convert_real_time_analysis_iterations( - simulation_settings=simulation_settings, - ) - - def _validate( - self, - stateA: ChemicalSystem, - stateB: ChemicalSystem, - mapping: gufe.ComponentMapping | list[gufe.ComponentMapping] | None, - extends: gufe.ProtocolDAGResult | None = None, - ) -> None: - # Check we're not trying to extend - if extends: - # This technically should be NotImplementedError - # but gufe.Protocol.validate calls `_validate` wrapped around an - # except for NotImplementedError, so we can't 
raise it here - raise ValueError("Can't extend simulations yet") - - # Validate the end states - self._validate_endstates(stateA, stateB) - - # Valildate the mapping - alchem_comps = system_validation.get_alchemical_components(stateA, stateB) - self._validate_mapping(mapping, alchem_comps) - - # Validate solvent component - nonbond = self.settings.forcefield_settings.nonbonded_method - system_validation.validate_solvent(stateA, nonbond) - - # Validate solvation settings - settings_validation.validate_openmm_solvation_settings(self.settings.solvation_settings) - - # Validate protein component - system_validation.validate_protein(stateA) - - # Validate charge difference - # Note: validation depends on the mapping & solvent component checks - if stateA.contains(SolventComponent): - solv_comp = stateA.get_components_of_type(SolventComponent)[0] - else: - solv_comp = None - - self._validate_charge_difference( - mapping=mapping[0] if isinstance(mapping, list) else mapping, - nonbonded_method=self.settings.forcefield_settings.nonbonded_method, - explicit_charge_correction=self.settings.alchemical_settings.explicit_charge_correction, - solvent_component=solv_comp, - ) - - # Validate integrator things - settings_validation.validate_timestep( - self.settings.forcefield_settings.hydrogen_mass, - self.settings.integrator_settings.timestep, - ) - - # Validate simulation & output settings - self._validate_simulation_settings( - self.settings.simulation_settings, - self.settings.integrator_settings, - self.settings.output_settings, - ) - - # Validate alchemical settings - # PR #125 temporarily pin lambda schedule spacing to n_replicas - if self.settings.simulation_settings.n_replicas != self.settings.lambda_settings.lambda_windows: - errmsg = ( - "Number of replicas in ``simulation_settings``: " - f"{self.settings.simulation_settings.n_replicas} must equal " - "the number of lambda windows in lambda_settings: " - f"{self.settings.lambda_settings.lambda_windows}." 
- ) - raise ValueError(errmsg) - - def _create( - self, - stateA: ChemicalSystem, - stateB: ChemicalSystem, - mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]], - extends: Optional[gufe.ProtocolDAGResult] = None, - ) -> list[gufe.ProtocolUnit]: - # validate inputs - self.validate(stateA=stateA, stateB=stateB, mapping=mapping, extends=extends) - - # get alchemical components and mapping - alchem_comps = system_validation.get_alchemical_components(stateA, stateB) - ligandmapping = mapping[0] if isinstance(mapping, list) else mapping - - # actually create and return Units - Anames = ",".join(c.name for c in alchem_comps["stateA"]) - Bnames = ",".join(c.name for c in alchem_comps["stateB"]) - - # our DAG has no dependencies, so just list units - n_repeats = self.settings.protocol_repeats - - units = [ - RelativeHybridTopologyProtocolUnit( - protocol=self, - stateA=stateA, - stateB=stateB, - ligandmapping=ligandmapping, - generation=0, - repeat_id=int(uuid.uuid4()), - name=f"{Anames} to {Bnames} repeat {i} generation 0", - ) - for i in range(n_repeats) - ] - - return units - - def _gather(self, protocol_dag_results: Iterable[gufe.ProtocolDAGResult]) -> dict[str, Any]: - # result units will have a repeat_id and generations within this repeat_id - # first group according to repeat_id - unsorted_repeats = defaultdict(list) - for d in protocol_dag_results: - pu: gufe.ProtocolUnitResult - for pu in d.protocol_unit_results: - if not pu.ok(): - continue - - unsorted_repeats[pu.outputs["repeat_id"]].append(pu) - - # then sort by generation within each repeat_id list - repeats: dict[str, list[gufe.ProtocolUnitResult]] = {} - for k, v in unsorted_repeats.items(): - repeats[str(k)] = sorted(v, key=lambda x: x.outputs["generation"]) - - # returns a dict of repeat_id: sorted list of ProtocolUnitResult - return repeats - - -class RelativeHybridTopologyProtocolUnit(gufe.ProtocolUnit): - """ - Calculates the relative free energy of an alchemical ligand 
transformation. - """ - - def __init__( - self, - *, - protocol: RelativeHybridTopologyProtocol, - stateA: ChemicalSystem, - stateB: ChemicalSystem, - ligandmapping: LigandAtomMapping, - generation: int, - repeat_id: int, - name: Optional[str] = None, - ): - """ - Parameters - ---------- - protocol : RelativeHybridTopologyProtocol - protocol used to create this Unit. Contains key information such - as the settings. - stateA, stateB : ChemicalSystem - the two ligand SmallMoleculeComponents to transform between. The - transformation will go from ligandA to ligandB. - ligandmapping : LigandAtomMapping - the mapping of atoms between the two ligand components - repeat_id : int - identifier for which repeat (aka replica/clone) this Unit is - generation : int - counter for how many times this repeat has been extended - name : str, optional - human-readable identifier for this Unit - - Notes - ----- - The mapping used must not involve any elemental changes. A check for - this is done on class creation. - """ - super().__init__( - name=name, - protocol=protocol, - stateA=stateA, - stateB=stateB, - ligandmapping=ligandmapping, - repeat_id=repeat_id, - generation=generation, - ) - - @staticmethod - def _assign_partial_charges( - charge_settings: OpenFFPartialChargeSettings, - off_small_mols: dict[str, list[tuple[SmallMoleculeComponent, OFFMolecule]]], - ) -> None: - """ - Assign partial charges to SMCs. - - Parameters - ---------- - charge_settings : OpenFFPartialChargeSettings - Settings for controlling how the partial charges are assigned. - off_small_mols : dict[str, list[tuple[SmallMoleculeComponent, OFFMolecule]]] - Dictionary of dictionary of OpenFF Molecules to add, keyed by - state and SmallMoleculeComponent. 
- """ - for smc, mol in chain( - off_small_mols["stateA"], off_small_mols["stateB"], off_small_mols["both"] - ): - charge_generation.assign_offmol_partial_charges( - offmol=mol, - overwrite=False, - method=charge_settings.partial_charge_method, - toolkit_backend=charge_settings.off_toolkit_backend, - generate_n_conformers=charge_settings.number_of_conformers, - nagl_model=charge_settings.nagl_model, - ) - - def run( - self, *, dry=False, verbose=True, scratch_basepath=None, shared_basepath=None - ) -> dict[str, Any]: - """Run the relative free energy calculation. - - Parameters - ---------- - dry : bool - Do a dry run of the calculation, creating all necessary hybrid - system components (topology, system, sampler, etc...) but without - running the simulation. - verbose : bool - Verbose output of the simulation progress. Output is provided via - INFO level logging. - scratch_basepath: Pathlike, optional - Where to store temporary files, defaults to current working directory - shared_basepath : Pathlike, optional - Where to run the calculation, defaults to current working directory - - Returns - ------- - dict - Outputs created in the basepath directory or the debug objects - (i.e. sampler) if ``dry==True``. - - Raises - ------ - error - Exception if anything failed - """ - if verbose: - self.logger.info("Preparing the hybrid topology simulation") - if scratch_basepath is None: - scratch_basepath = pathlib.Path(".") - if shared_basepath is None: - # use cwd - shared_basepath = pathlib.Path(".") - - # 0. 
General setup and settings dependency resolution step - - # Extract relevant settings - protocol_settings: RelativeHybridTopologyProtocolSettings = self._inputs[ - "protocol" - ].settings - stateA = self._inputs["stateA"] - stateB = self._inputs["stateB"] - mapping = self._inputs["ligandmapping"] - - forcefield_settings: settings.OpenMMSystemGeneratorFFSettings = ( - protocol_settings.forcefield_settings - ) - thermo_settings: settings.ThermoSettings = protocol_settings.thermo_settings - alchem_settings: AlchemicalSettings = protocol_settings.alchemical_settings - lambda_settings: LambdaSettings = protocol_settings.lambda_settings - charge_settings: BasePartialChargeSettings = protocol_settings.partial_charge_settings - solvation_settings: OpenMMSolvationSettings = protocol_settings.solvation_settings - sampler_settings: MultiStateSimulationSettings = protocol_settings.simulation_settings - output_settings: MultiStateOutputSettings = protocol_settings.output_settings - integrator_settings: IntegratorSettings = protocol_settings.integrator_settings - - # TODO: Also validate various conversions? - # Convert various time based inputs to steps/iterations - steps_per_iteration = settings_validation.convert_steps_per_iteration( - simulation_settings=sampler_settings, - integrator_settings=integrator_settings, - ) - - equil_steps = settings_validation.get_simsteps( - sim_length=sampler_settings.equilibration_length, - timestep=integrator_settings.timestep, - mc_steps=steps_per_iteration, - ) - prod_steps = settings_validation.get_simsteps( - sim_length=sampler_settings.production_length, - timestep=integrator_settings.timestep, - mc_steps=steps_per_iteration, - ) - - solvent_comp, protein_comp, small_mols = system_validation.get_components(stateA) - - # Get the change difference between the end states - # and check if the charge correction used is appropriate - charge_difference = mapping.get_alchemical_charge_difference() - - # 1. 
Create stateA system - self.logger.info("Parameterizing molecules") - - # a. create offmol dictionaries and assign partial charges - # workaround for conformer generation failures - # see openfe issue #576 - # calculate partial charges manually if not already given - # convert to OpenFF here, - # and keep the molecule around to maintain the partial charges - off_small_mols: dict[str, list[tuple[SmallMoleculeComponent, OFFMolecule]]] - off_small_mols = { - "stateA": [(mapping.componentA, mapping.componentA.to_openff())], - "stateB": [(mapping.componentB, mapping.componentB.to_openff())], - "both": [ - (m, m.to_openff()) - for m in small_mols - if (m != mapping.componentA and m != mapping.componentB) - ], - } - - self._assign_partial_charges(charge_settings, off_small_mols) - - # b. get a system generator - if output_settings.forcefield_cache is not None: - ffcache = shared_basepath / output_settings.forcefield_cache - else: - ffcache = None - - # Block out oechem backend in system_generator calls to avoid - # any issues with smiles roundtripping between rdkit and oechem - with without_oechem_backend(): - system_generator = system_creation.get_system_generator( - forcefield_settings=forcefield_settings, - integrator_settings=integrator_settings, - thermo_settings=thermo_settings, - cache=ffcache, - has_solvent=solvent_comp is not None, - ) - - # c. force the creation of parameters - # This is necessary because we need to have the FF templates - # registered ahead of solvating the system. - for smc, mol in chain( - off_small_mols["stateA"], off_small_mols["stateB"], off_small_mols["both"] - ): - system_generator.create_system(mol.to_topology().to_openmm(), molecules=[mol]) - - # c. 
get OpenMM Modeller + a dictionary of resids for each component - stateA_modeller, comp_resids = system_creation.get_omm_modeller( - protein_comp=protein_comp, - solvent_comp=solvent_comp, - small_mols=dict(chain(off_small_mols["stateA"], off_small_mols["both"])), - omm_forcefield=system_generator.forcefield, - solvent_settings=solvation_settings, - ) - - # d. get topology & positions - # Note: roundtrip positions to remove vec3 issues - stateA_topology = stateA_modeller.getTopology() - stateA_positions = to_openmm(from_openmm(stateA_modeller.getPositions())) - - # e. create the stateA System - # Block out oechem backend in system_generator calls to avoid - # any issues with smiles roundtripping between rdkit and oechem - with without_oechem_backend(): - stateA_system = system_generator.create_system( - stateA_modeller.topology, - molecules=[m for _, m in chain(off_small_mols["stateA"], off_small_mols["both"])], - ) - - # 2. Get stateB system - # a. get the topology - stateB_topology, stateB_alchem_resids = _rfe_utils.topologyhelpers.combined_topology( - stateA_topology, - # zeroth item (there's only one) then get the OFF representation - off_small_mols["stateB"][0][1].to_topology().to_openmm(), - exclude_resids=comp_resids[mapping.componentA], - ) - - # b. get a list of small molecules for stateB - # Block out oechem backend in system_generator calls to avoid - # any issues with smiles roundtripping between rdkit and oechem - with without_oechem_backend(): - stateB_system = system_generator.create_system( - stateB_topology, - molecules=[m for _, m in chain(off_small_mols["stateB"], off_small_mols["both"])], - ) - - # c. 
Define correspondence mappings between the two systems - ligand_mappings = _rfe_utils.topologyhelpers.get_system_mappings( - mapping.componentA_to_componentB, - stateA_system, - stateA_topology, - comp_resids[mapping.componentA], - stateB_system, - stateB_topology, - stateB_alchem_resids, - # These are non-optional settings for this method - fix_constraints=True, - ) - - # d. if a charge correction is necessary, select alchemical waters - # and transform them - if alchem_settings.explicit_charge_correction: - alchem_water_resids = _rfe_utils.topologyhelpers.get_alchemical_waters( - stateA_topology, - stateA_positions, - charge_difference, - alchem_settings.explicit_charge_correction_cutoff, - ) - _rfe_utils.topologyhelpers.handle_alchemical_waters( - alchem_water_resids, - stateB_topology, - stateB_system, - ligand_mappings, - charge_difference, - solvent_comp, - ) - - # e. Finally get the positions - stateB_positions = _rfe_utils.topologyhelpers.set_and_check_new_positions( - ligand_mappings, - stateA_topology, - stateB_topology, - old_positions=ensure_quantity(stateA_positions, "openmm"), - insert_positions=ensure_quantity( - off_small_mols["stateB"][0][1].conformers[0], "openmm" - ), - ) - - # 3. Create the hybrid topology - # a. Get softcore potential settings - if alchem_settings.softcore_LJ.lower() == "gapsys": - softcore_LJ_v2 = True - elif alchem_settings.softcore_LJ.lower() == "beutler": - softcore_LJ_v2 = False - # b. 
Get hybrid topology factory - hybrid_factory = _rfe_utils.relative.HybridTopologyFactory( - stateA_system, - stateA_positions, - stateA_topology, - stateB_system, - stateB_positions, - stateB_topology, - old_to_new_atom_map=ligand_mappings["old_to_new_atom_map"], - old_to_new_core_atom_map=ligand_mappings["old_to_new_core_atom_map"], - use_dispersion_correction=alchem_settings.use_dispersion_correction, - softcore_alpha=alchem_settings.softcore_alpha, - softcore_LJ_v2=softcore_LJ_v2, - softcore_LJ_v2_alpha=alchem_settings.softcore_alpha, - interpolate_old_and_new_14s=alchem_settings.turn_off_core_unique_exceptions, - ) - - # 4. Create lambda schedule - # TODO - this should be exposed to users, maybe we should offer the - # ability to print the schedule directly in settings? - # fmt: off - lambdas = _rfe_utils.lambdaprotocol.LambdaProtocol( - functions=lambda_settings.lambda_functions, - windows=lambda_settings.lambda_windows - ) - # fmt: on - # PR #125 temporarily pin lambda schedule spacing to n_replicas - n_replicas = sampler_settings.n_replicas - if n_replicas != len(lambdas.lambda_schedule): - errmsg = ( - f"Number of replicas {n_replicas} " - f"does not equal the number of lambda windows " - f"{len(lambdas.lambda_schedule)}" - ) - raise ValueError(errmsg) - - # 9. Create the multistate reporter - # Get the sub selection of the system to print coords for - selection_indices = hybrid_factory.hybrid_topology.select(output_settings.output_indices) - - # a. 
Create the multistate reporter - # convert checkpoint_interval from time to iterations - chk_intervals = settings_validation.convert_checkpoint_interval_to_iterations( - checkpoint_interval=output_settings.checkpoint_interval, - time_per_iteration=sampler_settings.time_per_iteration, - ) - - nc = shared_basepath / output_settings.output_filename - chk = output_settings.checkpoint_storage_filename - - if output_settings.positions_write_frequency is not None: - pos_interval = settings_validation.divmod_time_and_check( - numerator=output_settings.positions_write_frequency, - denominator=sampler_settings.time_per_iteration, - numerator_name="output settings' position_write_frequency", - denominator_name="sampler settings' time_per_iteration", - ) - else: - pos_interval = 0 - - if output_settings.velocities_write_frequency is not None: - vel_interval = settings_validation.divmod_time_and_check( - numerator=output_settings.velocities_write_frequency, - denominator=sampler_settings.time_per_iteration, - numerator_name="output settings' velocity_write_frequency", - denominator_name="sampler settings' time_per_iteration", - ) - else: - vel_interval = 0 - - reporter = multistate.MultiStateReporter( - storage=nc, - analysis_particle_indices=selection_indices, - checkpoint_interval=chk_intervals, - checkpoint_storage=chk, - position_interval=pos_interval, - velocity_interval=vel_interval, - ) - - # b. 
Write out a PDB containing the subsampled hybrid state - # fmt: off - bfactors = np.zeros_like(selection_indices, dtype=float) # solvent - bfactors[np.in1d(selection_indices, list(hybrid_factory._atom_classes['unique_old_atoms']))] = 0.25 # lig A - bfactors[np.in1d(selection_indices, list(hybrid_factory._atom_classes['core_atoms']))] = 0.50 # core - bfactors[np.in1d(selection_indices, list(hybrid_factory._atom_classes['unique_new_atoms']))] = 0.75 # lig B - # bfactors[np.in1d(selection_indices, protein)] = 1.0 # prot+cofactor - if len(selection_indices) > 0: - traj = mdtraj.Trajectory( - hybrid_factory.hybrid_positions[selection_indices, :], - hybrid_factory.hybrid_topology.subset(selection_indices), - ).save_pdb( - shared_basepath / output_settings.output_structure, - bfactors=bfactors, - ) - # fmt: on - - # 10. Get compute platform - # restrict to a single CPU if running vacuum - restrict_cpu = forcefield_settings.nonbonded_method.lower() == "nocutoff" - platform = omm_compute.get_openmm_platform( - platform_name=protocol_settings.engine_settings.compute_platform, - gpu_device_index=protocol_settings.engine_settings.gpu_device_index, - restrict_cpu_count=restrict_cpu, - ) - - # 11. Set the integrator - # a. Validate integrator settings for current system - # Virtual sites sanity check - ensure we restart velocities when - # there are virtual sites in the system - if hybrid_factory.has_virtual_sites: - if not integrator_settings.reassign_velocities: - errmsg = ( - "Simulations with virtual sites without velocity " - "reassignments are unstable in openmmtools" - ) - raise ValueError(errmsg) - - # b. 
create langevin integrator - integrator = openmmtools.mcmc.LangevinDynamicsMove( - timestep=to_openmm(integrator_settings.timestep), - collision_rate=to_openmm(integrator_settings.langevin_collision_rate), - n_steps=steps_per_iteration, - reassign_velocities=integrator_settings.reassign_velocities, - n_restart_attempts=integrator_settings.n_restart_attempts, - constraint_tolerance=integrator_settings.constraint_tolerance, - ) - - # 12. Create sampler - self.logger.info("Creating and setting up the sampler") - rta_its, rta_min_its = settings_validation.convert_real_time_analysis_iterations( - simulation_settings=sampler_settings, - ) - # convert early_termination_target_error from kcal/mol to kT - early_termination_target_error = ( - settings_validation.convert_target_error_from_kcal_per_mole_to_kT( - thermo_settings.temperature, - sampler_settings.early_termination_target_error, - ) - ) - - if sampler_settings.sampler_method.lower() == "repex": - sampler = _rfe_utils.multistate.HybridRepexSampler( - mcmc_moves=integrator, - hybrid_system=hybrid_factory.hybrid_system, - hybrid_positions=hybrid_factory.hybrid_positions, - online_analysis_interval=rta_its, - online_analysis_target_error=early_termination_target_error, - online_analysis_minimum_iterations=rta_min_its, - ) - elif sampler_settings.sampler_method.lower() == "sams": - sampler = _rfe_utils.multistate.HybridSAMSSampler( - mcmc_moves=integrator, - hybrid_system=hybrid_factory.hybrid_system, - hybrid_positions=hybrid_factory.hybrid_positions, - online_analysis_interval=rta_its, - online_analysis_minimum_iterations=rta_min_its, - flatness_criteria=sampler_settings.sams_flatness_criteria, - gamma0=sampler_settings.sams_gamma0, - ) - elif sampler_settings.sampler_method.lower() == "independent": - sampler = _rfe_utils.multistate.HybridMultiStateSampler( - mcmc_moves=integrator, - hybrid_system=hybrid_factory.hybrid_system, - hybrid_positions=hybrid_factory.hybrid_positions, - online_analysis_interval=rta_its, - 
online_analysis_target_error=early_termination_target_error, - online_analysis_minimum_iterations=rta_min_its, - ) - else: - raise AttributeError(f"Unknown sampler {sampler_settings.sampler_method}") - - sampler.setup( - n_replicas=sampler_settings.n_replicas, - reporter=reporter, - lambda_protocol=lambdas, - temperature=to_openmm(thermo_settings.temperature), - endstates=alchem_settings.endstate_dispersion_correction, - minimization_platform=platform.getName(), - # Set minimization steps to None when running in dry mode - # otherwise do a very small one to avoid NaNs - minimization_steps=100 if not dry else None, - ) - - try: - # Create context caches (energy + sampler) - energy_context_cache = openmmtools.cache.ContextCache( - capacity=None, - time_to_live=None, - platform=platform, - ) - - sampler_context_cache = openmmtools.cache.ContextCache( - capacity=None, - time_to_live=None, - platform=platform, - ) - - sampler.energy_context_cache = energy_context_cache - sampler.sampler_context_cache = sampler_context_cache - - if not dry: # pragma: no-cover - # minimize - if verbose: - self.logger.info("Running minimization") - - sampler.minimize(max_iterations=sampler_settings.minimization_steps) - - # equilibrate - if verbose: - self.logger.info("Running equilibration phase") - - sampler.equilibrate(int(equil_steps / steps_per_iteration)) - - # production - if verbose: - self.logger.info("Running production phase") - - sampler.extend(int(prod_steps / steps_per_iteration)) - - self.logger.info("Production phase complete") - - self.logger.info("Post-simulation analysis of results") - # calculate relevant analyses of the free energies & sampling - # First close & reload the reporter to avoid netcdf clashes - analyzer = multistate_analysis.MultistateEquilFEAnalysis( - reporter, - sampling_method=sampler_settings.sampler_method.lower(), - result_units=unit.kilocalorie_per_mole, - ) - analyzer.plot(filepath=shared_basepath, filename_prefix="") - analyzer.close() - - else: 
- # clean up the reporter file - fns = [ - shared_basepath / output_settings.output_filename, - shared_basepath / output_settings.checkpoint_storage_filename, - ] - for fn in fns: - os.remove(fn) - finally: - # close reporter when you're done, prevent - # file handle clashes - reporter.close() - - # clear GPU contexts - # TODO: use cache.empty() calls when openmmtools #690 is resolved - # replace with above - for context in list(energy_context_cache._lru._data.keys()): - del energy_context_cache._lru._data[context] - for context in list(sampler_context_cache._lru._data.keys()): - del sampler_context_cache._lru._data[context] - # cautiously clear out the global context cache too - for context in list(openmmtools.cache.global_context_cache._lru._data.keys()): - del openmmtools.cache.global_context_cache._lru._data[context] - - del sampler_context_cache, energy_context_cache - - if not dry: - del integrator, sampler - - if not dry: # pragma: no-cover - return {"nc": nc, "last_checkpoint": chk, **analyzer.unit_results_dict} - else: - return {"debug": - { - "sampler": sampler, - "hybrid_factory": hybrid_factory - } - } - - @staticmethod - def structural_analysis(scratch, shared) -> dict: - # don't put energy analysis in here, it uses the open file reporter - # whereas structural stuff requires that the file handle is closed - # TODO: we should just make openfe_analysis write an npz instead! 
- analysis_out = scratch / "structural_analysis.json" - - ret = subprocess.run( - [ - "openfe_analysis", # CLI entry point - "RFE_analysis", # CLI option - str(shared), # Where the simulation.nc fille - str(analysis_out), # Where the analysis json file is written - ], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - if ret.returncode: - return {"structural_analysis_error": ret.stderr} - - with open(analysis_out, "rb") as f: - data = json.load(f) - - savedir = pathlib.Path(shared) - if d := data["protein_2D_RMSD"]: - fig = plotting.plot_2D_rmsd(d) - fig.savefig(savedir / "protein_2D_RMSD.png") - plt.close(fig) - f2 = plotting.plot_ligand_COM_drift(data["time(ps)"], data["ligand_wander"]) - f2.savefig(savedir / "ligand_COM_drift.png") - plt.close(f2) - - f3 = plotting.plot_ligand_RMSD(data["time(ps)"], data["ligand_RMSD"]) - f3.savefig(savedir / "ligand_RMSD.png") - plt.close(f3) - - # Save to numpy compressed format (~ 6x more space efficient than JSON) - np.savez_compressed( - shared / "structural_analysis.npz", - protein_RMSD=np.asarray(data["protein_RMSD"], dtype=np.float32), - ligand_RMSD=np.asarray(data["ligand_RMSD"], dtype=np.float32), - ligand_COM_drift=np.asarray(data["ligand_wander"], dtype=np.float32), - protein_2D_RMSD=np.asarray(data["protein_2D_RMSD"], dtype=np.float32), - time_ps=np.asarray(data["time(ps)"], dtype=np.float32), - ) - - return {"structural_analysis": shared / "structural_analysis.npz"} - - def _execute( - self, - ctx: gufe.Context, - **kwargs, - ) -> dict[str, Any]: - log_system_probe(logging.INFO, paths=[ctx.scratch]) - - outputs = self.run(scratch_basepath=ctx.scratch, shared_basepath=ctx.shared) - - structural_analysis_outputs = self.structural_analysis(ctx.scratch, ctx.shared) - - return { - "repeat_id": self._inputs["repeat_id"], - "generation": self._inputs["generation"], - **outputs, - **structural_analysis_outputs, - } +from .equil_rfe_settings import RelativeHybridTopologyProtocolSettings +from .hybridtop_unit_results 
import RelativeHybridTopologyProtocolResult +from .hybridtop_units import RelativeHybridTopologyProtocolUnit +from .hybridtop_protocols import RelativeHybridTopologyProtocol diff --git a/openfe/protocols/openmm_rfe/hybridtop_protocols.py b/openfe/protocols/openmm_rfe/hybridtop_protocols.py new file mode 100644 index 000000000..42bf2ab9a --- /dev/null +++ b/openfe/protocols/openmm_rfe/hybridtop_protocols.py @@ -0,0 +1,571 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Hybrid Topology Protocols using OpenMM and OpenMMTools in a Perses-like manner. + +Acknowledgements +---------------- +These Protocols are based on, and leverages components originating from +the Perses toolkit (https://github.com/choderalab/perses). +""" + +from __future__ import annotations + +import logging +import uuid +import warnings +from collections import defaultdict +from typing import Any, Iterable, Optional, Union + +import gufe +from gufe import ( + ChemicalSystem, + Component, + ComponentMapping, + LigandAtomMapping, + ProteinComponent, + SmallMoleculeComponent, + SolventComponent, + settings, +) +from openff.units import unit as offunit + +from openfe.due import Doi, due + +from ..openmm_utils import ( + settings_validation, + system_validation, +) +from .equil_rfe_settings import ( + AlchemicalSettings, + IntegratorSettings, + LambdaSettings, + MultiStateOutputSettings, + MultiStateSimulationSettings, + OpenFFPartialChargeSettings, + OpenMMEngineSettings, + OpenMMSolvationSettings, + RelativeHybridTopologyProtocolSettings, +) +from .hybridtop_unit_results import RelativeHybridTopologyProtocolResult +from .hybridtop_units import RelativeHybridTopologyProtocolUnit + + +logger = logging.getLogger(__name__) + + +due.cite( + Doi("10.5281/zenodo.1297683"), + description="Perses", + path="openfe.protocols.openmm_rfe.hybridtop_protocols", + cite_module=True, +) + +due.cite( + 
Doi("10.5281/zenodo.596622"), + description="OpenMMTools", + path="openfe.protocols.openmm_rfe.hybridtop_protocols", + cite_module=True, +) + +due.cite( + Doi("10.1371/journal.pcbi.1005659"), + description="OpenMM", + path="openfe.protocols.openmm_rfe.hybridtop_protocols", + cite_module=True, +) + + +class RelativeHybridTopologyProtocol(gufe.Protocol): + """ + Relative Free Energy calculations using OpenMM and OpenMMTools. + + Based on `Perses <https://github.com/choderalab/perses>`_ + + See Also + -------- + :mod:`openfe.protocols` + :class:`openfe.protocols.openmm_rfe.RelativeHybridTopologyProtocolSettings` + :class:`openfe.protocols.openmm_rfe.RelativeHybridTopologyProtocolResult` + :class:`openfe.protocols.openmm_rfe.RelativeHybridTopologyProtocolUnit` + """ + + result_cls = RelativeHybridTopologyProtocolResult + _settings_cls = RelativeHybridTopologyProtocolSettings + _settings: RelativeHybridTopologyProtocolSettings + + @classmethod + def _default_settings(cls): + """A set of initial settings for creating this Protocol + + These settings are intended as a suitable starting point for creating + an instance of this protocol. It is recommended, however, that care is + taken to inspect and customize these before performing a Protocol.
+ + Returns + ------- + Settings + a set of default settings + """ + return RelativeHybridTopologyProtocolSettings( + protocol_repeats=3, + forcefield_settings=settings.OpenMMSystemGeneratorFFSettings(), + thermo_settings=settings.ThermoSettings( + temperature=298.15 * offunit.kelvin, + pressure=1 * offunit.bar, + ), + partial_charge_settings=OpenFFPartialChargeSettings(), + solvation_settings=OpenMMSolvationSettings(), + alchemical_settings=AlchemicalSettings(softcore_LJ="gapsys"), + lambda_settings=LambdaSettings(), + simulation_settings=MultiStateSimulationSettings( + equilibration_length=1.0 * offunit.nanosecond, + production_length=5.0 * offunit.nanosecond, + ), + engine_settings=OpenMMEngineSettings(), + integrator_settings=IntegratorSettings(), + output_settings=MultiStateOutputSettings(), + ) + + @classmethod + def _adaptive_settings( + cls, + stateA: ChemicalSystem, + stateB: ChemicalSystem, + mapping: gufe.LigandAtomMapping | list[gufe.LigandAtomMapping], + initial_settings: None | RelativeHybridTopologyProtocolSettings = None, + ) -> RelativeHybridTopologyProtocolSettings: + """ + Get the recommended OpenFE settings for this protocol based on the input states involved in the + transformation. + + These are intended as a suitable starting point for creating an instance of this protocol, which can be further + customized before performing a Protocol. + + Parameters + ---------- + stateA : ChemicalSystem + The initial state of the transformation. + stateB : ChemicalSystem + The final state of the transformation. + mapping : LigandAtomMapping | list[LigandAtomMapping] + The mapping(s) between transforming components in stateA and stateB. + initial_settings : None | RelativeHybridTopologyProtocolSettings, optional + Initial settings to base the adaptive settings on. If None, default settings are used. + + Returns + ------- + RelativeHybridTopologyProtocolSettings + The recommended settings for this protocol based on the input states. 
+ + Notes + ----- + - If the transformation involves a change in net charge, the settings are adapted to use a more expensive + protocol with 22 lambda windows and 20 ns production length per window. + - If both states contain a ProteinComponent, the solvation padding is set to 1 nm. + - If initial_settings is provided, the adaptive settings are based on a copy of these settings. + """ + # use initial settings or default settings + # this is needed for the CLI so we don't override user settings + if initial_settings is not None: + protocol_settings = initial_settings.copy(deep=True) + else: + protocol_settings = cls.default_settings() + + if isinstance(mapping, list): + mapping = mapping[0] + + if mapping.get_alchemical_charge_difference() != 0: + # apply the recommended charge change settings taken from the industry benchmarking as fast settings not validated + # + info = ( + "Charge changing transformation between ligands " + f"{mapping.componentA.name} and {mapping.componentB.name}. " + "A more expensive protocol with 22 lambda windows, sampled " + "for 20 ns each, will be used here." + ) + logger.info(info) + protocol_settings.alchemical_settings.explicit_charge_correction = True + protocol_settings.simulation_settings.production_length = 20 * offunit.nanosecond + protocol_settings.simulation_settings.n_replicas = 22 + protocol_settings.lambda_settings.lambda_windows = 22 + + # adapt the solvation padding based on the system components + if stateA.contains(ProteinComponent) and stateB.contains(ProteinComponent): + protocol_settings.solvation_settings.solvent_padding = 1 * offunit.nanometer + + return protocol_settings + + @staticmethod + def _validate_endstates( + stateA: ChemicalSystem, + stateB: ChemicalSystem, + ) -> None: + """ + Validates the end states for the RFE protocol. + + Parameters + ---------- + stateA : ChemicalSystem + The chemical system of end state A. + stateB : ChemicalSystem + The chemical system of end state B. 
+ + Raises + ------ + ValueError + * If either state does not contain exactly one unique Component. + * If unique components are not SmallMoleculeComponents. + """ + # Get the difference in Components between each state + diff = stateA.component_diff(stateB) + + for i, entry in enumerate(diff): + state_label = "A" if i == 0 else "B" + + # Check that there is only one unique Component in each state + if len(entry) != 1: + errmsg = ( + "Only one alchemical component is allowed per end state. " + f"Found {len(entry)} in state {state_label}." + ) + raise ValueError(errmsg) + + # Check that the unique Component is a SmallMoleculeComponent + if not isinstance(entry[0], SmallMoleculeComponent): + errmsg = ( + f"Alchemical component in state {state_label} is of type " + f"{type(entry[0])}, but only SmallMoleculeComponents " + "transformations are currently supported." + ) + raise ValueError(errmsg) + + @staticmethod + def _validate_mapping( + mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]], + alchemical_components: dict[str, list[Component]], + ) -> None: + """ + Validates that the provided mapping(s) are suitable for the RFE protocol. + + Parameters + ---------- + mapping : Optional[Union[ComponentMapping, list[ComponentMapping]]] + all mappings between transforming components. + alchemical_components : dict[str, list[Component]] + Dictionary containing the alchemical components for + states A and B. + + Raises + ------ + ValueError + * If there is more than one mapping or mapping is None + * If the mapping components are not in the alchemical components.
+ UserWarning + * Mappings which involve element changes in core atoms + """ + # if a single mapping is provided, convert to list + if isinstance(mapping, ComponentMapping): + mapping = [mapping] + + # For now we only support a single mapping + if mapping is None or len(mapping) > 1: + errmsg = "A single LigandAtomMapping is expected for this Protocol" + raise ValueError(errmsg) + + # check that the mapping components are in the alchemical components + for m in mapping: + if m.componentA not in alchemical_components["stateA"]: + raise ValueError(f"Mapping componentA {m.componentA} not in alchemical components of stateA") + if m.componentB not in alchemical_components["stateB"]: + raise ValueError(f"Mapping componentB {m.componentB} not in alchemical components of stateB") + + # TODO: remove - this is now the default behaviour? + # Check for element changes in mappings + for m in mapping: + molA = m.componentA.to_rdkit() + molB = m.componentB.to_rdkit() + for i, j in m.componentA_to_componentB.items(): + atomA = molA.GetAtomWithIdx(i) + atomB = molB.GetAtomWithIdx(j) + if atomA.GetAtomicNum() != atomB.GetAtomicNum(): + wmsg = ( + f"Element change in mapping between atoms " + f"Ligand A: {i} (element {atomA.GetAtomicNum()}) and " + f"Ligand B: {j} (element {atomB.GetAtomicNum()})\n" + "No mass scaling is attempted in the hybrid topology, " + "the average mass of the two atoms will be used in the " + "simulation" + ) + logger.warning(wmsg) + warnings.warn(wmsg) + + @staticmethod + def _validate_charge_difference( + mapping: LigandAtomMapping, + nonbonded_method: str, + explicit_charge_correction: bool, + solvent_component: SolventComponent | None, + ): + """ + Validates the net charge difference between the two states. + + Parameters + ---------- + mapping : LigandAtomMapping + The mapping between the transforming components. + nonbonded_method : str + The OpenMM nonbonded method used for the simulation.
+ explicit_charge_correction : bool + Whether or not to use an explicit charge correction. + solvent_component : openfe.SolventComponent | None + The SolventComponent of the simulation. + + Raises + ------ + ValueError + * If an explicit charge correction is attempted and the + nonbonded method is not PME. + * If the absolute charge difference is greater than one + and an explicit charge correction is attempted. + UserWarning + * If there is any charge difference. + """ + difference = mapping.get_alchemical_charge_difference() + + if abs(difference) == 0: + return + + if not explicit_charge_correction: + wmsg = ( + f"A charge difference of {difference} is observed " + "between the end states. No charge correction has " + "been requested, please account for this in your " + "final results." + ) + logger.warning(wmsg) + warnings.warn(wmsg) + return + + if solvent_component is None: + errmsg = "Cannot use eplicit charge correction without solvent" + raise ValueError(errmsg) + + # We implicitly check earlier that we have to have pme for a solvated + # system, so we only need to check the nonbonded method here + if nonbonded_method.lower() != "pme": + errmsg = "Explicit charge correction when not using PME is not currently supported." + raise ValueError(errmsg) + + if abs(difference) > 1: + errmsg = ( + f"A charge difference of {difference} is observed " + "between the end states and an explicit charge " + "correction has been requested. Unfortunately " + "only absolute differences of 1 are supported." + ) + raise ValueError(errmsg) + + ion = { + -1: solvent_component.positive_ion, + 1: solvent_component.negative_ion + }[difference] + + wmsg = ( + f"A charge difference of {difference} is observed " + "between the end states. 
This will be addressed by " + f"transforming a water into a {ion} ion" + ) + logger.info(wmsg) + + @staticmethod + def _validate_simulation_settings( + simulation_settings: MultiStateSimulationSettings, + integrator_settings: IntegratorSettings, + output_settings: MultiStateOutputSettings, + ): + """ + Validate various simulation settings, including but not limited to + timestep conversions, and output file write frequencies. + + Parameters + ---------- + simulation_settings : MultiStateSimulationSettings + The sampler simulation settings. + integrator_settings : IntegratorSettings + Settings defining the behaviour of the integrator. + output_settings : MultiStateOutputSettings + Settings defining the simulation file writing behaviour. + + Raises + ------ + ValueError + * If any of the ``settings_validation`` conversions or checks called here reject the given settings (exact conditions are defined by those helpers). + """ + + steps_per_iteration = settings_validation.convert_steps_per_iteration( + simulation_settings=simulation_settings, + integrator_settings=integrator_settings, + ) + + _ = settings_validation.get_simsteps( + sim_length=simulation_settings.equilibration_length, + timestep=integrator_settings.timestep, + mc_steps=steps_per_iteration, + ) + + _ = settings_validation.get_simsteps( + sim_length=simulation_settings.production_length, + timestep=integrator_settings.timestep, + mc_steps=steps_per_iteration, + ) + + _ = settings_validation.convert_checkpoint_interval_to_iterations( + checkpoint_interval=output_settings.checkpoint_interval, + time_per_iteration=simulation_settings.time_per_iteration, + ) + + if output_settings.positions_write_frequency is not None: + _ = settings_validation.divmod_time_and_check( + numerator=output_settings.positions_write_frequency, + denominator=simulation_settings.time_per_iteration, + numerator_name="output settings' positions_write_frequency", + denominator_name="sampler settings' time_per_iteration", + ) + + if output_settings.velocities_write_frequency is not None: + _ = settings_validation.divmod_time_and_check( +
numerator=output_settings.velocities_write_frequency, + denominator=simulation_settings.time_per_iteration, + numerator_name="output settings' velocities_write_frequency", + denominator_name="sampler settings' time_per_iteration", + ) + + _, _ = settings_validation.convert_real_time_analysis_iterations( + simulation_settings=simulation_settings, + ) + + def _validate( + self, + stateA: ChemicalSystem, + stateB: ChemicalSystem, + mapping: gufe.ComponentMapping | list[gufe.ComponentMapping] | None, + extends: gufe.ProtocolDAGResult | None = None, + ) -> None: + # Check we're not trying to extend + if extends: + # This technically should be NotImplementedError + # but gufe.Protocol.validate calls `_validate` wrapped around an + # except for NotImplementedError, so we can't raise it here + raise ValueError("Can't extend simulations yet") + + # Validate the end states + self._validate_endstates(stateA, stateB) + + # Validate the mapping + alchem_comps = system_validation.get_alchemical_components(stateA, stateB) + self._validate_mapping(mapping, alchem_comps) + + # Validate solvent component + nonbond = self.settings.forcefield_settings.nonbonded_method + system_validation.validate_solvent(stateA, nonbond) + + # Validate solvation settings + settings_validation.validate_openmm_solvation_settings(self.settings.solvation_settings) + + # Validate protein component + system_validation.validate_protein(stateA) + + # Validate charge difference + # Note: validation depends on the mapping & solvent component checks + if stateA.contains(SolventComponent): + solv_comp = stateA.get_components_of_type(SolventComponent)[0] + else: + solv_comp = None + + self._validate_charge_difference( + mapping=mapping[0] if isinstance(mapping, list) else mapping, + nonbonded_method=self.settings.forcefield_settings.nonbonded_method, + explicit_charge_correction=self.settings.alchemical_settings.explicit_charge_correction, + solvent_component=solv_comp, + ) + + # Validate integrator things +
def _create(
    self,
    stateA: ChemicalSystem,
    stateB: ChemicalSystem,
    mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]],
    extends: Optional[gufe.ProtocolDAGResult] = None,
) -> list[gufe.ProtocolUnit]:
    """
    Validate the inputs and build one ProtocolUnit per protocol repeat.

    Parameters
    ----------
    stateA, stateB : ChemicalSystem
      The two end states of the transformation.
    mapping : Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]]
      The mapping between the alchemical components. If a list is
      passed, only its first entry is handed to the units.
    extends : Optional[gufe.ProtocolDAGResult]
      Forwarded to ``self.validate``.

    Returns
    -------
    list[gufe.ProtocolUnit]
      ``settings.protocol_repeats`` independent
      ``RelativeHybridTopologyProtocolUnit`` objects, all at generation 0
      and each with a random integer ``repeat_id`` drawn from ``uuid4``.
    """
    # Validation comes first so no unit is built from bad inputs.
    self.validate(stateA=stateA, stateB=stateB, mapping=mapping, extends=extends)

    components = system_validation.get_alchemical_components(stateA, stateB)
    if isinstance(mapping, list):
        ligandmapping = mapping[0]
    else:
        ligandmapping = mapping

    # Human readable labels used in the unit names
    label_a = ",".join(comp.name for comp in components["stateA"])
    label_b = ",".join(comp.name for comp in components["stateB"])

    # The DAG has no dependencies between units, so a flat list suffices
    units = []
    for repeat_index in range(self.settings.protocol_repeats):
        unit = RelativeHybridTopologyProtocolUnit(
            protocol=self,
            stateA=stateA,
            stateB=stateB,
            ligandmapping=ligandmapping,
            generation=0,
            repeat_id=int(uuid.uuid4()),
            name=f"{label_a} to {label_b} repeat {repeat_index} generation 0",
        )
        units.append(unit)

    return units
class RelativeHybridTopologyProtocolResult(gufe.ProtocolResult):
    """Dict-like container for the output of a RelativeHybridTopologyProtocol

    ``self.data`` maps ``str(repeat_id)`` to the list of ProtocolUnitResults
    of that repeat. Every accessor below reads the first entry of each list
    (``pus[0]``) only.
    """

    def __init__(self, **data):
        super().__init__(**data)
        # data is mapping of str(repeat_id): list[protocolunitresults]
        # TODO: Detect when we have extensions and stitch these together?
        # NOTE(review): the threshold is ``> 2`` although every accessor only
        # reads ``pus[0]``; a repeat with two results is therefore accepted
        # and its second entry silently ignored -- confirm this is intended.
        if any(len(pur_list) > 2 for pur_list in self.data.values()):
            raise NotImplementedError("Can't stitch together results yet")

    @staticmethod
    def compute_mean_estimate(dGs: list[Quantity]) -> Quantity:
        """Return the mean of ``dGs`` in the units of its first entry.

        Parameters
        ----------
        dGs : list[openff.units.Quantity]
          Free energy estimates; must contain at least one entry since the
          units are taken from ``dGs[0]``.

        Returns
        -------
        openff.units.Quantity
          The average of the estimates.
        """
        u = dGs[0].u
        # convert all values to units of the first value, then take average of magnitude
        # this would avoid a screwy case where each value was in different units
        vals = np.asarray([dG.to(u).m for dG in dGs])

        return np.average(vals) * u

    def get_estimate(self) -> Quantity:
        """Average free energy difference of this transformation

        Returns
        -------
        dG : openff.units.Quantity
          The free energy difference between the first and last states. This is
          a Quantity defined with units.
        """
        # TODO: Check this holds up completely for SAMS.
        dGs = [pus[0].outputs["unit_estimate"] for pus in self.data.values()]
        return self.compute_mean_estimate(dGs)

    @staticmethod
    def compute_uncertainty(dGs: list[Quantity]) -> Quantity:
        """Return the standard deviation of ``dGs`` in the units of its first entry.

        Parameters
        ----------
        dGs : list[openff.units.Quantity]
          Free energy estimates; must contain at least one entry since the
          units are taken from ``dGs[0]``.

        Returns
        -------
        openff.units.Quantity
          ``np.std`` of the estimates (numpy's default, i.e. population,
          standard deviation).
        """
        u = dGs[0].u
        # convert all values to units of the first value, then take the
        # standard deviation of the magnitudes; this avoids a screwy case
        # where each value was in different units
        vals = np.asarray([dG.to(u).m for dG in dGs])

        return np.std(vals) * u

    def get_uncertainty(self) -> Quantity:
        """The uncertainty/error in the dG value: The std of the estimates of
        each independent repeat
        """

        dGs = [pus[0].outputs["unit_estimate"] for pus in self.data.values()]
        return self.compute_uncertainty(dGs)

    def get_individual_estimates(self) -> list[tuple[Quantity, Quantity]]:
        """Return a list of tuples containing the individual free energy
        estimates and associated MBAR errors for each repeat.

        Returns
        -------
        dGs : list[tuple[openff.units.Quantity]]
          n_replicate simulation list of tuples containing the free energy
          estimates (first entry) and associated MBAR estimate errors
          (second entry).
        """
        dGs = [
            (pus[0].outputs["unit_estimate"], pus[0].outputs["unit_estimate_error"])
            for pus in self.data.values()
        ]
        return dGs

    def get_forward_and_reverse_energy_analysis(
        self,
    ) -> list[Optional[dict[str, Union[npt.NDArray, Quantity]]]]:
        """
        Get a list of forward and reverse analysis of the free energies
        for each repeat using uncorrelated production samples.

        The returned dicts have keys:
        'fractions' - the fraction of data used for this estimate
        'forward_DGs', 'reverse_DGs' - for each fraction of data, the estimate
        'forward_dDGs', 'reverse_dDGs' - for each estimate, the uncertainty

        The 'fractions' values are a numpy array, while the other arrays are
        Quantity arrays, with units attached.

        If the list entry is ``None`` instead of a dictionary, this indicates
        that the analysis could not be carried out for that repeat. This
        is most likely caused by MBAR convergence issues when attempting to
        calculate free energies from too few samples.


        Returns
        -------
        forward_reverse : list[Optional[dict[str, Union[npt.NDArray, openff.units.Quantity]]]]


        Warns
        -----
        UserWarning
          If any of the forward and reverse entries are ``None``.
        """
        forward_reverse = [
            pus[0].outputs["forward_and_reverse_energies"] for pus in self.data.values()
        ]

        if None in forward_reverse:
            wmsg = (
                "One or more ``None`` entries were found in the list of "
                "forward and reverse analyses. This is likely caused by "
                "an MBAR convergence failure caused by too few independent "
                "samples when calculating the free energies of the 10% "
                "timeseries slice."
            )
            warnings.warn(wmsg)

        return forward_reverse

    def get_overlap_matrices(self) -> list[dict[str, npt.NDArray]]:
        """
        Return a list of dictionary containing the MBAR overlap estimates
        calculated for each repeat.

        Returns
        -------
        overlap_stats : list[dict[str, npt.NDArray]]
          A list of dictionaries containing the following keys:
            * ``scalar``: One minus the largest nontrivial eigenvalue
            * ``eigenvalues``: The sorted (descending) eigenvalues of the
              overlap matrix
            * ``matrix``: Estimated overlap matrix of observing a sample from
              state i in state j
        """
        # Loop through and get the repeats and get the matrices
        overlap_stats = [pus[0].outputs["unit_mbar_overlap"] for pus in self.data.values()]

        return overlap_stats

    def get_replica_transition_statistics(self) -> list[dict[str, npt.NDArray]]:
        """The replica lambda state transition statistics for each repeat.

        Note
        ----
        This is currently only available in cases where a replica exchange
        simulation was run.

        Returns
        -------
        repex_stats : list[dict[str, npt.NDArray]]
          A list of dictionaries containing the following:
            * ``eigenvalues``: The sorted (descending) eigenvalues of the
              lambda state transition matrix
            * ``matrix``: The transition matrix estimate of a replica switching
              from state i to state j.

        Raises
        ------
        ValueError
          If any repeat has no ``replica_exchange_statistics`` output.
        """
        try:
            repex_stats = [
                pus[0].outputs["replica_exchange_statistics"] for pus in self.data.values()
            ]
        except KeyError as err:
            errmsg = "Replica exchange statistics were not found, did you run a repex calculation?"
            # keep the KeyError as the explicit cause for easier debugging
            raise ValueError(errmsg) from err

        return repex_stats

    def get_replica_states(self) -> list[npt.NDArray]:
        """
        Returns the timeseries of replica states for each repeat.

        Returns
        -------
        replica_states : List[npt.NDArray]
          List of replica states for each repeat

        Raises
        ------
        ValueError
          If the trajectory file or the checkpoint file of a repeat does
          not exist on disk.
        """

        def is_file(filename: Union[str, pathlib.Path]) -> pathlib.Path:
            # Return ``filename`` as a Path, refusing paths that do not exist
            p = pathlib.Path(filename)
            if not p.exists():
                errmsg = f"File could not be found {p}"
                raise ValueError(errmsg)
            return p

        replica_states = []

        for pus in self.data.values():
            nc = is_file(pus[0].outputs["nc"])
            dir_path = nc.parents[0]
            # the checkpoint is looked up next to the trajectory file, but
            # only its file name is handed to the reporter
            chk = is_file(dir_path / pus[0].outputs["last_checkpoint"]).name
            reporter = multistate.MultiStateReporter(
                storage=nc, checkpoint_storage=chk, open_mode="r"
            )
            try:
                replica_states.append(np.asarray(reporter.read_replica_thermodynamic_states()))
            finally:
                # always release the file handle, even if the read fails
                reporter.close()

        return replica_states

    def equilibration_iterations(self) -> list[float]:
        """
        Returns the number of equilibration iterations for each repeat
        of the calculation.

        Returns
        -------
        equilibration_lengths : list[float]
        """
        equilibration_lengths = [
            pus[0].outputs["equilibration_iterations"] for pus in self.data.values()
        ]

        return equilibration_lengths

    def production_iterations(self) -> list[float]:
        """
        Returns the number of uncorrelated production samples for each
        repeat of the calculation.

        Returns
        -------
        production_lengths : list[float]
        """
        production_lengths = [pus[0].outputs["production_iterations"] for pus in self.data.values()]

        return production_lengths
class RelativeHybridTopologyProtocolUnit(gufe.ProtocolUnit):
    """
    Calculates the relative free energy of an alchemical ligand transformation.
    """

    def __init__(
        self,
        *,
        protocol: gufe.Protocol,
        stateA: ChemicalSystem,
        stateB: ChemicalSystem,
        ligandmapping: LigandAtomMapping,
        generation: int,
        repeat_id: int,
        name: Optional[str] = None,
    ):
        """
        Parameters
        ----------
        protocol : RelativeHybridTopologyProtocol
          protocol used to create this Unit. Contains key information such
          as the settings.
        stateA, stateB : ChemicalSystem
          the two end states of the transformation. The transformation
          goes from the ligand in stateA to the ligand in stateB.
        ligandmapping : LigandAtomMapping
          the mapping of atoms between the two ligand components
        generation : int
          counter for how many times this repeat has been extended
        repeat_id : int
          identifier for which repeat (aka replica/clone) this Unit is
        name : str, optional
          human-readable identifier for this Unit

        Notes
        -----
        All arguments are stored as unit inputs by the parent class and are
        read back from ``self._inputs`` when the unit is executed.
        """
        super().__init__(
            name=name,
            protocol=protocol,
            stateA=stateA,
            stateB=stateB,
            ligandmapping=ligandmapping,
            repeat_id=repeat_id,
            generation=generation,
        )

    @staticmethod
    def _assign_partial_charges(
        charge_settings: OpenFFPartialChargeSettings,
        off_small_mols: dict[str, list[tuple[SmallMoleculeComponent, OFFMolecule]]],
    ) -> None:
        """
        Assign partial charges to SMCs.

        Parameters
        ----------
        charge_settings : OpenFFPartialChargeSettings
          Settings for controlling how the partial charges are assigned.
        off_small_mols : dict[str, list[tuple[SmallMoleculeComponent, OFFMolecule]]]
          Lists of (SmallMoleculeComponent, OpenFF Molecule) pairs keyed by
          ``"stateA"``, ``"stateB"`` and ``"both"``. The OpenFF molecules are
          passed to the charge assignment helper with ``overwrite=False``.
        """
        for _, mol in chain(
            off_small_mols["stateA"], off_small_mols["stateB"], off_small_mols["both"]
        ):
            charge_generation.assign_offmol_partial_charges(
                offmol=mol,
                overwrite=False,
                method=charge_settings.partial_charge_method,
                toolkit_backend=charge_settings.off_toolkit_backend,
                generate_n_conformers=charge_settings.number_of_conformers,
                nagl_model=charge_settings.nagl_model,
            )

    def run(
        self, *, dry=False, verbose=True, scratch_basepath=None, shared_basepath=None
    ) -> dict[str, Any]:
        """Run the relative free energy calculation.

        Parameters
        ----------
        dry : bool
          Do a dry run of the calculation, creating all necessary hybrid
          system components (topology, system, sampler, etc...) but without
          running the simulation.
        verbose : bool
          Verbose output of the simulation progress. Output is provided via
          INFO level logging.
        scratch_basepath: Pathlike, optional
          Where to store temporary files, defaults to current working directory
        shared_basepath : Pathlike, optional
          Where to run the calculation, defaults to current working directory

        Returns
        -------
        dict
          Outputs created in the basepath directory or the debug objects
          (i.e. sampler) if ``dry==True``.

        Raises
        ------
        ValueError
          * If the softcore potential name is neither ``gapsys`` nor
            ``beutler``.
          * If the number of replicas differs from the number of lambda
            windows.
          * If the hybrid system has virtual sites but velocities are not
            reassigned.
        AttributeError
          If the sampler method is not one of ``repex``, ``sams`` or
          ``independent``.
        Exception
          Anything raised by the underlying setup or simulation is
          propagated after the reporter and context caches are cleaned up.
        """
        if verbose:
            self.logger.info("Preparing the hybrid topology simulation")
        if scratch_basepath is None:
            scratch_basepath = pathlib.Path(".")
        if shared_basepath is None:
            # use cwd
            shared_basepath = pathlib.Path(".")

        # 0. General setup and settings dependency resolution step

        # Extract relevant settings
        protocol_settings: RelativeHybridTopologyProtocolSettings = self._inputs[
            "protocol"
        ].settings
        stateA = self._inputs["stateA"]
        stateB = self._inputs["stateB"]
        mapping = self._inputs["ligandmapping"]

        forcefield_settings: settings.OpenMMSystemGeneratorFFSettings = (
            protocol_settings.forcefield_settings
        )
        thermo_settings: settings.ThermoSettings = protocol_settings.thermo_settings
        alchem_settings: AlchemicalSettings = protocol_settings.alchemical_settings
        lambda_settings: LambdaSettings = protocol_settings.lambda_settings
        charge_settings: BasePartialChargeSettings = protocol_settings.partial_charge_settings
        solvation_settings: OpenMMSolvationSettings = protocol_settings.solvation_settings
        sampler_settings: MultiStateSimulationSettings = protocol_settings.simulation_settings
        output_settings: MultiStateOutputSettings = protocol_settings.output_settings
        integrator_settings: IntegratorSettings = protocol_settings.integrator_settings

        # TODO: Also validate various conversions?
        # Convert various time based inputs to steps/iterations
        steps_per_iteration = settings_validation.convert_steps_per_iteration(
            simulation_settings=sampler_settings,
            integrator_settings=integrator_settings,
        )

        equil_steps = settings_validation.get_simsteps(
            sim_length=sampler_settings.equilibration_length,
            timestep=integrator_settings.timestep,
            mc_steps=steps_per_iteration,
        )
        prod_steps = settings_validation.get_simsteps(
            sim_length=sampler_settings.production_length,
            timestep=integrator_settings.timestep,
            mc_steps=steps_per_iteration,
        )

        solvent_comp, protein_comp, small_mols = system_validation.get_components(stateA)

        # Get the charge difference between the end states; it is only
        # consumed below if an explicit charge correction was requested
        charge_difference = mapping.get_alchemical_charge_difference()

        # 1. Create stateA system
        self.logger.info("Parameterizing molecules")

        # a. create offmol dictionaries and assign partial charges
        # workaround for conformer generation failures
        # see openfe issue #576
        # calculate partial charges manually if not already given
        # convert to OpenFF here,
        # and keep the molecule around to maintain the partial charges
        off_small_mols: dict[str, list[tuple[SmallMoleculeComponent, OFFMolecule]]]
        off_small_mols = {
            "stateA": [(mapping.componentA, mapping.componentA.to_openff())],
            "stateB": [(mapping.componentB, mapping.componentB.to_openff())],
            "both": [
                (m, m.to_openff())
                for m in small_mols
                if (m != mapping.componentA and m != mapping.componentB)
            ],
        }

        self._assign_partial_charges(charge_settings, off_small_mols)

        # b. get a system generator
        if output_settings.forcefield_cache is not None:
            ffcache = shared_basepath / output_settings.forcefield_cache
        else:
            ffcache = None

        # Block out oechem backend in system_generator calls to avoid
        # any issues with smiles roundtripping between rdkit and oechem
        with without_oechem_backend():
            system_generator = system_creation.get_system_generator(
                forcefield_settings=forcefield_settings,
                integrator_settings=integrator_settings,
                thermo_settings=thermo_settings,
                cache=ffcache,
                has_solvent=solvent_comp is not None,
            )

            # c. force the creation of parameters
            # This is necessary because we need to have the FF templates
            # registered ahead of solvating the system.
            for _, mol in chain(
                off_small_mols["stateA"], off_small_mols["stateB"], off_small_mols["both"]
            ):
                system_generator.create_system(mol.to_topology().to_openmm(), molecules=[mol])

            # d. get OpenMM Modeller + a dictionary of resids for each component
            stateA_modeller, comp_resids = system_creation.get_omm_modeller(
                protein_comp=protein_comp,
                solvent_comp=solvent_comp,
                small_mols=dict(chain(off_small_mols["stateA"], off_small_mols["both"])),
                omm_forcefield=system_generator.forcefield,
                solvent_settings=solvation_settings,
            )

        # e. get topology & positions
        # Note: roundtrip positions to remove vec3 issues
        stateA_topology = stateA_modeller.getTopology()
        stateA_positions = to_openmm(from_openmm(stateA_modeller.getPositions()))

        # f. create the stateA System
        # Block out oechem backend in system_generator calls to avoid
        # any issues with smiles roundtripping between rdkit and oechem
        with without_oechem_backend():
            stateA_system = system_generator.create_system(
                stateA_modeller.topology,
                molecules=[m for _, m in chain(off_small_mols["stateA"], off_small_mols["both"])],
            )

        # 2. Get stateB system
        # a. get the topology
        stateB_topology, stateB_alchem_resids = _rfe_utils.topologyhelpers.combined_topology(
            stateA_topology,
            # zeroth item (there's only one) then get the OFF representation
            off_small_mols["stateB"][0][1].to_topology().to_openmm(),
            exclude_resids=comp_resids[mapping.componentA],
        )

        # b. create the stateB System
        # Block out oechem backend in system_generator calls to avoid
        # any issues with smiles roundtripping between rdkit and oechem
        with without_oechem_backend():
            stateB_system = system_generator.create_system(
                stateB_topology,
                molecules=[m for _, m in chain(off_small_mols["stateB"], off_small_mols["both"])],
            )

        # c. Define correspondence mappings between the two systems
        ligand_mappings = _rfe_utils.topologyhelpers.get_system_mappings(
            mapping.componentA_to_componentB,
            stateA_system,
            stateA_topology,
            comp_resids[mapping.componentA],
            stateB_system,
            stateB_topology,
            stateB_alchem_resids,
            # These are non-optional settings for this method
            fix_constraints=True,
        )

        # d. if a charge correction is necessary, select alchemical waters
        #    and transform them
        if alchem_settings.explicit_charge_correction:
            alchem_water_resids = _rfe_utils.topologyhelpers.get_alchemical_waters(
                stateA_topology,
                stateA_positions,
                charge_difference,
                alchem_settings.explicit_charge_correction_cutoff,
            )
            _rfe_utils.topologyhelpers.handle_alchemical_waters(
                alchem_water_resids,
                stateB_topology,
                stateB_system,
                ligand_mappings,
                charge_difference,
                solvent_comp,
            )

        # e. Finally get the positions
        stateB_positions = _rfe_utils.topologyhelpers.set_and_check_new_positions(
            ligand_mappings,
            stateA_topology,
            stateB_topology,
            old_positions=ensure_quantity(stateA_positions, "openmm"),
            insert_positions=ensure_quantity(
                off_small_mols["stateB"][0][1].conformers[0], "openmm"
            ),
        )

        # 3. Create the hybrid topology
        # a. Get softcore potential settings
        softcore_name = alchem_settings.softcore_LJ.lower()
        if softcore_name == "gapsys":
            softcore_LJ_v2 = True
        elif softcore_name == "beutler":
            softcore_LJ_v2 = False
        else:
            # Fail with a clear message instead of an unbound local below
            errmsg = (
                f"Unknown softcore potential {alchem_settings.softcore_LJ}, "
                "expected 'gapsys' or 'beutler'"
            )
            raise ValueError(errmsg)

        # b. Get hybrid topology factory
        hybrid_factory = _rfe_utils.relative.HybridTopologyFactory(
            stateA_system,
            stateA_positions,
            stateA_topology,
            stateB_system,
            stateB_positions,
            stateB_topology,
            old_to_new_atom_map=ligand_mappings["old_to_new_atom_map"],
            old_to_new_core_atom_map=ligand_mappings["old_to_new_core_atom_map"],
            use_dispersion_correction=alchem_settings.use_dispersion_correction,
            softcore_alpha=alchem_settings.softcore_alpha,
            softcore_LJ_v2=softcore_LJ_v2,
            softcore_LJ_v2_alpha=alchem_settings.softcore_alpha,
            interpolate_old_and_new_14s=alchem_settings.turn_off_core_unique_exceptions,
        )

        # 4. Create lambda schedule
        # TODO - this should be exposed to users, maybe we should offer the
        # ability to print the schedule directly in settings?
        # fmt: off
        lambdas = _rfe_utils.lambdaprotocol.LambdaProtocol(
            functions=lambda_settings.lambda_functions,
            windows=lambda_settings.lambda_windows
        )
        # fmt: on
        # PR #125 temporarily pin lambda schedule spacing to n_replicas
        n_replicas = sampler_settings.n_replicas
        if n_replicas != len(lambdas.lambda_schedule):
            errmsg = (
                f"Number of replicas {n_replicas} "
                f"does not equal the number of lambda windows "
                f"{len(lambdas.lambda_schedule)}"
            )
            raise ValueError(errmsg)

        # 9. Create the multistate reporter
        # Get the sub selection of the system to print coords for
        selection_indices = hybrid_factory.hybrid_topology.select(output_settings.output_indices)

        #  a. Create the multistate reporter
        # convert checkpoint_interval from time to iterations
        chk_intervals = settings_validation.convert_checkpoint_interval_to_iterations(
            checkpoint_interval=output_settings.checkpoint_interval,
            time_per_iteration=sampler_settings.time_per_iteration,
        )

        nc = shared_basepath / output_settings.output_filename
        chk = output_settings.checkpoint_storage_filename

        # A frequency of None means "do not write"; the interval is then 0
        if output_settings.positions_write_frequency is not None:
            pos_interval = settings_validation.divmod_time_and_check(
                numerator=output_settings.positions_write_frequency,
                denominator=sampler_settings.time_per_iteration,
                numerator_name="output settings' positions_write_frequency",
                denominator_name="sampler settings' time_per_iteration",
            )
        else:
            pos_interval = 0

        if output_settings.velocities_write_frequency is not None:
            vel_interval = settings_validation.divmod_time_and_check(
                numerator=output_settings.velocities_write_frequency,
                denominator=sampler_settings.time_per_iteration,
                numerator_name="output settings' velocities_write_frequency",
                denominator_name="sampler settings' time_per_iteration",
            )
        else:
            vel_interval = 0

        reporter = multistate.MultiStateReporter(
            storage=nc,
            analysis_particle_indices=selection_indices,
            checkpoint_interval=chk_intervals,
            checkpoint_storage=chk,
            position_interval=pos_interval,
            velocity_interval=vel_interval,
        )

        #  b. Write out a PDB containing the subsampled hybrid state
        # The b-factor column tags which part of the hybrid each atom is in
        # fmt: off
        bfactors = np.zeros_like(selection_indices, dtype=float)  # solvent
        bfactors[np.isin(selection_indices, list(hybrid_factory._atom_classes['unique_old_atoms']))] = 0.25  # lig A
        bfactors[np.isin(selection_indices, list(hybrid_factory._atom_classes['core_atoms']))] = 0.50  # core
        bfactors[np.isin(selection_indices, list(hybrid_factory._atom_classes['unique_new_atoms']))] = 0.75  # lig B
        # bfactors[np.isin(selection_indices, protein)] = 1.0  # prot+cofactor
        if len(selection_indices) > 0:
            mdtraj.Trajectory(
                hybrid_factory.hybrid_positions[selection_indices, :],
                hybrid_factory.hybrid_topology.subset(selection_indices),
            ).save_pdb(
                shared_basepath / output_settings.output_structure,
                bfactors=bfactors,
            )
        # fmt: on

        # 10. Get compute platform
        # restrict to a single CPU if running vacuum
        restrict_cpu = forcefield_settings.nonbonded_method.lower() == "nocutoff"
        platform = omm_compute.get_openmm_platform(
            platform_name=protocol_settings.engine_settings.compute_platform,
            gpu_device_index=protocol_settings.engine_settings.gpu_device_index,
            restrict_cpu_count=restrict_cpu,
        )

        # 11. Set the integrator
        # a. Validate integrator settings for current system
        # Virtual sites sanity check - ensure we restart velocities when
        # there are virtual sites in the system
        if hybrid_factory.has_virtual_sites:
            if not integrator_settings.reassign_velocities:
                errmsg = (
                    "Simulations with virtual sites without velocity "
                    "reassignments are unstable in openmmtools"
                )
                raise ValueError(errmsg)

        #  b. create langevin integrator
        integrator = openmmtools.mcmc.LangevinDynamicsMove(
            timestep=to_openmm(integrator_settings.timestep),
            collision_rate=to_openmm(integrator_settings.langevin_collision_rate),
            n_steps=steps_per_iteration,
            reassign_velocities=integrator_settings.reassign_velocities,
            n_restart_attempts=integrator_settings.n_restart_attempts,
            constraint_tolerance=integrator_settings.constraint_tolerance,
        )

        # 12. Create sampler
        self.logger.info("Creating and setting up the sampler")
        rta_its, rta_min_its = settings_validation.convert_real_time_analysis_iterations(
            simulation_settings=sampler_settings,
        )
        # convert early_termination_target_error from kcal/mol to kT
        early_termination_target_error = (
            settings_validation.convert_target_error_from_kcal_per_mole_to_kT(
                thermo_settings.temperature,
                sampler_settings.early_termination_target_error,
            )
        )

        if sampler_settings.sampler_method.lower() == "repex":
            sampler = _rfe_utils.multistate.HybridRepexSampler(
                mcmc_moves=integrator,
                hybrid_system=hybrid_factory.hybrid_system,
                hybrid_positions=hybrid_factory.hybrid_positions,
                online_analysis_interval=rta_its,
                online_analysis_target_error=early_termination_target_error,
                online_analysis_minimum_iterations=rta_min_its,
            )
        elif sampler_settings.sampler_method.lower() == "sams":
            sampler = _rfe_utils.multistate.HybridSAMSSampler(
                mcmc_moves=integrator,
                hybrid_system=hybrid_factory.hybrid_system,
                hybrid_positions=hybrid_factory.hybrid_positions,
                online_analysis_interval=rta_its,
                online_analysis_minimum_iterations=rta_min_its,
                flatness_criteria=sampler_settings.sams_flatness_criteria,
                gamma0=sampler_settings.sams_gamma0,
            )
        elif sampler_settings.sampler_method.lower() == "independent":
            sampler = _rfe_utils.multistate.HybridMultiStateSampler(
                mcmc_moves=integrator,
                hybrid_system=hybrid_factory.hybrid_system,
                hybrid_positions=hybrid_factory.hybrid_positions,
                online_analysis_interval=rta_its,
                online_analysis_target_error=early_termination_target_error,
                online_analysis_minimum_iterations=rta_min_its,
            )
        else:
            raise AttributeError(f"Unknown sampler {sampler_settings.sampler_method}")

        sampler.setup(
            n_replicas=sampler_settings.n_replicas,
            reporter=reporter,
            lambda_protocol=lambdas,
            temperature=to_openmm(thermo_settings.temperature),
            endstates=alchem_settings.endstate_dispersion_correction,
            minimization_platform=platform.getName(),
            # Set minimization steps to None when running in dry mode
            # otherwise do a very small one to avoid NaNs
            minimization_steps=100 if not dry else None,
        )

        # Bound up front so the ``finally`` clause can always refer to them,
        # even when creating a cache is what failed
        energy_context_cache = None
        sampler_context_cache = None

        try:
            # Create context caches (energy + sampler)
            energy_context_cache = openmmtools.cache.ContextCache(
                capacity=None,
                time_to_live=None,
                platform=platform,
            )

            sampler_context_cache = openmmtools.cache.ContextCache(
                capacity=None,
                time_to_live=None,
                platform=platform,
            )

            sampler.energy_context_cache = energy_context_cache
            sampler.sampler_context_cache = sampler_context_cache

            if not dry:  # pragma: no-cover
                # minimize
                if verbose:
                    self.logger.info("Running minimization")

                sampler.minimize(max_iterations=sampler_settings.minimization_steps)

                # equilibrate
                if verbose:
                    self.logger.info("Running equilibration phase")

                sampler.equilibrate(int(equil_steps / steps_per_iteration))

                # production
                if verbose:
                    self.logger.info("Running production phase")

                sampler.extend(int(prod_steps / steps_per_iteration))

                self.logger.info("Production phase complete")

                self.logger.info("Post-simulation analysis of results")
                # calculate relevant analyses of the free energies & sampling
                # NOTE: the analysis reads from the still-open reporter, which
                # is only closed in the ``finally`` clause below
                analyzer = multistate_analysis.MultistateEquilFEAnalysis(
                    reporter,
                    sampling_method=sampler_settings.sampler_method.lower(),
                    result_units=offunit.kilocalorie_per_mole,
                )
                analyzer.plot(filepath=shared_basepath, filename_prefix="")
                analyzer.close()

            else:
                # clean up the reporter file
                fns = [
                    shared_basepath / output_settings.output_filename,
                    shared_basepath / output_settings.checkpoint_storage_filename,
                ]
                for fn in fns:
                    os.remove(fn)
        finally:
            # close reporter when you're done, prevent
            # file handle clashes
            reporter.close()

            # clear GPU contexts
            # TODO: use cache.empty() calls when openmmtools #690 is resolved
            # replace with above
            # (the global context cache is cautiously cleared out too)
            for cache in (
                energy_context_cache,
                sampler_context_cache,
                openmmtools.cache.global_context_cache,
            ):
                if cache is None:
                    # cache creation itself failed, nothing to clear
                    continue
                for context in list(cache._lru._data.keys()):
                    del cache._lru._data[context]

            del sampler_context_cache, energy_context_cache

            if not dry:
                del integrator, sampler

        if not dry:  # pragma: no-cover
            return {"nc": nc, "last_checkpoint": chk, **analyzer.unit_results_dict}
        else:
            return {"debug":
                {
                    "sampler": sampler,
                    "hybrid_factory": hybrid_factory
                }
            }

    @staticmethod
    def structural_analysis(scratch, shared) -> dict:
        """Run the ``openfe_analysis`` CLI and store plots and arrays.

        Parameters
        ----------
        scratch : Pathlike
          Directory in which the intermediate ``structural_analysis.json``
          file is written.
        shared : Pathlike
          Directory passed to the CLI as the simulation location; the
          plots and ``structural_analysis.npz`` are written here.

        Returns
        -------
        dict
          ``{"structural_analysis": <path to the npz file>}`` on success,
          or ``{"structural_analysis_error": <captured stderr>}`` if the
          CLI exits with a non-zero return code.
        """
        # don't put energy analysis in here, it uses the open file reporter
        # whereas structural stuff requires that the file handle is closed
        # TODO: we should just make openfe_analysis write an npz instead!
        scratch_dir = pathlib.Path(scratch)
        savedir = pathlib.Path(shared)
        analysis_out = scratch_dir / "structural_analysis.json"

        ret = subprocess.run(
            [
                "openfe_analysis",  # CLI entry point
                "RFE_analysis",  # CLI option
                str(shared),  # Where the simulation.nc file is located
                str(analysis_out),  # Where the analysis json file is written
            ],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        if ret.returncode:
            # best effort: report the failure instead of raising
            return {"structural_analysis_error": ret.stderr}

        with open(analysis_out, "rb") as f:
            data = json.load(f)

        if d := data["protein_2D_RMSD"]:
            fig = plotting.plot_2D_rmsd(d)
            fig.savefig(savedir / "protein_2D_RMSD.png")
            plt.close(fig)
            # NOTE(review): kept inside the protein branch; confirm the
            # ligand COM drift plot is only wanted when protein data exists
            f2 = plotting.plot_ligand_COM_drift(data["time(ps)"], data["ligand_wander"])
            f2.savefig(savedir / "ligand_COM_drift.png")
            plt.close(f2)

        f3 = plotting.plot_ligand_RMSD(data["time(ps)"], data["ligand_RMSD"])
        f3.savefig(savedir / "ligand_RMSD.png")
        plt.close(f3)

        # Save to numpy compressed format (~ 6x more space efficient than JSON)
        npz_path = savedir / "structural_analysis.npz"
        np.savez_compressed(
            npz_path,
            protein_RMSD=np.asarray(data["protein_RMSD"], dtype=np.float32),
            ligand_RMSD=np.asarray(data["ligand_RMSD"], dtype=np.float32),
            ligand_COM_drift=np.asarray(data["ligand_wander"], dtype=np.float32),
            protein_2D_RMSD=np.asarray(data["protein_2D_RMSD"], dtype=np.float32),
            time_ps=np.asarray(data["time(ps)"], dtype=np.float32),
        )

        return {"structural_analysis": npz_path}

    def _execute(
        self,
        ctx: gufe.Context,
        **kwargs,
    ) -> dict[str, Any]:
        """Run the simulation and the structural analysis for this unit.

        Parameters
        ----------
        ctx : gufe.Context
          Execution context; its ``scratch`` and ``shared`` paths are used
          as the working directories.
        **kwargs
          Accepted but not used by this method.

        Returns
        -------
        dict[str, Any]
          The unit's ``repeat_id`` and ``generation`` merged with the
          outputs of ``run`` and ``structural_analysis``.
        """
        log_system_probe(logging.INFO, paths=[ctx.scratch])

        outputs = self.run(scratch_basepath=ctx.scratch, shared_basepath=ctx.shared)

        structural_analysis_outputs = self.structural_analysis(ctx.scratch, ctx.shared)

        return {
            "repeat_id": self._inputs["repeat_id"],
            "generation": self._inputs["generation"],
            **outputs,
            **structural_analysis_outputs,
        }
2025 00:30:39 -0500 Subject: [PATCH 17/36] Add news item --- news/validate-rfe.rst | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 news/validate-rfe.rst diff --git a/news/validate-rfe.rst b/news/validate-rfe.rst new file mode 100644 index 000000000..d7036e8d4 --- /dev/null +++ b/news/validate-rfe.rst @@ -0,0 +1,26 @@ +**Added:** + +* The `validate` method for the RelativeHybridTopologyProtocol has been + implemented. This means that settings and system validation can mostly + be done prior to Protocol execution by calling + `RelativeHybridTopologyProtocol.validate(stateA, stateB, mapping)`. + +**Changed:** + +* + +**Deprecated:** + +* + +**Removed:** + +* + +**Fixed:** + +* + +**Security:** + +* From 7d179981e86b8d579951f10314ca14107df94f96 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Fri, 26 Dec 2025 20:25:30 -0500 Subject: [PATCH 18/36] fix redefine --- openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py index 0f7db50d6..349c81d06 100644 --- a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py +++ b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py @@ -521,7 +521,7 @@ def test_pos_vel_write_frequency_not_divisible( "attribute", ["real_time_analysis_interval", "real_time_analysis_interval"] ) -def test_pos_vel_write_frequency_not_divisible( +def test_real_time_analysis_not_divisible( benzene_vacuum_system, toluene_vacuum_system, benzene_to_toluene_mapping, From 43eb947872896f350c694d714ced77789b495b0b Mon Sep 17 00:00:00 2001 From: IAlibay Date: Fri, 26 Dec 2025 23:27:58 -0500 Subject: [PATCH 19/36] start modularising everything --- .../protocols/openmm_rfe/hybridtop_units.py | 307 +++++++++++++++--- 1 file changed, 253 insertions(+), 54 deletions(-) diff --git 
a/openfe/protocols/openmm_rfe/hybridtop_units.py b/openfe/protocols/openmm_rfe/hybridtop_units.py index 5b47bb09c..96a8a2389 100644 --- a/openfe/protocols/openmm_rfe/hybridtop_units.py +++ b/openfe/protocols/openmm_rfe/hybridtop_units.py @@ -111,25 +111,132 @@ def __init__( generation=generation, ) + def _prepare( + self, + verbose: bool, + scratch_basepath: pathlib.Path | None, + shared_basepath: pathlib.Path | None, + ): + """ + Set basepaths and do some initial logging. + + Parameters + ---------- + verbose : bool + Verbose output of the simulation progress. Output is provided at the + INFO level logging. + scratch_basepath : pathlib.Path | None + Optional scratch base path to write scratch files to. + shared_basepath : pathlib.Path | None + Optional shared base path to write shared files to. + """ + self.verbose = verbose + + if self.verbose: + self.logger.info("Setting up the hybrid topology simulation") + + # set basepaths + def _set_optional_path(basepath): + if basepath is None: + return pathlib.Path(".") + return basepath + + self.scratch_basepath = _set_optional_path(scratch_basepath) + self.shared_basepath = _set_optional_path(shared_basepath) + + @staticmethod + def _get_settings( + settings: RelativeHybridTopologyProtocolSettings + ) -> dict[str, SettingsBaseModel]: + """ + Get a dictionary of Protocol settings. + + Returns + ------- + protocol_settings : dict[str, SettingsBaseModel] + + Notes + ----- + We return a dict so that we can duck type behaviour between phases. + For example subclasses may contain both `solvent` and `complex` + settings, using this approach we can extract the relevant entry + to the same key and pass it to other methods in a seamless manner. 
+ """ + protocol_settings: dict[str, SettingsBaseModel] = {} + protocol_settings["forcefield_settings"] = settings.forcefield_settings + protocol_settings["thermo_settings"] = settings.thermo_settings + protocol_settings["alchemical_settings"] = settings.alchemical_settings + protocol_settings["lambda_settings"] = settings.lambda_settings + protocol_settings["charge_settings"] = settings.partial_charge_settings + protocol_settings["solvation_settings"] = settings.solvation_settings + protocol_settings["simulation_settings"] = settings.simulation_settings + protocol_settings["output_settings"] = settings.output_settings + protocol_settings["integrator_settings"] = settings.integrator_settings + protocol_settings["engine_settings"] = settings.engine_settings + return protocol_settings + + @staticmethod + def _get_components( + stateA: ChemicalSystem, + stateB: ChemicalSystem + ) -> tuple[ + dict[str, Component], + SolventComponent, + ProteinComponent, + dict[SmallMoleculeComponent, OFFMolecule] + ]: + """ + Get the components from the ChemicalSystem inputs. + + Parameters + ---------- + stateA : ChemicalSystem + ChemicalSystem defining the state A components. + stateB : CHemicalSystem + ChemicalSystem defining the state B components. + + Returns + ------- + alchem_comps : dict[str, Component] + Dictionary of alchemical components. + solv_comp : SolventComponent + The solvent component. + protein_comp : ProteinComponent + The protein component. + small_mols : dict[SmallMoleculeComponent, openff.toolkit.Molecule] + Dictionary of small molecule components paired + with their OpenFF Molecule. 
+ """ + alchem_comps = system_validation.get_alchemical_components(stateA, stateB) + + solvent_comp, protein_comp, smcs_A = system_validation.get_components(stateA) + _, _, smcs_B = system_validation.get_components(stateB) + + small_mols = { + m: m.to_openff() + for m in set(smcs_A).union(set(smcs_B)) + } + + return alchem_comps, solvent_comp, protein_comp, small_mols + @staticmethod def _assign_partial_charges( charge_settings: OpenFFPartialChargeSettings, - off_small_mols: dict[str, list[tuple[SmallMoleculeComponent, OFFMolecule]]], + small_mols: dict[SmallMoleculeComponent, OFFMolecule], ) -> None: """ - Assign partial charges to SMCs. + Assign partial charges to the OpenFF Molecules associated with all + the SmallMoleculeComponents in the transformation. Parameters ---------- charge_settings : OpenFFPartialChargeSettings Settings for controlling how the partial charges are assigned. - off_small_mols : dict[str, list[tuple[SmallMoleculeComponent, OFFMolecule]]] - Dictionary of dictionary of OpenFF Molecules to add, keyed by - state and SmallMoleculeComponent. + small_mols : dict[SmallMoleculeComponent, openff.toolkit.Molecule] + Dictionary of OpenFF Molecules to add, keyed by + their associated SmallMoleculeComponent. """ - for smc, mol in chain( - off_small_mols["stateA"], off_small_mols["stateB"], off_small_mols["both"] - ): + for smc, mol in small_mols.items(): charge_generation.assign_offmol_partial_charges( offmol=mol, overwrite=False, @@ -139,6 +246,129 @@ def _assign_partial_charges( nagl_model=charge_settings.nagl_model, ) + @staticmethod + def _get_system_generator( + shared_basepath: pathlib.Path, + settings: dict[str, SettingsBaseModel], + solvent_comp: SolventComponent | None, + openff_molecules: list[OFFMolecule] | None, + ) -> SystemGenerator: + """ + Get an OpenMM SystemGenerator. + + Parameters + ---------- + settings : dict[str, SettingsBaseModel] + A dictionary of protocol settings. 
+ solvent_comp : SolventComponent | None + The solvent component of the system, if any. + openff_molecules : list[openff.Toolkit] | None + A list of openff molecules to generate templates for, if any. + + Returns + ------- + system_generator : openmmtools.SystemGenerator + The SystemGenerator for the protocol. + """ + ffcache = settings["output_settings"].forcefield_cachea + + if ffcache is not None: + ffcache = shared_basepath / ffcache + + # Block out oechem backend in system_generator calls to avoid + # any issues with smiles roundtripping between rdkit and oechem + with without_oechem_backend(): + system_generator = system_creation.get_system_generator( + forcefield_settings=settings["forcefield_settings"], + integrator_settings=settings["integrator_settings"], + thermo_settings=settings["thermo_settings"], + cache=ffcache, + has_solvent=solvent_comp is not None, + ) + + # Handle openff Molecule templates + # TODO: revisit this once the SystemGenerator update happens + # and we start loading the whole protein into OpenFF Topologies + + # First deduplicate isomoprhic molecules + unique_offmols = [] + for mol in openff_molecules: + unique = all( + [ + not mol.is_isomorphic_with(umol) + for umol in unique_offmols + ] + ) + if unique: + unique_offmols.append(mol) + + # register all the templates + system_generator.add_molecules(unique_offmols) + + return system_generator + + def _get_omm_objects( + self, + stateA: ChemicalSystem, + stateB: ChemicalSystem, + mapping: LigandAtomMapping, + settings: dict[str, SettingsBaseModel], + protein_component: ProteinComponent | None, + solvent_component: SolventComponent | None, + small_mols: dict[SmallMoleculeComponent, OFFMolecule] + ): + """ + Get OpenMM objects for both end states A and B. + + Parameters + ---------- + stateA : ChemicalSystem + ChemicalSystem defining end state A. + stateB : ChmiecalSysstem + ChemicalSystem defining end state B. 
+ mapping : LigandAtomMapping + The mapping for alchemical components between state A and B. + settings : dict[str, SettingsBaseModel] + Settings for the transformation. + protein_component : ProteinComponent | None + The common ProteinComponent between the end states, if there is is one. + solvent_component : SolventComponent | None + The common SolventComponent between the end states, if there is one. + small_mols : dict[SmallMoleculeCOmponent, openff.toolkit.Molecule] + The small molecules for both end states. + + Returns + ------- + stateA_system : openmm.System + OpenMM System for state A. + stateA_topology : openmm.app.Topology + OpenMM Topology for the state A System. + stateA_positions : openmm.unit.Quantity + Positions of partials for state A System. + stateB_system : openmm.System + OpenMM System for state B. + stateB_topology : openmm.app.Topology + OpenMM Topology for the state B System. + stateB_positions : openmm.unit.Quantity + Positions of partials for state B System. + system_mapping : dict[str, dict[int, int]] + Dictionary of mappings defining the correspondance between + the two state Systems. + """ + if self.verbose: + self.logger.info("Parameterizing systems") + + # Get the system generator with all the templates registered + system_generator = self._get_system_generator( + shared_basepath=self.shared_basepath, + settings=settings, + solv_comp=solvent_component, + openff_molecules=list(small_mols.values()) + ) + + .... + + def run( self, *, dry=False, verbose=True, scratch_basepath=None, shared_basepath=None ) -> dict[str, Any]: @@ -169,36 +399,26 @@ def run( error Exception if anything failed """ - if verbose: - self.logger.info("Preparing the hybrid topology simulation") - if scratch_basepath is None: - scratch_basepath = pathlib.Path(".") - if shared_basepath is None: - # use cwd - shared_basepath = pathlib.Path(".") - - # 0. 
General setup and settings dependency resolution step - - # Extract relevant settings - protocol_settings: RelativeHybridTopologyProtocolSettings = self._inputs[ - "protocol" - ].settings + # Prepare paths & verbosity + self._prepare(verbose, scratch_basepath, shared_basepath) + + # Get settings + settings = self._get_settings(self._inputs["protocol"].settings) + + # Get components stateA = self._inputs["stateA"] stateB = self._inputs["stateB"] mapping = self._inputs["ligandmapping"] - - forcefield_settings: settings.OpenMMSystemGeneratorFFSettings = ( - protocol_settings.forcefield_settings + alchem_comps, solvent_comp, protein_comp, small_mols = self._get_components( + stateA, stateB ) - thermo_settings: settings.ThermoSettings = protocol_settings.thermo_settings - alchem_settings: AlchemicalSettings = protocol_settings.alchemical_settings - lambda_settings: LambdaSettings = protocol_settings.lambda_settings - charge_settings: BasePartialChargeSettings = protocol_settings.partial_charge_settings - solvation_settings: OpenMMSolvationSettings = protocol_settings.solvation_settings - sampler_settings: MultiStateSimulationSettings = protocol_settings.simulation_settings - output_settings: MultiStateOutputSettings = protocol_settings.output_settings - integrator_settings: IntegratorSettings = protocol_settings.integrator_settings + # Assign partial charges now to avoid any discrepancies later + self._assign_partial_charges(charge_settings, small_mols) + + + + # TODO: move these down, not needed until we get to the sampler # TODO: Also validate various conversions? 
# Convert various time based inputs to steps/iterations steps_per_iteration = settings_validation.convert_steps_per_iteration( @@ -217,8 +437,6 @@ def run( mc_steps=steps_per_iteration, ) - solvent_comp, protein_comp, small_mols = system_validation.get_components(stateA) - # Get the change difference between the end states # and check if the charge correction used is appropriate charge_difference = mapping.get_alchemical_charge_difference() @@ -226,25 +444,6 @@ def run( # 1. Create stateA system self.logger.info("Parameterizing molecules") - # a. create offmol dictionaries and assign partial charges - # workaround for conformer generation failures - # see openfe issue #576 - # calculate partial charges manually if not already given - # convert to OpenFF here, - # and keep the molecule around to maintain the partial charges - off_small_mols: dict[str, list[tuple[SmallMoleculeComponent, OFFMolecule]]] - off_small_mols = { - "stateA": [(mapping.componentA, mapping.componentA.to_openff())], - "stateB": [(mapping.componentB, mapping.componentB.to_openff())], - "both": [ - (m, m.to_openff()) - for m in small_mols - if (m != mapping.componentA and m != mapping.componentB) - ], - } - - self._assign_partial_charges(charge_settings, off_small_mols) - # b. 
get a system generator if output_settings.forcefield_cache is not None: ffcache = shared_basepath / output_settings.forcefield_cache From d1bd736414491f41f45ffb1341fca2d0dd36c86f Mon Sep 17 00:00:00 2001 From: IAlibay Date: Sat, 27 Dec 2025 17:23:43 -0500 Subject: [PATCH 20/36] Add charge validation for smcs when dealing with ismorphic molecules --- .../protocols/openmm_rfe/equil_rfe_methods.py | 55 ++++++++++++++++ .../openmm_utils/system_validation.py | 18 +++--- .../openmm_rfe/test_hybrid_top_validation.py | 63 +++++++++++++++++++ 3 files changed, 127 insertions(+), 9 deletions(-) diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index 189fffe79..6995f82dd 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -560,6 +560,58 @@ def _validate_mapping( logger.warning(wmsg) warnings.warn(wmsg) + @staticmethod + def _validate_smcs( + stateA: ChemicalSystem, + stateB: ChemicalSystem, + ) -> None: + """ + Validates the SmallMoleculeComponents. + + Parameters + ---------- + stateA : ChemicalSystem + The chemical system of end state A. + stateB : ChemicalSystem + The chemical system of end state B. + + Raises + ------ + ValueError + * If there are isomorphic SmallMoleculeComponents with + different charges. 
+ """ + smcs_A = stateA.get_components_of_type(SmallMoleculeComponent) + smcs_B = stateB.get_components_of_type(SmallMoleculeComponent) + smcs_all = list(set(smcs_A).union(set(smcs_B))) + offmols = [m.to_openff() for m in smcs_all] + + def _equal_charges(moli, molj): + # Base case, both molecules don't have charges + if (moli.partial_charges is None) & (molj.partial_charges is None): + return True + # If either is None but not the other + if (moli.partial_charges is None) ^ (molj.partial_charges is None): + return False + # Check if the charges are close to each other + return np.allclose(moli.partial_charges, molj.partial_charges) + + clashes = [] + + for i, moli in enumerate(offmols): + for molj in offmols: + if moli.is_isomorphic_with(molj): + if not _equal_charges(moli, molj): + clashes.append(smcs_all[i]) + + if len(clashes) > 0: + errmsg = ( + "Found SmallMoleculeComponents are are isomorphic " + "but with different charges, this is not currently allowed. " + f"Affected components: {clashes}" + ) + raise ValueError(errmsg) + @staticmethod def _validate_charge_difference( mapping: LigandAtomMapping, @@ -726,6 +778,9 @@ def _validate( alchem_comps = system_validation.get_alchemical_components(stateA, stateB) self._validate_mapping(mapping, alchem_comps) + # Validate the small molecule components + self._validate_smcs(stateA, stateB) + # Validate solvent component nonbond = self.settings.forcefield_settings.nonbonded_method system_validation.validate_solvent(stateA, nonbond) diff --git a/openfe/protocols/openmm_utils/system_validation.py b/openfe/protocols/openmm_utils/system_validation.py index 9d67e108f..750f5f565 100644 --- a/openfe/protocols/openmm_utils/system_validation.py +++ b/openfe/protocols/openmm_utils/system_validation.py @@ -162,24 +162,24 @@ def get_components(state: ChemicalSystem) -> ParseCompRet: small_mols : list[SmallMoleculeComponent] """ - def _get_single_comps(comp_list, comptype): - ret_comps = [comp for comp in comp_list if 
isinstance(comp, comptype)] - if ret_comps: + def _get_single_comps(state, comptype): + comps = state.get_components_of_type(comptype) + + if len(ret_comps) > 0: return ret_comps[0] else: return None solvent_comp: Optional[SolventComponent] = _get_single_comps( - list(state.values()), SolventComponent + state, + SolventComponent ) protein_comp: Optional[ProteinComponent] = _get_single_comps( - list(state.values()), ProteinComponent + state, + ProteinComponent ) - small_mols = [] - for comp in state.components.values(): - if isinstance(comp, SmallMoleculeComponent): - small_mols.append(comp) + small_mols = state.get_components_of_type(SmallMoleculeComponent) return solvent_comp, protein_comp, small_mols diff --git a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py index 349c81d06..57092ce3c 100644 --- a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py +++ b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py @@ -126,6 +126,69 @@ def test_vaccuum_PME_error( ) +@pytest.mark.parametrize('charge', [None, 'gasteiger']) +def test_smcs_same_charge_passes( + charge, + benzene_modifications +): + benzene = benzene_modifications['benzene'] + if charge is None: + smc = benzene + else: + offmol = benzene.to_openff() + offmol.assign_partial_charges(partial_charge_method='gasteiger') + smc = openfe.SmallMoleculeComponent.from_openff(offmol) + + # Just pass the same thing twice + state = openfe.ChemicalSystem({'l': smc}) + openmm_rfe.RelativeHybridTopologyProtocol._validate_smcs(state, state) + + +def test_smcs_different_charges_none_not_none( + benzene_modifications +): + # smcA has no charges + smcA = benzene_modifications['benzene'] + + # smcB has charges + offmol = smcA.to_openff() + offmol.assign_partial_charges(partial_charge_method='gasteiger') + smcB = openfe.SmallMoleculeComponent.from_openff(offmol) + + stateA = openfe.ChemicalSystem({'l': smcA}) + stateB = 
openfe.ChemicalSystem({'l': smcB}) + + errmsg = "isomorphic but with different charges" + with pytest.raises(ValueError, match=errmsg): + openmm_rfe.RelativeHybridTopologyProtocol._validate_smcs( + stateA, stateB + ) + + +def test_smcs_different_charges_all( + benzene_modifications +): + # For this test, we will assign both A and B to both states + # It wouldn't happen in real life, but it tests that within a state + # you can pick up isomorphic molecules with different charges + # create an offmol with gasteiger charges + offmol = benzene_modifications['benzene'].to_openff() + offmol.assign_partial_charges(partial_charge_method='gasteiger') + smcA = openfe.SmallMoleculeComponent.from_openff(offmol) + + # now alter the offmol charges, scaling by 0.1 + offmol.partial_charges *= 0.1 + smcB = openfe.SmallMoleculeComponent.from_openff(offmol) + + state = openfe.ChemicalSystem({'l1': smcA, 'l2': smcB}) + + errmsg = "isomorphic but with different charges" + with pytest.raises(ValueError, match=errmsg): + openmm_rfe.RelativeHybridTopologyProtocol._validate_smcs( + state, state + ) + + def test_solvent_nocutoff_error( benzene_system, toluene_system, From 51a6de1da7d9778c92b3aab4d57906c8e314fd3b Mon Sep 17 00:00:00 2001 From: IAlibay Date: Sun, 28 Dec 2025 22:30:32 -0500 Subject: [PATCH 21/36] break down the rfe units into bits --- .../protocols/openmm_rfe/hybridtop_units.py | 1051 ++++++++++++----- 1 file changed, 733 insertions(+), 318 deletions(-) diff --git a/openfe/protocols/openmm_rfe/hybridtop_units.py b/openfe/protocols/openmm_rfe/hybridtop_units.py index 96a8a2389..da090e88e 100644 --- a/openfe/protocols/openmm_rfe/hybridtop_units.py +++ b/openfe/protocols/openmm_rfe/hybridtop_units.py @@ -18,19 +18,30 @@ from itertools import chain from typing import Any, Optional -import gufe import matplotlib.pyplot as plt import mdtraj import numpy as np +import numpy.typing as npt +import openmm import openmmtools +from openmmforcefields.generators import SystemGenerator + 
+import gufe +from gufe.settings import ( + SettingsBaseModel, + ThermoSettings, +) from gufe import ( ChemicalSystem, LigandAtomMapping, + Component, + SolventComponent, + ProteinComponent, SmallMoleculeComponent, - settings, ) from openff.toolkit.topology import Molecule as OFFMolecule from openff.units import unit as offunit +from openff.units import Quantity from openff.units.openmm import ensure_quantity, from_openmm, to_openmm from openmmtools import multistate @@ -49,6 +60,7 @@ system_validation, ) from . import _rfe_utils +from ._rfe_utils.relative import HybridTopologyFactory from .equil_rfe_settings import ( AlchemicalSettings, IntegratorSettings, @@ -57,6 +69,7 @@ MultiStateSimulationSettings, OpenFFPartialChargeSettings, OpenMMSolvationSettings, + OpenMMEngineSettings, RelativeHybridTopologyProtocolSettings, ) @@ -270,7 +283,7 @@ def _get_system_generator( system_generator : openmmtools.SystemGenerator The SystemGenerator for the protocol. """ - ffcache = settings["output_settings"].forcefield_cachea + ffcache = settings["output_settings"].forcefield_cache if ffcache is not None: ffcache = shared_basepath / ffcache @@ -307,6 +320,163 @@ def _get_system_generator( return system_generator + @staticmethod + def _create_stateA_system( + small_mols: dict[SmallMoleculeComponent, OFFMolecule], + protein_component: ProteinComponent | None, + solvent_component: SolventComponent | None, + system_generator: SystemGenerator, + solvation_settings: OpenMMSolvationSettings, + ) -> tuple[ + openmm.System, + openmm.app.Topology, + openmm.unit.Quantity, + dict[Component, npt.NDArray] + ]: + """ + Create an OpenMM System for state A. + + Parameters + ---------- + small_mols : dict[SmallMoleculeComponent, openff.toolkit.Molecule] + A list of small molecules to include in the System. + protein_component : ProteinComponent | None + Optionally, the protein component to include in the System. 
+ solvent_component : SolventComponent | None + Optionally, the solvent component to include in the System. + system_generator : SystemGenerator + The SystemGenerator object ot use to construct the System. + solvation_settings : OpenMMSolvationSettings + Settings defining how to build the System. + + Returns + ------- + system : openmm.System + The System that defines state A. + topology : openmm.app.Topology + The Topology defining the returned System. + positions : openmm.unit.Quantity + The positions of the particles in the System. + comp_residues : dict[Component, npt.NDArray] + A dictionary defining which residues in the System + belong to which ChemicalSystem Component. + """ + modeller, comp_resids = system_creation.get_omm_modeller( + protein_comp=protein_component, + solvent_comp=solvent_component, + small_mols=small_mols, + omm_forcefield=system_generator.forcefield, + solvent_settings=solvation_settings, + ) + + topology = modeller.getTopology() + # Note: roundtrip positions to remove vec3 issues + positions = to_openmm(from_openmm(modeller.getPositions())) + + with without_oechem_backend(): + system = system_generator.create_system( + modeller.topology, + molecules=list(small_mols.values()), + ) + + return system, topology, positions, comp_resids + + @staticmethod + def _create_stateB_system( + small_mols: dict[SmallMoleculeComponent, OFFMolecule], + mapping: LigandAtomMapping, + stateA_topology: openmm.app.Topology, + exclude_resids: npt.NDArray, + system_generator: SystemGenerator, + ) -> tuple[openmm.System, openmm.app.Topology, npt.NDArray]: + """ + Create the state B System from the state A Topology. + + Parameters + ---------- + small_mols : dict[SmallMoleculeComponent, openff.toolkit.Molecule] + Dictionary of OpenFF Molecules keyed by SmallMoleculeComponent + to be present in system B. + mapping : LigandAtomMapping + LigandAtomMapping defining the correspondance betwee state A + and B's alchemical ligand. 
+ stateA_topology : openmm.app.Topology + The OpenMM topology for state A. + exclude_resids : npt.NDArray + A list of residues to exclude from state A when building state B. + system_generator : SystemGenerator + The SystemGenerator to use to build System B. + + Returns + ------- + system : openmm.System + The state B System. + topology : openmm.app.Topology + The OpenMM Topology associated with the state B System. + alchem_resids : npt.NDArray + The residue indices of the state B alchemical species. + """ + topology, alchem_resids = _rfe_utils.topologyhelpers.combined_topology( + topology1=stateA_topology, + topology2=small_mols[mapping.componentB].to_topology().to_openmm(), + exclude_resids=exclude_resids, + ) + + with without_oechem_backend(): + system = system_generator.create_system( + topology, + molecules=list(small_mols.values()), + ) + + return system, topology, alchem_resids + + @staticmethod + def _handle_net_charge( + stateA_topology: openmm.app.Topology, + stateA_positions: openmm.unit.Quantity, + stateB_topology: openmm.app.Topology, + stateB_system: openmm.System, + charge_difference: int, + system_mappings: dict[str, dict[int, int]], + distance_cutoff: Quantity, + solvent_component: SolventComponent | None, + ) -> None: + """ + Handle system net charge by adding an alchemical water. 
+ + Parameters + ---------- + stateA_topology : openmm.app.Topology + stateA_positions : openmm.unit.Quantity + stateB_topology : openmm.app.Topology + stateB_system : openmm.System + charge_difference : int + system_mappings : dict[str, dict[int, int]] + distance_cutoff : Quantity + solvent_component : SolventComponent | None + """ + # Base case, return if no net charge + if charge_difference == 0: + return + + # Get the residue ids for waters to turn alchemical + alchem_water_resids = _rfe_utils.topologyhelpers.get_alchemical_waters( + topology=stateA_topology, + positions=stateA_positions, + charge_difference=charge_difference, + distance_cutoff=distance_cutoff, + ) + + # In-place modify state B alchemical waters to ions + _rfe_utils.topologyhelpers.handle_alchemical_waters( + water_resids=alchem_water_resids, + topology=stateB_topology, + system=stateB_system, + system_mapping=system_mappings, + charge_difference=charge_difference, + solvent_component=solvent_component, + ) + def _get_omm_objects( self, stateA: ChemicalSystem, @@ -316,7 +486,15 @@ def _get_omm_objects( protein_component: ProteinComponent | None, solvent_component: SolventComponent | None, small_mols: dict[SmallMoleculeComponent, OFFMolecule] - ): + ) -> tuple[ + openmm.System, + openmm.app.Topology, + openmm.unit.Quantity, + openmm.System, + openmm.app.Topology, + openmm.unit.Quantity, + dict[str, dict[int, int]], + ]: """ Get OpenMM objects for both end states A and B. @@ -358,207 +536,140 @@ def _get_omm_objects( if self.verbose: self.logger.info("Parameterizing systems") + # TODO: get two generators, one for state A and one for stateB + # See issue #1120 # Get the system generator with all the templates registered system_generator = self._get_system_generator( shared_basepath=self.shared_basepath, settings=settings, - solv_comp=solvent_component, + solvent_comp=solvent_component, openff_molecules=list(small_mols.values()) ) - .... 
- - - def run( - self, *, dry=False, verbose=True, scratch_basepath=None, shared_basepath=None - ) -> dict[str, Any]: - """Run the relative free energy calculation. - - Parameters - ---------- - dry : bool - Do a dry run of the calculation, creating all necessary hybrid - system components (topology, system, sampler, etc...) but without - running the simulation. - verbose : bool - Verbose output of the simulation progress. Output is provided via - INFO level logging. - scratch_basepath: Pathlike, optional - Where to store temporary files, defaults to current working directory - shared_basepath : Pathlike, optional - Where to run the calculation, defaults to current working directory - - Returns - ------- - dict - Outputs created in the basepath directory or the debug objects - (i.e. sampler) if ``dry==True``. - - Raises - ------ - error - Exception if anything failed - """ - # Prepare paths & verbosity - self._prepare(verbose, scratch_basepath, shared_basepath) - - # Get settings - settings = self._get_settings(self._inputs["protocol"].settings) - - # Get components - stateA = self._inputs["stateA"] - stateB = self._inputs["stateB"] - mapping = self._inputs["ligandmapping"] - alchem_comps, solvent_comp, protein_comp, small_mols = self._get_components( - stateA, stateB - ) - - # Assign partial charges now to avoid any discrepancies later - self._assign_partial_charges(charge_settings, small_mols) - - - - # TODO: move these down, not needed until we get to the sampler - # TODO: Also validate various conversions? 
- # Convert various time based inputs to steps/iterations - steps_per_iteration = settings_validation.convert_steps_per_iteration( - simulation_settings=sampler_settings, - integrator_settings=integrator_settings, - ) + # Create the state A system + small_mols_stateA = { + smc: offmol + for smc, offmol in small_mols.items() + if stateA.contains(smc) + } - equil_steps = settings_validation.get_simsteps( - sim_length=sampler_settings.equilibration_length, - timestep=integrator_settings.timestep, - mc_steps=steps_per_iteration, - ) - prod_steps = settings_validation.get_simsteps( - sim_length=sampler_settings.production_length, - timestep=integrator_settings.timestep, - mc_steps=steps_per_iteration, + stateA_system, stateA_topology, stateA_positions, comp_resids = self._create_stateA_system( + small_mols=small_mols_stateA, + protein_component=protein_component, + solvent_component=solvent_component, + system_generator=system_generator, + solvation_settings=settings["solvation_settings"] ) - # Get the change difference between the end states - # and check if the charge correction used is appropriate - charge_difference = mapping.get_alchemical_charge_difference() - - # 1. Create stateA system - self.logger.info("Parameterizing molecules") - - # b. get a system generator - if output_settings.forcefield_cache is not None: - ffcache = shared_basepath / output_settings.forcefield_cache - else: - ffcache = None - - # Block out oechem backend in system_generator calls to avoid - # any issues with smiles roundtripping between rdkit and oechem - with without_oechem_backend(): - system_generator = system_creation.get_system_generator( - forcefield_settings=forcefield_settings, - integrator_settings=integrator_settings, - thermo_settings=thermo_settings, - cache=ffcache, - has_solvent=solvent_comp is not None, - ) - - # c. force the creation of parameters - # This is necessary because we need to have the FF templates - # registered ahead of solvating the system. 
- for smc, mol in chain( - off_small_mols["stateA"], off_small_mols["stateB"], off_small_mols["both"] - ): - system_generator.create_system(mol.to_topology().to_openmm(), molecules=[mol]) - - # c. get OpenMM Modeller + a dictionary of resids for each component - stateA_modeller, comp_resids = system_creation.get_omm_modeller( - protein_comp=protein_comp, - solvent_comp=solvent_comp, - small_mols=dict(chain(off_small_mols["stateA"], off_small_mols["both"])), - omm_forcefield=system_generator.forcefield, - solvent_settings=solvation_settings, - ) - - # d. get topology & positions - # Note: roundtrip positions to remove vec3 issues - stateA_topology = stateA_modeller.getTopology() - stateA_positions = to_openmm(from_openmm(stateA_modeller.getPositions())) - - # e. create the stateA System - # Block out oechem backend in system_generator calls to avoid - # any issues with smiles roundtripping between rdkit and oechem - with without_oechem_backend(): - stateA_system = system_generator.create_system( - stateA_modeller.topology, - molecules=[m for _, m in chain(off_small_mols["stateA"], off_small_mols["both"])], - ) + # State B system creation + small_mols_stateB = { + smc: offmol + for smc, offmol in small_mols.items() + if stateB.contains(smc) + } - # 2. Get stateB system - # a. get the topology - stateB_topology, stateB_alchem_resids = _rfe_utils.topologyhelpers.combined_topology( - stateA_topology, - # zeroth item (there's only one) then get the OFF representation - off_small_mols["stateB"][0][1].to_topology().to_openmm(), + ( + stateB_system, + stateB_topology, + stateB_alchem_resids + ) = self._create_stateB_system( + small_mols=small_mols_stateB, + mapping=mapping, + stateA_topology=stateA_topology, exclude_resids=comp_resids[mapping.componentA], + system_generator=system_generator, ) - # b. 
get a list of small molecules for stateB - # Block out oechem backend in system_generator calls to avoid - # any issues with smiles roundtripping between rdkit and oechem - with without_oechem_backend(): - stateB_system = system_generator.create_system( - stateB_topology, - molecules=[m for _, m in chain(off_small_mols["stateB"], off_small_mols["both"])], - ) - - # c. Define correspondence mappings between the two systems - ligand_mappings = _rfe_utils.topologyhelpers.get_system_mappings( - mapping.componentA_to_componentB, - stateA_system, - stateA_topology, - comp_resids[mapping.componentA], - stateB_system, - stateB_topology, - stateB_alchem_resids, + # Get the mapping between the two systems + system_mappings = _rfe_utils.topologyhelpers.get_system_mappings( + old_to_new_atom_map=mapping.componentA_to_componentB, + old_system=stateA_system, + old_topology=stateA_topology, + old_resids=comp_resids[mapping.componentA], + new_system=stateB_system, + new_topology=stateB_topology, + new_resids=stateB_alchem_resids, # These are non-optional settings for this method fix_constraints=True, ) - # d. if a charge correction is necessary, select alchemical waters - # and transform them - if alchem_settings.explicit_charge_correction: - alchem_water_resids = _rfe_utils.topologyhelpers.get_alchemical_waters( - stateA_topology, - stateA_positions, - charge_difference, - alchem_settings.explicit_charge_correction_cutoff, - ) - _rfe_utils.topologyhelpers.handle_alchemical_waters( - alchem_water_resids, - stateB_topology, - stateB_system, - ligand_mappings, - charge_difference, - solvent_comp, + # Net charge: add alchemical water if needed + # Must be done here as we in-place modify the particles of state B. 
+ if settings["alchemical_settings"].explicit_charge_correction: + self._handle_net_charge( + stateA_topology=stateA_topology, + stateA_positions=stateA_positions, + stateB_topology=stateB_topology, + stateB_system=stateB_system, + charge_difference=mapping.get_alchemical_charge_difference(), + system_mappings=system_mappings, + distance_cutoff=settings["alchemical_settings"].explicit_charge_correction_cutoff, + solvent_component=solvent_component, ) - # e. Finally get the positions + # Finally get the state B positions stateB_positions = _rfe_utils.topologyhelpers.set_and_check_new_positions( - ligand_mappings, + system_mappings, stateA_topology, stateB_topology, old_positions=ensure_quantity(stateA_positions, "openmm"), insert_positions=ensure_quantity( - off_small_mols["stateB"][0][1].conformers[0], "openmm" + small_mols[mapping.componentB].conformers[0], "openmm" ), ) - # 3. Create the hybrid topology - # a. Get softcore potential settings - if alchem_settings.softcore_LJ.lower() == "gapsys": + return ( + stateA_system, stateA_topology, stateA_positions, + stateB_system, stateB_topology, stateB_positions, + system_mappings + ) + + @staticmethod + def _get_alchemical_system( + stateA_system: openmm.System, + stateA_positions: openmm.unit.Quantity, + stateA_topology: openmm.app.Topology, + stateB_system: openmm.System, + stateB_positions: openmm.unit.Quantity, + stateB_topology: openmm.app.Topology, + system_mappings: dict[str, dict[int, int]], + alchemical_settings: AlchemicalSettings, + ): + """ + Get the hybrid topology alchemical system. 
+ + Parameters + ---------- + stateA_system : openmm.System + State A OpenMM System + stateA_positions : openmm.unit.Quantity + Positions of state A System + stateA_topology : openmm.app.Topology + Topology of state A System + stateB_system : openmm.System + State B OpenMM System + stateB_positions : openmm.unit.Quantity + Positions of state B System + stateB_topology : openmm.app.Topology + Topology of state B System + system_mappings : dict[str, dict[int, int]] + Mapping of corresponding atoms between the two Systems. + alchemical_settings : AlchemicalSettings + The alchemical settings defining how the alchemical system + will be built. + + Returns + ------- + hybrid_factory : HybridTopologyFactory + The factory creating the hybrid system. + hybrid_system : openmm.System + The hybrid System. + """ + if alchemical_settings.softcore_LJ.lower() == "gapsys": softcore_LJ_v2 = True - elif alchem_settings.softcore_LJ.lower() == "beutler": + elif alchemical_settings.softcore_LJ.lower() == "beutler": softcore_LJ_v2 = False - # b. 
Get hybrid topology factory + hybrid_factory = _rfe_utils.relative.HybridTopologyFactory( stateA_system, stateA_positions, @@ -566,54 +677,103 @@ def run( stateB_system, stateB_positions, stateB_topology, - old_to_new_atom_map=ligand_mappings["old_to_new_atom_map"], - old_to_new_core_atom_map=ligand_mappings["old_to_new_core_atom_map"], - use_dispersion_correction=alchem_settings.use_dispersion_correction, - softcore_alpha=alchem_settings.softcore_alpha, + old_to_new_atom_map=system_mappings["old_to_new_atom_map"], + old_to_new_core_atom_map=system_mappings["old_to_new_core_atom_map"], + use_dispersion_correction=alchemical_settings.use_dispersion_correction, + softcore_alpha=alchemical_settings.softcore_alpha, softcore_LJ_v2=softcore_LJ_v2, - softcore_LJ_v2_alpha=alchem_settings.softcore_alpha, - interpolate_old_and_new_14s=alchem_settings.turn_off_core_unique_exceptions, + softcore_LJ_v2_alpha=alchemical_settings.softcore_alpha, + interpolate_old_and_new_14s=alchemical_settings.turn_off_core_unique_exceptions, ) - # 4. Create lambda schedule - # TODO - this should be exposed to users, maybe we should offer the - # ability to print the schedule directly in settings? 
- # fmt: off - lambdas = _rfe_utils.lambdaprotocol.LambdaProtocol( - functions=lambda_settings.lambda_functions, - windows=lambda_settings.lambda_windows - ) - # fmt: on - # PR #125 temporarily pin lambda schedule spacing to n_replicas - n_replicas = sampler_settings.n_replicas - if n_replicas != len(lambdas.lambda_schedule): - errmsg = ( - f"Number of replicas {n_replicas} " - f"does not equal the number of lambda windows " - f"{len(lambdas.lambda_schedule)}" + return hybrid_factory, hybrid_factory.hybrid_system + + def _subsample_topology( + self, + hybrid_topology: openmm.app.Topology, + hybrid_positions: openmm.unit.Quantity, + output_selection: str, + output_filename: str, + atom_classes: dict[str, set[int]], + ) -> npt.NDArray: + """ + Subsample the hybrid topology based on user-selected output selection + and write the subsampled topology to a PDB file. + + Parameters + ---------- + hybrid_topology : openmm.app.Topology + The hybrid system topology to subsample. + hybrid_positions : openmm.unit.Quantity + The hybrid system positions. + output_selection : str + An MDTraj selection string to subsample the topology with. + output_filename : str + The name of the file to write the PDB to. + atom_classes : dict[str, set[int]] + A dictionary defining what atoms belong to the different + components of the hybrid system. + + Returns + ------- + selection_indices : npt.NDArray + The indices of the subselected system. + + TODO + ---- + Modify this to also store the full system. 
+ """ + selection_indices = hybrid_topology.select(output_selection) + + # Write out a PDB containing the subsampled hybrid state + # We use bfactors as a hack to label different states + # bfactor of 0 is environment atoms + # bfactor of 0.25 is unique old atoms + # bfactor of 0.5 is core atoms + # bfactor of 0.75 is unique new atoms + bfactors = np.zeros_like(selection_indices, dtype=float) + bfactors[np.isin(selection_indices, list(atom_classes['unique_old_atoms']))] = 0.25 + bfactors[np.isin(selection_indices, list(atom_classes['core_atoms']))] = 0.50 + bfactors[np.isin(selection_indices, list(atom_classes['unique_new_atoms']))] = 0.75 + + if len(selection_indices) > 0: + traj = mdtraj.Trajectory( + hybrid_positions[selection_indices, :], + hybrid_topology.subset(selection_indices), + ).save_pdb( + self.shared_basepath / output_filename, + bfactors=bfactors, ) - raise ValueError(errmsg) - # 9. Create the multistate reporter - # Get the sub selection of the system to print coords for - selection_indices = hybrid_factory.hybrid_topology.select(output_settings.output_indices) + return selection_indices - # a. Create the multistate reporter - # convert checkpoint_interval from time to iterations - chk_intervals = settings_validation.convert_checkpoint_interval_to_iterations( - checkpoint_interval=output_settings.checkpoint_interval, - time_per_iteration=sampler_settings.time_per_iteration, - ) + def _get_reporter( + self, + selection_indices: npt.NDArray, + output_settings: MultiStateOutputSettings, + simulation_settings: MultiStateSimulationSettings, + ) -> multistate.MultiStateReporter: + """ + Get the multistate reporter. - nc = shared_basepath / output_settings.output_filename + Parameters + ---------- + selection_indices : npt.NDArray + The set of system indices to report positions & velocities for. + output_settings : MultiStateOutputSettings + Settings defining how outputs should be written. 
+ simulation_settings : MultiStateSimulationSettings + Settings defining out the simulation should be run. + """ + nc = self.shared_basepath / output_settings.output_filename chk = output_settings.checkpoint_storage_filename if output_settings.positions_write_frequency is not None: pos_interval = settings_validation.divmod_time_and_check( numerator=output_settings.positions_write_frequency, - denominator=sampler_settings.time_per_iteration, + denominator=simulation_settings.time_per_iteration, numerator_name="output settings' position_write_frequency", - denominator_name="sampler settings' time_per_iteration", + denominator_name="simulation settings' time_per_iteration", ) else: pos_interval = 0 @@ -621,14 +781,19 @@ def run( if output_settings.velocities_write_frequency is not None: vel_interval = settings_validation.divmod_time_and_check( numerator=output_settings.velocities_write_frequency, - denominator=sampler_settings.time_per_iteration, + denominator=simulation_settings.time_per_iteration, numerator_name="output settings' velocity_write_frequency", denominator_name="sampler settings' time_per_iteration", ) else: vel_interval = 0 - reporter = multistate.MultiStateReporter( + chk_intervals = settings_validation.convert_checkpoint_interval_to_iterations( + checkpoint_interval=output_settings.checkpoint_interval, + time_per_iteration=simulation_settings.time_per_iteration, + ) + + return multistate.MultiStateReporter( storage=nc, analysis_particle_indices=selection_indices, checkpoint_interval=chk_intervals, @@ -637,45 +802,39 @@ def run( velocity_interval=vel_interval, ) - # b. 
Write out a PDB containing the subsampled hybrid state - # fmt: off - bfactors = np.zeros_like(selection_indices, dtype=float) # solvent - bfactors[np.in1d(selection_indices, list(hybrid_factory._atom_classes['unique_old_atoms']))] = 0.25 # lig A - bfactors[np.in1d(selection_indices, list(hybrid_factory._atom_classes['core_atoms']))] = 0.50 # core - bfactors[np.in1d(selection_indices, list(hybrid_factory._atom_classes['unique_new_atoms']))] = 0.75 # lig B - # bfactors[np.in1d(selection_indices, protein)] = 1.0 # prot+cofactor - if len(selection_indices) > 0: - traj = mdtraj.Trajectory( - hybrid_factory.hybrid_positions[selection_indices, :], - hybrid_factory.hybrid_topology.subset(selection_indices), - ).save_pdb( - shared_basepath / output_settings.output_structure, - bfactors=bfactors, - ) - # fmt: on + @staticmethod + def _get_integrator( + integrator_settings: IntegratorSettings, + simulation_settings: MultiStateSimulationSettings, + system: openmm.System + ) -> openmmtools.mcmc.LangevinDynamicsMove: + """ + Get and validate the integrator - # 10. Get compute platform - # restrict to a single CPU if running vacuum - restrict_cpu = forcefield_settings.nonbonded_method.lower() == "nocutoff" - platform = omm_compute.get_openmm_platform( - platform_name=protocol_settings.engine_settings.compute_platform, - gpu_device_index=protocol_settings.engine_settings.gpu_device_index, - restrict_cpu_count=restrict_cpu, - ) + Parameters + ---------- + integrator_settings : IntegratorSettings + Settings controlling the Langevin integrator. + simulation_settings : MultiStateSimulationSettings + Settings controlling the simulation. + system : openmm.System + The OpenMM System. - # 11. Set the integrator - # a. 
Validate integrator settings for current system - # Virtual sites sanity check - ensure we restart velocities when - # there are virtual sites in the system - if hybrid_factory.has_virtual_sites: - if not integrator_settings.reassign_velocities: - errmsg = ( - "Simulations with virtual sites without velocity " - "reassignments are unstable in openmmtools" - ) - raise ValueError(errmsg) + Returns + ------- + integrator : openmmtools.mcmc.LangevinDynamicsMove + The LangevinDynamicsMove integrator. + + Raises + ------ + ValueError + If there are virtual sites in the system, but velocities + are not being reassigned after every MCMC move. + """ + steps_per_iteration = settings_validation.convert_steps_per_iteration( + simulation_settings, integrator_settings + ) - # b. create langevin integrator integrator = openmmtools.mcmc.LangevinDynamicsMove( timestep=to_openmm(integrator_settings.timestep), collision_rate=to_openmm(integrator_settings.langevin_collision_rate), @@ -685,52 +844,112 @@ def run( constraint_tolerance=integrator_settings.constraint_tolerance, ) - # 12. Create sampler - self.logger.info("Creating and setting up the sampler") + # Validate for known issue when dealing with virtual sites + # and mutltistate simulations + if not integrator_settings.reassign_velocities: + for particle_idx in range(system.getNumParticles()): + if system.isVirtualSite(particle_idx): + errmsg = ( + "Simulations with virtual sites without velocity " + "reassignments are unstable with MCMC integrators." 
+ ) + raise ValueError(errmsg) + + return integrator + + @staticmethod + def _get_sampler( + system: openmm.System, + positions: openmm.unit.Quantity, + lambdas: _rfe_utils.lambdaprotocol.LambdaProtocol, + integrator: openmmtools.mcmc.MCMCMove, + reporter: multistate.MultiStateReporter, + simulation_settings: MultiStateSimulationSettings, + thermo_settings: ThermoSettings, + alchem_settings: AlchemicalSettings, + platform: openmm.Platform, + dry: bool, + ) -> multistate.MultiStateSampler: + """ + Get the MultiStateSampler. + + Parameters + ---------- + system : openmm.System + The OpenMM System to simulate. + positions : openmm.unit.Quantity + The positions of the OpenMM System. + lambdas : LambdaProtocol + The lambda protocol to sample along. + integrator : openmmtools.mcmc.MCMCMove + The integrator to use. + reporter : multistate.MultiStateReporter + The reporter to attach to the sampler. + simulation_settings : MultiStateSimulationSettings + The simulation control settings. + thermo_settings : ThermoSettings + The thermodynamic control settings. + alchem_settings : AlchemicalSettings + The alchemical transformation settings. + platform : openmm.Platform + The compute platform to use. + dry : bool + Whether or not this is a dry run. + + Returns + ------- + sampler : multistate.MultiStateSampler + The requested sampler. 
+ """ rta_its, rta_min_its = settings_validation.convert_real_time_analysis_iterations( - simulation_settings=sampler_settings, + simulation_settings=simulation_settings, ) + # convert early_termination_target_error from kcal/mol to kT early_termination_target_error = ( settings_validation.convert_target_error_from_kcal_per_mole_to_kT( thermo_settings.temperature, - sampler_settings.early_termination_target_error, + simulation_settings.early_termination_target_error, ) ) - if sampler_settings.sampler_method.lower() == "repex": + if simulation_settings.sampler_method.lower() == "repex": sampler = _rfe_utils.multistate.HybridRepexSampler( mcmc_moves=integrator, - hybrid_system=hybrid_factory.hybrid_system, - hybrid_positions=hybrid_factory.hybrid_positions, + hybrid_system=system, + hybrid_positions=positions, online_analysis_interval=rta_its, online_analysis_target_error=early_termination_target_error, online_analysis_minimum_iterations=rta_min_its, ) - elif sampler_settings.sampler_method.lower() == "sams": + + elif simulation_settings.sampler_method.lower() == "sams": sampler = _rfe_utils.multistate.HybridSAMSSampler( mcmc_moves=integrator, - hybrid_system=hybrid_factory.hybrid_system, - hybrid_positions=hybrid_factory.hybrid_positions, + hybrid_system=system, + hybrid_positions=positions, online_analysis_interval=rta_its, online_analysis_minimum_iterations=rta_min_its, - flatness_criteria=sampler_settings.sams_flatness_criteria, - gamma0=sampler_settings.sams_gamma0, + flatness_criteria=simulation_settings.sams_flatness_criteria, + gamma0=simulation_settings.sams_gamma0, ) - elif sampler_settings.sampler_method.lower() == "independent": + + elif simulation_settings.sampler_method.lower() == "independent": sampler = _rfe_utils.multistate.HybridMultiStateSampler( mcmc_moves=integrator, - hybrid_system=hybrid_factory.hybrid_system, - hybrid_positions=hybrid_factory.hybrid_positions, + hybrid_system=system, + hybrid_positions=positions, 
online_analysis_interval=rta_its, online_analysis_target_error=early_termination_target_error, online_analysis_minimum_iterations=rta_min_its, ) + + else: - raise AttributeError(f"Unknown sampler {sampler_settings.sampler_method}") + raise AttributeError(f"Unknown sampler {simulation_settings.sampler_method}") sampler.setup( - n_replicas=sampler_settings.n_replicas, + n_replicas=simulation_settings.n_replicas, reporter=reporter, lambda_protocol=lambdas, temperature=to_openmm(thermo_settings.temperature), @@ -741,63 +960,256 @@ def run( minimization_steps=100 if not dry else None, ) - try: - # Create context caches (energy + sampler) - energy_context_cache = openmmtools.cache.ContextCache( - capacity=None, - time_to_live=None, - platform=platform, - ) + # Get and set the context caches + sampler.energy_context_cache = openmmtools.cache.ContextCache( + capacity=None, + time_to_live=None, + platform=platform, + ) + sampler.sampler_context_cache = openmmtools.cache.ContextCache( + capacity=None, + time_to_live=None, + platform=platform, + ) - sampler_context_cache = openmmtools.cache.ContextCache( - capacity=None, - time_to_live=None, - platform=platform, + return sampler + + def _run_simulation( + self, + sampler: multistate.MultiStateSampler, + reporter: multistate.MultiStateReporter, + simulation_settings : MultiStateSimulationSettings, + integrator_settings : IntegratorSettings, + output_settings : MultiStateOutputSettings, + dry: bool, + ): + """ + Run the simulation. + + Parameters + ---------- + sampler : multistate.MultiStateSampler. + The sampler associated with the simulation to run. + reporter : multistate.MultiStateReporter + The reporter associated with the sampler. + simulation_settings : MultiStateSimulationSettings + Simulation control settings. + integrator_settings : IntegratorSettings + Integrator control settings. + output_settings : MultiStateOutputSettings + Simulation output control settings. 
+ dry : bool + Whether or not to dry run the simulation. + + Returns + ------- + unit_results_dict : dict | None + A dictionary containing the free energy results to report. + ``None`` if it is a dry run. + """ + # Get the relevant simulation steps + mc_steps = settings_validation.convert_steps_per_iteration( + simulation_settings=simulation_settings, + integrator_settings=integrator_settings, + ) + + equil_steps = settings_validation.get_simsteps( + sim_length=simulation_settings.equilibration_length, + timestep=integrator_settings.timestep, + mc_steps=mc_steps, + ) + prod_steps = settings_validation.get_simsteps( + sim_length=simulation_settings.production_length, + timestep=integrator_settings.timestep, + mc_steps=mc_steps, + ) + + if not dry: # pragma: no-cover + # minimize + if self.verbose: + self.logger.info("minimizing systems") + + sampler.minimize(max_iterations=simulation_settings.minimization_steps) + + # equilibrate + if self.verbose: + self.logger.info("equilibrating systems") + + sampler.equilibrate(int(equil_steps / mc_steps)) + + # production + if self.verbose: + self.logger.info("running production phase") + + sampler.extend(int(prod_steps / mc_steps)) + + if self.verbose: + self.logger.info("production phase complete") + + if self.verbose: + self.logger.info("post-simulation result analysis") + + # calculate relevant analysis of the free energies & sampling + analyzer = multistate_analysis.MultistateEquilFEAnalysis( + reporter, + sampling_method=simulation_settings.sampler_method.lower(), + result_units=offunit.kilocalorie_per_mole, ) + analyzer.plot(filepath=self.shared_basepath, filename_prefix="") + analyzer.close() + + return analyzer.unit_results_dict + + else: + # We ran a dry simulation + # close reporter when you're done, prevent file handle clashes + reporter.close() + + # TODO: review this is likely no longer necessary + # clean up the reporter file + fns = [ + self.shared_basepath / output_settings.output_filename, + 
self.shared_basepath / output_settings.checkpoint_storage_filename, + ] + for fn in fns: + os.remove(fn) + + return None + + def run( + self, *, dry=False, verbose=True, scratch_basepath=None, shared_basepath=None + ) -> dict[str, Any]: + """Run the relative free energy calculation. + + Parameters + ---------- + dry : bool + Do a dry run of the calculation, creating all necessary hybrid + system components (topology, system, sampler, etc...) but without + running the simulation. + verbose : bool + Verbose output of the simulation progress. Output is provided via + INFO level logging. + scratch_basepath: Pathlike, optional + Where to store temporary files, defaults to current working directory + shared_basepath : Pathlike, optional + Where to run the calculation, defaults to current working directory + + Returns + ------- + dict + Outputs created in the basepath directory or the debug objects + (i.e. sampler) if ``dry==True``. + + Raises + ------ + error + Exception if anything failed + """ + # Prepare paths & verbosity + self._prepare(verbose, scratch_basepath, shared_basepath) + + # Get settings + settings = self._get_settings(self._inputs["protocol"].settings) + + # Get components + stateA = self._inputs["stateA"] + stateB = self._inputs["stateB"] + mapping = self._inputs["ligandmapping"] + alchem_comps, solvent_comp, protein_comp, small_mols = self._get_components( + stateA, stateB + ) - sampler.energy_context_cache = energy_context_cache - sampler.sampler_context_cache = sampler_context_cache + # Assign partial charges now to avoid any discrepancies later + self._assign_partial_charges(settings["charge_settings"], small_mols) - if not dry: # pragma: no-cover - # minimize - if verbose: - self.logger.info("Running minimization") + ( + stateA_system, stateA_topology, stateA_positions, + stateB_system, stateB_topology, stateB_positions, + system_mappings + ) = self._get_omm_objects( + stateA=stateA, + stateB=stateB, + mapping=mapping, + settings=settings, + 
protein_component=protein_comp, + solvent_component=solvent_comp, + small_mols=small_mols + ) - sampler.minimize(max_iterations=sampler_settings.minimization_steps) + # Get the hybrid factory & system + hybrid_factory, hybrid_system = self._get_alchemical_system( + stateA_system=stateA_system, + stateA_positions=stateA_positions, + stateA_topology=stateA_topology, + stateB_system=stateB_system, + stateB_positions=stateB_positions, + stateB_topology=stateB_topology, + system_mappings=system_mappings, + alchemical_settings=settings["alchemical_settings"], + ) - # equilibrate - if verbose: - self.logger.info("Running equilibration phase") + # Subselect system based on user inputs & write initial PDB + selection_indices = self._subsample_topology( + hybrid_topology=hybrid_factory.hybrid_topology, + hybrid_positions=hybrid_factory.hybrid_positions, + output_selection=settings["output_settings"].output_indices, + output_filename=settings["output_settings"].output_structure, + atom_classes=hybrid_factory._atom_classes, + ) - sampler.equilibrate(int(equil_steps / steps_per_iteration)) + # Get the lambda schedule + # TODO - this should be better exposed to users + lambdas = _rfe_utils.lambdaprotocol.LambdaProtocol( + functions=settings["lambda_settings"].lambda_functions, + windows=settings["lambda_settings"].lambda_windows + ) + + # Get the reporter + reporter = self._get_reporter( + selection_indices=selection_indices, + output_settings=settings["output_settings"], + simulation_settings=settings["simulation_settings"], + ) - # production - if verbose: - self.logger.info("Running production phase") + # Get the compute platform + restrict_cpu = settings["forcefield_settings"].nonbonded_method.lower() == "nocutoff" + platform = omm_compute.get_openmm_platform( + platform_name=settings["engine_settings"].compute_platform, + gpu_device_index=settings["engine_settings"].gpu_device_index, + restrict_cpu_count=restrict_cpu, + ) - sampler.extend(int(prod_steps / 
steps_per_iteration)) + # Get the integrator + integrator = self._get_integrator( + integrator_settings=settings["integrator_settings"], + simulation_settings=settings["simulation_settings"], + system=hybrid_system + ) - self.logger.info("Production phase complete") + try: + # Get sampler + sampler = self._get_sampler( + system=hybrid_system, + positions=hybrid_factory.hybrid_positions, + lambdas=lambdas, + integrator=integrator, + reporter=reporter, + simulation_settings=settings["simulation_settings"], + thermo_settings=settings["thermo_settings"], + alchem_settings=settings["alchemical_settings"], + platform=platform, + dry=dry + ) - self.logger.info("Post-simulation analysis of results") - # calculate relevant analyses of the free energies & sampling - # First close & reload the reporter to avoid netcdf clashes - analyzer = multistate_analysis.MultistateEquilFEAnalysis( - reporter, - sampling_method=sampler_settings.sampler_method.lower(), - result_units=offunit.kilocalorie_per_mole, - ) - analyzer.plot(filepath=shared_basepath, filename_prefix="") - analyzer.close() - - else: - # clean up the reporter file - fns = [ - shared_basepath / output_settings.output_filename, - shared_basepath / output_settings.checkpoint_storage_filename, - ] - for fn in fns: - os.remove(fn) + unit_results_dict = self._run_simulation( + sampler=sampler, + reporter=reporter, + simulation_settings=settings["simulation_settings"], + integrator_settings=settings["integrator_settings"], + output_settings=settings["output_settings"], + dry=dry, + ) finally: # close reporter when you're done, prevent # file handle clashes @@ -806,26 +1218,29 @@ def run( # clear GPU contexts # TODO: use cache.empty() calls when openmmtools #690 is resolved # replace with above - for context in list(energy_context_cache._lru._data.keys()): - del energy_context_cache._lru._data[context] - for context in list(sampler_context_cache._lru._data.keys()): - del sampler_context_cache._lru._data[context] + for context 
in list(sampler.energy_context_cache._lru._data.keys()): + del sampler.energy_context_cache._lru._data[context] + for context in list(sampler.sampler_context_cache._lru._data.keys()): + del sampler.sampler_context_cache._lru._data[context] # cautiously clear out the global context cache too for context in list(openmmtools.cache.global_context_cache._lru._data.keys()): del openmmtools.cache.global_context_cache._lru._data[context] - del sampler_context_cache, energy_context_cache + del sampler.sampler_context_cache, sampler.energy_context_cache if not dry: del integrator, sampler if not dry: # pragma: no-cover - return {"nc": nc, "last_checkpoint": chk, **analyzer.unit_results_dict} + unit_results_dict["nc"] = nc + unit_results_dict["last_checkpoint"] = chk + unit_results_dict["selection_indices"] = selection_indices + return unit_results_dict else: return {"debug": { "sampler": sampler, - "hybrid_factory": hybrid_factory + "hybrid_factory": hybrid_factory, } } From 6a5a76a95957dbbb0722657d86a2604ce62a7875 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Mon, 29 Dec 2025 09:59:29 -0500 Subject: [PATCH 22/36] more broadly disallow oechem as a backend when creating systems --- .../protocols/openmm_rfe/hybridtop_units.py | 96 +++++++++---------- 1 file changed, 47 insertions(+), 49 deletions(-) diff --git a/openfe/protocols/openmm_rfe/hybridtop_units.py b/openfe/protocols/openmm_rfe/hybridtop_units.py index da090e88e..14f4485bb 100644 --- a/openfe/protocols/openmm_rfe/hybridtop_units.py +++ b/openfe/protocols/openmm_rfe/hybridtop_units.py @@ -373,11 +373,10 @@ def _create_stateA_system( # Note: roundtrip positions to remove vec3 issues positions = to_openmm(from_openmm(modeller.getPositions())) - with without_oechem_backend(): - system = system_generator.create_system( - modeller.topology, - molecules=list(small_mols.values()), - ) + system = system_generator.create_system( + modeller.topology, + molecules=list(small_mols.values()), + ) return system, topology, positions, 
comp_resids @@ -422,11 +421,10 @@ def _create_stateB_system( exclude_resids=exclude_resids, ) - with without_oechem_backend(): - system = system_generator.create_system( - topology, - molecules=list(small_mols.values()), - ) + system = system_generator.create_system( + topology, + molecules=list(small_mols.values()), + ) return system, topology, alchem_resids @@ -536,49 +534,49 @@ def _get_omm_objects( if self.verbose: self.logger.info("Parameterizing systems") - # TODO: get two generators, one for state A and one for stateB - # See issue #1120 - # Get the system generator with all the templates registered - system_generator = self._get_system_generator( - shared_basepath=self.shared_basepath, - settings=settings, - solvent_comp=solvent_component, - openff_molecules=list(small_mols.values()) - ) + def _filter_small_mols(smols, state): + return { + smc: offmol + for smc, offmol in smols.items() + if state.contains(smc) + } - # Create the state A system - small_mols_stateA = { - smc: offmol - for smc, offmol in small_mols.items() - if stateA.contains(smc) - } + small_mols_stateA = _filter_small_mols(small_mols, stateA) + small_mols_stateB = _filter_small_mols(small_mols, stateB) - stateA_system, stateA_topology, stateA_positions, comp_resids = self._create_stateA_system( - small_mols=small_mols_stateA, - protein_component=protein_component, - solvent_component=solvent_component, - system_generator=system_generator, - solvation_settings=settings["solvation_settings"] - ) + # Everything involving systemgenerator handling has a risk of + # oechem <-> rdkit smiles conversion clashes, cautiously ban it. 
+ with without_oechem_backend(): + # TODO: get two generators, one for state A and one for stateB + # See issue #1120 + # Get the system generator with all the templates registered + system_generator = self._get_system_generator( + shared_basepath=self.shared_basepath, + settings=settings, + solvent_comp=solvent_component, + openff_molecules=list(small_mols.values()) + ) - # State B system creation - small_mols_stateB = { - smc: offmol - for smc, offmol in small_mols.items() - if stateB.contains(smc) - } + ( + stateA_system, stateA_topology, stateA_positions, + comp_resids + ) = self._create_stateA_system( + small_mols=small_mols_stateA, + protein_component=protein_component, + solvent_component=solvent_component, + system_generator=system_generator, + solvation_settings=settings["solvation_settings"] + ) - ( - stateB_system, - stateB_topology, - stateB_alchem_resids - ) = self._create_stateB_system( - small_mols=small_mols_stateB, - mapping=mapping, - stateA_topology=stateA_topology, - exclude_resids=comp_resids[mapping.componentA], - system_generator=system_generator, - ) + ( + stateB_system, stateB_topology, stateB_alchem_resids + ) = self._create_stateB_system( + small_mols=small_mols_stateB, + mapping=mapping, + stateA_topology=stateA_topology, + exclude_resids=comp_resids[mapping.componentA], + system_generator=system_generator, + ) # Get the mapping between the two systems system_mappings = _rfe_utils.topologyhelpers.get_system_mappings( From cdd3da04b3a0110d240e34a4ce1e9231ceb4e687 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Mon, 29 Dec 2025 11:33:03 -0500 Subject: [PATCH 23/36] fix issue with nc being undefined --- openfe/protocols/openmm_rfe/hybridtop_units.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/openfe/protocols/openmm_rfe/hybridtop_units.py b/openfe/protocols/openmm_rfe/hybridtop_units.py index 14f4485bb..06d3270e6 100644 --- a/openfe/protocols/openmm_rfe/hybridtop_units.py +++ b/openfe/protocols/openmm_rfe/hybridtop_units.py @@ -1230,6 
+1230,8 @@ def run( del integrator, sampler if not dry: # pragma: no-cover + nc = self.shared_basepath / settings["output_settings"].output_filename + chk = settings["output_settings"].checkpoint_storage_filename unit_results_dict["nc"] = nc unit_results_dict["last_checkpoint"] = chk unit_results_dict["selection_indices"] = selection_indices From b8268036322106b6018c6b436fa5383c4ccc11e1 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Mon, 29 Dec 2025 17:21:18 -0500 Subject: [PATCH 24/36] Fix missing import --- openfe/protocols/openmm_rfe/hybridtop_protocols.py | 1 + 1 file changed, 1 insertion(+) diff --git a/openfe/protocols/openmm_rfe/hybridtop_protocols.py b/openfe/protocols/openmm_rfe/hybridtop_protocols.py index 40accc7e2..b2bcc1ab1 100644 --- a/openfe/protocols/openmm_rfe/hybridtop_protocols.py +++ b/openfe/protocols/openmm_rfe/hybridtop_protocols.py @@ -16,6 +16,7 @@ import warnings from collections import defaultdict from typing import Any, Iterable, Optional, Union +import numpy as np import gufe from gufe import ( From 063e8ced9d18e7889065b0fa5fbabb59ab02ceb1 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Mon, 29 Dec 2025 17:41:59 -0500 Subject: [PATCH 25/36] Fix comp getter --- openfe/protocols/openmm_utils/system_validation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/openfe/protocols/openmm_utils/system_validation.py b/openfe/protocols/openmm_utils/system_validation.py index 750f5f565..7cacaa1f1 100644 --- a/openfe/protocols/openmm_utils/system_validation.py +++ b/openfe/protocols/openmm_utils/system_validation.py @@ -165,8 +165,8 @@ def get_components(state: ChemicalSystem) -> ParseCompRet: def _get_single_comps(state, comptype): comps = state.get_components_of_type(comptype) - if len(ret_comps) > 0: - return ret_comps[0] + if len(comps) > 0: + return comps[0] else: return None From a98c799af2393f1e1f7b3985f1845579c38b7a38 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Tue, 30 Dec 2025 09:46:24 -0500 Subject: [PATCH 26/36] update 
module name --- openfe/protocols/openmm_rfe/__init__.py | 2 +- openfe/protocols/openmm_rfe/equil_rfe_methods.py | 2 +- ...{hybridtop_unit_results.py => hybridtop_protocol_results.py} | 0 openfe/protocols/openmm_rfe/hybridtop_protocols.py | 2 +- 4 files changed, 3 insertions(+), 3 deletions(-) rename openfe/protocols/openmm_rfe/{hybridtop_unit_results.py => hybridtop_protocol_results.py} (100%) diff --git a/openfe/protocols/openmm_rfe/__init__.py b/openfe/protocols/openmm_rfe/__init__.py index 137b641c0..c5b59b543 100644 --- a/openfe/protocols/openmm_rfe/__init__.py +++ b/openfe/protocols/openmm_rfe/__init__.py @@ -3,6 +3,6 @@ from . import _rfe_utils from .hybridtop_protocols import RelativeHybridTopologyProtocol -from .hybridtop_unit_results import RelativeHybridTopologyProtocolResult +from .hybridtop_protocol_results import RelativeHybridTopologyProtocolResult from .hybridtop_units import RelativeHybridTopologyProtocolUnit from .equil_rfe_settings import RelativeHybridTopologyProtocolSettings diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index 14d3c3eb5..0b3a72cc2 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -17,6 +17,6 @@ """ from .equil_rfe_settings import RelativeHybridTopologyProtocolSettings -from .hybridtop_unit_results import RelativeHybridTopologyProtocolResult +from .hybridtop_protocol_results import RelativeHybridTopologyProtocolResult from .hybridtop_units import RelativeHybridTopologyProtocolUnit from .hybridtop_protocols import RelativeHybridTopologyProtocol diff --git a/openfe/protocols/openmm_rfe/hybridtop_unit_results.py b/openfe/protocols/openmm_rfe/hybridtop_protocol_results.py similarity index 100% rename from openfe/protocols/openmm_rfe/hybridtop_unit_results.py rename to openfe/protocols/openmm_rfe/hybridtop_protocol_results.py diff --git a/openfe/protocols/openmm_rfe/hybridtop_protocols.py 
b/openfe/protocols/openmm_rfe/hybridtop_protocols.py index b2bcc1ab1..41c488d56 100644 --- a/openfe/protocols/openmm_rfe/hybridtop_protocols.py +++ b/openfe/protocols/openmm_rfe/hybridtop_protocols.py @@ -48,7 +48,7 @@ OpenMMSolvationSettings, RelativeHybridTopologyProtocolSettings, ) -from .hybridtop_unit_results import RelativeHybridTopologyProtocolResult +from .hybridtop_protocol_results import RelativeHybridTopologyProtocolResult from .hybridtop_units import RelativeHybridTopologyProtocolUnit From 7c915ed69c84f09e6997bb820b8d73ca74e858b1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 3 Jan 2026 01:08:51 +0000 Subject: [PATCH 27/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- openfe/protocols/openmm_rfe/__init__.py | 4 +- .../protocols/openmm_rfe/equil_rfe_methods.py | 2 +- .../openmm_rfe/hybridtop_protocol_results.py | 1 - .../openmm_rfe/hybridtop_protocols.py | 23 +- .../protocols/openmm_rfe/hybridtop_units.py | 146 ++++++------ .../openmm_utils/system_validation.py | 12 +- .../openmm_rfe/test_hybrid_top_validation.py | 211 ++++++++---------- 7 files changed, 174 insertions(+), 225 deletions(-) diff --git a/openfe/protocols/openmm_rfe/__init__.py b/openfe/protocols/openmm_rfe/__init__.py index c5b59b543..ca1ae0fd2 100644 --- a/openfe/protocols/openmm_rfe/__init__.py +++ b/openfe/protocols/openmm_rfe/__init__.py @@ -2,7 +2,7 @@ # For details, see https://github.com/OpenFreeEnergy/openfe from . 
import _rfe_utils -from .hybridtop_protocols import RelativeHybridTopologyProtocol +from .equil_rfe_settings import RelativeHybridTopologyProtocolSettings from .hybridtop_protocol_results import RelativeHybridTopologyProtocolResult +from .hybridtop_protocols import RelativeHybridTopologyProtocol from .hybridtop_units import RelativeHybridTopologyProtocolUnit -from .equil_rfe_settings import RelativeHybridTopologyProtocolSettings diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index 0b3a72cc2..91db1cdc3 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -18,5 +18,5 @@ from .equil_rfe_settings import RelativeHybridTopologyProtocolSettings from .hybridtop_protocol_results import RelativeHybridTopologyProtocolResult -from .hybridtop_units import RelativeHybridTopologyProtocolUnit from .hybridtop_protocols import RelativeHybridTopologyProtocol +from .hybridtop_units import RelativeHybridTopologyProtocolUnit diff --git a/openfe/protocols/openmm_rfe/hybridtop_protocol_results.py b/openfe/protocols/openmm_rfe/hybridtop_protocol_results.py index d3a6dc78d..08214a12a 100644 --- a/openfe/protocols/openmm_rfe/hybridtop_protocol_results.py +++ b/openfe/protocols/openmm_rfe/hybridtop_protocol_results.py @@ -16,7 +16,6 @@ from openff.units import Quantity from openmmtools import multistate - logger = logging.getLogger(__name__) diff --git a/openfe/protocols/openmm_rfe/hybridtop_protocols.py b/openfe/protocols/openmm_rfe/hybridtop_protocols.py index 41c488d56..00db4a9de 100644 --- a/openfe/protocols/openmm_rfe/hybridtop_protocols.py +++ b/openfe/protocols/openmm_rfe/hybridtop_protocols.py @@ -16,9 +16,9 @@ import warnings from collections import defaultdict from typing import Any, Iterable, Optional, Union -import numpy as np import gufe +import numpy as np from gufe import ( ChemicalSystem, Component, @@ -51,7 +51,6 @@ from 
.hybridtop_protocol_results import RelativeHybridTopologyProtocolResult from .hybridtop_units import RelativeHybridTopologyProtocolUnit - logger = logging.getLogger(__name__) @@ -277,9 +276,13 @@ def _validate_mapping( # check that the mapping components are in the alchemical components for m in mapping: if m.componentA not in alchemical_components["stateA"]: - raise ValueError(f"Mapping componentA {m.componentA} not in alchemical components of stateA") + raise ValueError( + f"Mapping componentA {m.componentA} not in alchemical components of stateA" + ) if m.componentB not in alchemical_components["stateB"]: - raise ValueError(f"Mapping componentB {m.componentB} not in alchemical components of stateB") + raise ValueError( + f"Mapping componentB {m.componentB} not in alchemical components of stateB" + ) # TODO: remove - this is now the default behaviour? # Check for element changes in mappings @@ -419,10 +422,7 @@ def _validate_charge_difference( ) raise ValueError(errmsg) - ion = { - -1: solvent_component.positive_ion, - 1: solvent_component.negative_ion - }[difference] + ion = {-1: solvent_component.positive_ion, 1: solvent_component.negative_ion}[difference] wmsg = ( f"A charge difference of {difference} is observed " @@ -453,7 +453,7 @@ def _validate_simulation_settings( Raises ------ ValueError - * If the + * If the """ steps_per_iteration = settings_validation.convert_steps_per_iteration( @@ -561,7 +561,10 @@ def _validate( # Validate alchemical settings # PR #125 temporarily pin lambda schedule spacing to n_replicas - if self.settings.simulation_settings.n_replicas != self.settings.lambda_settings.lambda_windows: + if ( + self.settings.simulation_settings.n_replicas + != self.settings.lambda_settings.lambda_windows + ): errmsg = ( "Number of replicas in ``simulation_settings``: " f"{self.settings.simulation_settings.n_replicas} must equal " diff --git a/openfe/protocols/openmm_rfe/hybridtop_units.py b/openfe/protocols/openmm_rfe/hybridtop_units.py index 
06d3270e6..41c6f1787 100644 --- a/openfe/protocols/openmm_rfe/hybridtop_units.py +++ b/openfe/protocols/openmm_rfe/hybridtop_units.py @@ -18,31 +18,30 @@ from itertools import chain from typing import Any, Optional +import gufe import matplotlib.pyplot as plt import mdtraj import numpy as np import numpy.typing as npt import openmm import openmmtools -from openmmforcefields.generators import SystemGenerator - -import gufe -from gufe.settings import ( - SettingsBaseModel, - ThermoSettings, -) from gufe import ( ChemicalSystem, - LigandAtomMapping, Component, - SolventComponent, + LigandAtomMapping, ProteinComponent, SmallMoleculeComponent, + SolventComponent, +) +from gufe.settings import ( + SettingsBaseModel, + ThermoSettings, ) from openff.toolkit.topology import Molecule as OFFMolecule -from openff.units import unit as offunit from openff.units import Quantity +from openff.units import unit as offunit from openff.units.openmm import ensure_quantity, from_openmm, to_openmm +from openmmforcefields.generators import SystemGenerator from openmmtools import multistate from openfe.protocols.openmm_utils.omm_settings import ( @@ -68,8 +67,8 @@ MultiStateOutputSettings, MultiStateSimulationSettings, OpenFFPartialChargeSettings, - OpenMMSolvationSettings, OpenMMEngineSettings, + OpenMMSolvationSettings, RelativeHybridTopologyProtocolSettings, ) @@ -80,6 +79,7 @@ class RelativeHybridTopologyProtocolUnit(gufe.ProtocolUnit): """ Calculates the relative free energy of an alchemical ligand transformation. """ + def __init__( self, *, @@ -159,7 +159,7 @@ def _set_optional_path(basepath): @staticmethod def _get_settings( - settings: RelativeHybridTopologyProtocolSettings + settings: RelativeHybridTopologyProtocolSettings, ) -> dict[str, SettingsBaseModel]: """ Get a dictionary of Protocol settings. 
@@ -190,14 +190,13 @@ def _get_settings( @staticmethod def _get_components( - stateA: ChemicalSystem, - stateB: ChemicalSystem - ) -> tuple[ - dict[str, Component], - SolventComponent, - ProteinComponent, - dict[SmallMoleculeComponent, OFFMolecule] - ]: + stateA: ChemicalSystem, stateB: ChemicalSystem + ) -> tuple[ + dict[str, Component], + SolventComponent, + ProteinComponent, + dict[SmallMoleculeComponent, OFFMolecule], + ]: """ Get the components from the ChemicalSystem inputs. @@ -225,10 +224,7 @@ def _get_components( solvent_comp, protein_comp, smcs_A = system_validation.get_components(stateA) _, _, smcs_B = system_validation.get_components(stateB) - small_mols = { - m: m.to_openff() - for m in set(smcs_A).union(set(smcs_B)) - } + small_mols = {m: m.to_openff() for m in set(smcs_A).union(set(smcs_B))} return alchem_comps, solvent_comp, protein_comp, small_mols @@ -277,7 +273,7 @@ def _get_system_generator( The solvent component of the system, if any. openff_molecules : list[openff.Toolkit] | None A list of openff molecules to generate templates for, if any. 
- + Returns ------- system_generator : openmmtools.SystemGenerator @@ -302,22 +298,17 @@ def _get_system_generator( # Handle openff Molecule templates # TODO: revisit this once the SystemGenerator update happens # and we start loading the whole protein into OpenFF Topologies - + # First deduplicate isomoprhic molecules unique_offmols = [] for mol in openff_molecules: - unique = all( - [ - not mol.is_isomorphic_with(umol) - for umol in unique_offmols - ] - ) + unique = all([not mol.is_isomorphic_with(umol) for umol in unique_offmols]) if unique: unique_offmols.append(mol) # register all the templates system_generator.add_molecules(unique_offmols) - + return system_generator @staticmethod @@ -328,10 +319,7 @@ def _create_stateA_system( system_generator: SystemGenerator, solvation_settings: OpenMMSolvationSettings, ) -> tuple[ - openmm.System, - openmm.app.Topology, - openmm.unit.Quantity, - dict[Component, npt.NDArray] + openmm.System, openmm.app.Topology, openmm.unit.Quantity, dict[Component, npt.NDArray] ]: """ Create an OpenMM System for state A. 
@@ -483,7 +471,7 @@ def _get_omm_objects( settings: dict[str, SettingsBaseModel], protein_component: ProteinComponent | None, solvent_component: SolventComponent | None, - small_mols: dict[SmallMoleculeComponent, OFFMolecule] + small_mols: dict[SmallMoleculeComponent, OFFMolecule], ) -> tuple[ openmm.System, openmm.app.Topology, @@ -535,11 +523,7 @@ def _get_omm_objects( self.logger.info("Parameterizing systems") def _filter_small_mols(smols, state): - return { - smc: offmol - for smc, offmol in smols.items() - if state.contains(smc) - } + return {smc: offmol for smc, offmol in smols.items() if state.contains(smc)} small_mols_stateA = _filter_small_mols(small_mols, stateA) small_mols_stateB = _filter_small_mols(small_mols, stateB) @@ -554,23 +538,20 @@ def _filter_small_mols(smols, state): shared_basepath=self.shared_basepath, settings=settings, solvent_comp=solvent_component, - openff_molecules=list(small_mols.values()) + openff_molecules=list(small_mols.values()), ) - ( - stateA_system, stateA_topology, stateA_positions, - comp_resids - ) = self._create_stateA_system( - small_mols=small_mols_stateA, - protein_component=protein_component, - solvent_component=solvent_component, - system_generator=system_generator, - solvation_settings=settings["solvation_settings"] + (stateA_system, stateA_topology, stateA_positions, comp_resids) = ( + self._create_stateA_system( + small_mols=small_mols_stateA, + protein_component=protein_component, + solvent_component=solvent_component, + system_generator=system_generator, + solvation_settings=settings["solvation_settings"], + ) ) - ( - stateB_system, stateB_topology, stateB_alchem_resids - ) = self._create_stateB_system( + (stateB_system, stateB_topology, stateB_alchem_resids) = self._create_stateB_system( small_mols=small_mols_stateB, mapping=mapping, stateA_topology=stateA_topology, @@ -617,9 +598,13 @@ def _filter_small_mols(smols, state): ) return ( - stateA_system, stateA_topology, stateA_positions, - stateB_system, 
stateB_topology, stateB_positions, - system_mappings + stateA_system, + stateA_topology, + stateA_positions, + stateB_system, + stateB_topology, + stateB_positions, + system_mappings, ) @staticmethod @@ -730,9 +715,9 @@ def _subsample_topology( # bfactor of 0.5 is core atoms # bfactor of 0.75 is unique new atoms bfactors = np.zeros_like(selection_indices, dtype=float) - bfactors[np.isin(selection_indices, list(atom_classes['unique_old_atoms']))] = 0.25 - bfactors[np.isin(selection_indices, list(atom_classes['core_atoms']))] = 0.50 - bfactors[np.isin(selection_indices, list(atom_classes['unique_new_atoms']))] = 0.75 + bfactors[np.isin(selection_indices, list(atom_classes["unique_old_atoms"]))] = 0.25 + bfactors[np.isin(selection_indices, list(atom_classes["core_atoms"]))] = 0.50 + bfactors[np.isin(selection_indices, list(atom_classes["unique_new_atoms"]))] = 0.75 if len(selection_indices) > 0: traj = mdtraj.Trajectory( @@ -804,7 +789,7 @@ def _get_reporter( def _get_integrator( integrator_settings: IntegratorSettings, simulation_settings: MultiStateSimulationSettings, - system: openmm.System + system: openmm.System, ) -> openmmtools.mcmc.LangevinDynamicsMove: """ Get and validate the integrator @@ -902,7 +887,7 @@ def _get_sampler( rta_its, rta_min_its = settings_validation.convert_real_time_analysis_iterations( simulation_settings=simulation_settings, ) - + # convert early_termination_target_error from kcal/mol to kT early_termination_target_error = ( settings_validation.convert_target_error_from_kcal_per_mole_to_kT( @@ -942,7 +927,6 @@ def _get_sampler( online_analysis_minimum_iterations=rta_min_its, ) - else: raise AttributeError(f"Unknown sampler {simulation_settings.sampler_method}") @@ -976,9 +960,9 @@ def _run_simulation( self, sampler: multistate.MultiStateSampler, reporter: multistate.MultiStateReporter, - simulation_settings : MultiStateSimulationSettings, - integrator_settings : IntegratorSettings, - output_settings : MultiStateOutputSettings, + 
simulation_settings: MultiStateSimulationSettings, + integrator_settings: IntegratorSettings, + output_settings: MultiStateOutputSettings, dry: bool, ): """ @@ -1114,17 +1098,19 @@ def run( stateA = self._inputs["stateA"] stateB = self._inputs["stateB"] mapping = self._inputs["ligandmapping"] - alchem_comps, solvent_comp, protein_comp, small_mols = self._get_components( - stateA, stateB - ) + alchem_comps, solvent_comp, protein_comp, small_mols = self._get_components(stateA, stateB) # Assign partial charges now to avoid any discrepancies later self._assign_partial_charges(settings["charge_settings"], small_mols) ( - stateA_system, stateA_topology, stateA_positions, - stateB_system, stateB_topology, stateB_positions, - system_mappings + stateA_system, + stateA_topology, + stateA_positions, + stateB_system, + stateB_topology, + stateB_positions, + system_mappings, ) = self._get_omm_objects( stateA=stateA, stateB=stateB, @@ -1132,7 +1118,7 @@ def run( settings=settings, protein_component=protein_comp, solvent_component=solvent_comp, - small_mols=small_mols + small_mols=small_mols, ) # Get the hybrid factory & system @@ -1160,7 +1146,7 @@ def run( # TODO - this should be better exposed to users lambdas = _rfe_utils.lambdaprotocol.LambdaProtocol( functions=settings["lambda_settings"].lambda_functions, - windows=settings["lambda_settings"].lambda_windows + windows=settings["lambda_settings"].lambda_windows, ) # Get the reporter @@ -1182,7 +1168,7 @@ def run( integrator = self._get_integrator( integrator_settings=settings["integrator_settings"], simulation_settings=settings["simulation_settings"], - system=hybrid_system + system=hybrid_system, ) try: @@ -1197,7 +1183,7 @@ def run( thermo_settings=settings["thermo_settings"], alchem_settings=settings["alchemical_settings"], platform=platform, - dry=dry + dry=dry, ) unit_results_dict = self._run_simulation( @@ -1237,8 +1223,8 @@ def run( unit_results_dict["selection_indices"] = selection_indices return unit_results_dict 
else: - return {"debug": - { + return { + "debug": { "sampler": sampler, "hybrid_factory": hybrid_factory, } diff --git a/openfe/protocols/openmm_utils/system_validation.py b/openfe/protocols/openmm_utils/system_validation.py index 7cacaa1f1..3e8ed5c50 100644 --- a/openfe/protocols/openmm_utils/system_validation.py +++ b/openfe/protocols/openmm_utils/system_validation.py @@ -170,15 +170,9 @@ def _get_single_comps(state, comptype): else: return None - solvent_comp: Optional[SolventComponent] = _get_single_comps( - state, - SolventComponent - ) - - protein_comp: Optional[ProteinComponent] = _get_single_comps( - state, - ProteinComponent - ) + solvent_comp: Optional[SolventComponent] = _get_single_comps(state, SolventComponent) + + protein_comp: Optional[ProteinComponent] = _get_single_comps(state, ProteinComponent) small_mols = state.get_components_of_type(SmallMoleculeComponent) diff --git a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py index 57092ce3c..08597df10 100644 --- a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py +++ b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py @@ -33,49 +33,54 @@ def test_invalid_protocol_repeats(): settings.protocol_repeats = -1 -@pytest.mark.parametrize('state', ['A', 'B']) +@pytest.mark.parametrize("state", ["A", "B"]) def test_endstate_two_alchemcomp_stateA(state, benzene_modifications): - first_state = openfe.ChemicalSystem({ - 'ligandA': benzene_modifications['benzene'], - 'ligandB': benzene_modifications['toluene'], - 'solvent': openfe.SolventComponent(), - }) - other_state = openfe.ChemicalSystem({ - 'ligandC': benzene_modifications['phenol'], - 'solvent': openfe.SolventComponent(), - }) - - if state == 'A': + first_state = openfe.ChemicalSystem( + { + "ligandA": benzene_modifications["benzene"], + "ligandB": benzene_modifications["toluene"], + "solvent": openfe.SolventComponent(), + } + ) + other_state = 
openfe.ChemicalSystem( + { + "ligandC": benzene_modifications["phenol"], + "solvent": openfe.SolventComponent(), + } + ) + + if state == "A": args = (first_state, other_state) else: args = (other_state, first_state) with pytest.raises(ValueError, match="Only one alchemical component"): - openmm_rfe.RelativeHybridTopologyProtocol._validate_endstates( - *args - ) + openmm_rfe.RelativeHybridTopologyProtocol._validate_endstates(*args) -@pytest.mark.parametrize('state', ['A', 'B']) + +@pytest.mark.parametrize("state", ["A", "B"]) def test_endstates_not_smc(state, benzene_modifications): - first_state = openfe.ChemicalSystem({ - 'ligand': benzene_modifications['benzene'], - 'foo': openfe.SolventComponent(), - }) - other_state = openfe.ChemicalSystem({ - 'ligand': benzene_modifications['benzene'], - 'foo': benzene_modifications['toluene'], - }) - - if state == 'A': + first_state = openfe.ChemicalSystem( + { + "ligand": benzene_modifications["benzene"], + "foo": openfe.SolventComponent(), + } + ) + other_state = openfe.ChemicalSystem( + { + "ligand": benzene_modifications["benzene"], + "foo": benzene_modifications["toluene"], + } + ) + + if state == "A": args = (first_state, other_state) else: args = (other_state, first_state) errmsg = "only SmallMoleculeComponents transformations" with pytest.raises(ValueError, match=errmsg): - openmm_rfe.RelativeHybridTopologyProtocol._validate_endstates( - *args - ) + openmm_rfe.RelativeHybridTopologyProtocol._validate_endstates(*args) def test_validate_mapping_none_mapping(): @@ -88,12 +93,11 @@ def test_validate_mapping_multi_mapping(benzene_to_toluene_mapping): errmsg = "A single LigandAtomMapping is expected" with pytest.raises(ValueError, match=errmsg): openmm_rfe.RelativeHybridTopologyProtocol._validate_mapping( - [benzene_to_toluene_mapping] * 2, - None + [benzene_to_toluene_mapping] * 2, None ) -@pytest.mark.parametrize('state', ['A', 'B']) +@pytest.mark.parametrize("state", ["A", "B"]) def 
test_validate_mapping_alchem_not_in(state, benzene_to_toluene_mapping): errmsg = f"not in alchemical components of state{state}" @@ -110,10 +114,7 @@ def test_validate_mapping_alchem_not_in(state, benzene_to_toluene_mapping): def test_vaccuum_PME_error( - benzene_vacuum_system, - toluene_vacuum_system, - benzene_to_toluene_mapping, - solv_settings + benzene_vacuum_system, toluene_vacuum_system, benzene_to_toluene_mapping, solv_settings ): p = openmm_rfe.RelativeHybridTopologyProtocol(settings=solv_settings) @@ -126,67 +127,56 @@ def test_vaccuum_PME_error( ) -@pytest.mark.parametrize('charge', [None, 'gasteiger']) -def test_smcs_same_charge_passes( - charge, - benzene_modifications -): - benzene = benzene_modifications['benzene'] +@pytest.mark.parametrize("charge", [None, "gasteiger"]) +def test_smcs_same_charge_passes(charge, benzene_modifications): + benzene = benzene_modifications["benzene"] if charge is None: smc = benzene else: offmol = benzene.to_openff() - offmol.assign_partial_charges(partial_charge_method='gasteiger') + offmol.assign_partial_charges(partial_charge_method="gasteiger") smc = openfe.SmallMoleculeComponent.from_openff(offmol) # Just pass the same thing twice - state = openfe.ChemicalSystem({'l': smc}) + state = openfe.ChemicalSystem({"l": smc}) openmm_rfe.RelativeHybridTopologyProtocol._validate_smcs(state, state) -def test_smcs_different_charges_none_not_none( - benzene_modifications -): +def test_smcs_different_charges_none_not_none(benzene_modifications): # smcA has no charges - smcA = benzene_modifications['benzene'] + smcA = benzene_modifications["benzene"] # smcB has charges offmol = smcA.to_openff() - offmol.assign_partial_charges(partial_charge_method='gasteiger') + offmol.assign_partial_charges(partial_charge_method="gasteiger") smcB = openfe.SmallMoleculeComponent.from_openff(offmol) - stateA = openfe.ChemicalSystem({'l': smcA}) - stateB = openfe.ChemicalSystem({'l': smcB}) + stateA = openfe.ChemicalSystem({"l": smcA}) + stateB = 
openfe.ChemicalSystem({"l": smcB}) errmsg = "isomorphic but with different charges" with pytest.raises(ValueError, match=errmsg): - openmm_rfe.RelativeHybridTopologyProtocol._validate_smcs( - stateA, stateB - ) + openmm_rfe.RelativeHybridTopologyProtocol._validate_smcs(stateA, stateB) -def test_smcs_different_charges_all( - benzene_modifications -): +def test_smcs_different_charges_all(benzene_modifications): # For this test, we will assign both A and B to both states # It wouldn't happen in real life, but it tests that within a state # you can pick up isomorphic molecules with different charges # create an offmol with gasteiger charges - offmol = benzene_modifications['benzene'].to_openff() - offmol.assign_partial_charges(partial_charge_method='gasteiger') + offmol = benzene_modifications["benzene"].to_openff() + offmol.assign_partial_charges(partial_charge_method="gasteiger") smcA = openfe.SmallMoleculeComponent.from_openff(offmol) # now alter the offmol charges, scaling by 0.1 offmol.partial_charges *= 0.1 smcB = openfe.SmallMoleculeComponent.from_openff(offmol) - state = openfe.ChemicalSystem({'l1': smcA, 'l2': smcB}) + state = openfe.ChemicalSystem({"l1": smcA, "l2": smcB}) errmsg = "isomorphic but with different charges" with pytest.raises(ValueError, match=errmsg): - openmm_rfe.RelativeHybridTopologyProtocol._validate_smcs( - state, state - ) + openmm_rfe.RelativeHybridTopologyProtocol._validate_smcs(state, state) def test_solvent_nocutoff_error( @@ -212,20 +202,15 @@ def test_nonwater_solvent_error( benzene_to_toluene_mapping, solv_settings, ): - solvent = openfe.SolventComponent(smiles='C') + solvent = openfe.SolventComponent(smiles="C") stateA = openfe.ChemicalSystem( { - 'ligand': benzene_modifications['benzene'], - 'solvent': solvent, + "ligand": benzene_modifications["benzene"], + "solvent": solvent, } ) - stateB = openfe.ChemicalSystem( - { - 'ligand': benzene_modifications['toluene'], - 'solvent': solvent - } - ) + stateB = 
openfe.ChemicalSystem({"ligand": benzene_modifications["toluene"], "solvent": solvent}) p = openmm_rfe.RelativeHybridTopologyProtocol(settings=solv_settings) @@ -246,17 +231,17 @@ def test_too_many_solv_comps_error( ): stateA = openfe.ChemicalSystem( { - 'ligand': benzene_modifications['benzene'], - 'solvent!': openfe.SolventComponent(neutralize=True), - 'solvent2': openfe.SolventComponent(neutralize=False), + "ligand": benzene_modifications["benzene"], + "solvent!": openfe.SolventComponent(neutralize=True), + "solvent2": openfe.SolventComponent(neutralize=False), } ) stateB = openfe.ChemicalSystem( { - 'ligand': benzene_modifications['toluene'], - 'solvent!': openfe.SolventComponent(neutralize=True), - 'solvent2': openfe.SolventComponent(neutralize=False), + "ligand": benzene_modifications["toluene"], + "solvent!": openfe.SolventComponent(neutralize=True), + "solvent2": openfe.SolventComponent(neutralize=False), } ) @@ -290,11 +275,7 @@ def test_bad_solv_settings( errmsg = "Only one of solvent_padding, number_of_solvent_molecules," with pytest.raises(ValueError, match=errmsg): - p.validate( - stateA=benzene_system, - stateB=toluene_system, - mapping=benzene_to_toluene_mapping - ) + p.validate(stateA=benzene_system, stateB=toluene_system, mapping=benzene_to_toluene_mapping) def test_too_many_prot_comps_error( @@ -304,22 +285,21 @@ def test_too_many_prot_comps_error( eg5_protein, solv_settings, ): - stateA = openfe.ChemicalSystem( { - 'ligand': benzene_modifications['benzene'], - 'solvent': openfe.SolventComponent(), - 'protein1': T4_protein_component, - 'protein2': eg5_protein, + "ligand": benzene_modifications["benzene"], + "solvent": openfe.SolventComponent(), + "protein1": T4_protein_component, + "protein2": eg5_protein, } ) stateB = openfe.ChemicalSystem( { - 'ligand': benzene_modifications['toluene'], - 'solvent': openfe.SolventComponent(), - 'protein1': T4_protein_component, - 'protein2': eg5_protein, + "ligand": benzene_modifications["toluene"], + "solvent": 
openfe.SolventComponent(), + "protein1": T4_protein_component, + "protein2": eg5_protein, } ) @@ -419,19 +399,16 @@ def test_greater_than_one_charge_difference_error(aniline_to_benzoic_mapping): def test_get_charge_difference(mapping_name, result, request, caplog): mapping = request.getfixturevalue(mapping_name) caplog.set_level(logging.INFO) - + ion = r"Na+" if result == -1 else r"Cl-" msg = ( f"A charge difference of {result} is observed " "between the end states. This will be addressed by " f"transforming a water into a {ion} ion" ) - + openmm_rfe.RelativeHybridTopologyProtocol._validate_charge_difference( - mapping, - "pme", - True, - openfe.SolventComponent() + mapping, "pme", True, openfe.SolventComponent() ) if result != 0: @@ -457,7 +434,7 @@ def test_hightimestep( stateA=benzene_vacuum_system, stateB=toluene_vacuum_system, mapping=benzene_to_toluene_mapping, - extends=None + extends=None, ) @@ -479,13 +456,11 @@ def test_time_per_iteration_divmod( stateA=benzene_vacuum_system, stateB=toluene_vacuum_system, mapping=benzene_to_toluene_mapping, - extends=None + extends=None, ) -@pytest.mark.parametrize( - "attribute", ["equilibration_length", "production_length"] -) +@pytest.mark.parametrize("attribute", ["equilibration_length", "production_length"]) def test_simsteps_not_timestep_divisible( attribute, benzene_vacuum_system, @@ -503,13 +478,11 @@ def test_simsteps_not_timestep_divisible( stateA=benzene_vacuum_system, stateB=toluene_vacuum_system, mapping=benzene_to_toluene_mapping, - extends=None + extends=None, ) -@pytest.mark.parametrize( - "attribute", ["equilibration_length", "production_length"] -) +@pytest.mark.parametrize("attribute", ["equilibration_length", "production_length"]) def test_simsteps_not_mcstep_divisible( attribute, benzene_vacuum_system, @@ -520,17 +493,14 @@ def test_simsteps_not_mcstep_divisible( setattr(vac_settings.simulation_settings, attribute, 102 * offunit.ps) p = openmm_rfe.RelativeHybridTopologyProtocol(settings=vac_settings) 
- errmsg = ( - "should contain a number of steps divisible by the number of " - "integrator timesteps" - ) + errmsg = "should contain a number of steps divisible by the number of integrator timesteps" with pytest.raises(ValueError, match=errmsg): p.validate( stateA=benzene_vacuum_system, stateB=toluene_vacuum_system, mapping=benzene_to_toluene_mapping, - extends=None + extends=None, ) @@ -551,14 +521,11 @@ def test_checkpoint_interval_not_divisible_time_per_iter( stateA=benzene_vacuum_system, stateB=toluene_vacuum_system, mapping=benzene_to_toluene_mapping, - extends=None + extends=None, ) -@pytest.mark.parametrize( - "attribute", - ["positions_write_frequency", "velocities_write_frequency"] -) +@pytest.mark.parametrize("attribute", ["positions_write_frequency", "velocities_write_frequency"]) def test_pos_vel_write_frequency_not_divisible( benzene_vacuum_system, toluene_vacuum_system, @@ -576,13 +543,12 @@ def test_pos_vel_write_frequency_not_divisible( stateA=benzene_vacuum_system, stateB=toluene_vacuum_system, mapping=benzene_to_toluene_mapping, - extends=None + extends=None, ) @pytest.mark.parametrize( - "attribute", - ["real_time_analysis_interval", "real_time_analysis_interval"] + "attribute", ["real_time_analysis_interval", "real_time_analysis_interval"] ) def test_real_time_analysis_not_divisible( benzene_vacuum_system, @@ -601,9 +567,10 @@ def test_real_time_analysis_not_divisible( stateA=benzene_vacuum_system, stateB=toluene_vacuum_system, mapping=benzene_to_toluene_mapping, - extends=None + extends=None, ) + def test_n_replicas_not_n_windows( benzene_vacuum_system, toluene_vacuum_system, @@ -623,5 +590,5 @@ def test_n_replicas_not_n_windows( stateA=benzene_vacuum_system, stateB=toluene_vacuum_system, mapping=benzene_to_toluene_mapping, - extends=None + extends=None, ) From 951ac1588265ddfa060a6a821dfa7d6f81bca021 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Fri, 2 Jan 2026 20:51:45 -0500 Subject: [PATCH 28/36] move a few things around to make life easier 
--- .../protocols/openmm_rfe/hybridtop_units.py | 142 +++++++++--------- 1 file changed, 73 insertions(+), 69 deletions(-) diff --git a/openfe/protocols/openmm_rfe/hybridtop_units.py b/openfe/protocols/openmm_rfe/hybridtop_units.py index 41c6f1787..e20906173 100644 --- a/openfe/protocols/openmm_rfe/hybridtop_units.py +++ b/openfe/protocols/openmm_rfe/hybridtop_units.py @@ -730,61 +730,6 @@ def _subsample_topology( return selection_indices - def _get_reporter( - self, - selection_indices: npt.NDArray, - output_settings: MultiStateOutputSettings, - simulation_settings: MultiStateSimulationSettings, - ) -> multistate.MultiStateReporter: - """ - Get the multistate reporter. - - Parameters - ---------- - selection_indices : npt.NDArray - The set of system indices to report positions & velocities for. - output_settings : MultiStateOutputSettings - Settings defining how outputs should be written. - simulation_settings : MultiStateSimulationSettings - Settings defining out the simulation should be run. 
- """ - nc = self.shared_basepath / output_settings.output_filename - chk = output_settings.checkpoint_storage_filename - - if output_settings.positions_write_frequency is not None: - pos_interval = settings_validation.divmod_time_and_check( - numerator=output_settings.positions_write_frequency, - denominator=simulation_settings.time_per_iteration, - numerator_name="output settings' position_write_frequency", - denominator_name="simulation settings' time_per_iteration", - ) - else: - pos_interval = 0 - - if output_settings.velocities_write_frequency is not None: - vel_interval = settings_validation.divmod_time_and_check( - numerator=output_settings.velocities_write_frequency, - denominator=simulation_settings.time_per_iteration, - numerator_name="output settings' velocity_write_frequency", - denominator_name="sampler settings' time_per_iteration", - ) - else: - vel_interval = 0 - - chk_intervals = settings_validation.convert_checkpoint_interval_to_iterations( - checkpoint_interval=output_settings.checkpoint_interval, - time_per_iteration=simulation_settings.time_per_iteration, - ) - - return multistate.MultiStateReporter( - storage=nc, - analysis_particle_indices=selection_indices, - checkpoint_interval=chk_intervals, - checkpoint_storage=chk, - position_interval=pos_interval, - velocity_interval=vel_interval, - ) - @staticmethod def _get_integrator( integrator_settings: IntegratorSettings, @@ -840,6 +785,64 @@ def _get_integrator( return integrator + @staticmethod + def _get_reporter( + storage_path: pathlib.Path + selection_indices: npt.NDArray, + output_settings: MultiStateOutputSettings, + simulation_settings: MultiStateSimulationSettings, + ) -> multistate.MultiStateReporter: + """ + Get the multistate reporter. + + Parameters + ---------- + storage_path : pathlib.Path + Path to the directory where files should be written. + selection_indices : npt.NDArray + The set of system indices to report positions & velocities for. 
+ output_settings : MultiStateOutputSettings + Settings defining how outputs should be written. + simulation_settings : MultiStateSimulationSettings + Settings defining out the simulation should be run. + """ + nc = self.shared_basepath / output_settings.output_filename + chk = output_settings.checkpoint_storage_filename + + if output_settings.positions_write_frequency is not None: + pos_interval = settings_validation.divmod_time_and_check( + numerator=output_settings.positions_write_frequency, + denominator=simulation_settings.time_per_iteration, + numerator_name="output settings' position_write_frequency", + denominator_name="simulation settings' time_per_iteration", + ) + else: + pos_interval = 0 + + if output_settings.velocities_write_frequency is not None: + vel_interval = settings_validation.divmod_time_and_check( + numerator=output_settings.velocities_write_frequency, + denominator=simulation_settings.time_per_iteration, + numerator_name="output settings' velocity_write_frequency", + denominator_name="sampler settings' time_per_iteration", + ) + else: + vel_interval = 0 + + chk_intervals = settings_validation.convert_checkpoint_interval_to_iterations( + checkpoint_interval=output_settings.checkpoint_interval, + time_per_iteration=simulation_settings.time_per_iteration, + ) + + return multistate.MultiStateReporter( + storage=nc, + analysis_particle_indices=selection_indices, + checkpoint_interval=chk_intervals, + checkpoint_storage=chk, + position_interval=pos_interval, + velocity_interval=vel_interval, + ) + @staticmethod def _get_sampler( system: openmm.System, @@ -1149,13 +1152,6 @@ def run( windows=settings["lambda_settings"].lambda_windows, ) - # Get the reporter - reporter = self._get_reporter( - selection_indices=selection_indices, - output_settings=settings["output_settings"], - simulation_settings=settings["simulation_settings"], - ) - # Get the compute platform restrict_cpu = settings["forcefield_settings"].nonbonded_method.lower() == "nocutoff" 
platform = omm_compute.get_openmm_platform( @@ -1164,14 +1160,22 @@ def run( restrict_cpu_count=restrict_cpu, ) - # Get the integrator - integrator = self._get_integrator( - integrator_settings=settings["integrator_settings"], - simulation_settings=settings["simulation_settings"], - system=hybrid_system, - ) - try: + # Get the integrator + integrator = self._get_integrator( + integrator_settings=settings["integrator_settings"], + simulation_settings=settings["simulation_settings"], + system=hybrid_system, + ) + + # get the reporter + reporter = self._get_reporter( + storage_path=self.shared_basepath, + selection_indices=selection_indices, + output_settings=settings["output_settings"], + simulation_settings=settings["simulation_settings"], + ) + # Get sampler sampler = self._get_sampler( system=hybrid_system, From 2e4b455a1a0a8185cf5e8f8d1d2fa4459445376c Mon Sep 17 00:00:00 2001 From: IAlibay Date: Wed, 7 Jan 2026 08:31:44 -0500 Subject: [PATCH 29/36] fix typo --- openfe/protocols/openmm_rfe/hybridtop_units.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openfe/protocols/openmm_rfe/hybridtop_units.py b/openfe/protocols/openmm_rfe/hybridtop_units.py index 71c7a9ec5..dfcf21726 100644 --- a/openfe/protocols/openmm_rfe/hybridtop_units.py +++ b/openfe/protocols/openmm_rfe/hybridtop_units.py @@ -787,7 +787,7 @@ def _get_integrator( @staticmethod def _get_reporter( - storage_path: pathlib.Path + storage_path: pathlib.Path, selection_indices: npt.NDArray, output_settings: MultiStateOutputSettings, simulation_settings: MultiStateSimulationSettings, From 718280528a3f4807fe755febe2edb286cc659f79 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Wed, 7 Jan 2026 08:54:57 -0500 Subject: [PATCH 30/36] fix some merge issues --- openfe/protocols/openmm_rfe/hybridtop_units.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/openfe/protocols/openmm_rfe/hybridtop_units.py b/openfe/protocols/openmm_rfe/hybridtop_units.py index 
dfcf21726..91336412c 100644 --- a/openfe/protocols/openmm_rfe/hybridtop_units.py +++ b/openfe/protocols/openmm_rfe/hybridtop_units.py @@ -22,7 +22,7 @@ import matplotlib.pyplot as plt import mdtraj import numpy as np -import numpy.typing as nptnpt +import numpy.typing as npt import openmm import openmmtools from gufe import ( @@ -806,7 +806,7 @@ def _get_reporter( simulation_settings : MultiStateSimulationSettings Settings defining out the simulation should be run. """ - nc = self.shared_basepath / output_settings.output_filename + nc = storage_path / output_settings.output_filename chk = output_settings.checkpoint_storage_filename if output_settings.positions_write_frequency is not None: @@ -1167,7 +1167,7 @@ def run( simulation_settings=settings["simulation_settings"], system=hybrid_system, ) - + # get the reporter reporter = self._get_reporter( storage_path=self.shared_basepath, @@ -1175,7 +1175,7 @@ def run( output_settings=settings["output_settings"], simulation_settings=settings["simulation_settings"], ) - + # Get sampler sampler = self._get_sampler( system=hybrid_system, @@ -1189,7 +1189,7 @@ def run( platform=platform, dry=dry, ) - + unit_results_dict = self._run_simulation( sampler=sampler, reporter=reporter, From 28b4381b1c2e178c1445e53487f871fda48ef221 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Wed, 7 Jan 2026 09:39:09 -0500 Subject: [PATCH 31/36] fix test failures due to integrator checks --- .../protocols/openmm_rfe/hybridtop_units.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/openfe/protocols/openmm_rfe/hybridtop_units.py b/openfe/protocols/openmm_rfe/hybridtop_units.py index 91336412c..e22944498 100644 --- a/openfe/protocols/openmm_rfe/hybridtop_units.py +++ b/openfe/protocols/openmm_rfe/hybridtop_units.py @@ -1159,15 +1159,15 @@ def run( gpu_device_index=settings["engine_settings"].gpu_device_index, restrict_cpu_count=restrict_cpu, ) - - try: - # Get the integrator - integrator = self._get_integrator( - 
integrator_settings=settings["integrator_settings"], - simulation_settings=settings["simulation_settings"], - system=hybrid_system, - ) + + # Get the integrator + integrator = self._get_integrator( + integrator_settings=settings["integrator_settings"], + simulation_settings=settings["simulation_settings"], + system=hybrid_system, + ) + try: # get the reporter reporter = self._get_reporter( storage_path=self.shared_basepath, @@ -1175,7 +1175,7 @@ def run( output_settings=settings["output_settings"], simulation_settings=settings["simulation_settings"], ) - + # Get sampler sampler = self._get_sampler( system=hybrid_system, @@ -1189,7 +1189,7 @@ def run( platform=platform, dry=dry, ) - + unit_results_dict = self._run_simulation( sampler=sampler, reporter=reporter, @@ -1202,7 +1202,7 @@ def run( # close reporter when you're done, prevent # file handle clashes reporter.close() - + # clear GPU contexts # TODO: use cache.empty() calls when openmmtools #690 is resolved # replace with above @@ -1213,9 +1213,9 @@ def run( # cautiously clear out the global context cache too for context in list(openmmtools.cache.global_context_cache._lru._data.keys()): del openmmtools.cache.global_context_cache._lru._data[context] - + del sampler.sampler_context_cache, sampler.energy_context_cache - + if not dry: del integrator, sampler From 726f517c892cf4cef4a90bf5bbd85dbdcd0ab059 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 7 Jan 2026 14:40:02 +0000 Subject: [PATCH 32/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- openfe/protocols/openmm_rfe/hybridtop_units.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/openfe/protocols/openmm_rfe/hybridtop_units.py b/openfe/protocols/openmm_rfe/hybridtop_units.py index e22944498..484612b11 100644 --- a/openfe/protocols/openmm_rfe/hybridtop_units.py +++ 
b/openfe/protocols/openmm_rfe/hybridtop_units.py @@ -1159,14 +1159,14 @@ def run( gpu_device_index=settings["engine_settings"].gpu_device_index, restrict_cpu_count=restrict_cpu, ) - + # Get the integrator integrator = self._get_integrator( integrator_settings=settings["integrator_settings"], simulation_settings=settings["simulation_settings"], system=hybrid_system, ) - + try: # get the reporter reporter = self._get_reporter( @@ -1175,7 +1175,7 @@ def run( output_settings=settings["output_settings"], simulation_settings=settings["simulation_settings"], ) - + # Get sampler sampler = self._get_sampler( system=hybrid_system, @@ -1189,7 +1189,7 @@ def run( platform=platform, dry=dry, ) - + unit_results_dict = self._run_simulation( sampler=sampler, reporter=reporter, @@ -1202,7 +1202,7 @@ def run( # close reporter when you're done, prevent # file handle clashes reporter.close() - + # clear GPU contexts # TODO: use cache.empty() calls when openmmtools #690 is resolved # replace with above @@ -1213,9 +1213,9 @@ def run( # cautiously clear out the global context cache too for context in list(openmmtools.cache.global_context_cache._lru._data.keys()): del openmmtools.cache.global_context_cache._lru._data[context] - + del sampler.sampler_context_cache, sampler.energy_context_cache - + if not dry: del integrator, sampler From 1587673a1c80151e8cd56d04f92c6ae131a4eafa Mon Sep 17 00:00:00 2001 From: IAlibay Date: Wed, 7 Jan 2026 09:42:45 -0500 Subject: [PATCH 33/36] try to make mypy happy --- openfe/protocols/openmm_rfe/hybridtop_units.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openfe/protocols/openmm_rfe/hybridtop_units.py b/openfe/protocols/openmm_rfe/hybridtop_units.py index e22944498..838b4c144 100644 --- a/openfe/protocols/openmm_rfe/hybridtop_units.py +++ b/openfe/protocols/openmm_rfe/hybridtop_units.py @@ -300,7 +300,7 @@ def _get_system_generator( # and we start loading the whole protein into OpenFF Topologies # First deduplicate isomoprhic 
molecules - unique_offmols = [] + unique_offmols: list[OFFMolecule] = [] for mol in openff_molecules: unique = all([not mol.is_isomorphic_with(umol) for umol in unique_offmols]) if unique: From 1fbec7dd9a198d6e5b43a08b67d7e7192dc37dda Mon Sep 17 00:00:00 2001 From: IAlibay Date: Wed, 7 Jan 2026 10:00:25 -0500 Subject: [PATCH 34/36] add early exist if there's no molecules --- openfe/protocols/openmm_rfe/hybridtop_units.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/openfe/protocols/openmm_rfe/hybridtop_units.py b/openfe/protocols/openmm_rfe/hybridtop_units.py index 5fc2bd577..3c9b75892 100644 --- a/openfe/protocols/openmm_rfe/hybridtop_units.py +++ b/openfe/protocols/openmm_rfe/hybridtop_units.py @@ -299,7 +299,10 @@ def _get_system_generator( # TODO: revisit this once the SystemGenerator update happens # and we start loading the whole protein into OpenFF Topologies - # First deduplicate isomoprhic molecules + # First deduplicate isomoprhic molecules, if there are any + if openff_molecules is None: + return system_generator + unique_offmols: list[OFFMolecule] = [] for mol in openff_molecules: unique = all([not mol.is_isomorphic_with(umol) for umol in unique_offmols]) From 3cd758ee1340066ebb40b79136f878d1fde662ba Mon Sep 17 00:00:00 2001 From: Irfan Alibay Date: Wed, 7 Jan 2026 16:41:42 +0000 Subject: [PATCH 35/36] Apply suggestions from code review Co-authored-by: Hannah Baumann <43765638+hannahbaumann@users.noreply.github.com> --- openfe/protocols/openmm_rfe/hybridtop_units.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/openfe/protocols/openmm_rfe/hybridtop_units.py b/openfe/protocols/openmm_rfe/hybridtop_units.py index 3c9b75892..e2ff2b102 100644 --- a/openfe/protocols/openmm_rfe/hybridtop_units.py +++ b/openfe/protocols/openmm_rfe/hybridtop_units.py @@ -299,7 +299,7 @@ def _get_system_generator( # TODO: revisit this once the SystemGenerator update happens # and we start loading the whole protein into OpenFF 
Topologies - # First deduplicate isomoprhic molecules, if there are any + # First deduplicate isomorphic molecules, if there are any if openff_molecules is None: return system_generator @@ -501,7 +501,7 @@ def _get_omm_objects( The common ProteinComponent between the end states, if there is is one. solvent_component : SolventComponent | None The common SolventComponent between the end states, if there is one. - small_mols : dict[SmallMoleculeCOmponent, openff.toolkit.Molecule] + small_mols : dict[SmallMoleculeComponent, openff.toolkit.Molecule] The small molecules for both end states. Returns @@ -776,7 +776,7 @@ def _get_integrator( ) # Validate for known issue when dealing with virtual sites - # and mutltistate simulations + # and multistate simulations if not integrator_settings.reassign_velocities: for particle_idx in range(system.getNumParticles()): if system.isVirtualSite(particle_idx): From 662242896e84b742d6b84b89bcff66180c029f3c Mon Sep 17 00:00:00 2001 From: Irfan Alibay Date: Wed, 7 Jan 2026 11:55:54 -0500 Subject: [PATCH 36/36] Update openfe/protocols/openmm_rfe/hybridtop_units.py Co-authored-by: Hannah Baumann <43765638+hannahbaumann@users.noreply.github.com> --- openfe/protocols/openmm_rfe/hybridtop_units.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openfe/protocols/openmm_rfe/hybridtop_units.py b/openfe/protocols/openmm_rfe/hybridtop_units.py index e2ff2b102..03102435f 100644 --- a/openfe/protocols/openmm_rfe/hybridtop_units.py +++ b/openfe/protocols/openmm_rfe/hybridtop_units.py @@ -491,7 +491,7 @@ def _get_omm_objects( ---------- stateA : ChemicalSystem ChemicalSystem defining end state A. - stateB : ChmiecalSysstem + stateB : ChemicalSystem ChemicalSystem defining end state B. mapping : LigandAtomMapping The mapping for alchemical components between state A and B.