Skip to content

Commit 4f3800a

Browse files
committed
Add SMCPHDPredictor
1 parent cd30860 commit 4f3800a

File tree

2 files changed

+202
-3
lines changed

2 files changed

+202
-3
lines changed

stonesoup/predictor/particle.py

Lines changed: 123 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import copy
2-
from typing import Sequence
2+
from scipy.stats import multivariate_normal
3+
from typing import Sequence, Union, Literal
34

45
import numpy as np
56
from scipy.special import logsumexp
@@ -10,6 +11,8 @@
1011
from .kalman import KalmanPredictor, ExtendedKalmanPredictor
1112
from ..base import Property
1213
from ..models.transition import TransitionModel
14+
from ..types.mixture import GaussianMixture
15+
from ..types.numeric import Probability
1316
from ..types.prediction import Prediction
1417
from ..types.state import GaussianState
1518
from ..sampler import Sampler
@@ -383,3 +386,122 @@ def get_detections(prior):
383386
detections |= {hypothesis.measurement}
384387

385388
return detections
389+
390+
391+
class SMCPHDPredictor(Predictor):
    """SMC-PHD Predictor class

    Sequential Monte-Carlo (SMC) PHD predictor implementation, based on [1]_.

    Notes
    -----
    - It is assumed that the proposal distribution is the same as the dynamics
    - Target "spawning" is not implemented

    .. [1] Ba-Ngu Vo, S. Singh and A. Doucet, "Sequential Monte Carlo Implementation of the
       PHD Filter for Multi-target Tracking," Sixth International Conference of Information
       Fusion, 2003. Proceedings of the, 2003, pp. 792-799, doi: 10.1109/ICIF.2003.177320.
    """
    prob_death: Probability = Property(
        doc="The probability of death")
    prob_birth: Probability = Property(
        doc="The probability of birth")
    birth_rate: float = Property(
        doc="The birth rate (i.e. number of new/born targets at each iteration)")
    birth_density: Union[GaussianState, GaussianMixture] = Property(
        doc="The birth density (i.e. density from which to sample birth particles)")
    birth_scheme: Literal["expansion", "mixture"] = Property(
        default="expansion",
        doc="The scheme for birth particles. Options are 'expansion' | 'mixture'. "
            "Default is 'expansion'"
    )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Fail fast on an unsupported scheme, rather than silently falling
        # through to the "mixture" branch inside predict()
        if self.birth_scheme not in ("expansion", "mixture"):
            raise ValueError("Invalid birth scheme. Options are 'expansion' | 'mixture'")

    @predict_lru_cache()
    def predict(self, prior, timestamp=None, **kwargs):
        """ SMC-PHD prediction step

        Parameters
        ----------
        prior: :class:`~.ParticleState`
            The prior state
        timestamp: :class:`datetime.datetime`
            The time at which to predict the next state

        Returns
        -------
        : :class:`~.ParticleStatePrediction`
            The predicted state

        """
        num_samples = len(prior)
        log_prior_weights = prior.log_weight
        time_interval = timestamp - prior.timestamp

        # Predict surviving particles forward
        pred_particles_sv = self.transition_model.function(prior,
                                                           time_interval=time_interval,
                                                           noise=True)

        # Perform birth and update weights
        num_state_dim = pred_particles_sv.shape[0]
        if self.birth_scheme == "expansion":
            # Expansion birth scheme, as described in [1]
            # Compute number of birth particles (J_k) as a fraction of the number of particles
            num_birth = round(float(self.prob_birth) * num_samples)

            # Sample birth particles. Guard against num_birth == 0 (possible
            # when prob_birth * num_samples rounds down to zero), which would
            # otherwise divide by zero when computing the birth weights.
            if num_birth:
                birth_particles_sv = self._sample_birth_particles(num_state_dim, num_birth)
                log_birth_weights = np.full((num_birth,),
                                            np.log(self.birth_rate / num_birth))
            else:
                birth_particles_sv = np.zeros((num_state_dim, 0))
                log_birth_weights = np.zeros((0,))

            # Surviving particle weights (continuous-time survival probability)
            log_prob_survive = -float(self.prob_death) * time_interval.total_seconds()
            log_pred_weights = log_prob_survive + log_prior_weights

            # Append birth particles to predicted ones
            pred_particles_sv = StateVectors(
                np.concatenate((pred_particles_sv, birth_particles_sv), axis=1))
            log_pred_weights = np.concatenate((log_pred_weights, log_birth_weights))
        else:
            # Flip a coin for each particle to decide if it gets replaced by a birth particle
            birth_inds = np.flatnonzero(np.random.binomial(1, float(self.prob_birth),
                                                           num_samples))

            # Sample birth particles and replace in original state vector matrix
            num_birth = len(birth_inds)
            birth_particles_sv = self._sample_birth_particles(num_state_dim, num_birth)
            pred_particles_sv[:, birth_inds] = birth_particles_sv

            # Process weights
            prob_survive = np.exp(-float(self.prob_death) * time_interval.total_seconds())
            birth_weight = self.birth_rate / num_samples
            log_pred_weights = np.log(prob_survive + birth_weight) + log_prior_weights

        prediction = Prediction.from_state(prior, state_vector=pred_particles_sv,
                                           log_weight=log_pred_weights,
                                           timestamp=timestamp, particle_list=None,
                                           transition_model=self.transition_model)

        return prediction

    def _sample_birth_particles(self, num_state_dim: int, num_birth: int):
        """Sample ``num_birth`` particles from the birth density.

        Returns a ``(num_state_dim, num_birth)`` array of state vectors. When
        the birth density is a :class:`~.GaussianMixture`, samples are split
        evenly across components, with the remainder drawn from the last one.
        """
        birth_particles = np.zeros((num_state_dim, 0))
        if isinstance(self.birth_density, GaussianMixture):
            n_parts_per_component = num_birth // len(self.birth_density)
            for i, component in enumerate(self.birth_density):
                # Last component absorbs the remainder
                if i == len(self.birth_density) - 1:
                    n_parts_per_component += num_birth % len(self.birth_density)
                # atleast_2d: rvs() returns a 1-D array when drawing a single
                # sample, which would otherwise break the hstack below. This
                # matches the guard already used in the single-Gaussian branch.
                birth_particles_component = np.atleast_2d(multivariate_normal.rvs(
                    component.mean.ravel(),
                    component.covar,
                    n_parts_per_component)).T
                birth_particles = np.hstack((birth_particles, birth_particles_component))
        else:
            birth_particles = np.atleast_2d(multivariate_normal.rvs(
                self.birth_density.mean.ravel(),
                self.birth_density.covar,
                num_birth)).T
        return birth_particles

