Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
43d6007
Migrate validation to Protocol._validate
IAlibay Dec 5, 2025
2cd56ba
some fixes
IAlibay Dec 5, 2025
70e6d7a
Merge branch 'main' into validate-rfe
IAlibay Dec 5, 2025
f330562
move some things around
IAlibay Dec 8, 2025
95b92b3
Merge branch 'main' into validate-rfe
IAlibay Dec 15, 2025
1e0153e
add validate endstate tests
IAlibay Dec 15, 2025
fe2b879
Merge branch 'validate-rfe' of github.com:OpenFreeEnergy/openfe into …
IAlibay Dec 15, 2025
fbc4554
validate mapping tests
IAlibay Dec 15, 2025
c2f49d2
net charge validation tests
IAlibay Dec 15, 2025
c50f99c
more stuff
IAlibay Dec 22, 2025
9e0d29b
remove old tests
IAlibay Dec 24, 2025
2fe8ff9
make hybrid samplers not rely on htf
IAlibay Dec 24, 2025
4a0bd26
fix up test
IAlibay Dec 24, 2025
5848adc
fix up some slow tests
IAlibay Dec 24, 2025
1aaef87
Merge branch 'main' into multistate-nohtf
IAlibay Dec 24, 2025
b6d5ecd
Fix up the one test
IAlibay Dec 26, 2025
0605d11
fix a few things
IAlibay Dec 26, 2025
48106a2
fix the remaining tests
IAlibay Dec 26, 2025
5af66e8
cleanup imports
IAlibay Dec 26, 2025
ad0b5fb
Merge branch 'validate-rfe' into move-rfe-protocol
IAlibay Dec 26, 2025
45e004c
Merge branch 'multistate-nohtf' into move-rfe-protocol
IAlibay Dec 26, 2025
58dd71c
Migrate protocol, units, and results for the hybridtop protocol
IAlibay Dec 26, 2025
792996e
Add news item
IAlibay Dec 26, 2025
91f1788
Merge branch 'validate-rfe' into move-rfe-protocol
IAlibay Dec 26, 2025
527b870
Merge branch 'main' into validate-rfe
IAlibay Dec 26, 2025
7d17998
fix redefine
IAlibay Dec 27, 2025
43eb947
start modularising everything
IAlibay Dec 27, 2025
d1bd736
Add charge validation for smcs when dealing with ismorphic molecules
IAlibay Dec 27, 2025
51a6de1
break down the rfe units into bits
IAlibay Dec 29, 2025
6a5a76a
more broadly disallow oechem as a backend when creating systems
IAlibay Dec 29, 2025
cdd3da0
fix issue with nc being undefined
IAlibay Dec 29, 2025
e0a8e2a
Merge branch 'validate-rfe' into move-rfe-protocol
IAlibay Dec 29, 2025
a0ef737
Merge branch 'move-rfe-protocol' into breakdown-rfe-protocolunit
IAlibay Dec 29, 2025
b826803
Fix missing import
IAlibay Dec 29, 2025
42ddbcf
Merge branch 'move-rfe-protocol' into breakdown-rfe-protocolunit
IAlibay Dec 29, 2025
063e8ce
Fix comp getter
IAlibay Dec 29, 2025
3844bb5
Merge branch 'move-rfe-protocol' into breakdown-rfe-protocolunit
IAlibay Dec 29, 2025
a98c799
update module name
IAlibay Dec 30, 2025
5d0bc7e
Merge branch 'move-rfe-protocol' into breakdown-rfe-protocolunit
IAlibay Dec 30, 2025
7c915ed
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 3, 2026
951ac15
move a few things around to make life easier
IAlibay Jan 3, 2026
b9f8264
Merge branch 'main' into breakdown-rfe-protocolunit
IAlibay Jan 7, 2026
2e4b455
fix typo
IAlibay Jan 7, 2026
7182805
fix some merge issues
IAlibay Jan 7, 2026
28b4381
fix test failures due to integrator checks
IAlibay Jan 7, 2026
726f517
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 7, 2026
1587673
try to make mypy happy
IAlibay Jan 7, 2026
5cca950
Merge branch 'breakdown-rfe-protocolunit' of github.com:OpenFreeEnerg…
IAlibay Jan 7, 2026
1fbec7d
add early exist if there's no molecules
IAlibay Jan 7, 2026
3cd758e
Apply suggestions from code review
IAlibay Jan 7, 2026
6622428
Update openfe/protocols/openmm_rfe/hybridtop_units.py
IAlibay Jan 7, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
make hybrid samplers not rely on htf
  • Loading branch information
