diff --git a/environment.yml b/environment.yml index 5b23e9065..b083bb0f2 100644 --- a/environment.yml +++ b/environment.yml @@ -19,8 +19,8 @@ dependencies: - openff-toolkit-base >=0.16.2 - openff-units==0.3.1 # https://github.com/OpenFreeEnergy/openfe/pull/1374 - openmm ~=8.2.0 # omit 8.3.0 and 8.3.1 due to https://github.com/openmm/openmm/pull/5069, unpin once we've qualified 8.3.2 - - openmmforcefields >=0.15.0 # min needed for https://github.com/OpenFreeEnergy/openfe/pull/1695 - - openmmtools >=0.25.3 # fix to support numpy >=2.3: https://github.com/choderalab/openmmtools/pull/793 + - openmmforcefields >=0.15.1 # min needed for https://github.com/OpenFreeEnergy/openfe/pull/414 + - openmmtools >=0.26 # fix to support membrane barostat: https://github.com/choderalab/openmmtools/pull/798 - packaging - pandas - parmed >=4.3.1 # fix to support numpy >=2.3: https://github.com/ParmEd/ParmEd/pull/1387 diff --git a/news/membrane.rst b/news/membrane.rst new file mode 100644 index 000000000..fbb11842b --- /dev/null +++ b/news/membrane.rst @@ -0,0 +1,23 @@ +**Added:** + +* This PR adds membrane support to the protocols PlainMDProtocol, RelativeHybridTopologyProtocol, SepTopProtocol, AbsoluteBindingProtocol. 
+ +**Changed:** + +* + +**Deprecated:** + +* + +**Removed:** + +* + +**Fixed:** + +* + +**Security:** + +* diff --git a/src/openfe/__init__.py b/src/openfe/__init__.py index cb689bbed..a320e7e98 100644 --- a/src/openfe/__init__.py +++ b/src/openfe/__init__.py @@ -54,6 +54,7 @@ LigandAtomMapping, NonTransformation, ProteinComponent, + ProteinMembraneComponent, SmallMoleculeComponent, SolventComponent, Transformation, diff --git a/src/openfe/protocols/openmm_afe/abfe_units.py b/src/openfe/protocols/openmm_afe/abfe_units.py index 4032c43a8..343cbd2e8 100644 --- a/src/openfe/protocols/openmm_afe/abfe_units.py +++ b/src/openfe/protocols/openmm_afe/abfe_units.py @@ -16,7 +16,7 @@ from gufe import ( SolventComponent, ) -from gufe.components import Component +from gufe.components import Component, SolvatedPDBComponent from openff.units import Quantity from openff.units.openmm import to_openmm from openmm import System @@ -70,6 +70,13 @@ def _get_components(self): # an error will have been raised when calling `validate_solvent` # in the Protocol's `_create`. 
# Similarly we don't need to check prot_comp + + # If there is a SolvatedPDBComponent, we set the solv_comp + # in the complex to the SolvatedPDBComponent, as the SolventComponent + # is only used in the solvent leg + if isinstance(prot_comp, SolvatedPDBComponent): + solv_comp = prot_comp + return alchem_comps, solv_comp, prot_comp, off_comps @@ -106,7 +113,7 @@ def _get_settings(self) -> dict[str, SettingsBaseModel]: settings["alchemical_settings"] = prot_settings.alchemical_settings settings["lambda_settings"] = prot_settings.complex_lambda_settings settings["engine_settings"] = prot_settings.engine_settings - settings["integrator_settings"] = prot_settings.integrator_settings + settings["integrator_settings"] = prot_settings.complex_integrator_settings settings["equil_simulation_settings"] = prot_settings.complex_equil_simulation_settings settings["equil_output_settings"] = prot_settings.complex_equil_output_settings settings["simulation_settings"] = prot_settings.complex_simulation_settings @@ -475,7 +482,7 @@ def _get_settings(self) -> dict[str, SettingsBaseModel]: settings["alchemical_settings"] = prot_settings.alchemical_settings settings["lambda_settings"] = prot_settings.solvent_lambda_settings settings["engine_settings"] = prot_settings.engine_settings - settings["integrator_settings"] = prot_settings.integrator_settings + settings["integrator_settings"] = prot_settings.solvent_integrator_settings settings["equil_simulation_settings"] = prot_settings.solvent_equil_simulation_settings settings["equil_output_settings"] = prot_settings.solvent_equil_output_settings settings["simulation_settings"] = prot_settings.solvent_simulation_settings diff --git a/src/openfe/protocols/openmm_afe/base_afe_units.py b/src/openfe/protocols/openmm_afe/base_afe_units.py index 994873b48..2f4302f44 100644 --- a/src/openfe/protocols/openmm_afe/base_afe_units.py +++ b/src/openfe/protocols/openmm_afe/base_afe_units.py @@ -29,6 +29,7 @@ import openmm import openmmtools from gufe 
import ( + BaseSolventComponent, ProteinComponent, SmallMoleculeComponent, SolventComponent, @@ -356,7 +357,7 @@ def _assign_partial_charges( @staticmethod def _get_system_generator( settings: dict[str, SettingsBaseModel], - solvent_component: SolventComponent | None, + solvent_component: BaseSolventComponent | None, openff_molecules: list[OFFMolecule], ffcache: pathlib.Path | None, ) -> SystemGenerator: @@ -368,8 +369,8 @@ def _get_system_generator( ---------- settings : dict[str, SettingsBaseModel] A dictionary of settings object for the unit. - solvent_comp : SolventComponent | None - The solvent component of this system, if there is one. + solvent_comp : BaseSolventComponent | None + The BaseSolventComponent of this system, if there is one. openff_molecules : list[openff.toolkit.Molecule] | None A list of OpenFF Molecules to generate templates for, if any. ffcache : pathlib.Path | None @@ -401,7 +402,7 @@ def _get_system_generator( @staticmethod def _get_modeller( protein_component: ProteinComponent | None, - solvent_component: SolventComponent | None, + solvent_component: BaseSolventComponent | None, small_mols: dict[SmallMoleculeComponent, OFFMolecule], system_generator: SystemGenerator, solvation_settings: BaseSolvationSettings, @@ -414,8 +415,8 @@ def _get_modeller( ---------- protein_component : ProteinComponent | None Protein Component, if it exists. - solvent_component : SolventComponent | None - Solvent Component, if it exists. + solvent_component : BaseSolventComponent | None + Base Solvent Component, if it exists. small_mols : dict[SmallMoleculeComponent, openff.toolkit.Molecule] Dictionary of OpenFF Molecules to add, keyed by SmallMoleculeComponent. 
@@ -447,7 +448,7 @@ def _get_omm_objects( self, settings: dict[str, SettingsBaseModel], protein_component: ProteinComponent | None, - solvent_component: SolventComponent | None, + solvent_component: BaseSolventComponent | None, small_mols: dict[SmallMoleculeComponent, OFFMolecule], ) -> tuple[ app.Topology, @@ -465,8 +466,8 @@ def _get_omm_objects( Protocol settings protein_component : ProteinComponent | None Protein component for the system. - solvent_component : SolventComponent | None - Solvent component for the system. + solvent_component : BaseSolventComponent | None + BaseSolventComponent for the system. small_mols : dict[str, openff.toolkit.Molecule] Dictionary of SmallMoleculeComponents and OpenFF Molecules defining the ligands to be added to the system @@ -847,7 +848,7 @@ def _get_states( box_vectors: openmm.unit.Quantity, thermodynamic_settings: ThermoSettings, lambdas: dict[str, list[float]], - solvent_component: SolventComponent | None, + solvent_component: BaseSolventComponent | None, alchemically_restrained: bool, ) -> tuple[list[SamplerState], list[ThermodynamicState]]: """ @@ -866,8 +867,8 @@ def _get_states( Settings controlling the thermodynamic parameters. lambdas : dict[str, list[float]] A dictionary of lambda scales. - solvent_component : SolventComponent | None - The solvent component of the system, if there is one. + solvent_component : BaseSolventComponent | None + The base solvent component of the system, if there is one. alchemically_restrained : bool Whether or not the system requires a control parameter for any alchemical restraints. 
diff --git a/src/openfe/protocols/openmm_afe/equil_afe_settings.py b/src/openfe/protocols/openmm_afe/equil_afe_settings.py index afb74528c..41cd5fa2e 100644 --- a/src/openfe/protocols/openmm_afe/equil_afe_settings.py +++ b/src/openfe/protocols/openmm_afe/equil_afe_settings.py @@ -387,10 +387,15 @@ def must_be_positive(cls, v): """ # Sampling State defining things - integrator_settings: IntegratorSettings + solvent_integrator_settings: IntegratorSettings """ Settings for controlling the integrator, such as the timestep and - barostat settings. + barostat settings in the solvent. + """ + complex_integrator_settings: IntegratorSettings + """ + Settings for controlling the integrator, such as the timestep and + barostat settings in the complex. """ # Simulation run settings diff --git a/src/openfe/protocols/openmm_afe/equil_binding_afe_method.py b/src/openfe/protocols/openmm_afe/equil_binding_afe_method.py index 28054a14d..cf0a99eba 100644 --- a/src/openfe/protocols/openmm_afe/equil_binding_afe_method.py +++ b/src/openfe/protocols/openmm_afe/equil_binding_afe_method.py @@ -32,8 +32,10 @@ import gufe from gufe import ( + BaseSolventComponent, ChemicalSystem, ProteinComponent, + ProteinMembraneComponent, SmallMoleculeComponent, SolventComponent, settings, @@ -171,7 +173,8 @@ def _default_settings(cls): ), solvent_solvation_settings=OpenMMSolvationSettings(), engine_settings=OpenMMEngineSettings(), - integrator_settings=IntegratorSettings(), + solvent_integrator_settings=IntegratorSettings(), + complex_integrator_settings=IntegratorSettings(), restraint_settings=BoreschRestraintSettings(), solvent_equil_simulation_settings=MDSimulationSettings( equilibration_length_nvt=0.1 * offunit.nanosecond, @@ -208,6 +211,53 @@ def _default_settings(cls): ) # fmt: on + @classmethod + def _adaptive_settings( + cls, + stateA: ChemicalSystem, + stateB: ChemicalSystem, + initial_settings: None | AbsoluteBindingSettings = None, + ) -> AbsoluteBindingSettings: + """ + Get the recommended 
OpenFE settings for this protocol based on the input states involved in the + transformation. + + These are intended as a suitable starting point for creating an instance of this protocol, which can be further + customized before performing a Protocol. + + Parameters + ---------- + stateA : ChemicalSystem + The initial state of the transformation. + stateB : ChemicalSystem + The final state of the transformation. + initial_settings : None | AbsoluteBindingSettings, optional + Initial settings to base the adaptive settings on. If None, default settings are used. + + Returns + ------- + AbsoluteBindingSettings + The recommended settings for this protocol based on the input states. + """ + # use initial settings or default settings + if initial_settings is not None: + protocol_settings = initial_settings.model_copy(deep=True) + else: + protocol_settings = cls.default_settings() + + # adapt the barostat and lipid forcefield based on the ProteinComponent + if stateA.contains(ProteinMembraneComponent): + protocol_settings.complex_integrator_settings.barostat = "MonteCarloMembraneBarostat" + protocol_settings.forcefield_settings.forcefields = [ + "amber/ff14SB.xml", + "amber/tip3p_standard.xml", + "amber/tip3p_HFE_multivalent.xml", + "amber/lipid17_merged.xml", + "amber/phosaa10.xml", + ] + + return protocol_settings + @staticmethod def _validate_endstates( stateA: ChemicalSystem, @@ -400,6 +450,11 @@ def _validate( # Use the more complete system validation solvent checks system_validation.validate_solvent(stateA, nonbonded_method) + # Validate the barostat used in combination with the protein component + system_validation.validate_barostat( + stateA, self.settings.complex_integrator_settings.barostat + ) + # Validate solvation settings settings_validation.validate_openmm_solvation_settings( self.settings.solvent_solvation_settings @@ -409,9 +464,15 @@ def _validate( ) # Validate integrator things + # We validate the timestep for both the complex & solvent settings + 
settings_validation.validate_timestep( + self.settings.forcefield_settings.hydrogen_mass, + self.settings.complex_integrator_settings.timestep, + ) + settings_validation.validate_timestep( self.settings.forcefield_settings.hydrogen_mass, - self.settings.integrator_settings.timestep, + self.settings.solvent_integrator_settings.timestep, ) def _create( diff --git a/src/openfe/protocols/openmm_md/plain_md_methods.py b/src/openfe/protocols/openmm_md/plain_md_methods.py index 5d626b4ac..23e51ac36 100644 --- a/src/openfe/protocols/openmm_md/plain_md_methods.py +++ b/src/openfe/protocols/openmm_md/plain_md_methods.py @@ -23,6 +23,7 @@ import openmm import openmm.unit as omm_unit from gufe import ( + BaseSolventComponent, ChemicalSystem, SmallMoleculeComponent, settings, @@ -32,6 +33,7 @@ from openff.toolkit.topology import Molecule as OFFMolecule from openff.units import Quantity, unit from openff.units.openmm import from_openmm, to_openmm +from openmm import MonteCarloBarostat, MonteCarloMembraneBarostat from openfe.protocols.openmm_md.plain_md_settings import ( IntegratorSettings, @@ -180,6 +182,9 @@ def _create( # Validate protein component system_validation.validate_protein(stateA) + # Validate the barostat used in combination with the protein component + system_validation.validate_barostat(stateA, self.settings.integrator_settings.barostat) + # Validate solvation settings settings_validation.validate_openmm_solvation_settings(self.settings.solvation_settings) @@ -187,7 +192,7 @@ def _create( # TODO: Deal with multiple ProteinComponents solvent_comp, protein_comp, small_mols = system_validation.get_components(stateA) - system_name = "Solvent MD" if solvent_comp is not None else "Vacuum MD" + system_name = "Solvent MD" if stateA.contains(BaseSolventComponent) else "Vacuum MD" for comp in [protein_comp] + small_mols: if comp is not None: @@ -363,9 +368,9 @@ def _run_MD( logger.info("Running NVT equilibration") # Set barostat frequency to zero for NVT - for x in 
simulation.context.getSystem().getForces(): - if x.getName() == "MonteCarloBarostat": - x.setFrequency(0) + for force in simulation.context.getSystem().getForces(): + if isinstance(force, (MonteCarloBarostat, MonteCarloMembraneBarostat)): + force.setFrequency(0) simulation.context.setVelocitiesToTemperature(to_openmm(temperature)) @@ -397,9 +402,9 @@ def _run_MD( simulation.context.setVelocitiesToTemperature(to_openmm(temperature)) # Enable the barostat for NPT - for x in simulation.context.getSystem().getForces(): - if x.getName() == "MonteCarloBarostat": - x.setFrequency(barostat_frequency.m) + for force in simulation.context.getSystem().getForces(): + if isinstance(force, (MonteCarloBarostat, MonteCarloMembraneBarostat)): + force.setFrequency(barostat_frequency.m) t0 = time.time() simulation.step(equil_steps_npt) @@ -585,7 +590,7 @@ def run( solvent_comp, protein_comp, small_mols = system_validation.get_components(stateA) # 1. Create stateA system - # Create a dictionary of OFFMol for each SMC for bookeeping + # Create a dictionary of OFFMol for each SMC for bookkeeping smc_components: dict[SmallMoleculeComponent, OFFMolecule] smc_components = {i: i.to_openff() for i in small_mols} @@ -637,6 +642,9 @@ def run( # f. 
Save pdb of entire system if output_settings.preminimized_structure: + # roundtrip box vectors to remove vec3 issues + box = to_openmm(from_openmm(stateA_system.getDefaultPeriodicBoxVectors())) + stateA_topology.setPeriodicBoxVectors(box) with open(shared_basepath / output_settings.preminimized_structure, "w") as f: openmm.app.PDBFile.writeFile( stateA_topology, stateA_positions, file=f, keepIds=True diff --git a/src/openfe/protocols/openmm_rfe/_rfe_utils/relative.py b/src/openfe/protocols/openmm_rfe/_rfe_utils/relative.py index 096199a68..9f50f4d4f 100644 --- a/src/openfe/protocols/openmm_rfe/_rfe_utils/relative.py +++ b/src/openfe/protocols/openmm_rfe/_rfe_utils/relative.py @@ -470,9 +470,12 @@ def _check_and_store_system_forces(self): def _check_unknown_forces(forces, system_name): # TODO: double check that CMMotionRemover is ok being here - known_forces = {'HarmonicBondForce', 'HarmonicAngleForce', - 'PeriodicTorsionForce', 'NonbondedForce', - 'MonteCarloBarostat', 'CMMotionRemover', 'CMAPTorsionForce'} + known_forces = { + 'HarmonicBondForce', 'HarmonicAngleForce', + 'PeriodicTorsionForce', 'NonbondedForce', + 'MonteCarloBarostat', 'CMMotionRemover', + 'CMAPTorsionForce', 'MonteCarloMembraneBarostat', + } force_names = forces.keys() unknown_forces = set(force_names) - set(known_forces) @@ -548,10 +551,17 @@ def _handle_box(self): """ # Check that if there is a barostat in the old system, # it is added to the hybrid system - if "MonteCarloBarostat" in self._old_system_forces.keys(): + present_barostat = [ + i for i in self._old_system_forces.keys() + if i in ["MonteCarloBarostat", "MonteCarloMembraneBarostat"] + ] + if len(present_barostat) == 1: barostat = copy.deepcopy( - self._old_system_forces["MonteCarloBarostat"]) + self._old_system_forces[present_barostat[0]]) self._hybrid_system.addForce(barostat) + if len(present_barostat) > 1: + errmsg = "More than 1 barostat are present which is not supported" + raise ValueError(errmsg) # Copy over the box vectors 
from the old system box_vectors = self._old_system.getDefaultPeriodicBoxVectors() diff --git a/src/openfe/protocols/openmm_rfe/hybridtop_protocols.py b/src/openfe/protocols/openmm_rfe/hybridtop_protocols.py index 47ac3779e..6e44002d8 100644 --- a/src/openfe/protocols/openmm_rfe/hybridtop_protocols.py +++ b/src/openfe/protocols/openmm_rfe/hybridtop_protocols.py @@ -25,6 +25,7 @@ ComponentMapping, LigandAtomMapping, ProteinComponent, + ProteinMembraneComponent, SmallMoleculeComponent, SolventComponent, settings, @@ -196,9 +197,20 @@ def _adaptive_settings( protocol_settings.lambda_settings.lambda_windows = 22 # adapt the solvation padding based on the system components - if stateA.contains(ProteinComponent) and stateB.contains(ProteinComponent): + if stateA.contains(ProteinComponent): protocol_settings.solvation_settings.solvent_padding = 1 * offunit.nanometer + # adapt the barostat based on the system components + if stateA.contains(ProteinMembraneComponent): + protocol_settings.integrator_settings.barostat = "MonteCarloMembraneBarostat" + protocol_settings.forcefield_settings.forcefields = [ + "amber/ff14SB.xml", + "amber/tip3p_standard.xml", + "amber/tip3p_HFE_multivalent.xml", + "amber/lipid17_merged.xml", + "amber/phosaa10.xml", + ] + return protocol_settings @staticmethod @@ -539,6 +551,9 @@ def _validate( # Validate protein component system_validation.validate_protein(stateA) + # Validate the barostat used in combination with the protein component + system_validation.validate_barostat(stateA, self.settings.integrator_settings.barostat) + # Validate charge difference # Note: validation depends on the mapping & solvent component checks if stateA.contains(SolventComponent): diff --git a/src/openfe/protocols/openmm_rfe/hybridtop_units.py b/src/openfe/protocols/openmm_rfe/hybridtop_units.py index b4cd6b744..00daf6f77 100644 --- a/src/openfe/protocols/openmm_rfe/hybridtop_units.py +++ b/src/openfe/protocols/openmm_rfe/hybridtop_units.py @@ -30,6 +30,7 @@ 
LigandAtomMapping, ProteinComponent, SmallMoleculeComponent, + SolvatedPDBComponent, SolventComponent, ) from gufe.protocols.errors import ProtocolUnitExecutionError @@ -200,6 +201,10 @@ def _get_components( small_mols = {m: m.to_openff() for m in set(smcs_A).union(set(smcs_B))} + # If there is a SolvatedPDBComponent, we set the solvent_comp + if isinstance(protein_comp, SolvatedPDBComponent): + solvent_comp = protein_comp + return solvent_comp, protein_comp, small_mols @staticmethod diff --git a/src/openfe/protocols/openmm_septop/base.py b/src/openfe/protocols/openmm_septop/base.py index b3af8a4fd..8033d3fcb 100644 --- a/src/openfe/protocols/openmm_septop/base.py +++ b/src/openfe/protocols/openmm_septop/base.py @@ -1374,4 +1374,11 @@ def run( **unit_result_dict, } else: - return {"debug": {"sampler": sampler}} + return { + "debug": { + "sampler": sampler, + "alchem_system": system, + "selection_indices": self.selection_indices, + "positions": equil_positions, + } + } diff --git a/src/openfe/protocols/openmm_septop/equil_septop_method.py b/src/openfe/protocols/openmm_septop/equil_septop_method.py index 51e0fdedd..2f3f512e3 100644 --- a/src/openfe/protocols/openmm_septop/equil_septop_method.py +++ b/src/openfe/protocols/openmm_septop/equil_septop_method.py @@ -50,9 +50,12 @@ import openmm.unit import openmm.unit as omm_units from gufe import ( + BaseSolventComponent, ChemicalSystem, ProteinComponent, + ProteinMembraneComponent, SmallMoleculeComponent, + SolvatedPDBComponent, SolventComponent, settings, ) @@ -227,6 +230,11 @@ def _get_components(self): small_mols_B = {m: m.to_openff() for m in alchem_comps["stateB"]} small_mols = small_mols | small_mols_B + # If there is a SolvatedPDBComponent, we set the solv_comp in the + # complex to that, as the SolventComponent is only used in the solvent leg + if isinstance(prot_comp, SolvatedPDBComponent): + solv_comp = prot_comp + return alchem_comps, solv_comp, prot_comp, small_mols def _handle_settings(self) -> dict[str, 
SettingsBaseModel]: @@ -261,7 +269,7 @@ def _handle_settings(self) -> dict[str, SettingsBaseModel]: "alchemical_settings": prot_settings.alchemical_settings, "lambda_settings": prot_settings.complex_lambda_settings, "engine_settings": prot_settings.engine_settings, - "integrator_settings": prot_settings.integrator_settings, + "integrator_settings": prot_settings.complex_integrator_settings, "equil_simulation_settings": prot_settings.complex_equil_simulation_settings, "equil_output_settings": prot_settings.complex_equil_output_settings, "simulation_settings": prot_settings.complex_simulation_settings, @@ -347,7 +355,7 @@ def _handle_settings(self) -> dict[str, SettingsBaseModel]: "alchemical_settings": prot_settings.alchemical_settings, "lambda_settings": prot_settings.solvent_lambda_settings, "engine_settings": prot_settings.engine_settings, - "integrator_settings": prot_settings.integrator_settings, + "integrator_settings": prot_settings.solvent_integrator_settings, "equil_simulation_settings": prot_settings.solvent_equil_simulation_settings, "equil_output_settings": prot_settings.solvent_equil_output_settings, "simulation_settings": prot_settings.solvent_simulation_settings, @@ -1078,7 +1086,8 @@ def _default_settings(cls): solvent_padding=1.0 * unit.nanometer, ), engine_settings=OpenMMEngineSettings(), - integrator_settings=IntegratorSettings(), + solvent_integrator_settings=IntegratorSettings(), + complex_integrator_settings=IntegratorSettings(), solvent_equil_simulation_settings=MDSimulationSettings( equilibration_length_nvt=0.1 * unit.nanosecond, equilibration_length=0.1 * unit.nanosecond, @@ -1128,6 +1137,53 @@ def _default_settings(cls): complex_restraint_settings=BoreschRestraintSettings(), ) + @classmethod + def _adaptive_settings( + cls, + stateA: ChemicalSystem, + stateB: ChemicalSystem, + initial_settings: None | SepTopSettings = None, + ) -> SepTopSettings: + """ + Get the recommended OpenFE settings for this protocol based on the input states 
involved in the + transformation. + + These are intended as a suitable starting point for creating an instance of this protocol, which can be further + customized before performing a Protocol. + + Parameters + ---------- + stateA : ChemicalSystem + The initial state of the transformation. + stateB : ChemicalSystem + The final state of the transformation. + initial_settings : None | SepTopSettings, optional + Initial settings to base the adaptive settings on. If None, default settings are used. + + Returns + ------- + SepTopSettings + The recommended settings for this protocol based on the input states. + """ + # use initial settings or default settings + if initial_settings is not None: + protocol_settings = initial_settings.model_copy(deep=True) + else: + protocol_settings = cls.default_settings() + + # adapt the barostat and lipid forcefield based on the ProteinComponent + if stateA.contains(ProteinMembraneComponent): + protocol_settings.complex_integrator_settings.barostat = "MonteCarloMembraneBarostat" + protocol_settings.forcefield_settings.forcefields = [ + "amber/ff14SB.xml", + "amber/tip3p_standard.xml", + "amber/tip3p_HFE_multivalent.xml", + "amber/lipid17_merged.xml", + "amber/phosaa10.xml", + ] + + return protocol_settings + @staticmethod def _validate_complex_endstates( stateA: ChemicalSystem, @@ -1160,7 +1216,7 @@ def _validate_complex_endstates( errmsg = "No ProteinComponent found in stateB" raise ValueError(errmsg) - # check that there is a solvent component + # check that there is a BaseSolvent component if not any(isinstance(comp, SolventComponent) for comp in stateA.values()): errmsg = "No SolventComponent found in stateA" raise ValueError(errmsg) @@ -1327,7 +1383,7 @@ def _create( # Check nonbonded and solvent compatibility nonbonded_method = self.settings.forcefield_settings.nonbonded_method - # Use the more complete system validation solvent checks + # Validate solvent component system_validation.validate_solvent(stateA, nonbonded_method) # 
Validate solvation settings @@ -1338,6 +1394,11 @@ def _create( # Validate protein component system_validation.validate_protein(stateA) + # Validate the barostat used in combination with the protein component + system_validation.validate_barostat( + stateA, self.settings.complex_integrator_settings.barostat + ) + # Create list units for complex and solvent transforms def create_setup_units(unit_cls, leg): return [ @@ -2001,6 +2062,15 @@ def run( self.verbose, self.logger, ) + # roundtrip box vectors to remove vec3 issues + box_AB = to_openmm(from_openmm(box_AB)) + omm_topology_AB.setPeriodicBoxVectors(box_AB) + + # ToDo: also apply REST + system_outfile = self.shared_basepath / "system.xml.bz2" + + # Serialize system, state and integrator + serialize(system, system_outfile) topology_file = self.shared_basepath / "topology.pdb" openmm.app.pdbfile.PDBFile.writeFile( @@ -2009,21 +2079,31 @@ def run( open(topology_file, "w"), ) - # ToDo: also apply REST - - system_outfile = self.shared_basepath / "system.xml.bz2" + if not dry: + return { + "system": system_outfile, + "topology": topology_file, + "standard_state_correction_A": corr_A.to("kilocalorie_per_mole"), + "standard_state_correction_B": corr_B.to("kilocalorie_per_mole"), + "restraint_geometry_A": restraint_geom_A.model_dump(), + "restraint_geometry_B": restraint_geom_B.model_dump(), + } - # Serialize system, state and integrator - serialize(system, system_outfile) - - return { - "system": system_outfile, - "topology": topology_file, - "standard_state_correction_A": corr_A.to("kilocalorie_per_mole"), - "standard_state_correction_B": corr_B.to("kilocalorie_per_mole"), - "restraint_geometry_A": restraint_geom_A.model_dump(), - "restraint_geometry_B": restraint_geom_B.model_dump(), - } + else: + return { + # Add in various objects we can use to test the system + "debug": { + "system": system_outfile, + "topology": topology_file, + "system_A": omm_system_A, + "system_B": omm_system_B, + "system_AB": omm_system_AB, + 
"restrained_system": system, + "alchem_system": alchemical_system, + "alchem_factory": alchemical_factory, + "positions": equil_positions_AB, + } + } def _execute( self, @@ -2286,11 +2366,25 @@ def run( # Serialize system, state and integrator serialize(system, system_outfile) - return { - "system": system_outfile, - "topology": topology_file, - "standard_state_correction": corr.to("kilocalorie_per_mole"), - } + if not dry: + return { + "system": system_outfile, + "topology": topology_file, + "standard_state_correction": corr.to("kilocalorie_per_mole"), + } + else: + return { + # Add in various objects we can use to test the system + "debug": { + "system": system_outfile, + "topology": topology_file, + "system_AB": omm_system_AB, + "restrained_system": system, + "alchem_system": alchemical_system, + "alchem_factory": alchemical_factory, + "positions": positions_AB, + } + } def _execute( self, @@ -2369,7 +2463,7 @@ def _execute( class SepTopComplexRunUnit(SepTopComplexMixin, BaseSepTopRunUnit): """ - Protocol Unit for the complex phase of an relative SepTop free energy + Protocol Unit for the complex phase of a relative SepTop free energy """ def _get_lambda_schedule( diff --git a/src/openfe/protocols/openmm_septop/equil_septop_settings.py b/src/openfe/protocols/openmm_septop/equil_septop_settings.py index e1898c5fa..768c77cc1 100644 --- a/src/openfe/protocols/openmm_septop/equil_septop_settings.py +++ b/src/openfe/protocols/openmm_septop/equil_septop_settings.py @@ -385,10 +385,15 @@ def must_be_positive(cls, v): """ # Sampling State defining things - integrator_settings: IntegratorSettings + solvent_integrator_settings: IntegratorSettings """ Settings for controlling the integrator, such as the timestep and - barostat settings. + barostat settings in the solvent. + """ + complex_integrator_settings: IntegratorSettings + """ + Settings for controlling the integrator, such as the timestep and + barostat settings in the complex. 
""" # Simulation run settings diff --git a/src/openfe/protocols/openmm_utils/omm_settings.py b/src/openfe/protocols/openmm_utils/omm_settings.py index 4e11685c6..3f4c0ca99 100644 --- a/src/openfe/protocols/openmm_utils/omm_settings.py +++ b/src/openfe/protocols/openmm_utils/omm_settings.py @@ -34,13 +34,14 @@ ) from openff.interchange.components._packmol import _box_vectors_are_in_reduced_form from openff.units import unit -from pydantic import ConfigDict, field_validator +from pydantic import ConfigDict, field_validator, model_validator FemtosecondQuantity: TypeAlias = Annotated[GufeQuantity, specify_quantity_units("femtosecond")] InversePicosecondQuantity: TypeAlias = Annotated[ GufeQuantity, specify_quantity_units("1/picosecond") ] TimestepQuantity: TypeAlias = Annotated[GufeQuantity, specify_quantity_units("timestep")] +SurfaceTensionQuantity: TypeAlias = Annotated[GufeQuantity, specify_quantity_units("bar*nanometer")] class BaseSolvationSettings(SettingsBaseModel): @@ -382,12 +383,28 @@ class IntegratorSettings(SettingsBaseModel): """ constraint_tolerance: float = 1e-06 """Tolerance for the constraint solver. Default 1e-6.""" + barostat: Literal["MonteCarloBarostat", "MonteCarloMembraneBarostat"] = "MonteCarloBarostat" + """ + The barostat to be used in the simulations. Default MonteCarloBarostat. + Notes + ----- + If the system contains a membrane, use the `MonteCarloMembraneBarostat`. + """ barostat_frequency: TimestepQuantity = 25.0 * unit.timestep """ Frequency at which volume scaling changes should be attempted. Note: The barostat frequency is ignored for gas-phase simulations. Default 25 * unit.timestep. """ + surface_tension: Optional[SurfaceTensionQuantity] = 0 * unit.bar * unit.nanometer + """ + The surface tension in bar*nm to define the `MonteCarloMembraneBarostat`. + Default 0 * unit.bar * unit.nanometer. + Notes + ----- + The `surface_tension` is ignored when the `MonteCarloMembraneBarostat` is + not used. 
+ """ remove_com: bool = False """ Whether or not to remove the center of mass motion. Default ``False``. @@ -423,6 +440,19 @@ def must_be_inverse_time(cls, v): raise ValueError("langevin collision_rate must be in inverse time (i.e. 1/picoseconds)") return v + @model_validator(mode="after") + def validate_surface_tension(self): + if self.barostat == "MonteCarloMembraneBarostat" and self.surface_tension is None: + raise ValueError( + "surface_tension must be set (may be zero) when using MonteCarloMembraneBarostat" + ) + if self.barostat == "MonteCarloBarostat" and self.surface_tension: + raise ValueError( + "Got a nonzero surface_tension which is not allowed when using " + f"the MonteCarloBarostat. Got surface_tension {self.surface_tension}" + ) + return self + class OutputSettings(SettingsBaseModel): """ diff --git a/src/openfe/protocols/openmm_utils/system_creation.py b/src/openfe/protocols/openmm_utils/system_creation.py index ebe79842c..85fd8bd85 100644 --- a/src/openfe/protocols/openmm_utils/system_creation.py +++ b/src/openfe/protocols/openmm_utils/system_creation.py @@ -10,11 +10,19 @@ import numpy as np import numpy.typing as npt -from gufe import Component, ProteinComponent, SmallMoleculeComponent, SolventComponent +from gufe import ( + BaseSolventComponent, + Component, + ProteinComponent, + ProteinMembraneComponent, + SmallMoleculeComponent, + SolvatedPDBComponent, + SolventComponent, +) from gufe.settings import OpenMMSystemGeneratorFFSettings, ThermoSettings from openff.toolkit import Molecule as OFFMol from openff.units.openmm import ensure_quantity, to_openmm -from openmm import MonteCarloBarostat, app +from openmm import MonteCarloBarostat, MonteCarloMembraneBarostat, app from openmm import unit as omm_unit from openmmforcefields.generators import SystemGenerator @@ -40,13 +48,11 @@ def get_system_generator( Force field settings, including necessary information for constraints, hydrogen mass, rigid waters, non-ligand FF xmls, and the ligand FF name. 
- integrator_settings: IntegratorSettings - Integrator settings, including COM removal. thermo_settings : ThermoSettings Thermodynamic settings, including necessary settings for defining the ensemble conditions. integrator_settings : IntegratorSettings - Integrator settings, including barostat control variables. + Integrator settings, including barostat control variables and COM removal. cache : Optional[pathlib.Path] Path to openff force field cache. has_solvent : bool @@ -108,13 +114,26 @@ def get_system_generator( nonperiodic_kwargs = periodic_kwargs # Add barostat if necessary - # TODO: move this to its own place where we can handle membranes + # For membrane systems, add a MonteCarloMembraneBarostat. + # ToDo: We could also only check for the barostat setting here. But for + # that we first need adaptive settings for the rfe protocol if has_solvent: - barostat = MonteCarloBarostat( - ensure_quantity(thermo_settings.pressure, "openmm"), - ensure_quantity(thermo_settings.temperature, "openmm"), - integrator_settings.barostat_frequency.m, - ) + if integrator_settings.barostat == "MonteCarloMembraneBarostat": + barostat = MonteCarloMembraneBarostat( + ensure_quantity(thermo_settings.pressure, "openmm"), + to_openmm(integrator_settings.surface_tension), + ensure_quantity(thermo_settings.temperature, "openmm"), + MonteCarloMembraneBarostat.XYIsotropic, + MonteCarloMembraneBarostat.ZFree, + integrator_settings.barostat_frequency.m, + ) + + else: + barostat = MonteCarloBarostat( + ensure_quantity(thermo_settings.pressure, "openmm"), + ensure_quantity(thermo_settings.temperature, "openmm"), + integrator_settings.barostat_frequency.m, + ) else: barostat = None @@ -136,7 +155,7 @@ def get_system_generator( def get_omm_modeller( protein_comp: Optional[ProteinComponent], - solvent_comp: Optional[SolventComponent], + solvent_comp: Optional[BaseSolventComponent], small_mols: dict[SmallMoleculeComponent, OFFMol], omm_forcefield: app.ForceField, solvent_settings: 
OpenMMSolvationSettings, @@ -149,8 +168,8 @@ def get_omm_modeller( ---------- protein_comp : Optional[ProteinComponent] Protein Component, if it exists. - solvent_comp : Optional[ProteinCompoinent] - Solvent Component, if it exists. + solvent_comp : Optional[BaseSolventComponent] + Base Solvent Component, if it exists. small_mols : dict Small molecules to add. omm_forcefield : app.ForceField @@ -195,7 +214,7 @@ def _add_small_mol( ) # if we solvate temporarily rename water molecules to 'WAT' # see openmm issue #4103 - if solvent_comp is not None: + if isinstance(solvent_comp, SolventComponent): for r in system_modeller.topology.residues(): if r.name == "HOH": r.name = "WAT" @@ -204,8 +223,8 @@ def _add_small_mol( for comp, mol in small_mols.items(): _add_small_mol(comp, mol, system_modeller, component_resids) - # Add solvent if neeeded - if solvent_comp is not None: + # Add solvent if needed + if isinstance(solvent_comp, SolventComponent): # Do unit conversions if necessary solvent_padding = None box_size = None @@ -243,5 +262,11 @@ def _add_small_mol( for r in system_modeller.topology.residues(): if r.name == "WAT": r.name = "HOH" + # If we are working with a presolvated system (with solvent + # already added) and we have predefined box vectors, then skip solvation + # with Modeller and set box vectors. + elif isinstance(solvent_comp, SolvatedPDBComponent): + # Set the periodic box vectors + system_modeller.topology.setPeriodicBoxVectors(to_openmm(solvent_comp.box_vectors)) return system_modeller, component_resids diff --git a/src/openfe/protocols/openmm_utils/system_validation.py b/src/openfe/protocols/openmm_utils/system_validation.py index 7b45a077a..d5244cbf1 100644 --- a/src/openfe/protocols/openmm_utils/system_validation.py +++ b/src/openfe/protocols/openmm_utils/system_validation.py @@ -5,19 +5,26 @@ Protocols. 
""" +import logging +import warnings from typing import Optional, Tuple import numpy as np import openmm from gufe import ( + BaseSolventComponent, ChemicalSystem, Component, ProteinComponent, + ProteinMembraneComponent, SmallMoleculeComponent, + SolvatedPDBComponent, SolventComponent, ) from openff.toolkit import Molecule as OFFMol +logger = logging.getLogger(__name__) + def get_alchemical_components( stateA: ChemicalSystem, @@ -82,6 +89,11 @@ def validate_solvent(state: ChemicalSystem, nonbonded_method: str): Checks that the ChemicalSystem component has the right solvent composition for an input nonbonded_methtod. + Supported configurations are: + * Vacuum (no BaseSolventComponent) + * One BaseSolventComponent + * One SolventComponent paired with one SolvatedPDBComponent + Parameters ---------- state : ChemicalSystem @@ -92,29 +104,39 @@ def validate_solvent(state: ChemicalSystem, nonbonded_method: str): Raises ------ ValueError - * If there are multiple SolventComponents in the ChemicalSystem. - * If there is a SolventComponent and the `nonbonded_method` is - `nocutoff`. + * If there are more than two BaseSolventComponents in the ChemicalSystem. + * If there are multiple SolventComponents or SolvatedPDBComponents in the ChemicalSystem. + * If `nocutoff` is requested with any BaseSolventComponent present. + * If there is no BaseSolventComponent and the `nonbonded_method` is `pme`. * If the SolventComponent solvent is not water. 
""" - solv_comps = state.get_components_of_type(SolventComponent) + nb_method = nonbonded_method.lower() + base_solv_comps = state.get_components_of_type(BaseSolventComponent) + solvation_comps = state.get_components_of_type(SolventComponent) + solvated_comps = state.get_components_of_type(SolvatedPDBComponent) + + if len(solvated_comps) > 1: + raise ValueError("Multiple SolvatedPDBComponent found, only one is supported") + + if len(solvation_comps) > 1: + raise ValueError("Multiple SolventComponent found, only one is supported") + + # Any BaseSolventComponent present → nocutoff is invalid + if base_solv_comps and nb_method == "nocutoff": + raise ValueError("nocutoff cannot be used for solvent transformations") - if len(solv_comps) > 0: - if nonbonded_method.lower() == "nocutoff": - errmsg = "nocutoff cannot be used for solvent transformations" - raise ValueError(errmsg) + # Vacuum transform + if not base_solv_comps: + if nb_method == "pme": + raise ValueError("PME cannot be used for vacuum transform") + return - if len(solv_comps) > 1: - errmsg = "Multiple SolventComponent found, only one is supported" - raise ValueError(errmsg) + # Solvent-specific checks + if solvation_comps: + solvent = solvation_comps[0] - if solv_comps[0].smiles != "O": - errmsg = "Non water solvent is not currently supported" - raise ValueError(errmsg) - else: - if nonbonded_method.lower() == "pme": - errmsg = "PME cannot be used for vacuum transform" - raise ValueError(errmsg) + if solvent.smiles != "O": + raise ValueError("Non water solvent is not currently supported") def validate_protein(state: ChemicalSystem): @@ -139,6 +161,49 @@ def validate_protein(state: ChemicalSystem): raise ValueError(errmsg) +def validate_barostat(state: ChemicalSystem, barostat: str): + """ + Warn if there is a mismatch between the protein component type and barostat. 
+ + A ProteinMembraneComponent should generally be simulated with a + MonteCarloMembraneBarostat, while non-membrane protein systems should + use a MonteCarloBarostat. + + Parameters + ---------- + state : ChemicalSystem + The chemical system to inspect. + barostat: str + The barostat to be applied to the simulation + """ + prot_comps = state.get_components_of_type(ProteinComponent) + protein_membrane = state.get_components_of_type(ProteinMembraneComponent) + + if not prot_comps: + return + + if protein_membrane: + if barostat != "MonteCarloMembraneBarostat": + wmsg = ( + "A ProteinMembraneComponent is present, but a membrane-specific " + "barostat (MonteCarloMembraneBarostat) is not specified. If you " + "are simulating a system with a membrane, consider using " + "integrator_settings.barostat='MonteCarloMembraneBarostat'." + ) + warnings.warn(wmsg) + logger.warning(wmsg) + + elif barostat == "MonteCarloMembraneBarostat": + wmsg = ( + "A MonteCarloMembraneBarostat is specified, but no " + "ProteinMembraneComponent is present. If you are not simulating a " + "membrane system, consider using " + "integrator_settings.barostat='MonteCarloBarostat'." 
+ ) + warnings.warn(wmsg) + logger.warning(wmsg) + + ParseCompRet = Tuple[ Optional[SolventComponent], Optional[ProteinComponent], diff --git a/src/openfe/tests/conftest.py b/src/openfe/tests/conftest.py index e3c5c84d4..5b1cec402 100644 --- a/src/openfe/tests/conftest.py +++ b/src/openfe/tests/conftest.py @@ -286,6 +286,13 @@ def T4_protein_component(): return gufe.ProteinComponent.from_pdb_file(fn, name="T4_protein") +@pytest.fixture(scope="session") +def a2a_protein_membrane_component(): + with resources.as_file(resources.files("openfe.tests.data")) as d: + with gzip.open(d / "a2a/protein.pdb.gz", "rb") as f: + yield openfe.ProteinMembraneComponent.from_pdb_file(f, name="a2a") + + @pytest.fixture(scope="session") def eg5_protein_pdb(): with resources.as_file(resources.files("openfe.tests.data.eg5")) as d: @@ -319,6 +326,19 @@ def eg5_cofactor(eg5_cofactor_sdf) -> SmallMoleculeComponent: return SmallMoleculeComponent.from_sdf_file(eg5_cofactor_sdf) +@pytest.fixture(scope="session") +def a2a_ligands_sdf(): + with resources.as_file(resources.files("openfe.tests.data.a2a")) as d: + yield str(d / "ligands.sdf.gz") + + +@pytest.fixture(scope="session") +def a2a_ligands(a2a_ligands_sdf): + with gzip.open(a2a_ligands_sdf, "rb") as gzf: + suppl = Chem.ForwardSDMolSupplier(gzf, removeHs=False) + yield [SmallMoleculeComponent(m) for m in suppl] + + @pytest.fixture() def orion_network(): with resources.as_file(resources.files("openfe.tests.data.external_formats")) as d: diff --git a/src/openfe/tests/data/a2a/__init__.py b/src/openfe/tests/data/a2a/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/openfe/tests/data/a2a/ligands.sdf.gz b/src/openfe/tests/data/a2a/ligands.sdf.gz new file mode 100644 index 000000000..ce17f4996 Binary files /dev/null and b/src/openfe/tests/data/a2a/ligands.sdf.gz differ diff --git a/src/openfe/tests/data/a2a/protein.pdb.gz b/src/openfe/tests/data/a2a/protein.pdb.gz new file mode 100644 index 000000000..e14d0f06a Binary 
files /dev/null and b/src/openfe/tests/data/a2a/protein.pdb.gz differ diff --git a/src/openfe/tests/data/openmm_afe/ABFEProtocol_json_results.json.gz b/src/openfe/tests/data/openmm_afe/ABFEProtocol_json_results.json.gz index 5cbc5d9fb..e776454f3 100644 Binary files a/src/openfe/tests/data/openmm_afe/ABFEProtocol_json_results.json.gz and b/src/openfe/tests/data/openmm_afe/ABFEProtocol_json_results.json.gz differ diff --git a/src/openfe/tests/data/openmm_md/MDProtocol_json_results.gz b/src/openfe/tests/data/openmm_md/MDProtocol_json_results.gz index c8f1dfd99..cafb2659e 100644 Binary files a/src/openfe/tests/data/openmm_md/MDProtocol_json_results.gz and b/src/openfe/tests/data/openmm_md/MDProtocol_json_results.gz differ diff --git a/src/openfe/tests/data/openmm_septop/SepTopProtocol_json_results.gz b/src/openfe/tests/data/openmm_septop/SepTopProtocol_json_results.gz index 6e2035889..5e7ed722a 100644 Binary files a/src/openfe/tests/data/openmm_septop/SepTopProtocol_json_results.gz and b/src/openfe/tests/data/openmm_septop/SepTopProtocol_json_results.gz differ diff --git a/src/openfe/tests/protocols/openmm_abfe/test_abfe_protocol.py b/src/openfe/tests/protocols/openmm_abfe/test_abfe_protocol.py index 21d1009a8..3fd8041e1 100644 --- a/src/openfe/tests/protocols/openmm_abfe/test_abfe_protocol.py +++ b/src/openfe/tests/protocols/openmm_abfe/test_abfe_protocol.py @@ -6,6 +6,7 @@ import gufe import mdtraj as mdt import numpy as np +import openmm import pytest from numpy.testing import assert_allclose from openff.units import unit as offunit @@ -17,9 +18,11 @@ HarmonicAngleForce, HarmonicBondForce, MonteCarloBarostat, + MonteCarloMembraneBarostat, NonbondedForce, PeriodicTorsionForce, ) +from openmm import unit as omm_unit from openmm import unit as ommunit from openmmtools.alchemy import ( AlchemicalRegion, @@ -179,7 +182,13 @@ class TestT4LysozymeDryRun: solvent = SolventComponent(ion_concentration=0 * offunit.molar) num_all_not_water = 2634 num_complex_atoms = 2613 - 
num_solvent_atoms = 12 + # No ions + num_ligand_atoms = 12 + + barostat_by_phase = { + "complex": MonteCarloBarostat, + "solvent": MonteCarloBarostat, + } @pytest.fixture(scope="class") def protocol(self, settings): @@ -253,19 +262,20 @@ def _assert_force_num(self, system, forcetype, number): forces = [f for f in system.getForces() if isinstance(f, forcetype)] assert len(forces) == number - def _assert_expected_alchemical_forces(self, system, complexed: bool, settings): + def _assert_expected_alchemical_forces(self, system, phase: str, settings): """ Assert the forces expected in the alchemical system. """ + barostat_type = self.barostat_by_phase[phase] self._assert_force_num(system, NonbondedForce, 1) self._assert_force_num(system, CustomNonbondedForce, 2) self._assert_force_num(system, CustomBondForce, 2) self._assert_force_num(system, HarmonicBondForce, 1) self._assert_force_num(system, HarmonicAngleForce, 1) self._assert_force_num(system, PeriodicTorsionForce, 1) - self._assert_force_num(system, MonteCarloBarostat, 1) + self._assert_force_num(system, barostat_type, 1) - if complexed: + if phase == "complex": self._assert_force_num(system, CustomCompoundBondForce, 1) assert len(system.getForces()) == 10 else: @@ -281,23 +291,31 @@ def _assert_expected_alchemical_forces(self, system, complexed: bool, settings): ) # Check the barostat made it all the way through - barostat = [f for f in system.getForces() if isinstance(f, MonteCarloBarostat)] + barostat = [f for f in system.getForces() if isinstance(f, barostat_type)] assert len(barostat) == 1 - assert barostat[0].getFrequency() == int(settings.integrator_settings.barostat_frequency.m) + expected_frequency = int( + ( + settings.complex_integrator_settings + if phase == "complex" + else settings.solvent_integrator_settings + ).barostat_frequency.m + ) + assert barostat[0].getFrequency() == expected_frequency assert barostat[0].getDefaultPressure() == to_openmm(settings.thermo_settings.pressure) assert 
barostat[0].getDefaultTemperature() == to_openmm( settings.thermo_settings.temperature ) - def _assert_expected_nonalchemical_forces(self, system, settings): + def _assert_expected_nonalchemical_forces(self, system, phase: str, settings): """ Assert the forces expected in the non-alchemical system. """ + barostat_type = self.barostat_by_phase[phase] self._assert_force_num(system, NonbondedForce, 1) self._assert_force_num(system, HarmonicBondForce, 1) self._assert_force_num(system, HarmonicAngleForce, 1) self._assert_force_num(system, PeriodicTorsionForce, 1) - self._assert_force_num(system, MonteCarloBarostat, 1) + self._assert_force_num(system, barostat_type, 1) assert len(system.getForces()) == 5 @@ -311,27 +329,39 @@ def _assert_expected_nonalchemical_forces(self, system, settings): ) # Check the barostat made it all the way through - barostat = [f for f in system.getForces() if isinstance(f, MonteCarloBarostat)] + barostat = [f for f in system.getForces() if isinstance(f, barostat_type)] assert len(barostat) == 1 - assert barostat[0].getFrequency() == int(settings.integrator_settings.barostat_frequency.m) + expected_frequency = int( + ( + settings.complex_integrator_settings + if phase == "complex" + else settings.solvent_integrator_settings + ).barostat_frequency.m + ) + assert barostat[0].getFrequency() == expected_frequency + assert barostat[0].getDefaultPressure() == to_openmm(settings.thermo_settings.pressure) assert barostat[0].getDefaultTemperature() == to_openmm( settings.thermo_settings.temperature ) - def _verify_sampler(self, sampler, complexed: bool, settings): + def _verify_sampler(self, sampler, phase: str, settings): """ Utility to verify the contents of the sampler. 
""" assert sampler.is_periodic assert isinstance(sampler, MultiStateSampler) - assert isinstance(sampler._thermodynamic_states[0].barostat, MonteCarloBarostat) + barostat_type = self.barostat_by_phase[phase] + assert isinstance(sampler._thermodynamic_states[0].barostat, barostat_type) assert sampler._thermodynamic_states[1].pressure == to_openmm( settings.thermo_settings.pressure ) for state in sampler._thermodynamic_states: system = state.get_system(remove_thermostat=True) - self._assert_expected_alchemical_forces(system, complexed, settings) + self._assert_expected_alchemical_forces(system, phase, settings) + + def _check_box_vectors(self, system): + self._test_dodecahedron_vectors(system) @staticmethod def _test_dodecahedron_vectors(system): @@ -405,21 +435,24 @@ def test_complex_dry_run(self, complex_setup_units, complex_sim_units, settings, ) # Check the sampler - self._verify_sampler(sim_results["sampler"], complexed=True, settings=settings) + self._verify_sampler(sim_results["sampler"], phase="complex", settings=settings) # Check the alchemical system self._assert_expected_alchemical_forces( - setup_results["alchem_system"], complexed=True, settings=settings + setup_results["alchem_system"], phase="complex", settings=settings ) - self._test_dodecahedron_vectors(setup_results["alchem_system"]) + self._check_box_vectors(setup_results["alchem_system"]) # Check the alchemical indices - expected_indices = [i + self.num_complex_atoms for i in range(self.num_solvent_atoms)] + expected_indices = [i + self.num_complex_atoms for i in range(self.num_ligand_atoms)] assert expected_indices == setup_results["alchem_indices"] # Check the non-alchemical system - self._assert_expected_nonalchemical_forces(setup_results["standard_system"], settings) - self._test_dodecahedron_vectors(setup_results["standard_system"]) + self._assert_expected_nonalchemical_forces( + setup_results["standard_system"], "complex", settings=settings + ) + 
self._check_box_vectors(setup_results["standard_system"]) + # Check the box vectors haven't changed (they shouldn't have because we didn't do MD) assert_allclose( from_openmm(setup_results["alchem_system"].getDefaultPeriodicBoxVectors()), @@ -452,21 +485,24 @@ def test_solvent_dry_run(self, solvent_setup_units, solvent_sim_units, settings, ) # Check the sampler - self._verify_sampler(sim_results["sampler"], complexed=False, settings=settings) + self._verify_sampler(sim_results["sampler"], phase="solvent", settings=settings) # Check the alchemical system self._assert_expected_alchemical_forces( - setup_results["alchem_system"], complexed=False, settings=settings + setup_results["alchem_system"], phase="solvent", settings=settings ) self._test_cubic_vectors(setup_results["alchem_system"]) # Check the alchemical indices - expected_indices = [i for i in range(self.num_solvent_atoms)] + expected_indices = [i for i in range(self.num_ligand_atoms)] assert expected_indices == setup_results["alchem_indices"] # Check the non-alchemical system - self._assert_expected_nonalchemical_forces(setup_results["standard_system"], settings) + self._assert_expected_nonalchemical_forces( + setup_results["standard_system"], "solvent", settings=settings + ) self._test_cubic_vectors(setup_results["standard_system"]) + # Check the box vectors haven't changed (they shouldn't have because we didn't do MD) assert_allclose( from_openmm(setup_results["alchem_system"].getDefaultPeriodicBoxVectors()), @@ -475,7 +511,7 @@ def test_solvent_dry_run(self, solvent_setup_units, solvent_sim_units, settings, # Check the PDB pdb = mdt.load_pdb(setup_results["pdb_structure"]) - assert pdb.n_atoms == self.num_solvent_atoms + assert pdb.n_atoms == self.num_ligand_atoms # Check energies alchem_region = AlchemicalRegion(alchemical_atoms=setup_results["alchem_indices"]) @@ -511,8 +547,10 @@ def settings(self): "amber/tip4pew_standard.xml", # FF we are testsing with the fun VS "amber/phosaa10.xml", # Handles THE 
TPO ] - s.integrator_settings.reassign_velocities = True - s.integrator_settings.barostat_frequency = 100.0 * offunit.timestep + s.complex_integrator_settings.reassign_velocities = True + s.solvent_integrator_settings.reassign_velocities = True + s.complex_integrator_settings.barostat_frequency = 100.0 * offunit.timestep + s.solvent_integrator_settings.barostat_frequency = 100.0 * offunit.timestep s.thermo_settings.pressure = 1.1 * offunit.bar return s @@ -585,3 +623,85 @@ def assign_fictitious_charges(offmol): offsets = alchem_system_nbf.getParticleParameterOffset(i) assert pytest.approx(prop_chgs[i]) == offsets[2] + + +@pytest.mark.slow +class TestA2AMembraneDryRun(TestT4LysozymeDryRun): + solvent = SolventComponent(ion_concentration=0 * offunit.molar) + num_all_not_water = 16080 + num_complex_atoms = 39390 + # No ions + num_ligand_atoms = 36 + + barostat_by_phase = { + "complex": MonteCarloMembraneBarostat, + "solvent": MonteCarloBarostat, + } + + @pytest.fixture(scope="class") + def settings(self): + s = openmm_afe.AbsoluteBindingProtocol.default_settings() + s.protocol_repeats = 1 + s.engine_settings.compute_platform = "cpu" + s.complex_output_settings.output_indices = "not water" + s.solvent_solvation_settings.box_shape = "cube" + return s + + @pytest.fixture(scope="class") + def dag(self, settings, a2a_ligands, a2a_protein_membrane_component): + stateA = ChemicalSystem( + { + "ligand": a2a_ligands[0], + "protein": a2a_protein_membrane_component, + "solvent": self.solvent, + } + ) + + stateB = ChemicalSystem( + { + "protein": a2a_protein_membrane_component, + "solvent": self.solvent, + } + ) + + adaptive_settings = AbsoluteBindingProtocol._adaptive_settings( + stateA=stateA, + stateB=stateB, + initial_settings=settings, + ) + + protocol = AbsoluteBindingProtocol(settings=adaptive_settings) + + return protocol.create( + stateA=stateA, + stateB=stateB, + mapping=None, + ) + + def _check_box_vectors(self, system): + self._test_orthogonal_vectors(system) + + 
@staticmethod + def _test_orthogonal_vectors(system): + """Test that the system has an orthorhombic (rectangular) periodic box.""" + vectors = system.getDefaultPeriodicBoxVectors() + vectors = from_openmm(vectors) # convert to a Quantity array + + # Extract box lengths in nanometers + width_x, width_y, width_z = [v[i].to("nanometer").m for i, v in enumerate(vectors)] + + # Expected orthogonal box (axis-aligned) + expected_vectors = ( + np.array( + [ + [width_x, 0, 0], + [0, width_y, 0], + [0, 0, width_z], + ] + ) + * offunit.nanometer + ) + + assert_allclose( + vectors, expected_vectors, atol=1e-5, err_msg=f"Box is not orthogonal:\n{vectors}" + ) diff --git a/src/openfe/tests/protocols/openmm_abfe/test_abfe_protocol_results.py b/src/openfe/tests/protocols/openmm_abfe/test_abfe_protocol_results.py index 5d815c713..78577ac34 100644 --- a/src/openfe/tests/protocols/openmm_abfe/test_abfe_protocol_results.py +++ b/src/openfe/tests/protocols/openmm_abfe/test_abfe_protocol_results.py @@ -157,7 +157,7 @@ def test_get_estimate(self, protocolresult): est = protocolresult.get_estimate() assert est - assert est.m == pytest.approx(-21.35, abs=0.01) + assert est.m == pytest.approx(-19.74, abs=0.01) assert isinstance(est, offunit.Quantity) assert est.is_compatible_with(offunit.kilojoule_per_mole) @@ -165,7 +165,7 @@ def test_get_uncertainty(self, protocolresult): est = protocolresult.get_uncertainty() assert est - assert est.m == pytest.approx(1.04, abs=0.01) + assert est.m == pytest.approx(0.85, abs=0.01) assert isinstance(est, offunit.Quantity) assert est.is_compatible_with(offunit.kilojoule_per_mole) @@ -270,13 +270,13 @@ def test_restraint_geometry(self, protocolresult): assert len(geom) == 3 assert isinstance(geom[0], BoreschRestraintGeometry) assert geom[0].guest_atoms == [1779, 1778, 1777] - assert geom[0].host_atoms == [880, 865, 864] - assert pytest.approx(geom[0].r_aA0, rel=1e-2) == 1.083558 * offunit.nanometer - assert pytest.approx(geom[0].theta_A0, rel=1e-2) == 
0.711876 * offunit.radian - assert pytest.approx(geom[0].theta_B0, rel=1e-2) == 1.687366 * offunit.radian - assert pytest.approx(geom[0].phi_A0, rel=1e-2) == -0.2164231 * offunit.radian - assert pytest.approx(geom[0].phi_B0, rel=1e-2) == 1.892376 * offunit.radian - assert pytest.approx(geom[0].phi_C0, rel=1e-2) == -0.522031870 * offunit.radian + assert geom[0].host_atoms == [852, 853, 854] + assert pytest.approx(geom[0].r_aA0, rel=1e-2) == 1.041035 * offunit.nanometer + assert pytest.approx(geom[0].theta_A0, rel=1e-2) == 1.063788 * offunit.radian + assert pytest.approx(geom[0].theta_B0, rel=1e-2) == 1.230858 * offunit.radian + assert pytest.approx(geom[0].phi_A0, rel=1e-2) == 1.155133 * offunit.radian + assert pytest.approx(geom[0].phi_B0, rel=1e-2) == 1.141134 * offunit.radian + assert pytest.approx(geom[0].phi_C0, rel=1e-2) == -0.621615 * offunit.radian @pytest.mark.parametrize( "key, expected_size", diff --git a/src/openfe/tests/protocols/openmm_abfe/test_abfe_settings.py b/src/openfe/tests/protocols/openmm_abfe/test_abfe_settings.py index 273a3030f..08d10809a 100644 --- a/src/openfe/tests/protocols/openmm_abfe/test_abfe_settings.py +++ b/src/openfe/tests/protocols/openmm_abfe/test_abfe_settings.py @@ -2,8 +2,10 @@ # For details, see https://github.com/OpenFreeEnergy/openfe import pytest +from openfe import ChemicalSystem, SolventComponent from openfe.protocols.openmm_afe import ( AbsoluteBindingProtocol, + AbsoluteBindingSettings, ) @@ -76,3 +78,34 @@ def test_equil_not_all_complex(default_settings): def test_equil_not_all_solvent(default_settings): with pytest.raises(ValueError, match="output_indices must be all"): default_settings.solvent_equil_output_settings.output_indices = "not water" + + +def test_adaptive_settings_no_protein_membrane(toluene_complex_system, default_settings): + settings = AbsoluteBindingProtocol._adaptive_settings( + toluene_complex_system, + toluene_complex_system, + default_settings, + ) + + assert isinstance(settings, 
AbsoluteBindingSettings) + # Should use default barostat since no ProteinMembraneComponent + assert settings.complex_integrator_settings.barostat == "MonteCarloBarostat" + + +def test_adaptive_settings_with_protein_membrane(a2a_protein_membrane_component, a2a_ligands): + stateA = ChemicalSystem( + { + "ligandA": a2a_ligands[0], + "protein": a2a_protein_membrane_component, + "solvent": SolventComponent(), + } + ) + + settings = AbsoluteBindingProtocol._adaptive_settings(stateA, stateA) + assert isinstance(settings, AbsoluteBindingSettings) + # Barostat should have been updated + assert settings.complex_integrator_settings.barostat == "MonteCarloMembraneBarostat" + + # Forcefields should include the lipid forcefields + ff = settings.forcefield_settings.forcefields + assert "amber/lipid17_merged.xml" in ff diff --git a/src/openfe/tests/protocols/openmm_abfe/test_abfe_validation.py b/src/openfe/tests/protocols/openmm_abfe/test_abfe_validation.py index 38cfcc558..77b5286dc 100644 --- a/src/openfe/tests/protocols/openmm_abfe/test_abfe_validation.py +++ b/src/openfe/tests/protocols/openmm_abfe/test_abfe_validation.py @@ -167,7 +167,7 @@ def test_validate_endstates_nosolvcomp_stateB(benzene_modifications, T4_protein_ } ) - with pytest.raises(ValueError, match="No SolventComponent"): + with pytest.raises(ValueError, match="No SolventComponent found"): AbsoluteBindingProtocol._validate_endstates(stateA, stateB) diff --git a/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py b/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py index 7e257e865..4cce21795 100644 --- a/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py +++ b/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py @@ -20,7 +20,14 @@ from openff.toolkit import Molecule from openff.units import unit from openff.units.openmm import ensure_quantity, from_openmm, to_openmm -from openmm import CustomNonbondedForce, MonteCarloBarostat, NonbondedForce, XmlSerializer, 
app +from openmm import ( + CustomNonbondedForce, + MonteCarloBarostat, + MonteCarloMembraneBarostat, + NonbondedForce, + XmlSerializer, + app, +) from openmm import unit as omm_unit from openmmforcefields.generators import SMIRNOFFTemplateGenerator from openmmtools.multistate.multistatesampler import MultiStateSampler @@ -1111,6 +1118,93 @@ def test_dry_run_complex( assert pdb.n_atoms == 2629 +@pytest.mark.slow +def test_dry_run_membrane_complex( + a2a_protein_membrane_component, + a2a_ligands, + tmpdir, +): + ligA = next(c for c in a2a_ligands if c.name == "4g") + ligB = next(c for c in a2a_ligands if c.name == "4h") + + mapping = openfe.LigandAtomMapping( + componentA=ligA, + componentB=ligB, + componentA_to_componentB={i: i for i in range(36)}, + ) + + settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() + settings.protocol_repeats = 1 + settings.engine_settings.compute_platform = "cpu" + settings.output_settings.output_indices = "protein or resname UNK" + + systemA = openfe.ChemicalSystem( + {"ligand": mapping.componentA, "protein": a2a_protein_membrane_component}, + name=f"{mapping.componentA.name}_{a2a_protein_membrane_component.name}", + ) + systemB = openfe.ChemicalSystem( + {"ligand": mapping.componentB, "protein": a2a_protein_membrane_component}, + name=f"{mapping.componentB.name}_{a2a_protein_membrane_component.name}", + ) + + adaptive_settings = openmm_rfe.RelativeHybridTopologyProtocol._adaptive_settings( + stateA=systemA, stateB=systemB, mapping=mapping, initial_settings=settings + ) + protocol = openmm_rfe.RelativeHybridTopologyProtocol( + settings=adaptive_settings, + ) + dag = protocol.create( + stateA=systemA, + stateB=systemB, + mapping=mapping, + ) + dag_setup_unit = _get_units(dag.protocol_units, HybridTopologySetupUnit)[0] + dag_sim_unit = _get_units(dag.protocol_units, HybridTopologyMultiStateSimulationUnit)[0] + + with tmpdir.as_cwd(): + setup_results = dag_setup_unit.run(dry=True) + input_box = 
a2a_protein_membrane_component.box_vectors + system_box = from_openmm(setup_results["hybrid_system"].getDefaultPeriodicBoxVectors()) + assert_allclose(system_box, input_box, atol=1e-5) + sim_results = dag_sim_unit.run( + system=setup_results["hybrid_system"], + positions=setup_results["hybrid_positions"], + selection_indices=setup_results["selection_indices"], + dry=True, + ) + sampler = sim_results["sampler"] + + assert isinstance(sampler, MultiStateSampler) + assert sampler.is_periodic + assert isinstance(sampler._thermodynamic_states[0].barostat, MonteCarloMembraneBarostat) + assert sampler._thermodynamic_states[1].pressure == 1 * omm_unit.bar + + # Check we have the right number of atoms in the PDB + pdb = mdt.load_pdb("hybrid_system.pdb") + assert pdb.n_atoms == 4690 + box = sampler._thermodynamic_states[0].system.getDefaultPeriodicBoxVectors() + vectors = from_openmm(box) # convert to a Quantity array + + # Extract box lengths in nanometers + width_x, width_y, width_z = [v[i].to("nanometer").m for i, v in enumerate(vectors)] + + # Expected orthogonal box (axis-aligned) + expected_vectors = ( + np.array( + [ + [width_x, 0, 0], + [0, width_y, 0], + [0, 0, width_z], + ] + ) + * unit.nanometer + ) + + assert_allclose( + vectors, expected_vectors, atol=1e-5, err_msg=f"Box is not orthogonal:\n{vectors}" + ) + + def test_lambda_schedule_default(): lambdas = openmm_rfe._rfe_utils.lambdaprotocol.LambdaProtocol(functions="default") assert len(lambdas.lambda_schedule) == 10 diff --git a/src/openfe/tests/protocols/openmm_septop/test_septop_protocol.py b/src/openfe/tests/protocols/openmm_septop/test_septop_protocol.py index e5e4a9f91..3ddfb7391 100644 --- a/src/openfe/tests/protocols/openmm_septop/test_septop_protocol.py +++ b/src/openfe/tests/protocols/openmm_septop/test_septop_protocol.py @@ -9,14 +9,13 @@ import gufe import mdtraj as md import numpy as np -import numpy.typing as npt import openmm import openmm.app import openmm.unit import pytest from numpy.testing 
import assert_allclose from openff.units import unit as offunit -from openff.units.openmm import ensure_quantity, from_openmm +from openff.units.openmm import ensure_quantity, from_openmm, to_openmm from openmm import ( CustomBondForce, CustomCompoundBondForce, @@ -24,6 +23,7 @@ HarmonicAngleForce, HarmonicBondForce, MonteCarloBarostat, + MonteCarloMembraneBarostat, NonbondedForce, PeriodicTorsionForce, ) @@ -43,6 +43,7 @@ from openfe.protocols.openmm_septop.equil_septop_method import ( _check_alchemical_charge_difference, ) +from openfe.protocols.openmm_septop.equil_septop_settings import SepTopSettings from openfe.protocols.openmm_utils import system_validation from openfe.protocols.openmm_utils.serialization import deserialize from openfe.protocols.restraint_utils.geometry.boresch import BoreschRestraintGeometry @@ -359,7 +360,7 @@ def test_charge_error_create(charged_benzene_modifications, T4_protein_component ], ) def test_validate_complex_endstates_protcomp(request, system_A, system_B, fail_endstate): - with pytest.raises(ValueError, match=f"No ProteinComponent found in {fail_endstate}"): + with pytest.raises(ValueError, match="No ProteinComponent found"): SepTopProtocol._validate_complex_endstates( request.getfixturevalue(system_A), request.getfixturevalue(system_B), @@ -389,7 +390,7 @@ def test_validate_complex_endstates_nosolvcomp( system_B, fail_endstate, ): - with pytest.raises(ValueError, match=f"No SolventComponent found in {fail_endstate}"): + with pytest.raises(ValueError, match="No SolventComponent found"): SepTopProtocol._validate_complex_endstates( request.getfixturevalue(system_A), request.getfixturevalue(system_B), @@ -710,7 +711,7 @@ def test_dry_run_benzene_toluene(benzene_toluene_dag, tmpdir): assert len(complex_run_unit) == 1 with tmpdir.as_cwd(): - solv_setup_output = solv_setup_unit[0].run(dry=True) + solv_setup_output = solv_setup_unit[0].run(dry=True)["debug"] pdb = md.load_pdb("topology.pdb") assert pdb.n_atoms == 1762 central_atoms = 
np.array([[2, 19]], dtype=np.int32) @@ -746,7 +747,7 @@ def test_dry_run_benzene_toluene(benzene_toluene_dag, tmpdir): if isinstance(f, CustomNonbondedForce) and "U_sterics" in f.getEnergyFunction(): _verify_alchemical_sterics_force_parameters(f) - complex_setup_output = complex_setup_unit[0].run(dry=True) + complex_setup_output = complex_setup_unit[0].run(dry=True)["debug"] serialized_topology = complex_setup_output["topology"] serialized_system = complex_setup_output["system"] complex_sampler = complex_run_unit[0].run( @@ -805,7 +806,7 @@ def test_dry_run_methods( solv_setup_unit = [u for u in dag_units if isinstance(u, SepTopSolventSetupUnit)] sol_run_unit = [u for u in dag_units if isinstance(u, SepTopSolventRunUnit)] with tmpdir.as_cwd(): - solv_setup_output = solv_setup_unit[0].run(dry=True) + solv_setup_output = solv_setup_unit[0].run(dry=True)["debug"] serialized_topology = solv_setup_output["topology"] serialized_system = solv_setup_output["system"] solv_sampler = sol_run_unit[0].run( @@ -856,7 +857,7 @@ def test_dry_run_ligand_system_pressure( solv_setup_unit = [u for u in dag_units if isinstance(u, SepTopSolventSetupUnit)] sol_run_unit = [u for u in dag_units if isinstance(u, SepTopSolventRunUnit)] with tmpdir.as_cwd(): - solv_setup_output = solv_setup_unit[0].run(dry=True) + solv_setup_output = solv_setup_unit[0].run(dry=True)["debug"] serialized_topology = solv_setup_output["topology"] serialized_system = solv_setup_output["system"] solv_sampler = sol_run_unit[0].run( @@ -882,7 +883,7 @@ def test_virtual_sites_no_reassign( "amber/tip4pew_standard.xml", # FF with VS ] protocol_dry_settings.solvent_solvation_settings.solvent_model = "tip4pew" - protocol_dry_settings.integrator_settings.reassign_velocities = False + protocol_dry_settings.solvent_integrator_settings.reassign_velocities = False protocol = SepTopProtocol( settings=protocol_dry_settings, @@ -931,7 +932,7 @@ def test_dry_run_ligand_system_cutoff( solv_setup_unit = [u for u in dag_units if 
isinstance(u, SepTopSolventSetupUnit)] with tmpdir.as_cwd(): - serialized_system = solv_setup_unit[0].run(dry=True)["system"] + serialized_system = solv_setup_unit[0].run(dry=True)["debug"]["system"] system = deserialize(serialized_system) nbfs = [ f @@ -956,7 +957,7 @@ def test_dry_run_benzene_toluene_tip4p( "amber/phosaa10.xml", # Handles THE TPO ] protocol_dry_settings.solvent_solvation_settings.solvent_model = "tip4pew" - protocol_dry_settings.integrator_settings.reassign_velocities = True + protocol_dry_settings.solvent_integrator_settings.reassign_velocities = True protocol = SepTopProtocol(settings=protocol_dry_settings) @@ -979,7 +980,7 @@ def test_dry_run_benzene_toluene_tip4p( assert len(sol_run_unit) == 1 with tmpdir.as_cwd(): - solv_setup_output = solv_setup_unit[0].run(dry=True) + solv_setup_output = solv_setup_unit[0].run(dry=True)["debug"] serialized_topology = solv_setup_output["topology"] serialized_system = solv_setup_output["system"] solv_run = sol_run_unit[0].run( @@ -1017,7 +1018,7 @@ def test_dry_run_benzene_toluene_noncubic( assert len(solv_setup_unit) == 1 with tmpdir.as_cwd(): - solv_setup_output = solv_setup_unit[0].run(dry=True) + solv_setup_output = solv_setup_unit[0].run(dry=True)["debug"] serialized_system = solv_setup_output["system"] system = deserialize(serialized_system) vectors = system.getDefaultPeriodicBoxVectors() @@ -1106,7 +1107,7 @@ def check_partial_charges(offmol): # check sol_unit charges with tmpdir.as_cwd(): - serialized_system = solv_setup_unit[0].run(dry=True)["system"] + serialized_system = solv_setup_unit[0].run(dry=True)["debug"]["system"] system = deserialize(serialized_system) nonbond = [f for f in system.getForces() if isinstance(f, openmm.NonbondedForce)] assert len(nonbond) == 1 @@ -1125,7 +1126,7 @@ def check_partial_charges(offmol): # check complex_unit charges with tmpdir.as_cwd(): - serialized_system = complex_setup_unit[0].run(dry=True)["system"] + serialized_system = 
complex_setup_unit[0].run(dry=True)["debug"]["system"] system = deserialize(serialized_system) nonbond = [f for f in system.getForces() if isinstance(f, openmm.NonbondedForce)] assert len(nonbond) == 1 @@ -1192,7 +1193,7 @@ def T4L_xml( tmp = tmp_path_factory.mktemp("xml_reg") - dryrun = solv_setup_unit[0].run(dry=True, shared_basepath=tmp) + dryrun = solv_setup_unit[0].run(dry=True, shared_basepath=tmp)["debug"] system = dryrun["system"] return deserialize(system) @@ -1373,7 +1374,7 @@ def test_get_estimate(self, protocolresult): est = protocolresult.get_estimate() assert est - assert est.m == pytest.approx(5.18, abs=0.1) + assert est.m == pytest.approx(3.82, abs=0.1) assert isinstance(est, offunit.Quantity) assert est.is_compatible_with(offunit.kilojoule_per_mole) @@ -1491,9 +1492,332 @@ def test_restraint_geometry(self, protocolresult): assert isinstance(geom[0][0], BoreschRestraintGeometry) assert geom[0][0].guest_atoms == [1779, 1778, 1777] assert geom[0][0].host_atoms == [802, 801, 800] - assert pytest.approx(geom[0][0].r_aA0) == 0.774170 * offunit.nanometer - assert pytest.approx(geom[0][0].theta_A0) == 1.793181 * offunit.radian - assert pytest.approx(geom[0][0].theta_B0) == 1.501008 * offunit.radian - assert pytest.approx(geom[0][0].phi_A0) == 0.939174 * offunit.radian - assert pytest.approx(geom[0][0].phi_B0) == -1.504071 * offunit.radian - assert pytest.approx(geom[0][0].phi_C0) == -0.745093 * offunit.radian + assert pytest.approx(geom[0][0].r_aA0) == 0.798936 * offunit.nanometer + assert pytest.approx(geom[0][0].theta_A0) == 2.049091 * offunit.radian + assert pytest.approx(geom[0][0].theta_B0) == 1.221973 * offunit.radian + assert pytest.approx(geom[0][0].phi_A0) == 0.956774 * offunit.radian + assert pytest.approx(geom[0][0].phi_B0) == -1.217188 * offunit.radian + assert pytest.approx(geom[0][0].phi_C0) == -1.068226 * offunit.radian + + +@pytest.mark.slow +class TestA2AMembraneDryRun: + solvent = SolventComponent(ion_concentration=0 * offunit.molar) + 
num_all_not_water = 16116 + num_complex_atoms = 39462 + num_ligand_atoms_A = 36 + num_ligand_atoms_B = 36 + + @pytest.fixture(scope="class") + def settings(self): + s = SepTopProtocol.default_settings() + s.protocol_repeats = 1 + s.engine_settings.compute_platform = "cpu" + s.complex_output_settings.output_indices = "not water" + s.complex_solvation_settings.box_shape = "dodecahedron" + s.complex_solvation_settings.solvent_padding = 0.9 * offunit.nanometer + s.solvent_solvation_settings.box_shape = "cube" + return s + + @pytest.fixture(scope="function") + def dag(self, settings, a2a_ligands, a2a_protein_membrane_component): + stateA = ChemicalSystem( + { + "ligandA": a2a_ligands[0], + "protein": a2a_protein_membrane_component, + "solvent": self.solvent, + } + ) + + stateB = ChemicalSystem( + { + "ligandB": a2a_ligands[1], + "protein": a2a_protein_membrane_component, + "solvent": self.solvent, + } + ) + + # adaptive settings + protocol_settings = SepTopProtocol._adaptive_settings( + stateA=stateA, + stateB=stateB, + initial_settings=settings, + ) + protocol = SepTopProtocol(settings=protocol_settings) + + return protocol.create( + stateA=stateA, + stateB=stateB, + mapping=None, + ) + + @pytest.fixture(scope="function") + def complex_setup_units(self, dag): + return [u for u in dag.protocol_units if isinstance(u, SepTopComplexSetupUnit)] + + @pytest.fixture(scope="function") + def complex_run_units(self, dag): + return [u for u in dag.protocol_units if isinstance(u, SepTopComplexRunUnit)] + + @pytest.fixture(scope="function") + def solvent_setup_units(self, dag): + return [u for u in dag.protocol_units if isinstance(u, SepTopSolventSetupUnit)] + + @pytest.fixture(scope="function") + def solvent_run_units(self, dag): + return [u for u in dag.protocol_units if isinstance(u, SepTopSolventRunUnit)] + + def test_number_of_units( + self, dag, complex_setup_units, complex_run_units, solvent_setup_units, solvent_run_units + ): + assert len(list(dag.protocol_units)) == 4 + 
assert len(complex_setup_units) == 1 + assert len(complex_run_units) == 1 + assert len(solvent_setup_units) == 1 + assert len(solvent_run_units) == 1 + + def _assert_force_num(self, system, forcetype, number): + forces = [f for f in system.getForces() if isinstance(f, forcetype)] + assert len(forces) == number + + def _assert_expected_alchemical_forces(self, system, complexed: bool, settings): + """ + Assert the forces expected in the alchemical system. + """ + if complexed: + barostat_type = MonteCarloMembraneBarostat + self._assert_force_num(system, HarmonicBondForce, 1) + # Two custom bonds for the two Boresch restraints + self._assert_force_num(system, CustomCompoundBondForce, 2) + assert len(system.getForces()) == 15 + else: + # Extra bond in the solvent + self._assert_force_num(system, HarmonicBondForce, 2) + assert len(system.getForces()) == 14 + barostat_type = MonteCarloBarostat + + self._assert_force_num(system, NonbondedForce, 1) + self._assert_force_num(system, CustomNonbondedForce, 4) + self._assert_force_num(system, CustomBondForce, 4) + self._assert_force_num(system, HarmonicAngleForce, 1) + self._assert_force_num(system, PeriodicTorsionForce, 1) + self._assert_force_num(system, barostat_type, 1) + + # Check the nonbonded force has the right contents + nonbond = [f for f in system.getForces() if isinstance(f, NonbondedForce)] + assert len(nonbond) == 1 + assert nonbond[0].getNonbondedMethod() == NonbondedForce.PME + assert ( + from_openmm(nonbond[0].getCutoffDistance()) + == settings.forcefield_settings.nonbonded_cutoff + ) + + # Check the barostat made it all the way through + barostat = [f for f in system.getForces() if isinstance(f, barostat_type)] + assert len(barostat) == 1 + assert barostat[0].getFrequency() == int( + settings.complex_integrator_settings.barostat_frequency.m + ) + assert barostat[0].getDefaultPressure() == to_openmm(settings.thermo_settings.pressure) + assert barostat[0].getDefaultTemperature() == to_openmm( + 
settings.thermo_settings.temperature + ) + + def _assert_expected_nonalchemical_forces(self, system, complexed: bool, settings): + """ + Assert the forces expected in the non-alchemical system. + """ + if complexed: + barostat_type = MonteCarloMembraneBarostat + else: + barostat_type = MonteCarloBarostat + self._assert_force_num(system, NonbondedForce, 1) + self._assert_force_num(system, HarmonicBondForce, 1) + self._assert_force_num(system, HarmonicAngleForce, 1) + self._assert_force_num(system, PeriodicTorsionForce, 1) + self._assert_force_num(system, barostat_type, 1) + + assert len(system.getForces()) == 5 + + # Check that the nonbonded force has the right contents + nonbond = [f for f in system.getForces() if isinstance(f, NonbondedForce)] + assert len(nonbond) == 1 + assert nonbond[0].getNonbondedMethod() == NonbondedForce.PME + assert ( + from_openmm(nonbond[0].getCutoffDistance()) + == settings.forcefield_settings.nonbonded_cutoff + ) + + # Check the barostat made it all the way through + barostat = [f for f in system.getForces() if isinstance(f, barostat_type)] + assert len(barostat) == 1 + assert barostat[0].getFrequency() == int( + settings.complex_integrator_settings.barostat_frequency.m + ) + assert barostat[0].getDefaultPressure() == to_openmm(settings.thermo_settings.pressure) + assert barostat[0].getDefaultTemperature() == to_openmm( + settings.thermo_settings.temperature + ) + + def _verify_sampler(self, sampler, complexed: bool, settings): + """ + Utility to verify the contents of the sampler. 
+ """ + assert sampler.is_periodic + assert isinstance(sampler, MultiStateSampler) + if complexed: + barostat_type = MonteCarloMembraneBarostat + else: + barostat_type = MonteCarloBarostat + assert isinstance(sampler._thermodynamic_states[0].barostat, barostat_type) + assert sampler._thermodynamic_states[1].pressure == to_openmm( + settings.thermo_settings.pressure + ) + for state in sampler._thermodynamic_states: + system = state.get_system(remove_thermostat=True) + self._assert_expected_alchemical_forces(system, complexed, settings) + + @staticmethod + def _test_orthogonal_vectors(system): + """Test that the system has an orthorhombic (rectangular) periodic box.""" + vectors = system.getDefaultPeriodicBoxVectors() + vectors = from_openmm(vectors) # convert to a Quantity array + + # Extract box lengths in nanometers + width_x, width_y, width_z = [v[i].to("nanometer").m for i, v in enumerate(vectors)] + + # Expected orthogonal box (axis-aligned) + expected_vectors = ( + np.array( + [ + [width_x, 0, 0], + [0, width_y, 0], + [0, 0, width_z], + ] + ) + * offunit.nanometer + ) + + assert_allclose( + vectors, expected_vectors, atol=1e-5, err_msg=f"Box is not orthogonal:\n{vectors}" + ) + + @staticmethod + def _test_cubic_vectors(system): + # cube is an identity matrix + vectors = system.getDefaultPeriodicBoxVectors() + width = float(from_openmm(vectors)[0][0].to("nanometer").m) + + expected_vectors = [ + [width, 0, 0], + [0, width, 0], + [0, 0, width], + ] * offunit.nanometer + + assert_allclose( + expected_vectors, + from_openmm(vectors), + ) + + def test_complex_dry_run(self, complex_setup_units, complex_run_units, tmpdir): + with tmpdir.as_cwd(): + # Get adaptive settings + adaptive_settings = complex_setup_units[0]._inputs["protocol"].settings + # Check that adaptive settings changed the barostat to membrane barostat + assert ( + adaptive_settings.complex_integrator_settings.barostat + == "MonteCarloMembraneBarostat" + ) + complex_setup_output = 
complex_setup_units[0].run(dry=True)["debug"] + serialized_topology = complex_setup_output["topology"] + serialized_system = complex_setup_output["system"] + data = complex_run_units[0].run( + serialized_system, serialized_topology, dry=True + )["debug"] # fmt: skip + # Check the sampler + self._verify_sampler(data["sampler"], complexed=True, settings=adaptive_settings) + + # Check the alchemical system + self._assert_expected_alchemical_forces( + data["alchem_system"], complexed=True, settings=adaptive_settings + ) + self._test_orthogonal_vectors(data["alchem_system"]) + + # Check the non-alchemical system + self._assert_expected_nonalchemical_forces( + complex_setup_output["system_AB"], complexed=True, settings=adaptive_settings + ) + self._test_orthogonal_vectors(complex_setup_output["system_AB"]) + # Check the box vectors haven't changed (they shouldn't have because we didn't do MD) + assert_allclose( + from_openmm(data["alchem_system"].getDefaultPeriodicBoxVectors()), + from_openmm(complex_setup_output["system_AB"].getDefaultPeriodicBoxVectors()), + ) + + # Check the PDB + pdb = md.load_pdb("alchemical_system.pdb") + assert pdb.n_atoms == self.num_all_not_water + + full_pdb = md.load_pdb("topology.pdb") + assert full_pdb.n_atoms == self.num_complex_atoms + + def test_solvent_dry_run(self, solvent_setup_units, solvent_run_units, settings, tmpdir): + with tmpdir.as_cwd(): + solv_setup_output = solvent_setup_units[0].run(dry=True)["debug"] + serialized_topology = solv_setup_output["topology"] + serialized_system = solv_setup_output["system"] + data = solvent_run_units[0].run( + serialized_system, serialized_topology, dry=True + )["debug"] # fmt: skip + + # Check the sampler + self._verify_sampler(data["sampler"], complexed=False, settings=settings) + + # Check the alchemical system + self._assert_expected_alchemical_forces( + data["alchem_system"], complexed=False, settings=settings + ) + self._test_cubic_vectors(data["alchem_system"]) + + # Check the alchemical 
indices + expected_indices = [i for i in range(self.num_ligand_atoms_A + self.num_ligand_atoms_B)] + assert expected_indices == data["selection_indices"].tolist() + + # Check the non-alchemical system + self._assert_expected_nonalchemical_forces( + solv_setup_output["system_AB"], complexed=False, settings=settings + ) + self._test_cubic_vectors(solv_setup_output["system_AB"]) + + # Check the PDB + pdb = md.load_pdb("alchemical_system.pdb") + assert pdb.n_atoms == (self.num_ligand_atoms_A + self.num_ligand_atoms_B) + + +def test_adaptive_settings_no_protein_membrane(toluene_complex_system, default_settings): + settings = SepTopProtocol._adaptive_settings( + toluene_complex_system, toluene_complex_system, default_settings + ) + + assert isinstance(settings, SepTopSettings) + # Should use default barostat since no ProteinMembraneComponent + assert settings.complex_integrator_settings.barostat == "MonteCarloBarostat" + + +def test_adaptive_settings_with_protein_membrane(a2a_protein_membrane_component, a2a_ligands): + stateA = ChemicalSystem( + { + "ligandA": a2a_ligands[0], + "protein": a2a_protein_membrane_component, + "solvent": SolventComponent(), + } + ) + + settings = SepTopProtocol._adaptive_settings(stateA, stateA) + assert isinstance(settings, SepTopSettings) + # Barostat should have been updated + assert settings.complex_integrator_settings.barostat == "MonteCarloMembraneBarostat" + + # Forcefields should include the lipid forcefields + ff = settings.forcefield_settings.forcefields + assert "amber/lipid17_merged.xml" in ff diff --git a/src/openfe/tests/protocols/test_openmmutils.py b/src/openfe/tests/protocols/test_openmmutils.py index d3838bfdd..ac3d78420 100644 --- a/src/openfe/tests/protocols/test_openmmutils.py +++ b/src/openfe/tests/protocols/test_openmmutils.py @@ -10,14 +10,15 @@ import numpy as np import pooch import pytest +from gufe import BaseSolventComponent from gufe.settings import OpenMMSystemGeneratorFFSettings, ThermoSettings from 
numpy.testing import assert_allclose, assert_equal from openff.toolkit import Molecule as OFFMol from openff.toolkit.utils.toolkit_registry import ToolkitRegistry from openff.toolkit.utils.toolkits import RDKitToolkitWrapper from openff.units import unit -from openff.units.openmm import ensure_quantity, from_openmm -from openmm import MonteCarloBarostat, NonbondedForce, app +from openff.units.openmm import ensure_quantity, from_openmm, to_openmm +from openmm import MonteCarloBarostat, MonteCarloMembraneBarostat, NonbondedForce, app from openmm import unit as ommunit from openmmtools import multistate from pymbar.utils import ParameterError @@ -190,6 +191,19 @@ def test_validate_solvent_multiple_solvent(benzene_modifications): system_validation.validate_solvent(state, "pme") +def test_validate_solvent_multiple_solvated(benzene_modifications, a2a_protein_membrane_component): + state = openfe.ChemicalSystem( + { + "A": benzene_modifications["toluene"], + "S": a2a_protein_membrane_component, + "S2": a2a_protein_membrane_component, + } + ) + + with pytest.raises(ValueError, match="Multiple SolvatedPDBComponent"): + system_validation.validate_solvent(state, "pme") + + def test_not_water_solvent(benzene_modifications): state = openfe.ChemicalSystem( {"A": benzene_modifications["toluene"], "S": openfe.SolventComponent(smiles="C")} @@ -206,6 +220,24 @@ def test_multiple_proteins(T4_protein_component): system_validation.validate_protein(state) +def test_membrane_protein_warns_with_non_membrane_barostat(a2a_protein_membrane_component): + state = openfe.ChemicalSystem({"A": a2a_protein_membrane_component}) + with pytest.warns(UserWarning, match="ProteinMembraneComponent"): + system_validation.validate_barostat( + state, + barostat="MonteCarloBarostat", + ) + + +def test_non_membrane_protein_warns_with_membrane_barostat(T4_protein_component): + state = openfe.ChemicalSystem({"A": T4_protein_component}) + with pytest.warns(UserWarning, match="MonteCarloMembraneBarostat"): + 
system_validation.validate_barostat( + state, + barostat="MonteCarloMembraneBarostat", + ) + + def test_get_components_gas(benzene_modifications): state = openfe.ChemicalSystem( { @@ -424,6 +456,28 @@ def test_system_generator_solv_cache(self, get_settings): # Check cache file assert generator.template_generator._cache == "db.json" + def test_system_generator_membrane(self, get_settings): + ffsets, intsets, thermosets = get_settings + + thermosets.temperature = 320 * unit.kelvin + thermosets.pressure = 1.25 * unit.bar + intsets.barostat = "MonteCarloMembraneBarostat" + intsets.barostat_frequency = 200 * unit.timestep + generator = system_creation.get_system_generator( + ffsets, thermosets, intsets, Path("./db.json"), True + ) + + # Check barostat conditions + assert isinstance(generator.barostat, MonteCarloMembraneBarostat) + + pressure = ensure_quantity(generator.barostat.getDefaultPressure(), "openff") + temperature = ensure_quantity(generator.barostat.getDefaultTemperature(), "openff") + assert pressure.m == pytest.approx(1.25) + assert pressure.units == unit.bar + assert temperature.m == pytest.approx(320) + assert temperature.units == unit.kelvin + assert generator.barostat.getFrequency() == 200 + def test_get_omm_modeller_complex( self, T4_protein_component, @@ -456,6 +510,39 @@ def test_get_omm_modeller_complex( np.linspace(165, len(resids) - 1, len(resids) - 165), ) + def test_get_omm_modeller_membrane_box( + self, + a2a_protein_membrane_component, + a2a_ligands, + get_settings, + ): + ffsets, intsets, thermosets = get_settings + intsets.barostat = "MonteCarloMembraneBarostat" + ffsets.forcefields = [ + "amber/ff14SB.xml", + "amber/tip3p_standard.xml", + "amber/tip3p_HFE_multivalent.xml", + "amber/lipid17_merged.xml", + "amber/phosaa10.xml", + ] + generator = system_creation.get_system_generator(ffsets, thermosets, intsets, None, True) + + smc = a2a_ligands[0] + mol = smc.to_openff() + generator.create_system(mol.to_topology().to_openmm(), molecules=[mol]) 
+ + model, comp_resids = system_creation.get_omm_modeller( + a2a_protein_membrane_component, + a2a_protein_membrane_component, + {smc: mol}, + generator.forcefield, + OpenMMSolvationSettings(), + ) + box_modeller = model.topology.getPeriodicBoxVectors() + box_protein = a2a_protein_membrane_component.box_vectors + + assert np.allclose(box_modeller, to_openmm(box_protein), atol=1e-6) + @pytest.fixture(scope="module") def ligand_mol_and_generator(self, get_settings): # Create offmol diff --git a/src/openfecli/commands/gather.py b/src/openfecli/commands/gather.py index cbb2b93f6..7f05e58c5 100644 --- a/src/openfecli/commands/gather.py +++ b/src/openfecli/commands/gather.py @@ -249,13 +249,18 @@ def _get_type(result: dict) -> Literal["vacuum", "solvent", "complex"]: component_types = [ x["__module__"] for x in protocol_data["inputs"]["stateA"]["components"].values() ] - if "gufe.components.solventcomponent" not in component_types: - return "vacuum" - elif "gufe.components.proteincomponent" in component_types: + if ( + "gufe.components.proteincomponent" in component_types + or "gufe.components.solvatedpdbcomponent" in component_types + ): return "complex" - else: + + elif "gufe.components.solventcomponent" in component_types: return "solvent" + else: + return "vacuum" + def _legacy_get_type(res_fn: os.PathLike | str) -> Literal["vacuum", "solvent", "complex"]: # TODO: Deprecate this when we no longer rely on key names in `_get_type()` diff --git a/src/openfecli/data/_registry.py b/src/openfecli/data/_registry.py index f9b811f73..3fdf5fd31 100644 --- a/src/openfecli/data/_registry.py +++ b/src/openfecli/data/_registry.py @@ -19,14 +19,14 @@ known_hash="md5:ff7313e14eb6f2940c6ffd50f2192181", ) zenodo_abfe_data = dict( - base_url="doi:10.5281/zenodo.17348229/", + base_url="doi:10.5281/zenodo.18757962/", fname="abfe_results.zip", - known_hash="md5:547f896e867cce61979d75b7e082f6ba", + known_hash="md5:5c4195ae5089b463534896d0dd5dc4a8", ) zenodo_septop_data = dict( - 
base_url="doi:10.5281/zenodo.17435569/", + base_url="doi:10.5281/zenodo.18758162/", fname="septop_results.zip", - known_hash="md5:2cfa18da59a20228f5c75a1de6ec879e", + known_hash="md5:4d1bae4a5a62067c2644059ba183f8a0", ) zenodo_data_registry = [