stonesoup/predictor/tests/test_particle.py

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import itertools
2+
13
import datetime
24
import copy
35

@@ -6,11 +8,13 @@
68

79
from ...models.transition.linear import ConstantVelocity
810
from ...predictor.particle import (
9-
ParticlePredictor, ParticleFlowKalmanPredictor, BernoulliParticlePredictor)
11+
ParticlePredictor, ParticleFlowKalmanPredictor, BernoulliParticlePredictor, SMCPHDPredictor)
12+
from ...types.array import StateVector
13+
from ...types.numeric import Probability
1014
from ...types.particle import Particle
1115
from ...types.prediction import ParticleStatePrediction, BernoulliParticleStatePrediction
1216
from ...types.update import BernoulliParticleStateUpdate
13-
from ...types.state import ParticleState, BernoulliParticleState
17+
from ...types.state import ParticleState, BernoulliParticleState, GaussianState
1418
from ...models.measurement.linear import LinearGaussian
1519
from ...types.detection import Detection
1620
from ...sampler.particle import ParticleSampler
@@ -271,3 +275,76 @@ def test_bernoulli_particle_detection():
271275
assert np.allclose(eval_weight, prediction.weight.astype(np.float64))
272276
# check that the weights are normalised
273277
assert np.around(float(np.sum(prediction.weight)), decimals=1) == 1
278+
279+
280+
@pytest.mark.parametrize(
    "birth_scheme",
    ('mixture', 'expansion', 'some_other_scheme'))
