Skip to content

Commit 50e0b09

Browse files
committed
Add SMCPHDUpdater
1 parent 4f3800a commit 50e0b09

File tree

2 files changed

+133
-2
lines changed

2 files changed

+133
-2
lines changed

stonesoup/updater/particle.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from ..predictor.particle import MultiModelPredictor, RaoBlackwellisedMultiModelPredictor
1515
from ..resampler import Resampler
1616
from ..regulariser import Regulariser
17+
from ..types.numeric import Probability
1718
from ..types.prediction import (
1819
Prediction, ParticleMeasurementPrediction, GaussianStatePrediction, MeasurementPrediction)
1920
from ..types.update import ParticleStateUpdate, Update
@@ -519,3 +520,100 @@ def _log_space_product(A, B):
519520
Astack = np.stack([A] * B.shape[1]).transpose(1, 0, 2)
520521
Bstack = np.stack([B] * A.shape[0]).transpose(0, 2, 1)
521522
return np.squeeze(logsumexp(Astack + Bstack, axis=2))
523+
524+
525+
class SMCPHDUpdater(ParticleUpdater):
    """ SMC-PHD updater class

    Sequential Monte-Carlo (SMC) PHD updater implementation, based on [1]_ .

    Notes
    -----
    - It is assumed that the proposal distribution is the same as the dynamics
    - Target "spawning" is not implemented

    .. [1] Ba-Ngu Vo, S. Singh and A. Doucet, "Sequential Monte Carlo Implementation of the
       PHD Filter for Multi-target Tracking," Sixth International Conference of Information
       Fusion, 2003. Proceedings of the, 2003, pp. 792-799, doi: 10.1109/ICIF.2003.177320.
    .. [2] P. Horridge and S. Maskell, “Using a probabilistic hypothesis density filter to
       confirm tracks in a multi-target environment,” in 2011 Jahrestagung der Gesellschaft
       für Informatik, October 2011.
    """
    prob_detect: Probability = Property(
        default=Probability(0.85),
        doc="Target Detection Probability")
    clutter_intensity: float = Property(
        doc="Average number of clutter measurements per time step, per unit volume")
    num_samples: int = Property(
        default=1024,
        doc="The number of samples. Default is 1024")

    def update(self, multihypothesis, **kwargs):
        """ SMC-PHD update step

        Parameters
        ----------
        multihypothesis : :class:`~.MultipleHypothesis`
            A container of :class:`~SingleHypothesis` objects. All hypotheses are assumed to have
            the same prediction (and hence same timestamp).

        Returns
        -------
        : :class:`~.ParticleStateUpdate`
            The state posterior
        """
        # Shallow copy so log_weight can be rebound without mutating the input prediction
        prediction = copy.copy(multihypothesis[0].prediction)
        # Hypotheses that are falsy are missed detections and carry no measurement
        detections = [hypothesis.measurement for hypothesis in multihypothesis if hypothesis]

        # Calculate w^{n,i} Eq. (20) of [2]
        log_weights_per_hyp = self.get_log_weights_per_hypothesis(prediction, detections)

        # Update weights Eq. (8) of [1]
        # w_k^i = \sum_{z \in Z_k}{w^{n,i}}, where i is the index of z in Z_k
        log_post_weights = logsumexp(log_weights_per_hyp, axis=1)
        prediction.log_weight = log_post_weights

        # Resample
        log_num_targets = logsumexp(log_post_weights)  # N_{k|k}
        # Normalize weights
        prediction.log_weight = log_post_weights - log_num_targets
        if self.resampler is not None:
            prediction = self.resampler.resample(prediction, self.num_samples)  # Resample
        # De-normalize so the weights sum to the estimated number of targets again
        prediction.log_weight = prediction.log_weight + log_num_targets

        return Update.from_state(
            state=multihypothesis[0].prediction,
            state_vector=prediction.state_vector,
            log_weight=prediction.log_weight,
            hypothesis=multihypothesis,
            timestamp=multihypothesis[0].measurement.timestamp,
        )

    def get_log_weights_per_hypothesis(self, prediction, detections):
        """Compute the log weight of each particle under each hypothesis.

        Returns an array of shape ``(num_samples, len(detections) + 1)`` where
        column 0 is the missed-detection hypothesis and column ``j + 1``
        corresponds to ``detections[j]``.
        """
        num_samples = prediction.state_vector.shape[1]

        # Compute g(z|x) matrix as in [1]
        g = self._get_measurement_loglikelihoods(prediction, detections)

        # Calculate w^{n,i} Eq. (20) of [2]
        Ck = self.prob_detect.log() + g + prediction.log_weight[:, np.newaxis]
        C = logsumexp(Ck, axis=0)
        k = np.log(self.clutter_intensity)
        C_plus = np.logaddexp(C, k)
        log_weights_per_hyp = np.full((num_samples, len(detections) + 1), -np.inf)
        # Missed-detection column: (1 - P_D) * prior weight, in log space
        log_weights_per_hyp[:, 0] = np.log(1 - self.prob_detect) + prediction.log_weight
        if len(detections):
            log_weights_per_hyp[:, 1:] = Ck - C_plus

        return log_weights_per_hyp

    def _get_measurement_loglikelihoods(self, prediction, detections):
        """Return the ``(num_samples, len(detections))`` matrix of log g(z|x)."""
        num_samples = prediction.state_vector.shape[1]
        # Compute g(z|x) matrix as in [1]
        g = np.zeros((num_samples, len(detections)))
        for i, detection in enumerate(detections):
            # Fall back to the updater's model when the detection carries none
            measurement_model = self._check_measurement_model(detection.measurement_model)
            g[:, i] = measurement_model.logpdf(detection, prediction, noise=True)
        return g

