Skip to content

Commit bccc53c

Browse files
committed
remove force lookups by looping and use hashing
1 parent f06a7dd commit bccc53c

File tree

1 file changed

+104
-100
lines changed

1 file changed

+104
-100
lines changed

src/openfe/protocols/openmm_rfe/_rfe_utils/relative.py

Lines changed: 104 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,25 @@ def _invert_dict(dictionary):
415415
"""
416416
return {v: k for k, v in dictionary.items()}
417417

418+
@staticmethod
419+
def _pair_key(particle1, particle2):
420+
"""
421+
Convenience method to generate a key for a pair of atom indexes with consistent normalization.
422+
423+
Parameters
424+
----------
425+
particle1 : int
426+
Index of first particle in exception
427+
particle2 : int
428+
Index of second particle in exception
429+
430+
Returns
431+
-------
432+
tuple
433+
Sorted tuple of the two particle indices, which is used as a key in the exception dicts.
434+
"""
435+
return (particle1, particle2) if particle1 < particle2 else (particle2, particle1)
436+
418437
def _set_mappings(self, old_to_new_map, core_old_to_new_map):
419438
"""
420439
Parameters
@@ -628,7 +647,7 @@ def _generate_dict_from_exceptions(force):
628647

629648
for exception_index in range(force.getNumExceptions()):
630649
[index1, index2, chargeProd, sigma, epsilon] = force.getExceptionParameters(exception_index)
631-
exceptions_dict[tuple(sorted([index1, index2]))] = [chargeProd, sigma, epsilon]
650+
exceptions_dict[HybridTopologyFactory._pair_key(index1, index2)] = [chargeProd, sigma, epsilon]
632651

633652
return exceptions_dict
634653

