-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathinference_model.py
More file actions
96 lines (70 loc) · 2.89 KB
/
inference_model.py
File metadata and controls
96 lines (70 loc) · 2.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
class InferenceModel(object):
    """Base interface for a model's inference / prediction / visualization steps.

    Subclasses must override `get_inference` and `prediction_vis` (the base
    implementations raise `NotImplementedError`). `get_predictions` and
    `get_warm_start_settings` have usable defaults: identity and `None`.
    """

    def get_inference(self, features, mode):
        """
        Get the inference associated with this model.

        This should create all variables necessary. For example, for
        classification problems, this would be the logits.

        Args:
            features: possibly nested structure of tensors returned by
                the data source's `get_inputs` (first return value).
            mode: one of `tf.estimator.ModeKeys` - 'train', 'eval' or 'infer'.

        Returns:
            possibly nested structure of tensors for predictions/losses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError('Abstract method')

    def get_predictions(self, features, inference):
        """
        Convert inferences to predictions.

        This should not introduce new trainable parameters.

        Args:
            features: first output of `DataSource.get_inputs` - possibly nested
                structure of batched tensors.
            inference: output of `self.get_inference`.

        Returns:
            possibly nested structure of tensor predictions.
            Defaults to returning inference unchanged.
        """
        # Default is the identity: models whose inference already is the
        # prediction need not override this.
        return inference

    def prediction_vis(self, prediction_data):
        """
        Get a vis of prediction data for a single example.

        Args:
            prediction_data: numpy data with same structure as
                `self.get_predictions` output.

        Returns:
            `Visualization`, or iterable of `Visualization`s

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError('Abstract method')

    def get_warm_start_settings(self):
        """
        Get `tf.estimator.WarmStartSettings` for transfer learning.

        Can be None, in which case variables are initialized from scratch.
        This base implementation always returns None.
        """
        return None
class DelegatingInferenceModel(InferenceModel):
    """Inference model that forwards every method call to a wrapped model.

    Each `InferenceModel` method is passed straight through to the wrapped
    object with the same positional arguments. Subclasses are expected to
    override the methods they want to customise and rely on the forwarding
    for the rest.

    Args:
        base: wrapped object providing the `InferenceModel` methods;
            calls are forwarded to it.
    """

    def __init__(self, base):
        self._base = base

    @property
    def base(self):
        """The wrapped model that calls are forwarded to."""
        return self._base

    def get_inference(self, features, mode):
        """Forward to the wrapped model's `get_inference`."""
        wrapped = self._base
        return wrapped.get_inference(features, mode)

    def get_predictions(self, features, inference):
        """Forward to the wrapped model's `get_predictions`."""
        wrapped = self._base
        return wrapped.get_predictions(features, inference)

    def prediction_vis(self, prediction_data):
        """Forward to the wrapped model's `prediction_vis`."""
        wrapped = self._base
        return wrapped.prediction_vis(prediction_data)

    def get_warm_start_settings(self):
        """Forward to the wrapped model's `get_warm_start_settings`."""
        wrapped = self._base
        return wrapped.get_warm_start_settings()
class TransferInferenceModel(DelegatingInferenceModel):
    """Delegating model that substitutes its own warm-start settings.

    Every method inherited from `DelegatingInferenceModel` forwards to
    `base`, except `get_warm_start_settings`, which instead returns the
    value supplied at construction and never consults `base`.

    Args:
        base: wrapped model that the other calls are forwarded to.
        warm_start_settings: value returned by `get_warm_start_settings`,
            e.g. a `tf.estimator.WarmStartSettings` or None.
    """

    def __init__(self, base, warm_start_settings):
        super(TransferInferenceModel, self).__init__(base)
        self._warm_start_settings = warm_start_settings

    def get_warm_start_settings(self):
        """Return the warm-start settings given to the constructor."""
        return self._warm_start_settings