Skip to content

Commit 409d3f9

Browse files
authored
Merge pull request dstl#1106 from dstl/optuna_manager
Add Optuna Sensor Manager
2 parents 3bb5efb + 005e7f0 commit 409d3f9

File tree

5 files changed

+181
-2
lines changed

5 files changed

+181
-2
lines changed

.circleci/config.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ jobs:
2828
python -m venv venv
2929
. venv/bin/activate
3030
pip install --upgrade pip
31-
pip install -e .[dev,orbital] opencv-python-headless pyehm
31+
pip install -e .[dev,ehm,optuna,orbital] opencv-python-headless
3232
- save_cache:
3333
paths:
3434
- ./venv
@@ -75,7 +75,7 @@ jobs:
7575
python -m venv venv
7676
. venv/bin/activate
7777
pip install --upgrade pip
78-
pip install -e .[orbital] opencv-python-headless plotly pytest-cov pytest-remotedata pytest-skip-slow pyehm confluent-kafka h5py pandas
78+
pip install -e .[ehm,optuna,orbital] opencv-python-headless plotly pytest-cov pytest-remotedata pytest-skip-slow pyehm confluent-kafka h5py pandas
7979
- save_cache:
8080
paths:
8181
- ./venv

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ mfa = [
8484
ehm = [
8585
"pyehm",
8686
]
87+
optuna = [
88+
"optuna",
89+
]
8790

8891
[tool.setuptools]
8992
include-package-data = false

stonesoup/sensormanager/action.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ def min(self):
9393
def max(self):
9494
raise NotImplementedError
9595

96+
@abstractmethod
97+
def action_from_value(self):
98+
raise NotImplementedError
99+
96100

97101
class ActionableProperty(Property):
98102
"""Property that is modified via an :class:`~.Action` with defined, non-equal start and end

stonesoup/sensormanager/optuna.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
from typing import Iterable
2+
from collections import defaultdict
3+
import warnings
4+
5+
try:
6+
import optuna
7+
except ImportError as error:
8+
raise ImportError("Usage of Optuna Sensor Manager requires that the optional package "
9+
"`optuna`is installed") from error
10+
11+
from ..base import Property
12+
from ..sensor.sensor import Sensor
13+
from .action import RealNumberActionGenerator, Action
14+
from . import SensorManager
15+
16+
17+
class OptunaSensorManager(SensorManager):
18+
"""Sensor Manager that uses the optuna package to determine the best actions available within
19+
a time frame specified by :attr:`timeout`."""
20+
timeout: float = Property(
21+
doc="Number of seconds that the sensor manager should optimise for each time-step",
22+
default=10.)
23+
24+
def __init__(self, *args, **kwargs):
25+
super().__init__(*args, **kwargs)
26+
optuna.logging.set_verbosity(optuna.logging.CRITICAL)
27+
28+
def choose_actions(self, tracks, timestamp, nchoose=1, **kwargs) -> Iterable[tuple[Sensor,
29+
Action]]:
30+
"""Method to find the best actions for the given :attr:`sensors` to according to the
31+
:attr:`reward_function`.
32+
33+
Parameters
34+
----------
35+
tracks_list : List[Track]
36+
List of Tracks for the sensor manager to observe.
37+
timestamp: datetime.datetime
38+
The time for the actions to be produced for.
39+
40+
Returns
41+
-------
42+
Iterable[Tuple[Sensor, Action]]
43+
The actions and associated sensors produced by the sensor manager."""
44+
all_action_generators = dict()
45+
46+
for sensor in self.sensors:
47+
action_generators = sensor.actions(timestamp)
48+
all_action_generators[sensor] = action_generators # set of generators
49+
50+
def config_from_trial(trial):
51+
config = defaultdict(list)
52+
for i, (sensor, generators) in enumerate(all_action_generators.items()):
53+
54+
for j, generator in enumerate(generators):
55+
if isinstance(generator, RealNumberActionGenerator):
56+
with warnings.catch_warnings():
57+
warnings.simplefilter("ignore", UserWarning)
58+
value = trial.suggest_float(
59+
f'{i}{j}', generator.min, generator.max + generator.epsilon,
60+
step=getattr(generator, 'resolution', None))
61+
else:
62+
raise TypeError(f"type {type(generator)} not handled yet")
63+
action = generator.action_from_value(value)
64+
if action is not None:
65+
config[sensor].append(action)
66+
else:
67+
config[sensor].append(generator.default_action)
68+
return config
69+
70+
def optimise_func(trial):
71+
config = config_from_trial(trial)
72+
73+
return -self.reward_function(config, tracks, timestamp)
74+
75+
study = optuna.create_study()
76+
# will finish study after `timeout` seconds has elapsed.
77+
study.optimize(optimise_func, n_trials=None, timeout=self.timeout)
78+
79+
best_params = study.best_params
80+
config = defaultdict(list)
81+
for i, (sensor, generators) in enumerate(all_action_generators.items()):
82+
for j, generator in enumerate(generators):
83+
if isinstance(generator, RealNumberActionGenerator):
84+
action = generator.action_from_value(best_params[f'{i}{j}'])
85+
else:
86+
raise TypeError(f"generator type {type(generator)} not supported")
87+
if action is not None:
88+
config[sensor].append(action)
89+
else:
90+
config[sensor].append(generator.default_action)
91+
92+
# Return mapping of sensors and chosen actions for sensors
93+
return [config]
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import copy
2+
from collections import defaultdict
3+
import pytest
4+
from ordered_set import OrderedSet
5+
import numpy as np
6+
7+
try:
8+
from ..optuna import OptunaSensorManager
9+
except ImportError:
10+
# Catch optional dependencies import error
11+
pytest.skip(
12+
"Skipping due to missing optional dependencies. Usage of Optuna Sensor Manager requires "
13+
"that the optional package `optuna`is installed.",
14+
allow_module_level=True
15+
)
16+
17+
from ..reward import UncertaintyRewardFunction
18+
from ...hypothesiser.distance import DistanceHypothesiser
19+
from ...measures import Mahalanobis
20+
from ...dataassociator.neighbour import GNNWith2DAssignment
21+
from ...sensor.radar.radar import RadarRotatingBearingRange
22+
from ...sensor.action.dwell_action import ChangeDwellAction
23+
24+
25+
def test_optuna_manager(params):
26+
predictor = params['predictor']
27+
updater = params['updater']
28+
sensor_set = params['sensor_set']
29+
timesteps = params['timesteps']
30+
tracks = params['tracks']
31+
truths = params['truths']
32+
33+
reward_function = UncertaintyRewardFunction(predictor, updater)
34+
optunasensormanager = OptunaSensorManager(sensor_set, reward_function=reward_function,
35+
timeout=0.1)
36+
37+
hypothesiser = DistanceHypothesiser(predictor, updater, measure=Mahalanobis(),
38+
missed_distance=5)
39+
data_associator = GNNWith2DAssignment(hypothesiser)
40+
41+
sensor_history = defaultdict(dict)
42+
dwell_centres = dict()
43+
44+
for timestep in timesteps[1:]:
45+
chosen_actions = optunasensormanager.choose_actions(tracks, timestep)
46+
measurements = set()
47+
for chosen_action in chosen_actions:
48+
for sensor, actions in chosen_action.items():
49+
sensor.add_actions(actions)
50+
for sensor in sensor_set:
51+
sensor.act(timestep)
52+
sensor_history[timestep][sensor] = copy.copy(sensor)
53+
dwell_centres[timestep] = sensor.dwell_centre[0][0]
54+
measurements |= sensor.measure(OrderedSet(truth[timestep] for truth in truths),
55+
noise=False)
56+
hypotheses = data_associator.associate(tracks,
57+
measurements,
58+
timestep)
59+
for track in tracks:
60+
hypothesis = hypotheses[track]
61+
if hypothesis.measurement:
62+
post = updater.update(hypothesis)
63+
track.append(post)
64+
else:
65+
track.append(hypothesis.prediction)
66+
67+
# Double check choose_actions method types are as expected
68+
assert isinstance(chosen_actions, list)
69+
70+
for chosen_actions in chosen_actions:
71+
for sensor, actions in chosen_action.items():
72+
assert isinstance(sensor, RadarRotatingBearingRange)
73+
assert isinstance(actions[0], ChangeDwellAction)
74+
75+
# Check sensor following track as expected
76+
assert dwell_centres[timesteps[5]] - np.radians(135) < 1e-3
77+
assert dwell_centres[timesteps[15]] - np.radians(45) < 1e-3
78+
assert dwell_centres[timesteps[25]] - np.radians(-45) < 1e-3
79+
assert dwell_centres[timesteps[35]] - np.radians(-135) < 1e-3

0 commit comments

Comments
 (0)