IAlibay committed Dec 24, 2025
commit 2fe8ff937683e2d237cddb0b54ea9ace79be1cf7
61 changes: 38 additions & 23 deletions openfe/protocols/openmm_rfe/_rfe_utils/multistate.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@
logger = logging.getLogger(__name__)


class HybridCompatibilityMixin(object):
class HybridCompatibilityMixin:
"""
Mixin that allows the MultistateSampler to accommodate the situation where
unsampled endpoints have a different number of degrees of freedom.
"""

def __init__(self, *args, hybrid_factory=None, **kwargs):
self._hybrid_factory = hybrid_factory
def __init__(self, *args, hybrid_system, hybrid_positions, **kwargs):
self._hybrid_system = hybrid_system
self._hybrid_positions = hybrid_positions
super(HybridCompatibilityMixin, self).__init__(*args, **kwargs)

def setup(self, reporter, lambda_protocol,
Expand Down Expand Up @@ -73,15 +74,17 @@ class creation of LambdaProtocol.
"""
n_states = len(lambda_protocol.lambda_schedule)

hybrid_system = self._factory.hybrid_system
lambda_zero_state = RelativeAlchemicalState.from_system(self._hybrid_system)

lambda_zero_state = RelativeAlchemicalState.from_system(hybrid_system)
thermostate = ThermodynamicState(
self._hybrid_system,
temperature=temperature
)

thermostate = ThermodynamicState(hybrid_system,
temperature=temperature)
compound_thermostate = CompoundThermodynamicState(
thermostate,
composable_states=[lambda_zero_state])
thermostate,
composable_states=[lambda_zero_state]
)

# create lists for storing thermostates and sampler states
thermodynamic_state_list = []
Expand All @@ -105,16 +108,20 @@ class creation of LambdaProtocol.
raise ValueError(errmsg)

# starting with the hybrid factory positions
box = hybrid_system.getDefaultPeriodicBoxVectors()
sampler_state = SamplerState(self._factory.hybrid_positions,
box_vectors=box)
box = self._hybrid_system.getDefaultPeriodicBoxVectors()
sampler_state = SamplerState(
self._hybrid_positions,
box_vectors=box
)

# Loop over the lambdas and create & store a compound thermostate at
# that lambda value
for lambda_val in lambda_schedule:
compound_thermostate_copy = copy.deepcopy(compound_thermostate)
compound_thermostate_copy.set_alchemical_parameters(
lambda_val, lambda_protocol)
lambda_val,
lambda_protocol
)
thermodynamic_state_list.append(compound_thermostate_copy)

# now generating a sampler_state for each thermodyanmic state,
Expand Down Expand Up @@ -143,7 +150,8 @@ class creation of LambdaProtocol.
# generating unsampled endstates
unsampled_dispersion_endstates = create_endstates(
copy.deepcopy(thermodynamic_state_list[0]),
copy.deepcopy(thermodynamic_state_list[-1]))
copy.deepcopy(thermodynamic_state_list[-1])
)
self.create(thermodynamic_states=thermodynamic_state_list,
sampler_states=sampler_state_list, storage=reporter,
unsampled_thermodynamic_states=unsampled_dispersion_endstates)
Expand All @@ -159,10 +167,13 @@ class HybridRepexSampler(HybridCompatibilityMixin,
number of positions
"""

def __init__(self, *args, hybrid_factory=None, **kwargs):
def __init__(self, *args, hybrid_system, hybrid_positions, **kwargs):
super(HybridRepexSampler, self).__init__(
*args, hybrid_factory=hybrid_factory, **kwargs)
self._factory = hybrid_factory
*args,
hybrid_system=hybrid_system,
hybrid_positions=hybrid_positions,
**kwargs
)