@@ -684,7 +703,7 @@ def _handle_constraints(self):
684703
):
685704
for const_idx in range(system.getNumConstraints()):
686705
at1, at2, length = system.getConstraintParameters(const_idx)
687-
hybrid_atoms = tuple(sorted([hybrid_map[at1], hybrid_map[at2]]))
706+
hybrid_atoms = self._pair_key(hybrid_map[at1], hybrid_map[at2])
688707
if hybrid_atoms not in constraint_lengths.keys():
689708
# add to the system
690709
self._hybrid_system.addConstraint(hybrid_atoms[0],
@@ -1144,33 +1163,25 @@ def _add_nonbonded_force_terms(self):
11441163
sterics_custom_nonbonded_force.setUseSwitchingFunction(False)
11451164

11461165
@staticmethod
1147-
def _find_bond_parameters(bond_force, index1, index2):
1166+
def _build_bond_lookup(bond_force) -> dict:
11481167
"""
1149-
This is a convenience function to find bond parameters in another
1150-
system given the two indices.
1168+
Build a lookup dictionary for bond parameters given a HarmonicBondForce.
11511169
11521170
Parameters
11531171
----------
11541172
bond_force : openmm.HarmonicBondForce
1155-
The bond force where the parameters should be found
1156-
index1 : int
1157-
Index1 (order does not matter) of the bond atoms
1158-
index2 : int
1159-
Index2 (order does not matter) of the bond atoms
1173+
The bond force from which to build the lookup.
11601174
11611175
Returns
11621176
-------
1163-
bond_parameters : list
1164-
List of relevant bond parameters
1177+
dict
1178+
A dictionary mapping a sorted tuple of atom indices to their bond parameters.
11651179
"""
1166-
index_set = {index1, index2}
1167-
# Loop through all the bonds:
1180+
bond_lookup = {}
11681181
for bond_index in range(bond_force.getNumBonds()):
1169-
parms = bond_force.getBondParameters(bond_index)
1170-
if index_set == {parms[0], parms[1]}:
1171-
return parms
1172-
1173-
return []
1182+
index1, index2, r0, k = bond_force.getBondParameters(bond_index)
1183+
bond_lookup[HybridTopologyFactory._pair_key(index1, index2)] = (r0, k)
1184+
return bond_lookup
11741185

11751186
def _handle_harmonic_bonds(self):
11761187
"""
@@ -1191,6 +1202,10 @@ def _handle_harmonic_bonds(self):
11911202
"""
11921203
old_system_bond_force = self._old_system_forces['HarmonicBondForce']
11931204
new_system_bond_force = self._new_system_forces['HarmonicBondForce']
1205+
# build a lookup for the new bonds so we don't have to loop through the force
1206+
new_bond_lookup = self._build_bond_lookup(new_system_bond_force)
1207+
# track the terms in the hybrid system
1208+
hybrid_bonds = {}
11941209

11951210
# First, loop through the old system bond forces and add relevant terms
11961211
for bond_index in range(old_system_bond_force.getNumBonds()):
@@ -1210,18 +1225,18 @@ def _handle_harmonic_bonds(self):
12101225
if index_set.issubset(self._atom_classes['core_atoms']):
12111226
index1_new = self._old_to_new_map[index1_old]
12121227
index2_new = self._old_to_new_map[index2_old]
1213-
new_bond_parameters = self._find_bond_parameters(
1214-
new_system_bond_force, index1_new, index2_new)
1215-
if not new_bond_parameters:
1228+
new_bond_parameters = new_bond_lookup.get(self._pair_key(index1_new, index2_new), None)
1229+
if new_bond_parameters is None:
12161230
r0_new = r0_old
12171231
k_new = 0.0*unit.kilojoule_per_mole/unit.angstrom**2
12181232
else:
1219-
# TODO - why is this being recalculated?
1220-
[index1, index2, r0_new, k_new] = self._find_bond_parameters(
1221-
new_system_bond_force, index1_new, index2_new)
1233+
r0_new, k_new = new_bond_parameters
1234+
12221235
self._hybrid_system_forces['core_bond_force'].addBond(
12231236
index1_hybrid, index2_hybrid,
12241237
[r0_old, k_old, r0_new, k_new])
1238+
# track that we've added this bond
1239+
hybrid_bonds[self._pair_key(index1_hybrid, index2_hybrid)] = True
12251240

12261241
# Check if the index set is a subset of anything besides
12271242
# environment (in the case of environment, we just add the bond to
@@ -1285,14 +1300,16 @@ def _handle_harmonic_bonds(self):
12851300
# This has some peculiarities to be discussed...
12861301
# TODO - Work out what the above peculiarities are...
12871302
elif index_set.issubset(self._atom_classes['core_atoms']):
1288-
if not self._find_bond_parameters(
1289-
self._hybrid_system_forces['core_bond_force'],
1290-
index1_hybrid, index2_hybrid):
1303+
bond_key = self._pair_key(index1_hybrid, index2_hybrid)
1304+
if bond_key not in hybrid_bonds:
12911305
r0_old = r0_new
12921306
k_old = 0.0*unit.kilojoule_per_mole/unit.angstrom**2
12931307
self._hybrid_system_forces['core_bond_force'].addBond(
12941308
index1_hybrid, index2_hybrid,
12951309
[r0_old, k_old, r0_new, k_new])
1310+
# track that we've added this bond
1311+
hybrid_bonds[bond_key] = True
1312+
12961313
elif index_set.issubset(self._atom_classes['environment_atoms']):
12971314
# Already been added
12981315
pass
@@ -1307,36 +1324,46 @@ def _handle_harmonic_bonds(self):
13071324
raise ValueError(errmsg)
13081325

13091326
@staticmethod
1310-
def _find_angle_parameters(angle_force, indices):
1327+
def _triplet_key(a, b, c) -> tuple[int, int, int]:
13111328
"""
1312-
Convenience function to find the angle parameters corresponding to a
1313-
particular set of indices
1329+
Create a key for angle lookups that is invariant to the order of the first and third indices.
13141330
13151331
Parameters
13161332
----------
1317-
angle_force : openmm.HarmonicAngleForce
1318-
The force where the angle of interest may be found.
1319-
indices : list of int
1320-
The indices (any order) of the angle atoms
1333+
a : int
1334+
The first atom index.
1335+
b : int
1336+
The second atom index (the central atom in the angle).
1337+
c : int
1338+
The third atom index.
13211339
13221340
Returns
13231341
-------
1324-
angle_params : list
1325-
list of angle parameters
1342+
tuple[int, int, int]
1343+
A tuple representing the angle, with the first and third indices sorted.
13261344
"""
1327-
indices_reversed = indices[::-1]
1345+
return (a, b, c) if a < c else (c, b, a)
13281346

1329-
# Now loop through and try to find the angle:
1330-
for angle_index in range(angle_force.getNumAngles()):
1331-
angle_params = angle_force.getAngleParameters(angle_index)
1347+
@staticmethod
1348+
def _build_angle_lookup(angle_force) -> dict:
1349+
"""
1350+
Build a lookup dictionary for angle parameters given a HarmonicAngleForce.
13321351
1333-
# Get a set representing the angle indices
1334-
angle_param_indices = angle_params[:3]
1352+
Parameters
1353+
----------
1354+
angle_force : openmm.HarmonicAngleForce
1355+
The angle force from which to build the lookup.
13351356
1336-
if (indices == angle_param_indices or
1337-
indices_reversed == angle_param_indices):
1338-
return angle_params
1339-
return [] # Return empty if no matching angle found
1357+
Returns
1358+
-------
1359+
dict
1360+
A dictionary mapping a sorted tuple of atom indices to their angle parameters.
1361+
"""
1362+
angle_lookup = {}
1363+
for angle_index in range(angle_force.getNumAngles()):
1364+
index1, index2, index3, theta0, k = angle_force.getAngleParameters(angle_index)
1365+
angle_lookup[HybridTopologyFactory._triplet_key(index1, index2, index3)] = (theta0, k)
1366+
return angle_lookup
13401367

13411368
def _handle_harmonic_angles(self):
13421369
"""
@@ -1363,6 +1390,10 @@ def _handle_harmonic_angles(self):
13631390
"""
13641391
old_system_angle_force = self._old_system_forces['HarmonicAngleForce']
13651392
new_system_angle_force = self._new_system_forces['HarmonicAngleForce']
1393+
# build a lookup for the new system angles to save iterating through the force
1394+
new_angle_lookup = self._build_angle_lookup(new_system_angle_force)
1395+
# hybrid angle tracking
1396+
hybrid_angles = {}
13661397

13671398
# First, loop through all the angles in the old system to determine
13681399
# what to do with them. We will only use the
@@ -1386,26 +1417,27 @@ def _handle_harmonic_angles(self):
13861417
new_indices = [
13871418
self._old_to_new_map[old_atomid] for old_atomid in old_angle_parameters[:3]
13881419
]
1389-
new_angle_parameters = self._find_angle_parameters(
1390-
new_system_angle_force, new_indices
1391-
)
1392-
if not new_angle_parameters:
1393-
new_angle_parameters = [
1394-
0, 0, 0, old_angle_parameters[3],
1395-
0.0*unit.kilojoule_per_mole/unit.radian**2
1396-
]
1420+
new_angle_parameters = new_angle_lookup.get(self._triplet_key(*new_indices), None)
1421+
1422+
if new_angle_parameters is None:
1423+
new_theta0 = old_angle_parameters[3]
1424+
new_k = 0.0*unit.kilojoule_per_mole/unit.radian**2
1425+
else:
1426+
new_theta0, new_k = new_angle_parameters
13971427

13981428
# Add to the hybrid force:
13991429
# the parameters at indices 3 and 4 represent theta0 and k,
14001430
# respectively.
14011431
hybrid_force_parameters = [
14021432
old_angle_parameters[3], old_angle_parameters[4],
1403-
new_angle_parameters[3], new_angle_parameters[4]
1433+
new_theta0, new_k
14041434
]
14051435
self._hybrid_system_forces['core_angle_force'].addAngle(
14061436
hybrid_index_list[0], hybrid_index_list[1],
14071437
hybrid_index_list[2], hybrid_force_parameters
14081438
)
1439+
# track that we have added this angle for the hybrid system
1440+
hybrid_angles[self._triplet_key(*hybrid_index_list)] = True
14091441

14101442
# Check if the atoms are neither all core nor all environment,
14111443
# which would mean they involve unique old interactions
@@ -1483,8 +1515,8 @@ def _handle_harmonic_angles(self):
14831515
)
14841516

