|
1 | 1 | import copy |
2 | | -from typing import Sequence |
| 2 | +from scipy.stats import multivariate_normal |
| 3 | +from typing import Sequence, Union, Literal |
3 | 4 |
|
4 | 5 | import numpy as np |
5 | 6 | from scipy.special import logsumexp |
|
10 | 11 | from .kalman import KalmanPredictor, ExtendedKalmanPredictor |
11 | 12 | from ..base import Property |
12 | 13 | from ..models.transition import TransitionModel |
| 14 | +from ..types.mixture import GaussianMixture |
| 15 | +from ..types.numeric import Probability |
13 | 16 | from ..types.prediction import Prediction |
14 | 17 | from ..types.state import GaussianState |
15 | 18 | from ..sampler import Sampler |
@@ -383,3 +386,122 @@ def get_detections(prior): |
383 | 386 | detections |= {hypothesis.measurement} |
384 | 387 |
|
385 | 388 | return detections |
| 389 | + |
| 390 | + |
| 391 | +class SMCPHDPredictor(Predictor): |
| 392 | + """SMC-PHD Predictor class |
| 393 | +
|
| 394 | + Sequential Monte-Carlo (SMC) PHD predictor implementation, based on [1]_. |
| 395 | +
|
| 396 | + Notes |
| 397 | + ----- |
| 398 | + - It is assumed that the proposal distribution is the same as the dynamics |
| 399 | + - Target "spawing" is not implemented |
| 400 | +
|
| 401 | + .. [1] Ba-Ngu Vo, S. Singh and A. Doucet, "Sequential Monte Carlo Implementation of the |
| 402 | + PHD Filter for Multi-target Tracking," Sixth International Conference of Information |
| 403 | + Fusion, 2003. Proceedings of the, 2003, pp. 792-799, doi: 10.1109/ICIF.2003.177320. |
| 404 | + """ |
| 405 | + prob_death: Probability = Property( |
| 406 | + doc="The probability of death") |
| 407 | + prob_birth: Probability = Property( |
| 408 | + doc="The probability of birth") |
| 409 | + birth_rate: float = Property( |
| 410 | + doc="The birth rate (i.e. number of new/born targets at each iteration)") |
| 411 | + birth_density: Union[GaussianState, GaussianMixture] = Property( |
| 412 | + doc="The birth density (i.e. density from which to sample birth particles)") |
| 413 | + birth_scheme: Literal["expansion", "mixture"] = Property( |
| 414 | + default="expansion", |
| 415 | + doc="The scheme for birth particles. Options are 'expansion' | 'mixture'. " |
| 416 | + "Default is 'expansion'" |
| 417 | + ) |
| 418 | + |
| 419 | + def __init__(self, *args, **kwargs): |
| 420 | + super().__init__(*args, **kwargs) |
| 421 | + if self.birth_scheme not in ["expansion", "mixture"]: |
| 422 | + raise ValueError("Invalid birth scheme. Options are 'expansion' | 'mixture'") |
| 423 | + |
| 424 | + @predict_lru_cache() |
| 425 | + def predict(self, prior, timestamp=None, **kwargs): |
| 426 | + """ SMC-PHD prediction step |
| 427 | +
|
| 428 | + Parameters |
| 429 | + ---------- |
| 430 | + prior: :class:`~.ParticleState` |
| 431 | + The prior state |
| 432 | + timestamp: :class:`datetime.datetime` |
| 433 | + The time at which to predict the next state |
| 434 | +
|
| 435 | + Returns |
| 436 | + ------- |
| 437 | + : :class:`~.ParticleStatePrediction` |
| 438 | + The predicted state |
| 439 | +
|
| 440 | + """ |
| 441 | + num_samples = len(prior) |
| 442 | + log_prior_weights = prior.log_weight |
| 443 | + time_interval = timestamp - prior.timestamp |
| 444 | + |
| 445 | + # Predict surviving particles forward |
| 446 | + pred_particles_sv = self.transition_model.function(prior, |
| 447 | + time_interval=time_interval, |
| 448 | + noise=True) |
| 449 | + |
| 450 | + # Perform birth and update weights |
| 451 | + num_state_dim = pred_particles_sv.shape[0] |
| 452 | + if self.birth_scheme == "expansion": |
| 453 | + # Expansion birth scheme, as described in [1] |
| 454 | + # Compute number of birth particles (J_k) as a fraction of the number of particles |
| 455 | + num_birth = round(float(self.prob_birth) * num_samples) |
| 456 | + |
| 457 | + # Sample birth particles |
| 458 | + birth_particles_sv = self._sample_birth_particles(num_state_dim, num_birth) |
| 459 | + log_birth_weights = np.full((num_birth,), np.log(self.birth_rate / num_birth)) |
| 460 | + |
| 461 | + # Surviving particle weights |
| 462 | + log_prob_survive = -float(self.prob_death) * time_interval.total_seconds() |
| 463 | + log_pred_weights = log_prob_survive + log_prior_weights |
| 464 | + |
| 465 | + # Append birth particles to predicted ones |
| 466 | + pred_particles_sv = StateVectors( |
| 467 | + np.concatenate((pred_particles_sv, birth_particles_sv), axis=1)) |
| 468 | + log_pred_weights = np.concatenate((log_pred_weights, log_birth_weights)) |
| 469 | + else: |
| 470 | + # Flip a coin for each particle to decide if it gets replaced by a birth particle |
| 471 | + birth_inds = np.flatnonzero(np.random.binomial(1, float(self.prob_birth), num_samples)) |
| 472 | + |
| 473 | + # Sample birth particles and replace in original state vector matrix |
| 474 | + num_birth = len(birth_inds) |
| 475 | + birth_particles_sv = self._sample_birth_particles(num_state_dim, num_birth) |
| 476 | + pred_particles_sv[:, birth_inds] = birth_particles_sv |
| 477 | + |
| 478 | + # Process weights |
| 479 | + prob_survive = np.exp(-float(self.prob_death) * time_interval.total_seconds()) |
| 480 | + birth_weight = self.birth_rate / num_samples |
| 481 | + log_pred_weights = np.log(prob_survive + birth_weight) + log_prior_weights |
| 482 | + |
| 483 | + prediction = Prediction.from_state(prior, state_vector=pred_particles_sv, |
| 484 | + log_weight=log_pred_weights, |
| 485 | + timestamp=timestamp, particle_list=None, |
| 486 | + transition_model=self.transition_model) |
| 487 | + |
| 488 | + return prediction |
| 489 | + |
| 490 | + def _sample_birth_particles(self, num_state_dim: int, num_birth: int): |
| 491 | + birth_particles = np.zeros((num_state_dim, 0)) |
| 492 | + if isinstance(self.birth_density, GaussianMixture): |
| 493 | + n_parts_per_component = num_birth // len(self.birth_density) |
| 494 | + for i, component in enumerate(self.birth_density): |
| 495 | + if i == len(self.birth_density) - 1: |
| 496 | + n_parts_per_component += num_birth % len(self.birth_density) |
| 497 | + birth_particles_component = multivariate_normal.rvs( |
| 498 | + component.mean.ravel(), |
| 499 | + component.covar, |
| 500 | + n_parts_per_component).T |
| 501 | + birth_particles = np.hstack((birth_particles, birth_particles_component)) |
| 502 | + else: |
| 503 | + birth_particles = np.atleast_2d(multivariate_normal.rvs( |
| 504 | + self.birth_density.mean.ravel(), |
| 505 | + self.birth_density.covar, |
| 506 | + num_birth)).T |
| 507 | + return birth_particles |
0 commit comments