@@ -415,6 +415,25 @@ def _invert_dict(dictionary):
415415 """
416416 return {v : k for k , v in dictionary .items ()}
417417
418+ @staticmethod
419+ def _pair_key (particle1 , particle2 ):
420+ """
421+ Convenience method to generate a key for a pair of atom indexes with consistent normalization.
422+
423+ Parameters
424+ ----------
425+ particle1 : int
426+ Index of first particle in exception
427+ particle2 : int
428+ Index of second particle in exception
429+
430+ Returns
431+ -------
432+ tuple
433+ Sorted tuple of the two particle indices, which is used as a key in the exception dicts.
434+ """
435+ return (particle1 , particle2 ) if particle1 < particle2 else (particle2 , particle1 )
436+
418437 def _set_mappings (self , old_to_new_map , core_old_to_new_map ):
419438 """
420439 Parameters
@@ -628,7 +647,7 @@ def _generate_dict_from_exceptions(force):
628647
629648 for exception_index in range (force .getNumExceptions ()):
630649 [index1 , index2 , chargeProd , sigma , epsilon ] = force .getExceptionParameters (exception_index )
631- exceptions_dict [tuple ( sorted ([ index1 , index2 ]) )] = [chargeProd , sigma , epsilon ]
650+ exceptions_dict [HybridTopologyFactory . _pair_key ( index1 , index2 )] = [chargeProd , sigma , epsilon ]
632651
633652 return exceptions_dict
634653
@@ -684,7 +703,7 @@ def _handle_constraints(self):
684703 ):
685704 for const_idx in range (system .getNumConstraints ()):
686705 at1 , at2 , length = system .getConstraintParameters (const_idx )
687- hybrid_atoms = tuple ( sorted ([ hybrid_map [at1 ], hybrid_map [at2 ]]) )
706+ hybrid_atoms = self . _pair_key ( hybrid_map [at1 ], hybrid_map [at2 ])
688707 if hybrid_atoms not in constraint_lengths .keys ():
689708 # add to the system
690709 self ._hybrid_system .addConstraint (hybrid_atoms [0 ],
@@ -1144,33 +1163,25 @@ def _add_nonbonded_force_terms(self):
11441163 sterics_custom_nonbonded_force .setUseSwitchingFunction (False )
11451164
11461165 @staticmethod
1147- def _find_bond_parameters (bond_force , index1 , index2 ) :
1166+ def _build_bond_lookup (bond_force ) -> dict :
11481167 """
1149- This is a convenience function to find bond parameters in another
1150- system given the two indices.
1168+ Build a lookup dictionary for bond parameters given a HarmonicBondForce.
11511169
11521170 Parameters
11531171 ----------
11541172 bond_force : openmm.HarmonicBondForce
1155- The bond force where the parameters should be found
1156- index1 : int
1157- Index1 (order does not matter) of the bond atoms
1158- index2 : int
1159- Index2 (order does not matter) of the bond atoms
1173+ The bond force from which to build the lookup.
11601174
11611175 Returns
11621176 -------
1163- bond_parameters : list
1164- List of relevant bond parameters
1177+ dict
1178+ A dictionary mapping a sorted tuple of atom indices to their bond parameters.
11651179 """
1166- index_set = {index1 , index2 }
1167- # Loop through all the bonds:
1180+ bond_lookup = {}
11681181 for bond_index in range (bond_force .getNumBonds ()):
1169- parms = bond_force .getBondParameters (bond_index )
1170- if index_set == {parms [0 ], parms [1 ]}:
1171- return parms
1172-
1173- return []
1182+ index1 , index2 , r0 , k = bond_force .getBondParameters (bond_index )
1183+ bond_lookup [HybridTopologyFactory ._pair_key (index1 , index2 )] = (r0 , k )
1184+ return bond_lookup
11741185
11751186 def _handle_harmonic_bonds (self ):
11761187 """
@@ -1191,6 +1202,10 @@ def _handle_harmonic_bonds(self):
11911202 """
11921203 old_system_bond_force = self ._old_system_forces ['HarmonicBondForce' ]
11931204 new_system_bond_force = self ._new_system_forces ['HarmonicBondForce' ]
1205+ # build a lookup for the new bonds so we don't have to loop through the force
1206+ new_bond_lookup = self ._build_bond_lookup (new_system_bond_force )
1207+ # track the terms in the hybrid system
1208+ hybrid_bonds = {}
11941209
11951210 # First, loop through the old system bond forces and add relevant terms
11961211 for bond_index in range (old_system_bond_force .getNumBonds ()):
@@ -1210,18 +1225,18 @@ def _handle_harmonic_bonds(self):
12101225 if index_set .issubset (self ._atom_classes ['core_atoms' ]):
12111226 index1_new = self ._old_to_new_map [index1_old ]
12121227 index2_new = self ._old_to_new_map [index2_old ]
1213- new_bond_parameters = self ._find_bond_parameters (
1214- new_system_bond_force , index1_new , index2_new )
1215- if not new_bond_parameters :
1228+ new_bond_parameters = new_bond_lookup .get (self ._pair_key (index1_new , index2_new ), None )
1229+ if new_bond_parameters is None :
12161230 r0_new = r0_old
12171231 k_new = 0.0 * unit .kilojoule_per_mole / unit .angstrom ** 2
12181232 else :
1219- # TODO - why is this being recalculated?
1220- [index1 , index2 , r0_new , k_new ] = self ._find_bond_parameters (
1221- new_system_bond_force , index1_new , index2_new )
1233+ r0_new , k_new = new_bond_parameters
1234+
12221235 self ._hybrid_system_forces ['core_bond_force' ].addBond (
12231236 index1_hybrid , index2_hybrid ,
12241237 [r0_old , k_old , r0_new , k_new ])
1238+ # track that we've added this bond
1239+ hybrid_bonds [self ._pair_key (index1_hybrid , index2_hybrid )] = True
12251240
12261241 # Check if the index set is a subset of anything besides
12271242 # environment (in the case of environment, we just add the bond to
@@ -1285,14 +1300,16 @@ def _handle_harmonic_bonds(self):
12851300 # This has some peculiarities to be discussed...
12861301 # TODO - Work out what the above peculiarities are...
12871302 elif index_set .issubset (self ._atom_classes ['core_atoms' ]):
1288- if not self ._find_bond_parameters (
1289- self ._hybrid_system_forces ['core_bond_force' ],
1290- index1_hybrid , index2_hybrid ):
1303+ bond_key = self ._pair_key (index1_hybrid , index2_hybrid )
1304+ if bond_key not in hybrid_bonds :
12911305 r0_old = r0_new
12921306 k_old = 0.0 * unit .kilojoule_per_mole / unit .angstrom ** 2
12931307 self ._hybrid_system_forces ['core_bond_force' ].addBond (
12941308 index1_hybrid , index2_hybrid ,
12951309 [r0_old , k_old , r0_new , k_new ])
1310+ # track that we've added this bond
1311+ hybrid_bonds [bond_key ] = True
1312+
12961313 elif index_set .issubset (self ._atom_classes ['environment_atoms' ]):
12971314 # Already been added
12981315 pass
@@ -1307,36 +1324,46 @@ def _handle_harmonic_bonds(self):
13071324 raise ValueError (errmsg )
13081325
13091326 @staticmethod
1310- def _find_angle_parameters ( angle_force , indices ) :
1327+ def _triplet_key ( a , b , c ) -> tuple [ int , int , int ] :
13111328 """
1312- Convenience function to find the angle parameters corresponding to a
1313- particular set of indices
1329+ Create a key for angle lookups that is invariant to the order of the first and third indices.
13141330
13151331 Parameters
13161332 ----------
1317- angle_force : openmm.HarmonicAngleForce
1318- The force where the angle of interest may be found.
1319- indices : list of int
1320- The indices (any order) of the angle atoms
1333+ a : int
1334+ The first atom index.
1335+ b : int
1336+ The second atom index (the central atom in the angle).
1337+ c : int
1338+ The third atom index.
13211339
13221340 Returns
13231341 -------
1324- angle_params : list
1325- list of angle parameters
1342+ tuple[int, int, int]
1343+ A tuple representing the angle, with the first and third indices sorted.
13261344 """
1327- indices_reversed = indices [:: - 1 ]
1345+ return ( a , b , c ) if a < c else ( c , b , a )
13281346
1329- # Now loop through and try to find the angle:
1330- for angle_index in range (angle_force .getNumAngles ()):
1331- angle_params = angle_force .getAngleParameters (angle_index )
1347+ @staticmethod
1348+ def _build_angle_lookup (angle_force ) -> dict :
1349+ """
1350+ Build a lookup dictionary for angle parameters given a HarmonicAngleForce.
13321351
1333- # Get a set representing the angle indices
1334- angle_param_indices = angle_params [:3 ]
1352+ Parameters
1353+ ----------
1354+ angle_force : openmm.HarmonicAngleForce
1355+ The angle force from which to build the lookup.
13351356
1336- if (indices == angle_param_indices or
1337- indices_reversed == angle_param_indices ):
1338- return angle_params
1339- return [] # Return empty if no matching angle found
1357+ Returns
1358+ -------
1359+ dict
1360+ A dictionary mapping a sorted tuple of atom indices to their angle parameters.
1361+ """
1362+ angle_lookup = {}
1363+ for angle_index in range (angle_force .getNumAngles ()):
1364+ index1 , index2 , index3 , theta0 , k = angle_force .getAngleParameters (angle_index )
1365+ angle_lookup [HybridTopologyFactory ._triplet_key (index1 , index2 , index3 )] = (theta0 , k )
1366+ return angle_lookup
13401367
13411368 def _handle_harmonic_angles (self ):
13421369 """
@@ -1363,6 +1390,10 @@ def _handle_harmonic_angles(self):
13631390 """
13641391 old_system_angle_force = self ._old_system_forces ['HarmonicAngleForce' ]
13651392 new_system_angle_force = self ._new_system_forces ['HarmonicAngleForce' ]
1393+ # build a lookup for the new system angles to save iterating through the force
1394+ new_angle_lookup = self ._build_angle_lookup (new_system_angle_force )
1395+ # hybrid angle tracking
1396+ hybrid_angles = {}
13661397
13671398 # First, loop through all the angles in the old system to determine
13681399 # what to do with them. We will only use the
@@ -1386,26 +1417,27 @@ def _handle_harmonic_angles(self):
13861417 new_indices = [
13871418 self ._old_to_new_map [old_atomid ] for old_atomid in old_angle_parameters [:3 ]
13881419 ]
1389- new_angle_parameters = self ._find_angle_parameters (
1390- new_system_angle_force , new_indices
1391- )
1392- if not new_angle_parameters :
1393- new_angle_parameters = [
1394- 0 , 0 , 0 , old_angle_parameters [3 ],
1395- 0.0 * unit .kilojoule_per_mole / unit .radian ** 2
1396- ]
1420+ new_angle_parameters = new_angle_lookup .get (self ._triplet_key (* new_indices ), None )
1421+
1422+ if new_angle_parameters is None :
1423+ new_theta0 = old_angle_parameters [3 ]
1424+ new_k = 0.0 * unit .kilojoule_per_mole / unit .radian ** 2
1425+ else :
1426+ new_theta0 , new_k = new_angle_parameters
13971427
13981428 # Add to the hybrid force:
13991429 # the parameters at indices 3 and 4 represent theta0 and k,
14001430 # respectively.
14011431 hybrid_force_parameters = [
14021432 old_angle_parameters [3 ], old_angle_parameters [4 ],
1403- new_angle_parameters [ 3 ], new_angle_parameters [ 4 ]
1433+ new_theta0 , new_k
14041434 ]
14051435 self ._hybrid_system_forces ['core_angle_force' ].addAngle (
14061436 hybrid_index_list [0 ], hybrid_index_list [1 ],
14071437 hybrid_index_list [2 ], hybrid_force_parameters
14081438 )
1439+ # track that we have added this angle for the hybrid system
1440+ hybrid_angles [self ._triplet_key (* hybrid_index_list )] = True
14091441
14101442 # Check if the atoms are neither all core nor all environment,
14111443 # which would mean they involve unique old interactions
@@ -1483,8 +1515,8 @@ def _handle_harmonic_angles(self):
14831515 )
14841516
14851517 elif hybrid_index_set .issubset (self ._atom_classes ['core_atoms' ]):
1486- if not self ._find_angle_parameters ( self . _hybrid_system_forces [ 'core_angle_force' ],
1487- hybrid_index_list ) :
1518+ angle_key = self ._triplet_key ( * hybrid_index_list )
1519+ if angle_key not in hybrid_angles :
14881520 hybrid_force_parameters = [
14891521 new_angle_parameters [3 ],
14901522 0.0 * unit .kilojoule_per_mole / unit .radian ** 2 ,
@@ -1494,6 +1526,9 @@ def _handle_harmonic_angles(self):
14941526 hybrid_index_list [0 ], hybrid_index_list [1 ],
14951527 hybrid_index_list [2 ], hybrid_force_parameters
14961528 )
1529+ # track that we have added this angle
1530+ hybrid_angles [angle_key ] = True
1531+
14971532 elif hybrid_index_set .issubset (self ._atom_classes ['environment_atoms' ]):
14981533 # We have already added the appropriate environmental atom
14991534 # terms
@@ -1863,8 +1898,11 @@ def _handle_hybrid_exceptions(self):
18631898 for atom_pair in unique_old_pairs :
18641899 # Since the pairs are indexed in the dictionary by the old system
18651900 # indices, we need to convert
1866- old_index_atom_pair = (self ._hybrid_to_old_map [atom_pair [0 ]],
1867- self ._hybrid_to_old_map [atom_pair [1 ]])
1901+ # use the exception key function to ensure we check using the correct order
1902+ old_index_atom_pair = self ._pair_key (
1903+ self ._hybrid_to_old_map [atom_pair [0 ]],
1904+ self ._hybrid_to_old_map [atom_pair [1 ]]
1905+ )
18681906
18691907 # Now we check if the pair is in the exception dictionary
18701908 if old_index_atom_pair in self ._old_system_exceptions :
@@ -1885,23 +1923,6 @@ def _handle_hybrid_exceptions(self):
18851923 atom_pair [0 ], atom_pair [1 ]
18861924 )
18871925
1888- # Check if the pair is in the reverse order and use that if so
1889- elif old_index_atom_pair [::- 1 ] in self ._old_system_exceptions :
1890- [chargeProd , sigma , epsilon ] = self ._old_system_exceptions [old_index_atom_pair [::- 1 ]]
1891- # If we are interpolating 1,4 exceptions then we have to
1892- if self ._interpolate_14s :
1893- self ._hybrid_system_forces ['standard_nonbonded_force' ].addException (
1894- atom_pair [0 ], atom_pair [1 ], chargeProd * 0.0 ,
1895- sigma , epsilon * 0.0
1896- )
1897- else :
1898- self ._hybrid_system_forces ['standard_nonbonded_force' ].addException (
1899- atom_pair [0 ], atom_pair [1 ], chargeProd , sigma , epsilon )
1900-
1901- # Add exclusion to ensure exceptions are consistent
1902- self ._hybrid_system_forces ['core_sterics_force' ].addExclusion (
1903- atom_pair [0 ], atom_pair [1 ])
1904-
19051926 # TODO: work out why there's a bunch of commented out code here
19061927 # Exerpt:
19071928 # If it's not handled by an exception in the original system, we
@@ -1915,8 +1936,10 @@ def _handle_hybrid_exceptions(self):
19151936 for atom_pair in unique_new_pairs :
19161937 # Since the pairs are indexed in the dictionary by the new system
19171938 # indices, we need to convert
1918- new_index_atom_pair = (self ._hybrid_to_new_map [atom_pair [0 ]],
1919- self ._hybrid_to_new_map [atom_pair [1 ]])
1939+ new_index_atom_pair = self ._pair_key (
1940+ self ._hybrid_to_new_map [atom_pair [0 ]],
1941+ self ._hybrid_to_new_map [atom_pair [1 ]]
1942+ )
19201943
19211944 # Now we check if the pair is in the exception dictionary
19221945 if new_index_atom_pair in self ._new_system_exceptions :
@@ -1934,25 +1957,6 @@ def _handle_hybrid_exceptions(self):
19341957 self ._hybrid_system_forces ['core_sterics_force' ].addExclusion (
19351958 atom_pair [0 ], atom_pair [1 ]
19361959 )
1937-
1938- # Check if the pair is present in the reverse order and use that if so
1939- elif new_index_atom_pair [::- 1 ] in self ._new_system_exceptions :
1940- [chargeProd , sigma , epsilon ] = self ._new_system_exceptions [new_index_atom_pair [::- 1 ]]
1941- if self ._interpolate_14s :
1942- self ._hybrid_system_forces ['standard_nonbonded_force' ].addException (
1943- atom_pair [0 ], atom_pair [1 ], chargeProd * 0.0 ,
1944- sigma , epsilon * 0.0
1945- )
1946- else :
1947- self ._hybrid_system_forces ['standard_nonbonded_force' ].addException (
1948- atom_pair [0 ], atom_pair [1 ], chargeProd , sigma , epsilon
1949- )
1950-
1951- self ._hybrid_system_forces ['core_sterics_force' ].addExclusion (
1952- atom_pair [0 ], atom_pair [1 ]
1953- )
1954-
1955-
19561960 # TODO: work out why there's a bunch of commented out code here
19571961 # If it's not handled by an exception in the original system, we
19581962 # just add the regular parameters as an exception
@@ -2027,7 +2031,7 @@ def _handle_original_exceptions(self):
20272031 index1_new = hybrid_to_new_map [index1_hybrid ]
20282032 index2_new = hybrid_to_new_map [index2_hybrid ]
20292033 # Get the exception parameters: make sure to sort the keys to match how they are stored
2030- new_exception_parms = self ._new_system_exceptions .get (tuple ( sorted ([ index1_new , index2_new ]) ), [])
2034+ new_exception_parms = self ._new_system_exceptions .get (self . _pair_key ( index1_new , index2_new ), [])
20312035
20322036 # If there's no new exception, then we should just set the
20332037 # exception parameters to be the nonbonded parameters
@@ -2112,7 +2116,7 @@ def _handle_original_exceptions(self):
21122116
21132117 # See if it's also in the old nonbonded force. if it is, then we don't need to add it.
21142118 # But if it's not, we need to interpolate
2115- old_exception_parms = self ._old_system_exceptions .get (tuple ( sorted ([ index1_old , index2_old ]) ), [])
2119+ old_exception_parms = self ._old_system_exceptions .get (self . _pair_key ( index1_old , index2_old ), [])
21162120 if not old_exception_parms :
21172121
21182122 [charge1_old , sigma1_old , epsilon1_old ] = old_system_nonbonded_force .getParticleParameters (index1_old )
0 commit comments