class HybridSAMSSampler(HybridCompatibilityMixin, sams.SAMSSampler):
Expand All @@ -171,11 +182,13 @@ class HybridSAMSSampler(HybridCompatibilityMixin, sams.SAMSSampler):
of positions
"""

def __init__(self, *args, hybrid_factory=None, **kwargs):
def __init__(self, *args, hybrid_system, hybrid_positions, **kwargs):
super(HybridSAMSSampler, self).__init__(
*args, hybrid_factory=hybrid_factory, **kwargs
*args,
hybrid_system=hybrid_system,
hybrid_positions=hybrid_positions,
**kwargs
)
self._factory = hybrid_factory


class HybridMultiStateSampler(HybridCompatibilityMixin,
Expand All @@ -184,11 +197,13 @@ class HybridMultiStateSampler(HybridCompatibilityMixin,
MultiStateSampler that supports unsample end states with a different
number of positions
"""
def __init__(self, *args, hybrid_factory=None, **kwargs):
def __init__(self, *args, hybrid_system, hybrid_positions, **kwargs):
super(HybridMultiStateSampler, self).__init__(
*args, hybrid_factory=hybrid_factory, **kwargs
*args,
hybrid_system=hybrid_system,
hybrid_positions=hybrid_positions,
**kwargs
)
self._factory = hybrid_factory