def test_smcphd(birth_scheme):
    """Check SMCPHDPredictor for each supported birth scheme against
    hand-built expected particles, and check that an unknown scheme is
    rejected at construction time."""

    # Initialise a transition model
    # (noise_diff_coeff=0, so surviving particles move deterministically
    # under the CV transition matrix)
    cv = ConstantVelocity(noise_diff_coeff=0)

    # Define time related variables
    timestamp = datetime.datetime.now()
    timediff = 2  # 2sec
    new_timestamp = timestamp + datetime.timedelta(seconds=timediff)
    time_interval = new_timestamp - timestamp

    # Parameters for SMC-PHD
    prob_death = Probability(0.01)  # Probability of death
    prob_birth = Probability(0.1)  # Probability of birth
    birth_rate = 0.05  # Birth-rate (Mean number of new targets per scan)
    birth_density = GaussianState(StateVector(np.array([20., 0.0])),
                                  np.diag([10. ** 2, 1. ** 2]))  # Birth density
    num_particles = 9  # Number of particles

    # Define prior state: a 3x3 grid of (position, velocity) combinations
    # with uniform weights
    prior_particles = [Particle(np.array([[i], [j]]), 1/num_particles)
                       for i, j in itertools.product([10, 20, 30], [10, 20, 30])]
    prior = ParticleState(None, particle_list=prior_particles, timestamp=timestamp)

    # An invalid scheme must raise at construction; nothing further to check
    if birth_scheme == 'some_other_scheme':
        with pytest.raises(ValueError):
            SMCPHDPredictor(transition_model=cv, birth_density=birth_density,
                            prob_death=prob_death, prob_birth=prob_birth,
                            birth_rate=birth_rate, birth_scheme=birth_scheme)
        return

    predictor = SMCPHDPredictor(transition_model=cv, birth_density=birth_density,
                                prob_death=prob_death, prob_birth=prob_birth,
                                birth_rate=birth_rate, birth_scheme=birth_scheme)

    # Ensure same random numbers are generated
    np.random.seed(16549)

    prediction = predictor.predict(prior, timestamp=new_timestamp)

    # Expected surviving particles: deterministic CV motion, with weights
    # scaled by the survival probability exp(-prob_death * dt)
    prob_survive = np.exp(-float(prob_death) * time_interval.total_seconds())
    eval_particles = [Particle(cv.matrix(timestamp=new_timestamp,
                                         time_interval=time_interval)
                               @ particle.state_vector,
                               prob_survive * particle.weight)
                      for particle in prior_particles]
    if birth_scheme == 'mixture':
        # Mixture scheme keeps the particle count fixed; for seed 16549 only
        # the first particle is replaced by a birth sample.
        # NOTE(review): the hard-coded state vector below is tied to the seed
        # and to the predictor's internal sampling order — it must be
        # regenerated if either changes.
        birth_weight = birth_rate / num_particles
        new_weight = (prob_survive + birth_weight) * 1 / num_particles
        eval_particles[0].state_vector = StateVector([[11.31091636],
                                                      [-1.39374536]])
        for particle in eval_particles:
            particle.weight = new_weight

    else:
        # Expansion scheme appends round(prob_birth * N) birth particles with
        # weight birth_rate / num_birth; the hard-coded birth sample below is
        # likewise seed-dependent.
        num_birth = round(float(prob_birth) * num_particles)
        birth_weight = birth_rate / num_birth
        eval_particles.append(Particle(state_vector=StateVector([[18.3918058],
                                                                 [0.31072265]]),
                              weight=birth_weight,
                              parent=None))

    eval_prediction = ParticleStatePrediction(None, new_timestamp, particle_list=eval_particles)

    assert np.allclose(prediction.mean, eval_prediction.mean)
    assert prediction.timestamp == new_timestamp
    assert np.allclose(eval_prediction.state_vector, prediction.state_vector)
    assert np.allclose(prediction.log_weight, eval_prediction.log_weight)

0 commit comments

Comments
 (0)