Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 23 additions & 21 deletions autofit/graphical/mean_field.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import warnings
from collections import ChainMap
from typing import Dict, Tuple, Optional, Union, Iterable

Expand All @@ -24,9 +23,8 @@
Plate,
VariableData,
FactorValue,
VariableLinearOperator,
)
from autofit.mapper.variable_operator import MatrixOperator, VariableFullOperator
from autofit.mapper.variable_operator import VariableFullOperator
from autofit.messages.abstract import AbstractMessage
from autofit.messages.fixed import FixedMessage

Expand Down Expand Up @@ -176,12 +174,12 @@ def rescale(self, rescale: Dict[Variable, float]) -> "MeanField":
for v, message in self.items():
scale = rescale.get(v, 1)
if scale == 1:
rescaled[v] = message
rescaled[v] = message
elif scale == 0:
rescaled[v] = 1.
else:
rescaled[v] = message ** scale
rescaled[v] = message ** scale

return MeanField(rescaled)

@property
Expand Down Expand Up @@ -238,7 +236,7 @@ def logpdf_gradient(self, values: Dict[Variable, np.ndarray], **kwargs):

def __repr__(self):
reprdict = (
"{\n" + "\n".join(f" {k}: {v}" for k, v in self.items()) + "\n }"
"{\n" + "\n".join(f" {k}: {v}" for k, v in self.items()) + "\n }"
)
classname = type(self).__name__
return f"{classname}({reprdict}, log_norm={self.log_norm})"
Expand All @@ -248,7 +246,7 @@ def is_valid(self):
return all(d.is_valid for d in self.values())