def create_endstates(first_thermostate, last_thermostate):
Expand Down
17 changes: 12 additions & 5 deletions openfe/protocols/openmm_rfe/equil_rfe_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -1128,15 +1128,17 @@ def run(
if sampler_settings.sampler_method.lower() == "repex":
sampler = _rfe_utils.multistate.HybridRepexSampler(
mcmc_moves=integrator,
hybrid_factory=hybrid_factory,
hybrid_system=hybrid_factory.hybrid_system,
hybrid_positions=hybrid_factory.hybrid_positions,
online_analysis_interval=rta_its,
online_analysis_target_error=early_termination_target_error,
online_analysis_minimum_iterations=rta_min_its,
)
elif sampler_settings.sampler_method.lower() == "sams":
sampler = _rfe_utils.multistate.HybridSAMSSampler(
mcmc_moves=integrator,
hybrid_factory=hybrid_factory,
hybrid_system=hybrid_factory.hybrid_system,
hybrid_positions=hybrid_factory.hybrid_positions,
online_analysis_interval=rta_its,
online_analysis_minimum_iterations=rta_min_its,
flatness_criteria=sampler_settings.sams_flatness_criteria,
Expand All @@ -1145,12 +1147,12 @@ def run(
elif sampler_settings.sampler_method.lower() == "independent":
sampler = _rfe_utils.multistate.HybridMultiStateSampler(
mcmc_moves=integrator,
hybrid_factory=hybrid_factory,
hybrid_system=hybrid_factory.hybrid_system,
hybrid_positions=hybrid_factory.hybrid_positions,
online_analysis_interval=rta_its,
online_analysis_target_error=early_termination_target_error,
online_analysis_minimum_iterations=rta_min_its,
)

else:
raise AttributeError(f"Unknown sampler {sampler_settings.sampler_method}")

Expand Down Expand Up @@ -1247,7 +1249,12 @@ def run(
if not dry: # pragma: no-cover
return {"nc": nc, "last_checkpoint": chk, **analyzer.unit_results_dict}
else:
return {"debug": {"sampler": sampler}}
return {"debug":
{
"sampler": sampler,
"hybrid_factory": hybrid_factory
}
}

@staticmethod
def structural_analysis(scratch, shared) -> dict:
Expand Down
27 changes: 15 additions & 12 deletions openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,13 +236,14 @@ def test_dry_run_default_vacuum(
dag_unit = list(dag.protocol_units)[0]

with tmpdir.as_cwd():
sampler = dag_unit.run(dry=True)["debug"]["sampler"]
debug = dag_unit.run(dry=True)["debug"]
sampler = debug["sampler"]
assert isinstance(sampler, MultiStateSampler)
assert not sampler.is_periodic
assert sampler._thermodynamic_states[0].barostat is None

# Check hybrid OMM and MDTtraj Topologies
htf = sampler._hybrid_factory
htf = debug["hybrid_factory"]
# 16 atoms:
# 11 common atoms, 1 extra hydrogen in benzene, 4 extra in toluene
# 12 bonds in benzene + 4 extra toluene bonds
Expand Down Expand Up @@ -414,7 +415,7 @@ def test_dry_core_element_change(vac_settings, tmpdir):

with tmpdir.as_cwd():
sampler = dag_unit.run(dry=True)["debug"]["sampler"]
system = sampler._hybrid_factory.hybrid_system
system = sampler._hybrid_system
assert system.getNumParticles() == 12
# Average mass between nitrogen and carbon
assert system.getParticleMass(1) == 12.0127235 * omm_unit.amu
Expand Down Expand Up @@ -518,7 +519,7 @@ def tip4p_hybrid_factory(
shared_basepath=shared_temp,
)

return dag_unit_result["debug"]["sampler"]._factory
return dag_unit_result["debug"]["hybrid_factory"]


def test_tip4p_particle_count(tip4p_hybrid_factory):
Expand Down Expand Up @@ -624,7 +625,7 @@ def test_dry_run_ligand_system_cutoff(

with tmpdir.as_cwd():
sampler = dag_unit.run(dry=True)["debug"]["sampler"]
hs = sampler._factory.hybrid_system
hs = sampler._hybrid_system

nbfs = [
f
Expand Down Expand Up @@ -691,9 +692,10 @@ def test_dry_run_charge_backends(
dag_unit = list(dag.protocol_units)[0]

with tmpdir.as_cwd():
sampler = dag_unit.run(dry=True)["debug"]["sampler"]
htf = sampler._factory
hybrid_system = htf.hybrid_system
debug = dag_unit.run(dry=True)["debug"]
sampler = debug["sampler"]
htf = debug["hybrid_factory"]
hybrid_system = sampler._hybrid_system

# get the standard nonbonded force
nonbond = [f for f in hybrid_system.getForces() if isinstance(f, NonbondedForce)]
Expand Down Expand Up @@ -785,9 +787,10 @@ def check_propchgs(smc, charge_array):
dag_unit = list(dag.protocol_units)[0]

with tmpdir.as_cwd():
sampler = dag_unit.run(dry=True)["debug"]["sampler"]
htf = sampler._factory
hybrid_system = htf.hybrid_system
debug = dag_unit.run(dry=True)["debug"]
sampler = debug["sampler"]
htf = debug["hybrid_factory"]
hybrid_system = sampler._hybrid_system

# get the standard nonbonded force
nonbond = [f for f in hybrid_system.getForces() if isinstance(f, NonbondedForce)]
Expand Down Expand Up @@ -902,7 +905,7 @@ def test_dodecahdron_ligand_box(

with tmpdir.as_cwd():
sampler = dag_unit.run(dry=True)["debug"]["sampler"]
hs = sampler._factory.hybrid_system
hs = sampler._hybrid_system

vectors = hs.getDefaultPeriodicBoxVectors()

Expand Down