14851517
elif hybrid_index_set.issubset(self._atom_classes['core_atoms']):
1486-
if not self._find_angle_parameters(self._hybrid_system_forces['core_angle_force'],
1487-
hybrid_index_list):
1518+
angle_key = self._triplet_key(*hybrid_index_list)
1519+
if angle_key not in hybrid_angles:
14881520
hybrid_force_parameters = [
14891521
new_angle_parameters[3],
14901522
0.0*unit.kilojoule_per_mole/unit.radian**2,
@@ -1494,6 +1526,9 @@ def _handle_harmonic_angles(self):
14941526
hybrid_index_list[0], hybrid_index_list[1],
14951527
hybrid_index_list[2], hybrid_force_parameters
14961528
)
1529+
# track that we have added this angle
1530+
hybrid_angles[angle_key] = True
1531+
14971532
elif hybrid_index_set.issubset(self._atom_classes['environment_atoms']):
14981533
# We have already added the appropriate environmental atom
14991534
# terms
@@ -1863,8 +1898,11 @@ def _handle_hybrid_exceptions(self):
18631898
for atom_pair in unique_old_pairs:
18641899
# Since the pairs are indexed in the dictionary by the old system
18651900
# indices, we need to convert
1866-
old_index_atom_pair = (self._hybrid_to_old_map[atom_pair[0]],
1867-
self._hybrid_to_old_map[atom_pair[1]])
1901+
# use the exception key function to ensure we check using the correct order
1902+
old_index_atom_pair = self._pair_key(
1903+
self._hybrid_to_old_map[atom_pair[0]],
1904+
self._hybrid_to_old_map[atom_pair[1]]
1905+
)
18681906