stonesoup/updater/tests/test_particle.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
"""Test for updater.particle module"""
2+
import itertools
3+
24
import datetime
35
from functools import partial
46

@@ -8,16 +10,18 @@
810
from ...models.measurement.linear import LinearGaussian
911
from ...resampler.particle import SystematicResampler
1012
from ...types.array import StateVectors
11-
from ...types.detection import Detection
13+
from ...types.detection import Detection, MissedDetection
1214
from ...types.hypothesis import SingleHypothesis
1315
from ...types.multihypothesis import MultipleHypothesis
16+
from ...types.numeric import Probability
1417
from ...types.particle import Particle
1518
from ...types.state import ParticleState
1619
from ...types.prediction import (
1720
ParticleStatePrediction, ParticleMeasurementPrediction)
1821
from ...updater.particle import (
1922
ParticleUpdater, GromovFlowParticleUpdater,
20-
GromovFlowKalmanParticleUpdater, BernoulliParticleUpdater)
23+
GromovFlowKalmanParticleUpdater, BernoulliParticleUpdater,
24+
SMCPHDUpdater)
2125
from ...predictor.particle import BernoulliParticlePredictor
2226
from ...models.transition.linear import ConstantVelocity, CombinedLinearGaussianTransitionModel
2327
from ...types.update import BernoulliParticleStateUpdate
@@ -263,3 +267,32 @@ def test_regularised_particle(transition_model, model_flag):
263267
assert updated_state.hypothesis.measurement_prediction == measurement_prediction
264268
assert updated_state.hypothesis.prediction == prediction
265269
assert updated_state.hypothesis.measurement == measurement
270+
271+
272+
def test_smcphd():
    """SMC-PHD update over a 3x3 particle grid with three well-separated detections."""
    now = datetime.datetime.now()
    p_d = Probability(.9)  # 90% chance of detection.
    kappa = 1e-5  # clutter intensity
    n_particles = 9

    grid = [10., 20., 30.]
    particles = []
    for x, y in itertools.product(grid, grid):
        particles.append(Particle(np.array([[x], [y]]), 1 / n_particles))
    prediction = ParticleStatePrediction(None, particle_list=particles, timestamp=now)

    detections = [Detection([[value]], timestamp=now) for value in grid]

    # One missed-detection hypothesis, followed by one hypothesis per detection
    hypotheses = [SingleHypothesis(prediction, MissedDetection(timestamp=now), None)]
    for detection in detections:
        hypotheses.append(SingleHypothesis(prediction, detection, None))
    multihypothesis = MultipleHypothesis(hypotheses)

    model = LinearGaussian(ndim_state=2, mapping=[0], noise_covar=np.array([[0.04]]))
    updater = SMCPHDUpdater(measurement_model=model,
                            resampler=SystematicResampler(),
                            prob_detect=p_d,
                            clutter_intensity=kappa,
                            num_samples=n_particles)

    updated_state = updater.update(multihypothesis)

    assert updated_state.timestamp == now
    assert updated_state.hypothesis == multihypothesis
    # Three clear targets -> posterior weights should sum to roughly 3
    assert np.isclose(float(updated_state.weight.sum()), 3, atol=1e-1)

0 commit comments

Comments
 (0)