From d7cf9d49fbbd0dcb390e38eaddbb5fb29b231b39 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Mon, 1 Jun 2026 15:05:51 +0100 Subject: [PATCH] feat: cross-Analysis shared per-evaluation state in FactorGraphModel Add an opt-in mechanism for the per-factor Analysis objects of a FactorGraphModel to compute a model-dependent object once per likelihood evaluation and reuse it across every factor, instead of each factor recomputing identical work. - Analysis.shared_state_from(instance): opt-in hook that returns the object to share, or None (the default) for no sharing; it is the per-evaluation cross-factor sibling of modify_before_fit. - Analysis.log_likelihood_function gains an optional shared= kwarg. - FactorGraphModel computes the shared object once per evaluation from the lead factor (the first factor whose Analysis returns a non-None value) and forwards it to each factor only when non-None, so existing graphs are evaluated exactly as before. - AnalysisFactor forwards shared= when it is non-None; PriorFactor and HierarchicalFactor accept and ignore it for a uniform calling convention. - af.ex.Analysis demonstrates it on the 1D Gaussian toy via an opt-in share_model_data flag. - New unit tests in test_autofit/graphical/test_shared_state.py. 
Co-Authored-By: Claude Opus 4.8 (1M context) --- autofit/example/analysis.py | 52 ++++++- autofit/graphical/declarative/collection.py | 38 ++++- .../graphical/declarative/factor/analysis.py | 6 +- .../declarative/factor/hierarchical.py | 2 +- autofit/graphical/declarative/factor/prior.py | 6 +- autofit/non_linear/analysis/analysis.py | 42 +++++- test_autofit/graphical/test_shared_state.py | 131 ++++++++++++++++++ 7 files changed, 268 insertions(+), 9 deletions(-) create mode 100644 test_autofit/graphical/test_shared_state.py diff --git a/autofit/example/analysis.py b/autofit/example/analysis.py index 8ed8a5ffd..f2681dc5f 100644 --- a/autofit/example/analysis.py +++ b/autofit/example/analysis.py @@ -36,7 +36,13 @@ class Analysis(af.Analysis): LATENT_KEYS = ["gaussian.fwhm"] - def __init__(self, data: np.ndarray, noise_map: np.ndarray, use_jax=False): + def __init__( + self, + data: np.ndarray, + noise_map: np.ndarray, + use_jax=False, + share_model_data=False, + ): """ In this example the `Analysis` object only contains the data and noise-map. It can be easily extended, for more complex data-sets and model fitting problems. @@ -48,13 +54,45 @@ def __init__(self, data: np.ndarray, noise_map: np.ndarray, use_jax=False): noise_map A 1D numpy array containing the noise values of the data, used for computing the goodness of fit metric. + share_model_data + If `True`, opt this `Analysis` into the `FactorGraphModel` cross-factor shared-state mechanism + (see `shared_state_from`). This is only valid when the *entire* model is shared across every + factor, so the model data is identical for all of them and can be computed once instead of being + rebuilt by each factor. It is `False` by default, so the standard per-analysis behaviour is + unchanged. 
""" super().__init__(use_jax=use_jax) self.data = data self.noise_map = noise_map + self.share_model_data = share_model_data + + def shared_state_from(self, instance: af.ModelInstance): + """ + Compute the model data once so that it can be shared across the factors of a `FactorGraphModel`. + + This is the worked example of `Analysis.shared_state_from` (see that method). When every factor of + the graph shares the *entire* model — for example several datasets fit by the same 1D profile via + shared priors — the model data is identical for every factor, so it is wasteful to rebuild it once + per factor. Returning it here means the `FactorGraphModel` computes it a single time on the lead + factor and reuses it for all the others. + + In this toy the model data is cheap, but it stands in for an expensive shared computation: it is the + 1D analog of the lensing case, where the shared work (ray-tracing, the source-plane mapper, the + mapping matrix and the curvature matrix) dominates the per-factor cost. + + Sharing is opt-in (`share_model_data`) because it is only correct when the model really is fully + shared. If only some parameters are shared the model data differs between factors and this returns + `None`, so each factor computes its own as usual. + """ + if not self.share_model_data: + return None + + return self.model_data_1d_from(instance=instance) - def log_likelihood_function(self, instance: af.ModelInstance, xp=np) -> float: + def log_likelihood_function( + self, instance: af.ModelInstance, shared=None, xp=np + ) -> float: """ Determine the log likelihood of a fit of multiple profiles to the dataset. @@ -62,12 +100,20 @@ def log_likelihood_function(self, instance: af.ModelInstance, xp=np) -> float: ---------- instance : af.Collection The model instances of the profiles. + shared + The model data shared across the factors of a `FactorGraphModel`, computed once by + `shared_state_from` (see that method). 
When provided it is used directly instead of being + recomputed here; when `None` (the default, e.g. a standalone fit) the model data is computed + as normal. Returns ------- The log likelihood value indicating how well this model fit the dataset. """ - model_data_1d = self.model_data_1d_from(instance=instance) + if shared is None: + model_data_1d = self.model_data_1d_from(instance=instance) + else: + model_data_1d = shared residual_map = self.data - model_data_1d chi_squared_map = (residual_map / self.noise_map) ** 2.0 diff --git a/autofit/graphical/declarative/collection.py b/autofit/graphical/declarative/collection.py index 6dff83693..f8ba1f291 100644 --- a/autofit/graphical/declarative/collection.py +++ b/autofit/graphical/declarative/collection.py @@ -91,6 +91,17 @@ def log_likelihood_function(self, instance: ModelInstance) -> float: Compute the combined likelihood of each factor from a collection of instances with the same ordering as the factors. + Before the per-factor loop, the lead factor is asked to compute a `shared` + object via `Analysis.shared_state_from` (see that method). If it returns a + non-`None` value — the opt-in case — that object is forwarded as the `shared` + keyword argument to every factor's `log_likelihood_function`, so that work + which is identical for all factors at this point in parameter space is computed + once and reused rather than recomputed for each factor. + + When no factor provides a shared object (the default) the loop calls each + factor's `log_likelihood_function` exactly as it would without this mechanism, + so existing graphs are unchanged. 
+ Parameters ---------- instance @@ -100,12 +111,37 @@ def log_likelihood_function(self, instance: ModelInstance) -> float: ------- The combined likelihood of all factors """ + shared = self._shared_state_from(instance) + log_likelihood = 0 for model_factor, instance_ in zip(self.model_factors, instance): - log_likelihood += model_factor.log_likelihood_function(instance_) + if shared is None: + log_likelihood += model_factor.log_likelihood_function(instance_) + else: + log_likelihood += model_factor.log_likelihood_function( + instance_, shared=shared + ) return log_likelihood + def _shared_state_from(self, instance: ModelInstance): + """ + Compute the per-evaluation object shared across factors, by asking each factor's + `Analysis` in turn (via `Analysis.shared_state_from`) until one returns a + non-`None` value — the "lead" factor. Returns `None` if no factor opts in, in + which case no state is shared this evaluation. + """ + for model_factor, instance_ in zip(self.model_factors, instance): + analysis = getattr(model_factor, "analysis", None) + shared_state_from = getattr(analysis, "shared_state_from", None) + if shared_state_from is None: + continue + shared = shared_state_from(instance_) + if shared is not None: + return shared + + return None + @property def model_factors(self): model_factors = list() diff --git a/autofit/graphical/declarative/factor/analysis.py b/autofit/graphical/declarative/factor/analysis.py index 368a14eda..55dda7d8d 100644 --- a/autofit/graphical/declarative/factor/analysis.py +++ b/autofit/graphical/declarative/factor/analysis.py @@ -250,8 +250,10 @@ def save_results(self, paths: AbstractPaths, result): """ self.analysis.save_results(paths=paths, result=result) - def log_likelihood_function(self, instance: ModelInstance) -> float: - return self.analysis.log_likelihood_function(instance) + def log_likelihood_function(self, instance: ModelInstance, shared=None) -> float: + if shared is None: + return 
self.analysis.log_likelihood_function(instance) + return self.analysis.log_likelihood_function(instance, shared=shared) class EPAnalysisFactor(AnalysisFactor): diff --git a/autofit/graphical/declarative/factor/hierarchical.py b/autofit/graphical/declarative/factor/hierarchical.py index 69bd3c272..f61f46d9c 100644 --- a/autofit/graphical/declarative/factor/hierarchical.py +++ b/autofit/graphical/declarative/factor/hierarchical.py @@ -190,7 +190,7 @@ def message_dict(self) -> Dict[Prior, NormalMessage]: def variable(self): return self.drawn_prior - def log_likelihood_function(self, instance): + def log_likelihood_function(self, instance, shared=None): return instance.distribution_model.message(instance.drawn_prior, xp=self._xp) @property diff --git a/autofit/graphical/declarative/factor/prior.py b/autofit/graphical/declarative/factor/prior.py index 093219466..c89c92644 100644 --- a/autofit/graphical/declarative/factor/prior.py +++ b/autofit/graphical/declarative/factor/prior.py @@ -47,13 +47,17 @@ def analysis(self) -> "PriorFactor": """ return self - def log_likelihood_function(self, instance) -> float: + def log_likelihood_function(self, instance, shared=None) -> float: """ Compute the likelihood. The instance is a collection with a single argument expressing a possible value for this prior. The likelihood is computed by simply evaluating the prior's PDF for the given value. + + The `shared` argument (the cross-factor shared state of a + `FactorGraphModel`) is accepted for a uniform calling convention but is + not used by a prior factor. 
""" return self.prior.factor(instance[0]) diff --git a/autofit/non_linear/analysis/analysis.py b/autofit/non_linear/analysis/analysis.py index 097ceb911..c947a36e0 100644 --- a/autofit/non_linear/analysis/analysis.py +++ b/autofit/non_linear/analysis/analysis.py @@ -305,9 +305,49 @@ def with_model(self, model): return ModelAnalysis(analysis=self, model=model) - def log_likelihood_function(self, instance): + def log_likelihood_function(self, instance, shared=None): raise NotImplementedError() + def shared_state_from(self, instance): + """ + Optionally compute a per-evaluation object that is shared across the factors + of a `FactorGraphModel`. + + This is the per-evaluation, cross-factor sibling of `modify_before_fit`. Where + `modify_before_fit` runs once before sampling to precompute analysis state that + does not depend on the model, `shared_state_from` runs once per likelihood + evaluation (the model parameters change every sample) and computes state that + is identical for every factor at the current point in parameter space. + + When a `FactorGraphModel` evaluates its likelihood it calls this method on its + lead factor's `Analysis`. If the returned value is not `None` it is forwarded as + the `shared` keyword argument to every factor's `log_likelihood_function`, so + that work which is identical for all factors (because they share model + parameters) is computed once and reused rather than recomputed `N` times. + + The default implementation returns `None`, meaning no state is shared and every + factor's `log_likelihood_function` runs exactly as it does without this + mechanism. An `Analysis` opts in by overriding this method. + + The returned object must be a valid JAX pytree of traced arrays when the fit is + JIT-compiled: it is recomputed inside the jitted region each evaluation (it + depends on the traced model parameters) and must not be memoised on the instance. 
+ + Correctness is the responsibility of the overriding `Analysis`: only return a + shared object when the parameters it depends on really are shared across every + factor. If they are not, return `None` and let each factor compute its own state. + + Parameters + ---------- + instance + The model instance of the factor whose `Analysis` is acting as the lead. + + Returns + ------- + An object shared across all factors for this evaluation, or `None` for no sharing. + """ + return None + def save_attributes(self, paths: AbstractPaths): pass diff --git a/test_autofit/graphical/test_shared_state.py b/test_autofit/graphical/test_shared_state.py new file mode 100644 index 000000000..96545ce62 --- /dev/null +++ b/test_autofit/graphical/test_shared_state.py @@ -0,0 +1,131 @@ +import itertools + +import numpy as np +import pytest + +import autofit as af +import autofit.graphical as g + + +@pytest.fixture(autouse=True) +def reset_ids(): + af.Prior._ids = itertools.count() + + +class CountingAnalysis(af.ex.Analysis): + """ + An example `Analysis` that counts how many times the (notionally expensive) model + data computation runs, so a test can prove the `FactorGraphModel` shared-state + mechanism computes it once per evaluation rather than once per factor. + """ + + def __init__(self, data, noise_map, share_model_data=True): + super().__init__( + data=data, noise_map=noise_map, share_model_data=share_model_data + ) + self.model_data_calls = 0 + + def model_data_1d_from(self, instance): + self.model_data_calls += 1 + return super().model_data_1d_from(instance=instance) + + +def _shared_gaussian_graph(analyses): + """ + Build a `FactorGraphModel` whose factors share the *entire* Gaussian model via + shared prior objects, so the model data is identical for every factor. 
+ """ + centre = af.UniformPrior(lower_limit=0.0, upper_limit=10.0) + normalization = af.UniformPrior(lower_limit=0.0, upper_limit=10.0) + sigma = af.UniformPrior(lower_limit=0.0, upper_limit=10.0) + + factors = [] + for analysis in analyses: + gaussian = af.Model(af.ex.Gaussian) + gaussian.centre = centre + gaussian.normalization = normalization + gaussian.sigma = sigma + factors.append(af.AnalysisFactor(gaussian, analysis)) + + return g.FactorGraphModel(*factors) + + +def _datasets(n=3, size=10): + """`n` distinct 1D datasets sharing a common noise map of ones.""" + return [ + (np.arange(size, dtype=float) + float(i), np.ones(size)) + for i in range(n) + ] + + +def _instance(collection): + prior_count = collection.global_prior_model.prior_count + return collection.global_prior_model.instance_from_unit_vector( + [0.5] * prior_count + ) + + +def _reference_log_likelihood(collection, instance): + """Sum each factor's likelihood with no sharing (each computes its own model data).""" + return sum( + factor.analysis.log_likelihood_function(instance_) + for factor, instance_ in zip(collection.model_factors, instance) + ) + + +def test_shared_state_computed_once_per_evaluation(): + analyses = [ + CountingAnalysis(data, noise_map) for data, noise_map in _datasets(n=3) + ] + collection = _shared_gaussian_graph(analyses) + instance = _instance(collection) + + collection.log_likelihood_function(instance) + + total_calls = sum(analysis.model_data_calls for analysis in analyses) + assert total_calls == 1 + + +def test_shared_likelihood_equals_unshared_sum(): + analyses = [ + CountingAnalysis(data, noise_map) for data, noise_map in _datasets(n=3) + ] + collection = _shared_gaussian_graph(analyses) + instance = _instance(collection) + + shared_log_likelihood = collection.log_likelihood_function(instance) + reference_log_likelihood = _reference_log_likelihood(collection, instance) + + assert shared_log_likelihood == pytest.approx(reference_log_likelihood) + + +def 
test_no_provider_graph_is_unchanged(): + """ + With `share_model_data=False` no factor opts in, so no state is shared: each factor + computes its own model data (N calls) and the summed likelihood is unchanged. + """ + analyses = [ + CountingAnalysis(data, noise_map, share_model_data=False) + for data, noise_map in _datasets(n=3) + ] + collection = _shared_gaussian_graph(analyses) + instance = _instance(collection) + + log_likelihood = collection.log_likelihood_function(instance) + reference_log_likelihood = _reference_log_likelihood(collection, instance) + + total_calls = sum(analysis.model_data_calls for analysis in analyses) + # one call per factor from the graph evaluation, plus one per factor from the + # reference sum — the graph did not share, so it computed all three itself. + assert total_calls == 2 * len(analyses) + assert log_likelihood == pytest.approx(reference_log_likelihood) + + +def test_shared_state_from_default_returns_none(): + analysis = af.ex.Analysis( + data=np.arange(10, dtype=float), noise_map=np.ones(10) + ) + model = af.Model(af.ex.Gaussian) + instance = model.instance_from_unit_vector([0.5] * model.prior_count) + + assert analysis.shared_state_from(instance) is None