18691907
# Now we check if the pair is in the exception dictionary
18701908
if old_index_atom_pair in self._old_system_exceptions:
@@ -1885,23 +1923,6 @@ def _handle_hybrid_exceptions(self):
18851923
atom_pair[0], atom_pair[1]
18861924
)
18871925

1888-
# Check if the pair is in the reverse order and use that if so
1889-
elif old_index_atom_pair[::-1] in self._old_system_exceptions:
1890-
[chargeProd, sigma, epsilon] = self._old_system_exceptions[old_index_atom_pair[::-1]]
1891-
# If we are interpolating 1,4 exceptions then we have to
1892-
if self._interpolate_14s:
1893-
self._hybrid_system_forces['standard_nonbonded_force'].addException(
1894-
atom_pair[0], atom_pair[1], chargeProd*0.0,
1895-
sigma, epsilon*0.0
1896-
)
1897-
else:
1898-
self._hybrid_system_forces['standard_nonbonded_force'].addException(
1899-
atom_pair[0], atom_pair[1], chargeProd, sigma, epsilon)
1900-
1901-
# Add exclusion to ensure exceptions are consistent
1902-
self._hybrid_system_forces['core_sterics_force'].addExclusion(
1903-
atom_pair[0], atom_pair[1])
1904-
19051926
# TODO: work out why there's a bunch of commented out code here
19061927
# Exerpt:
19071928
# If it's not handled by an exception in the original system, we
@@ -1915,8 +1936,10 @@ def _handle_hybrid_exceptions(self):
19151936
for atom_pair in unique_new_pairs:
19161937
# Since the pairs are indexed in the dictionary by the new system
19171938
# indices, we need to convert
1918-
new_index_atom_pair = (self._hybrid_to_new_map[atom_pair[0]],
1919-
self._hybrid_to_new_map[atom_pair[1]])
1939+
new_index_atom_pair = self._pair_key(
1940+
self._hybrid_to_new_map[atom_pair[0]],
1941+
self._hybrid_to_new_map[atom_pair[1]]
1942+
)
19201943

19211944
# Now we check if the pair is in the exception dictionary
19221945
if new_index_atom_pair in self._new_system_exceptions:
@@ -1934,25 +1957,6 @@ def _handle_hybrid_exceptions(self):
19341957
self._hybrid_system_forces['core_sterics_force'].addExclusion(
19351958
atom_pair[0], atom_pair[1]
19361959
)
1937-
1938-
# Check if the pair is present in the reverse order and use that if so
1939-
elif new_index_atom_pair[::-1] in self._new_system_exceptions:
1940-
[chargeProd, sigma, epsilon] = self._new_system_exceptions[new_index_atom_pair[::-1]]
1941-
if self._interpolate_14s:
1942-
self._hybrid_system_forces['standard_nonbonded_force'].addException(
1943-
atom_pair[0], atom_pair[1], chargeProd*0.0,
1944-
sigma, epsilon*0.0
1945-
)
1946-
else:
1947-
self._hybrid_system_forces['standard_nonbonded_force'].addException(
1948-
atom_pair[0], atom_pair[1], chargeProd, sigma, epsilon
1949-
)
1950-
1951-
self._hybrid_system_forces['core_sterics_force'].addExclusion(
1952-
atom_pair[0], atom_pair[1]
1953-
)
1954-
1955-
19561960
# TODO: work out why there's a bunch of commented out code here
19571961
# If it's not handled by an exception in the original system, we
19581962
# just add the regular parameters as an exception
@@ -2027,7 +2031,7 @@ def _handle_original_exceptions(self):
20272031
index1_new = hybrid_to_new_map[index1_hybrid]
20282032
index2_new = hybrid_to_new_map[index2_hybrid]
20292033
# Get the exception parameters: make sure to sort the keys to match how they are stored
2030-
new_exception_parms = self._new_system_exceptions.get(tuple(sorted([index1_new, index2_new])), [])
2034+
new_exception_parms = self._new_system_exceptions.get(self._pair_key(index1_new, index2_new), [])
20312035

20322036
# If there's no new exception, then we should just set the
20332037
# exception parameters to be the nonbonded parameters
@@ -2112,7 +2116,7 @@ def _handle_original_exceptions(self):
21122116

21132117
# See if it's also in the old nonbonded force. if it is, then we don't need to add it.
21142118
# But if it's not, we need to interpolate
2115-
old_exception_parms = self._old_system_exceptions.get(tuple(sorted([index1_old, index2_old])), [])
2119+
old_exception_parms = self._old_system_exceptions.get(self._pair_key(index1_old, index2_old), [])
21162120
if not old_exception_parms:
21172121

21182122
[charge1_old, sigma1_old, epsilon1_old] = old_system_nonbonded_force.getParticleParameters(index1_old)

0 commit comments

Comments
 (0)