diff --git a/src/openfe/protocols/openmm_rfe/_rfe_utils/relative.py b/src/openfe/protocols/openmm_rfe/_rfe_utils/relative.py index 096199a68..b304de320 100644 --- a/src/openfe/protocols/openmm_rfe/_rfe_utils/relative.py +++ b/src/openfe/protocols/openmm_rfe/_rfe_utils/relative.py @@ -408,19 +408,40 @@ def _invert_dict(dictionary): """ Convenience method to invert a dictionary (since we do it so often). - Paramters: + Parameters: ---------- dictionary : dict Dictionary you want to invert """ return {v: k for k, v in dictionary.items()} + @staticmethod + def _pair_key(particle1, particle2): + """ + Convenience method to generate a key for a pair of atom indexes with consistent normalization. + + Parameters + ---------- + particle1 : int + Index of first particle in exception + particle2 : int + Index of second particle in exception + + Returns + ------- + tuple + Sorted tuple of the two particle indices, which is used as a key in the exception dicts. + """ + return (particle1, particle2) if particle1 < particle2 else (particle2, particle1) + def _set_mappings(self, old_to_new_map, core_old_to_new_map): """ Parameters ---------- old_to_new_map : dict of int : int Dictionary mapping atoms between the old and new systems. + core_old_to_new_map: dict[int,int] + Dictionary mapping core atoms between the old and new systems. This is a subset of the old_to_new_map. Notes ----- @@ -615,12 +636,18 @@ def _generate_dict_from_exceptions(force): ------- exceptions_dict : dict Dictionary of exceptions + + Note + ---- + * The keys of the dictionary are sorted tuples of the particle indices + to make it easier to search for exceptions between two particles + without worrying about order. """ exceptions_dict = {} for exception_index in range(force.getNumExceptions()): [index1, index2, chargeProd, sigma, epsilon] = force.getExceptionParameters(exception_index) - exceptions_dict[(index1, index2)] = [chargeProd, sigma, epsilon] + exceptions_dict[HybridTopologyFactory._pair_key(index1, index2)] = [chargeProd, sigma, epsilon] return exceptions_dict @@ -664,42 +691,29 @@ def _handle_constraints(self): """ This method adds relevant constraints from the old and new systems. - First, all constraints from the old systenm are added. + First, all constraints from the old system are added. Then, constraints to atoms unique to the new system are added. - - TODO: condense duplicated code """ # lengths of constraints already added constraint_lengths = dict() - # old system - hybrid_map = self._old_to_hybrid_map - for const_idx in range(self._old_system.getNumConstraints()): - at1, at2, length = self._old_system.getConstraintParameters( - const_idx) - hybrid_atoms = tuple(sorted([hybrid_map[at1], hybrid_map[at2]])) - if hybrid_atoms not in constraint_lengths.keys(): - self._hybrid_system.addConstraint(hybrid_atoms[0], - hybrid_atoms[1], length) - constraint_lengths[hybrid_atoms] = length - else: + for system, hybrid_map in ( + (self._old_system, self._old_to_hybrid_map), + (self._new_system, self._new_to_hybrid_map), + ): + for const_idx in range(system.getNumConstraints()): + at1, at2, length = system.getConstraintParameters(const_idx) + hybrid_atoms = self._pair_key(hybrid_map[at1], hybrid_map[at2]) + if hybrid_atoms not in constraint_lengths.keys(): + # add to the system + self._hybrid_system.addConstraint(hybrid_atoms[0], + hybrid_atoms[1], length) + # store for later checks + constraint_lengths[hybrid_atoms] = length + else: + if constraint_lengths[hybrid_atoms] != length: + raise AssertionError('constraint length is changing') - if constraint_lengths[hybrid_atoms] != length: - raise AssertionError('constraint length is changing') - - # new system - hybrid_map = self._new_to_hybrid_map - for const_idx in range(self._new_system.getNumConstraints()): - at1, at2, length = self._new_system.getConstraintParameters( - const_idx) - hybrid_atoms = tuple(sorted([hybrid_map[at1], hybrid_map[at2]])) - if hybrid_atoms not in constraint_lengths.keys(): - self._hybrid_system.addConstraint(hybrid_atoms[0], - hybrid_atoms[1], length) - constraint_lengths[hybrid_atoms] = length - else: - if constraint_lengths[hybrid_atoms] != length: - raise AssertionError('constraint length is changing') @staticmethod def _copy_threeparticleavg(atm_map, env_atoms, vs): @@ -1149,33 +1163,25 @@ def _add_nonbonded_force_terms(self): sterics_custom_nonbonded_force.setUseSwitchingFunction(False) @staticmethod - def _find_bond_parameters(bond_force, index1, index2): + def _build_bond_lookup(bond_force) -> dict: """ - This is a convenience function to find bond parameters in another - system given the two indices. + Build a lookup dictionary for bond parameters given a HarmonicBondForce. Parameters ---------- bond_force : openmm.HarmonicBondForce - The bond force where the parameters should be found - index1 : int - Index1 (order does not matter) of the bond atoms - index2 : int - Index2 (order does not matter) of the bond atoms + The bond force from which to build the lookup. Returns ------- - bond_parameters : list - List of relevant bond parameters + dict + A dictionary mapping a sorted tuple of atom indices to their bond parameters. """ - index_set = {index1, index2} - # Loop through all the bonds: + bond_lookup = {} for bond_index in range(bond_force.getNumBonds()): - parms = bond_force.getBondParameters(bond_index) - if index_set == {parms[0], parms[1]}: - return parms - - return [] + index1, index2, r0, k = bond_force.getBondParameters(bond_index) + bond_lookup[HybridTopologyFactory._pair_key(index1, index2)] = (r0, k) + return bond_lookup def _handle_harmonic_bonds(self): """ @@ -1196,6 +1202,10 @@ def _handle_harmonic_bonds(self): """ old_system_bond_force = self._old_system_forces['HarmonicBondForce'] new_system_bond_force = self._new_system_forces['HarmonicBondForce'] + # build a lookup for the new bonds so we don't have to loop through the force + new_bond_lookup = self._build_bond_lookup(new_system_bond_force) + # track the terms in the hybrid system + hybrid_bonds = {} # First, loop through the old system bond forces and add relevant terms for bond_index in range(old_system_bond_force.getNumBonds()): @@ -1215,18 +1225,18 @@ def _handle_harmonic_bonds(self): if index_set.issubset(self._atom_classes['core_atoms']): index1_new = self._old_to_new_map[index1_old] index2_new = self._old_to_new_map[index2_old] - new_bond_parameters = self._find_bond_parameters( - new_system_bond_force, index1_new, index2_new) - if not new_bond_parameters: + new_bond_parameters = new_bond_lookup.get(self._pair_key(index1_new, index2_new), None) + if new_bond_parameters is None: r0_new = r0_old k_new = 0.0*unit.kilojoule_per_mole/unit.angstrom**2 else: - # TODO - why is this being recalculated? - [index1, index2, r0_new, k_new] = self._find_bond_parameters( - new_system_bond_force, index1_new, index2_new) + r0_new, k_new = new_bond_parameters + self._hybrid_system_forces['core_bond_force'].addBond( index1_hybrid, index2_hybrid, [r0_old, k_old, r0_new, k_new]) + # track that we've added this bond + hybrid_bonds[self._pair_key(index1_hybrid, index2_hybrid)] = True # Check if the index set is a subset of anything besides # environment (in the case of environment, we just add the bond to @@ -1290,14 +1300,16 @@ def _handle_harmonic_bonds(self): # This has some peculiarities to be discussed... # TODO - Work out what the above peculiarities are... elif index_set.issubset(self._atom_classes['core_atoms']): - if not self._find_bond_parameters( - self._hybrid_system_forces['core_bond_force'], - index1_hybrid, index2_hybrid): + bond_key = self._pair_key(index1_hybrid, index2_hybrid) + if bond_key not in hybrid_bonds: r0_old = r0_new k_old = 0.0*unit.kilojoule_per_mole/unit.angstrom**2 self._hybrid_system_forces['core_bond_force'].addBond( index1_hybrid, index2_hybrid, [r0_old, k_old, r0_new, k_new]) + # track that we've added this bond + hybrid_bonds[bond_key] = True + elif index_set.issubset(self._atom_classes['environment_atoms']): # Already been added pass @@ -1312,36 +1324,46 @@ def _handle_harmonic_bonds(self): raise ValueError(errmsg) @staticmethod - def _find_angle_parameters(angle_force, indices): + def _triplet_key(a, b, c) -> tuple[int, int, int]: """ - Convenience function to find the angle parameters corresponding to a - particular set of indices + Create a key for angle lookups that is invariant to the order of the first and third indices. Parameters ---------- - angle_force : openmm.HarmonicAngleForce - The force where the angle of interest may be found. - indices : list of int - The indices (any order) of the angle atoms + a : int + The first atom index. + b : int + The second atom index (the central atom in the angle). + c : int + The third atom index. Returns ------- - angle_params : list - list of angle parameters + tuple[int, int, int] + A tuple representing the angle, with the first and third indices sorted. """ - indices_reversed = indices[::-1] + return (a, b, c) if a < c else (c, b, a) - # Now loop through and try to find the angle: - for angle_index in range(angle_force.getNumAngles()): - angle_params = angle_force.getAngleParameters(angle_index) + @staticmethod + def _build_angle_lookup(angle_force) -> dict: + """ + Build a lookup dictionary for angle parameters given a HarmonicAngleForce. - # Get a set representing the angle indices - angle_param_indices = angle_params[:3] + Parameters + ---------- + angle_force : openmm.HarmonicAngleForce + The angle force from which to build the lookup. - if (indices == angle_param_indices or - indices_reversed == angle_param_indices): - return angle_params - return [] # Return empty if no matching angle found + Returns + ------- + dict + A dictionary mapping a sorted tuple of atom indices to their angle parameters. + """ + angle_lookup = {} + for angle_index in range(angle_force.getNumAngles()): + index1, index2, index3, theta0, k = angle_force.getAngleParameters(angle_index) + angle_lookup[HybridTopologyFactory._triplet_key(index1, index2, index3)] = (theta0, k) + return angle_lookup def _handle_harmonic_angles(self): """ @@ -1368,6 +1390,10 @@ def _handle_harmonic_angles(self): """ old_system_angle_force = self._old_system_forces['HarmonicAngleForce'] new_system_angle_force = self._new_system_forces['HarmonicAngleForce'] + # build a lookup for the new system angles to save iterating through the force + new_angle_lookup = self._build_angle_lookup(new_system_angle_force) + # hybrid angle tracking + hybrid_angles = {} # First, loop through all the angles in the old system to determine # what to do with them. We will only use the @@ -1391,26 +1417,27 @@ def _handle_harmonic_angles(self): new_indices = [ self._old_to_new_map[old_atomid] for old_atomid in old_angle_parameters[:3] ] - new_angle_parameters = self._find_angle_parameters( - new_system_angle_force, new_indices - ) - if not new_angle_parameters: - new_angle_parameters = [ - 0, 0, 0, old_angle_parameters[3], - 0.0*unit.kilojoule_per_mole/unit.radian**2 - ] + new_angle_parameters = new_angle_lookup.get(self._triplet_key(*new_indices), None) + + if new_angle_parameters is None: + new_theta0 = old_angle_parameters[3] + new_k = 0.0*unit.kilojoule_per_mole/unit.radian**2 + else: + new_theta0, new_k = new_angle_parameters # Add to the hybrid force: # the parameters at indices 3 and 4 represent theta0 and k, # respectively. hybrid_force_parameters = [ old_angle_parameters[3], old_angle_parameters[4], - new_angle_parameters[3], new_angle_parameters[4] + new_theta0, new_k ] self._hybrid_system_forces['core_angle_force'].addAngle( hybrid_index_list[0], hybrid_index_list[1], hybrid_index_list[2], hybrid_force_parameters ) + # track that we have added this angle for the hybrid system + hybrid_angles[self._triplet_key(*hybrid_index_list)] = True # Check if the atoms are neither all core nor all environment, # which would mean they involve unique old interactions @@ -1488,8 +1515,8 @@ def _handle_harmonic_angles(self): ) elif hybrid_index_set.issubset(self._atom_classes['core_atoms']): - if not self._find_angle_parameters(self._hybrid_system_forces['core_angle_force'], - hybrid_index_list): + angle_key = self._triplet_key(*hybrid_index_list) + if angle_key not in hybrid_angles: hybrid_force_parameters = [ new_angle_parameters[3], 0.0*unit.kilojoule_per_mole/unit.radian**2, @@ -1499,6 +1526,9 @@ def _handle_harmonic_angles(self): hybrid_index_list[0], hybrid_index_list[1], hybrid_index_list[2], hybrid_force_parameters ) + # track that we have added this angle + hybrid_angles[angle_key] = True + elif hybrid_index_set.issubset(self._atom_classes['environment_atoms']): # We have already added the appropriate environmental atom # terms @@ -1513,40 +1543,6 @@ def _handle_harmonic_angles(self): "fit into a canonical atom set") raise ValueError(errmsg) - @staticmethod - def _find_torsion_parameters(torsion_force, indices): - """ - Convenience function to find the torsion parameters corresponding to a - particular set of indices. - - Parameters - ---------- - torsion_force : openmm.PeriodicTorsionForce - torsion force where the torsion of interest may be found - indices : list of int - The indices of the atoms of the torsion - - Returns - ------- - torsion_parameters : list - torsion parameters - """ - indices_reversed = indices[::-1] - - torsion_params_list = list() - - # Now loop through and try to find the torsion: - for torsion_idx in range(torsion_force.getNumTorsions()): - torsion_params = torsion_force.getTorsionParameters(torsion_idx) - - # Get a set representing the torsion indices: - torsion_param_indices = torsion_params[:4] - - if (indices == torsion_param_indices or - indices_reversed == torsion_param_indices): - torsion_params_list.append(torsion_params) - - return torsion_params_list def _handle_periodic_torsion_force(self): """ @@ -1568,9 +1564,13 @@ def _handle_periodic_torsion_force(self): """ old_system_torsion_force = self._old_system_forces['PeriodicTorsionForce'] new_system_torsion_force = self._new_system_forces['PeriodicTorsionForce'] + # local variables for speed while doing many lookups + unique_old_atoms = self._atom_classes["unique_old_atoms"] + unique_new_atoms = self._atom_classes["unique_new_atoms"] - auxiliary_custom_torsion_force = [] - old_custom_torsions_to_standard = [] + # use sets to keep membership checks quick as systems have many torsions + auxiliary_custom_torsion_force = set() + old_custom_torsions_to_standard = set() # We need to keep track of what torsions we added so that we do not # double count @@ -1590,7 +1590,7 @@ def _handle_periodic_torsion_force(self): # If all atoms are in the core, we'll need to find the # corresponding parameters in the old system and interpolate - if hybrid_index_set.intersection(self._atom_classes['unique_old_atoms']): + if hybrid_index_set.intersection(unique_old_atoms): # Then it goes to a standard force... self._hybrid_system_forces['unique_atom_torsion_force'].addTorsion( hybrid_index_list[0], hybrid_index_list[1], @@ -1603,13 +1603,13 @@ def _handle_periodic_torsion_force(self): # core/env term; in any case, it goes to the core torsion_force # TODO - why are we even adding the 0.0, 0.0, 0.0 section? hybrid_force_parameters = [ - torsion_parameters[4], torsion_parameters[5], - torsion_parameters[6], 0.0, 0.0, 0.0 + torsion_parameters[4], torsion_parameters[5].value_in_unit(unit.radian), + torsion_parameters[6].value_in_unit(unit.kilojoule_per_mole), 0.0, 0.0, 0.0 ] - auxiliary_custom_torsion_force.append( - [hybrid_index_list[0], hybrid_index_list[1], + auxiliary_custom_torsion_force.add( + (hybrid_index_list[0], hybrid_index_list[1], hybrid_index_list[2], hybrid_index_list[3], - hybrid_force_parameters[:3]] + *hybrid_force_parameters[:3]) ) for torsion_index in range(new_system_torsion_force.getNumTorsions()): @@ -1620,7 +1620,7 @@ def _handle_periodic_torsion_force(self): self._new_to_hybrid_map[new_index] for new_index in torsion_parameters[:4]] hybrid_index_set = set(hybrid_index_list) - if hybrid_index_set.intersection(self._atom_classes['unique_new_atoms']): + if hybrid_index_set.intersection(unique_new_atoms): # Then it goes to the custom torsion force (scaled to zero) self._hybrid_system_forces['unique_atom_torsion_force'].addTorsion( hybrid_index_list[0], hybrid_index_list[1], @@ -1631,16 +1631,15 @@ def _handle_periodic_torsion_force(self): else: hybrid_force_parameters = [ 0.0, 0.0, 0.0, torsion_parameters[4], - torsion_parameters[5], torsion_parameters[6]] + torsion_parameters[5].value_in_unit(unit.radian), torsion_parameters[6].value_in_unit(unit.kilojoule_per_mole)] # Check to see if this term is in the olds... - term = [hybrid_index_list[0], hybrid_index_list[1], + term = (hybrid_index_list[0], hybrid_index_list[1], hybrid_index_list[2], hybrid_index_list[3], - hybrid_force_parameters[3:]] + *hybrid_force_parameters[3:]) if term in auxiliary_custom_torsion_force: # Then this terms has to go to standard and be deleted... - old_index = auxiliary_custom_torsion_force.index(term) - old_custom_torsions_to_standard.append(old_index) + old_custom_torsions_to_standard.add(term) self._hybrid_system_forces['unique_atom_torsion_force'].addTorsion( hybrid_index_list[0], hybrid_index_list[1], hybrid_index_list[2], hybrid_index_list[3], @@ -1656,16 +1655,15 @@ def _handle_periodic_torsion_force(self): ) # Now we have to loop through the aux custom torsion force - for index in [q for q in range(len(auxiliary_custom_torsion_force)) - if q not in old_custom_torsions_to_standard]: - terms = auxiliary_custom_torsion_force[index] - hybrid_index_list = terms[:4] - hybrid_force_parameters = terms[4] + [0., 0., 0.] - self._hybrid_system_forces['custom_torsion_force'].addTorsion( - hybrid_index_list[0], hybrid_index_list[1], - hybrid_index_list[2], hybrid_index_list[3], - hybrid_force_parameters - ) + for term in auxiliary_custom_torsion_force: + if term not in old_custom_torsions_to_standard: + hybrid_index_list = term[:4] + hybrid_force_parameters = term[4:] + (0., 0., 0.) + self._hybrid_system_forces['custom_torsion_force'].addTorsion( + hybrid_index_list[0], hybrid_index_list[1], + hybrid_index_list[2], hybrid_index_list[3], + hybrid_force_parameters + ) def _handle_nonbonded(self): """ @@ -1900,8 +1898,11 @@ def _handle_hybrid_exceptions(self): for atom_pair in unique_old_pairs: # Since the pairs are indexed in the dictionary by the old system # indices, we need to convert - old_index_atom_pair = (self._hybrid_to_old_map[atom_pair[0]], - self._hybrid_to_old_map[atom_pair[1]]) + # use the exception key function to ensure we check using the correct order + old_index_atom_pair = self._pair_key( + self._hybrid_to_old_map[atom_pair[0]], + self._hybrid_to_old_map[atom_pair[1]] + ) # Now we check if the pair is in the exception dictionary if old_index_atom_pair in self._old_system_exceptions: @@ -1922,23 +1923,6 @@ def _handle_hybrid_exceptions(self): atom_pair[0], atom_pair[1] ) - # Check if the pair is in the reverse order and use that if so - elif old_index_atom_pair[::-1] in self._old_system_exceptions: - [chargeProd, sigma, epsilon] = self._old_system_exceptions[old_index_atom_pair[::-1]] - # If we are interpolating 1,4 exceptions then we have to - if self._interpolate_14s: - self._hybrid_system_forces['standard_nonbonded_force'].addException( - atom_pair[0], atom_pair[1], chargeProd*0.0, - sigma, epsilon*0.0 - ) - else: - self._hybrid_system_forces['standard_nonbonded_force'].addException( - atom_pair[0], atom_pair[1], chargeProd, sigma, epsilon) - - # Add exclusion to ensure exceptions are consistent - self._hybrid_system_forces['core_sterics_force'].addExclusion( - atom_pair[0], atom_pair[1]) - # TODO: work out why there's a bunch of commented out code here # Exerpt: # If it's not handled by an exception in the original system, we @@ -1952,8 +1936,10 @@ def _handle_hybrid_exceptions(self): for atom_pair in unique_new_pairs: # Since the pairs are indexed in the dictionary by the new system # indices, we need to convert - new_index_atom_pair = (self._hybrid_to_new_map[atom_pair[0]], - self._hybrid_to_new_map[atom_pair[1]]) + new_index_atom_pair = self._pair_key( + self._hybrid_to_new_map[atom_pair[0]], + self._hybrid_to_new_map[atom_pair[1]] + ) # Now we check if the pair is in the exception dictionary if new_index_atom_pair in self._new_system_exceptions: @@ -1971,57 +1957,10 @@ def _handle_hybrid_exceptions(self): self._hybrid_system_forces['core_sterics_force'].addExclusion( atom_pair[0], atom_pair[1] ) - - # Check if the pair is present in the reverse order and use that if so - elif new_index_atom_pair[::-1] in self._new_system_exceptions: - [chargeProd, sigma, epsilon] = self._new_system_exceptions[new_index_atom_pair[::-1]] - if self._interpolate_14s: - self._hybrid_system_forces['standard_nonbonded_force'].addException( - atom_pair[0], atom_pair[1], chargeProd*0.0, - sigma, epsilon*0.0 - ) - else: - self._hybrid_system_forces['standard_nonbonded_force'].addException( - atom_pair[0], atom_pair[1], chargeProd, sigma, epsilon - ) - - self._hybrid_system_forces['core_sterics_force'].addExclusion( - atom_pair[0], atom_pair[1] - ) - - # TODO: work out why there's a bunch of commented out code here # If it's not handled by an exception in the original system, we # just add the regular parameters as an exception - @staticmethod - def _find_exception(force, index1, index2): - """ - Find the exception that corresponds to the given indices in the given - system - - Parameters - ---------- - force : openmm.NonbondedForce object - System containing the exceptions - index1 : int - The index of the first atom (order is unimportant) - index2 : int - The index of the second atom (order is unimportant) - - Returns - ------- - exception_parameters : list - List of exception parameters - """ - index_set = {index1, index2} - - # Loop through the exceptions and try to find one matching the criteria - for exception_idx in range(force.getNumExceptions()): - exception_parameters = force.getExceptionParameters(exception_idx) - if index_set==set(exception_parameters[:2]): - return exception_parameters - return [] def _handle_original_exceptions(self): """ @@ -2091,10 +2030,8 @@ def _handle_original_exceptions(self): # First get the new indices. index1_new = hybrid_to_new_map[index1_hybrid] index2_new = hybrid_to_new_map[index2_hybrid] - # Get the exception parameters: - new_exception_parms= self._find_exception( - new_system_nonbonded_force, - index1_new, index2_new) + # Get the exception parameters: make sure to sort the keys to match how they are stored + new_exception_parms = self._new_system_exceptions.get(self._pair_key(index1_new, index2_new), []) # If there's no new exception, then we should just set the # exception parameters to be the nonbonded parameters @@ -2106,7 +2043,7 @@ def _handle_original_exceptions(self): sigma_new = 0.5 * (sigma1_new + sigma2_new) epsilon_new = unit.sqrt(epsilon1_new*epsilon2_new) else: - [index1_new, index2_new, chargeProd_new, sigma_new, epsilon_new] = new_exception_parms + [chargeProd_new, sigma_new, epsilon_new] = new_exception_parms # Interpolate between old and new exception_index = self._hybrid_system_forces['standard_nonbonded_force'].addException( @@ -2179,7 +2116,8 @@ def _handle_original_exceptions(self): # See if it's also in the old nonbonded force. if it is, then we don't need to add it. # But if it's not, we need to interpolate - if not self._find_exception(old_system_nonbonded_force, index1_old, index2_old): + old_exception_parms = self._old_system_exceptions.get(self._pair_key(index1_old, index2_old), []) + if not old_exception_parms: [charge1_old, sigma1_old, epsilon1_old] = old_system_nonbonded_force.getParticleParameters(index1_old) [charge2_old, sigma2_old, epsilon2_old] = old_system_nonbonded_force.getParticleParameters(index2_old)