def prod(self, *approxs: "MeanField") -> "MeanField":
dists = (
dists = list(
(k, prod((m.get(k, 1.0) for m in approxs), m)) for k, m in self.items()
)
return MeanField({k: m for k, m in dists if isinstance(m, Prior)})
Expand Down Expand Up @@ -300,7 +298,13 @@ def from_mode_covariance(
mode = ChainMap(mode, self.fixed_values)
projection = MeanField(
{
v: self[v].from_mode(mode[v], covar.get(v), id_=self[v].id)
v: self[v].from_mode(
mode[v],
covar.get(v),
id_=self[v].id,
lower_limit=self[v].lower_limit,
upper_limit=self[v].upper_limit,
)
for v in self.keys() & mode.keys()
}
)
Expand All @@ -324,11 +328,11 @@ def from_dist(
return dist if isinstance(dist, cls) else MeanField(dist)

def update_factor_mean_field(
self,
cavity_dist: "MeanField",
last_dist: Optional["MeanField"] = None,
delta: float = 1.0,
status: Status = Status(),
self,
cavity_dist: "MeanField",
last_dist: Optional["MeanField"] = None,
delta: float = 1.0,
status: Status = Status(),
) -> Tuple["MeanField", Status]:

success, messages, _, flag = status
Expand All @@ -340,7 +344,7 @@ def update_factor_mean_field(
log_norm = factor_dist.log_norm
factor_dist = factor_dist ** delta * last_dist ** (1 - delta)
factor_dist.log_norm = (
delta * log_norm + (1 - delta) * last_dist.log_norm
delta * log_norm + (1 - delta) * last_dist.log_norm
)

for m in caught_warnings.messages:
Expand Down Expand Up @@ -492,7 +496,6 @@ def func_gradient(
self,
values: Dict[Variable, np.ndarray],
) -> Tuple[FactorValue, VariableData]:

variable_dict = {**self.fixed_values, **values}
fval, fjac = self.factor.func_jacobian(variable_dict)

Expand All @@ -507,12 +510,11 @@ def func_gradient(
return logl, grad

def project_mean_field(
self,
model_dist: MeanField,
delta: float = 1.0,
status: Status = Status(),
self,
model_dist: MeanField,
delta: float = 1.0,
status: Status = Status(),
) -> Tuple["FactorApproximation", Status]:

factor_dist, status = model_dist.update_factor_mean_field(
self.cavity_dist,
last_dist=self.factor_dist,
Expand Down
1 change: 1 addition & 0 deletions autofit/graphical/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ def prod(iterable: Iterable[_M], *arg: Tuple[_M]) -> _M:
>>> prod(range(1, 3), 2.)
4.
"""
iterable = list(iterable)
return reduce(mul, iterable, *arg)


Expand Down
64 changes: 50 additions & 14 deletions autofit/messages/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ def __init__(self, *parameters: Union[np.ndarray, float], log_norm=0.0, **kwargs
def copy(self):
cls = self._Base_class or type(self)
result = cls(
*(copy(params) for params in self.parameters), log_norm=self.log_norm
*(copy(params) for params in self.parameters),
log_norm=self.log_norm,
lower_limit=self.lower_limit,
upper_limit=self.upper_limit,
)
result.id = self.id
return result
Expand Down Expand Up @@ -184,19 +187,26 @@ def sum_natural_parameters(self, *dists: "AbstractMessage") -> "AbstractMessage"
),
self.natural_parameters,
)
mul_dist = self.from_natural_parameters(new_params, id_=self.id)
return mul_dist
return self.from_natural_parameters(
new_params,
id_=self.id,
lower_limit=self.lower_limit,
upper_limit=self.upper_limit,
)

def sub_natural_parameters(self, other: "AbstractMessage") -> "AbstractMessage":
"""return the unnormalised result of dividing the pdf
of this distribution with another distribution of the same
type"""
log_norm = self.log_norm - other.log_norm
new_params = self.natural_parameters - other.natural_parameters
div_dist = self.from_natural_parameters(
new_params, log_norm=log_norm, id_=self.id
return self.from_natural_parameters(
new_params,
log_norm=log_norm,
id_=self.id,
lower_limit=self.lower_limit,
upper_limit=self.upper_limit,
)
return div_dist

_multiply = sum_natural_parameters
_divide = sub_natural_parameters
Expand All @@ -207,7 +217,13 @@ def __mul__(self, other: Union["AbstractMessage", Real]) -> "AbstractMessage":
else:
cls = self._Base_class or type(self)
log_norm = self.log_norm + np.log(other)
return cls(*self.parameters, log_norm=log_norm, id_=self.id)
return cls(
*self.parameters,
log_norm=log_norm,
id_=self.id,
lower_limit=self.lower_limit,
upper_limit=self.upper_limit,
)

def __rmul__(self, other: "AbstractMessage") -> "AbstractMessage":
return self * other
Expand All @@ -218,13 +234,26 @@ def __truediv__(self, other: Union["AbstractMessage", Real]) -> "AbstractMessage
else:
cls = self._Base_class or type(self)
log_norm = self.log_norm - np.log(other)
return cls(*self.parameters, log_norm=log_norm, id_=self.id)
return cls(
*self.parameters,
log_norm=log_norm,
id_=self.id,
lower_limit=self.lower_limit,
upper_limit=self.upper_limit,
)

def __pow__(self, other: Real) -> "AbstractMessage":
natural = self.natural_parameters
new_params = other * natural
log_norm = other * self.log_norm
return self.from_natural_parameters(new_params, log_norm=log_norm, id_=self.id)
new = self.from_natural_parameters(
new_params,
log_norm=log_norm,
id_=self.id,
lower_limit=self.lower_limit,
upper_limit=self.upper_limit,
)
return new

@classmethod
def parameter_names(cls):
Expand Down Expand Up @@ -268,7 +297,7 @@ def _broadcast_natural_parameters(self, x):
)

def factor(self, x):
# self.assert_within_limits(x)
# self.assert_within_limits(x)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still cant tell if we should be setting up the bounds ofr the factor optimization properly here or not...

Anyway, if the test works this is fine for now.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I'm not sure

return self.logpdf(x)

def logpdf(self, x: np.ndarray) -> np.ndarray:
Expand Down Expand Up @@ -362,7 +391,7 @@ def numerical_logpdf_gradient_hessian(

@classmethod
def project(
cls, samples: np.ndarray, log_weight_list: Optional[np.ndarray] = None, id_=None
cls, samples: np.ndarray, log_weight_list: Optional[np.ndarray] = None, **kwargs
) -> "AbstractMessage":
"""Calculates the sufficient statistics of a set of samples
and returns the distribution with the appropriate parameters
Expand All @@ -388,11 +417,11 @@ def project(
assert np.isfinite(suff_stats).all()

cls_ = cls._projection_class or cls._Base_class or cls
return cls_.from_sufficient_statistics(suff_stats, log_norm=log_norm, id_=id_)
return cls_.from_sufficient_statistics(suff_stats, log_norm=log_norm, **kwargs)

@classmethod
def from_mode(
cls, mode: np.ndarray, covariance: np.ndarray, id_
cls, mode: np.ndarray, covariance: np.ndarray, **kwargs
) -> "AbstractMessage":
pass

Expand Down Expand Up @@ -441,7 +470,14 @@ def update_invalid(self, other: "AbstractMessage") -> "AbstractMessage":
# TODO: Fairly certain this would not work
valid_parameters = iter(self if valid else other)
cls = self._Base_class or type(self)
return cls(*valid_parameters, log_norm=self.log_norm, id_=self.id)
new = cls(
*valid_parameters,
log_norm=self.log_norm,
id_=self.id,
lower_limit=self.lower_limit,
upper_limit=self.upper_limit,
)
return new

def check_support(self) -> np.ndarray:
if self._parameter_support is not None:
Expand Down
5 changes: 5 additions & 0 deletions autofit/messages/beta.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
import warnings
from typing import Union

Expand Down Expand Up @@ -74,6 +75,8 @@ def __init__(
self,
alpha=0.5,
beta=0.5,
lower_limit=-math.inf,
upper_limit=math.inf,
log_norm=0,
id_=None
):
Expand All @@ -82,6 +85,8 @@ def __init__(
super().__init__(
alpha,
beta,
lower_limit=lower_limit,
upper_limit=upper_limit,
log_norm=log_norm,
id_=id_
)
Expand Down
5 changes: 5 additions & 0 deletions autofit/messages/fixed.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
from typing import Optional, Tuple

import numpy as np
Expand All @@ -12,12 +13,16 @@ class FixedMessage(AbstractMessage):
def __init__(
self,
value: np.ndarray,
lower_limit=-math.inf,
upper_limit=math.inf,
log_norm: np.ndarray = 0.,
id_=None
):
self.value = value
super().__init__(
value,
lower_limit=lower_limit,
upper_limit=upper_limit,
log_norm=log_norm,
id_=id_
)
Expand Down
25 changes: 21 additions & 4 deletions autofit/messages/gamma.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import math

import numpy as np
from scipy import special

Expand All @@ -16,10 +18,25 @@ def log_partition(self):
_support = ((0, np.inf),)
_parameter_support = ((0, np.inf), (0, np.inf))

def __init__(self, alpha=1.0, beta=1.0, log_norm=0.0, id_=None):
def __init__(
self,
alpha=1.0,
beta=1.0,
lower_limit=-math.inf,
upper_limit=math.inf,
log_norm=0.0,
id_=None
):
self.alpha = alpha
self.beta = beta
super().__init__(alpha, beta, log_norm=log_norm, id_=id_)
super().__init__(
alpha,
beta,
lower_limit=lower_limit,
upper_limit=upper_limit,
log_norm=log_norm,
id_=id_
)

def value_for(self, unit: float) -> float:
raise NotImplemented()
Expand Down Expand Up @@ -62,12 +79,12 @@ def sample(self, n_samples=None):
return np.random.gamma(a1, scale=1 / b1, size=shape)

@classmethod
def from_mode(cls, mode, covariance, id_):
def from_mode(cls, mode, covariance, **kwargs):
m, V = cls._get_mean_variance(mode, covariance)

alpha = 1 + m ** 2 * V # match variance
beta = alpha / m # match mean
return cls(alpha, beta, id_=id_)
return cls(alpha, beta, **kwargs)

def kl(self, dist):
P, Q = dist, self
Expand Down
4 changes: 2 additions & 2 deletions autofit/messages/transform_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ def project(
self,
samples: np.ndarray,
log_weight_list: Optional[np.ndarray] = None,
id_=None
**kwargs,
):
return self._new_for_base_message(
self.transformed_wrapper.project(
samples=samples,
log_weight_list=log_weight_list,
id_=id_
**kwargs,
)
)

Expand Down
4 changes: 3 additions & 1 deletion autofit/non_linear/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ def projected_model(self) -> AbstractPriorModel:
)
),
log_weight_list=weights,
id_=prior.id
id_=prior.id,
lower_limit=prior.lower_limit,
upper_limit=prior.upper_limit,
)
for path, prior
in self._model.path_priors_tuples
Expand Down
Loading