Skip to content

Commit edb43ec

Browse files
committed
Tweaked plotting code.
1 parent ab62f32 commit edb43ec

File tree

7 files changed

+311
-20
lines changed

7 files changed

+311
-20
lines changed
Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
"""Makes 4-panel figure to show results of permutation test."""
2+
3+
import os
4+
import argparse
5+
import numpy
6+
import matplotlib
7+
matplotlib.use('agg')
8+
import matplotlib.pyplot as pyplot
9+
from gewittergefahr.gg_utils import file_system_utils
10+
from gewittergefahr.deep_learning import permutation_utils as gg_permutation
11+
from gewittergefahr.plotting import permutation_plotting
12+
from gewittergefahr.plotting import imagemagick_utils
13+
from ml4convection.machine_learning import permutation
14+
15+
BAR_FACE_COLOUR = numpy.array([27, 158, 119], dtype=float) / 255
16+
17+
FIGURE_WIDTH_INCHES = 15.
18+
FIGURE_HEIGHT_INCHES = 15.
19+
FIGURE_RESOLUTION_DPI = 300
20+
CONCAT_FIGURE_SIZE_PX = int(1e7)
21+
22+
PREDICTOR_NAME_TO_VERBOSE = {
23+
'Band 8': r'Band 8 (6.25 $\mu$m)',
24+
'Band 9': r'Band 9 (6.95 $\mu$m)',
25+
'Band 10': r'Band 10 (7.35 $\mu$m)',
26+
'Band 11': r'Band 11 (8.60 $\mu$m)',
27+
'Band 13': r'Band 13 (10.45 $\mu$m)',
28+
'Band 14': r'Band 14 (11.20 $\mu$m)',
29+
'Band 16': r'Band 16 (13.30 $\mu$m)'
30+
}
31+
32+
FORWARD_FILE_ARG_NAME = 'input_forward_file_name'
33+
BACKWARDS_FILE_ARG_NAME = 'input_backwards_file_name'
34+
NUM_PREDICTORS_ARG_NAME = 'num_predictors_to_plot'
35+
CONFIDENCE_LEVEL_ARG_NAME = 'confidence_level'
36+
OUTPUT_DIR_ARG_NAME = 'output_dir_name'
37+
38+
FORWARD_FILE_HELP_STRING = (
39+
'Path to file with results of forward test (will be read by '
40+
'`permutation.read_file` in the ml4convection library).'
41+
)
42+
BACKWARDS_FILE_HELP_STRING = (
43+
'Path to file with results of backwards test (will be read by '
44+
'`permutation.read_file` in the ml4convection library).'
45+
)
46+
NUM_PREDICTORS_HELP_STRING = (
47+
'Will plot only the `{0:s}` most important predictors in each panel. To '
48+
'plot all predictors, leave this argument alone.'
49+
).format(NUM_PREDICTORS_ARG_NAME)
50+
51+
CONFIDENCE_LEVEL_HELP_STRING = (
52+
'Confidence level for error bars (in range 0...1).'
53+
)
54+
OUTPUT_DIR_HELP_STRING = (
55+
'Path to output directory (figures will be saved here).'
56+
)
57+
58+
INPUT_ARG_PARSER = argparse.ArgumentParser()
59+
INPUT_ARG_PARSER.add_argument(
60+
'--' + FORWARD_FILE_ARG_NAME, type=str, required=True,
61+
help=FORWARD_FILE_HELP_STRING
62+
)
63+
INPUT_ARG_PARSER.add_argument(
64+
'--' + BACKWARDS_FILE_ARG_NAME, type=str, required=True,
65+
help=BACKWARDS_FILE_HELP_STRING
66+
)
67+
INPUT_ARG_PARSER.add_argument(
68+
'--' + NUM_PREDICTORS_ARG_NAME, type=int, required=False, default=-1,
69+
help=NUM_PREDICTORS_HELP_STRING
70+
)
71+
INPUT_ARG_PARSER.add_argument(
72+
'--' + CONFIDENCE_LEVEL_ARG_NAME, type=float, required=False, default=0.95,
73+
help=CONFIDENCE_LEVEL_HELP_STRING
74+
)
75+
INPUT_ARG_PARSER.add_argument(
76+
'--' + OUTPUT_DIR_ARG_NAME, type=str, required=True,
77+
help=OUTPUT_DIR_HELP_STRING
78+
)
79+
80+
81+
def _results_to_gg_format(permutation_dict):
82+
"""Converts permutation results from ml4rt format to GewitterGefahr format.
83+
84+
:param permutation_dict: Dictionary created by `run_forward_test` or
85+
`run_backwards_test` in `ml4rt.machine_learning.permutation`.
86+
:return: permutation_dict: Same but in format created by `run_forward_test`
87+
or `run_backwards_test` in `gewittergefahr.deep_learning.permutation`.
88+
"""
89+
90+
permutation_dict[gg_permutation.ORIGINAL_COST_ARRAY_KEY] = (
91+
permutation_dict[permutation.ORIGINAL_COST_KEY]
92+
)
93+
94+
permutation_dict[gg_permutation.BACKWARDS_FLAG] = (
95+
permutation_dict[permutation.BACKWARDS_FLAG_KEY]
96+
)
97+
98+
permutation_dict[gg_permutation.BEST_PREDICTORS_KEY] = [
99+
PREDICTOR_NAME_TO_VERBOSE[s] for s in
100+
permutation_dict[permutation.BEST_PREDICTORS_KEY]
101+
]
102+
103+
permutation_dict[gg_permutation.STEP1_PREDICTORS_KEY] = [
104+
PREDICTOR_NAME_TO_VERBOSE[s] for s in
105+
permutation_dict[permutation.STEP1_PREDICTORS_KEY]
106+
]
107+
108+
return permutation_dict
109+
110+
111+
def _run(forward_file_name, backwards_file_name, num_predictors_to_plot,
112+
confidence_level, output_dir_name):
113+
"""Makes 4-panel figure to show results of permutation test.
114+
115+
This is effectively the main method.
116+
117+
:param forward_file_name: See documentation at top of file.
118+
:param backwards_file_name: Same.
119+
:param num_predictors_to_plot: Same.
120+
:param confidence_level: Same.
121+
:param output_dir_name: Same.
122+
"""
123+
124+
if num_predictors_to_plot <= 0:
125+
num_predictors_to_plot = None
126+
127+
file_system_utils.mkdir_recursive_if_necessary(
128+
directory_name=output_dir_name
129+
)
130+
131+
print('Reading data from: "{0:s}"...'.format(forward_file_name))
132+
forward_permutation_dict = permutation.read_file(forward_file_name)
133+
forward_permutation_dict = _results_to_gg_format(forward_permutation_dict)
134+
135+
print('Reading data from: "{0:s}"...'.format(backwards_file_name))
136+
backwards_permutation_dict = permutation.read_file(backwards_file_name)
137+
backwards_permutation_dict = _results_to_gg_format(
138+
backwards_permutation_dict
139+
)
140+
141+
figure_object, axes_object = pyplot.subplots(
142+
1, 1, figsize=(FIGURE_WIDTH_INCHES, FIGURE_HEIGHT_INCHES)
143+
)
144+
permutation_plotting.plot_single_pass_test(
145+
permutation_dict=forward_permutation_dict, axes_object=axes_object,
146+
num_predictors_to_plot=num_predictors_to_plot,
147+
plot_percent_increase=False, confidence_level=confidence_level,
148+
bar_face_colour=BAR_FACE_COLOUR
149+
)
150+
axes_object.set_title('Single-pass forward')
151+
axes_object.set_xlabel('')
152+
153+
this_file_name = '{0:s}/single_pass_forward.jpg'.format(output_dir_name)
154+
panel_file_names = [this_file_name]
155+
156+
print('Saving figure to: "{0:s}"...'.format(panel_file_names[-1]))
157+
figure_object.savefig(
158+
panel_file_names[-1], dpi=FIGURE_RESOLUTION_DPI,
159+
pad_inches=0, bbox_inches='tight'
160+
)
161+
pyplot.close(figure_object)
162+
163+
figure_object, axes_object = pyplot.subplots(
164+
1, 1, figsize=(FIGURE_WIDTH_INCHES, FIGURE_HEIGHT_INCHES)
165+
)
166+
permutation_plotting.plot_multipass_test(
167+
permutation_dict=forward_permutation_dict, axes_object=axes_object,
168+
num_predictors_to_plot=num_predictors_to_plot,
169+
plot_percent_increase=False, confidence_level=confidence_level,
170+
bar_face_colour=BAR_FACE_COLOUR
171+
)
172+
axes_object.set_title('Multi-pass forward')
173+
axes_object.set_xlabel('')
174+
axes_object.set_ylabel('')
175+
176+
this_file_name = '{0:s}/multi_pass_forward.jpg'.format(output_dir_name)
177+
panel_file_names.append(this_file_name)
178+
179+
print('Saving figure to: "{0:s}"...'.format(panel_file_names[-1]))
180+
figure_object.savefig(
181+
panel_file_names[-1], dpi=FIGURE_RESOLUTION_DPI,
182+
pad_inches=0, bbox_inches='tight'
183+
)
184+
pyplot.close(figure_object)
185+
186+
figure_object, axes_object = pyplot.subplots(
187+
1, 1, figsize=(FIGURE_WIDTH_INCHES, FIGURE_HEIGHT_INCHES)
188+
)
189+
permutation_plotting.plot_single_pass_test(
190+
permutation_dict=backwards_permutation_dict, axes_object=axes_object,
191+
num_predictors_to_plot=num_predictors_to_plot,
192+
plot_percent_increase=False, confidence_level=confidence_level,
193+
bar_face_colour=BAR_FACE_COLOUR
194+
)
195+
axes_object.set_title('Single-pass backward')
196+
axes_object.set_xlabel('1 - FSS')
197+
198+
this_file_name = '{0:s}/single_pass_backward.jpg'.format(output_dir_name)
199+
panel_file_names.append(this_file_name)
200+
201+
print('Saving figure to: "{0:s}"...'.format(panel_file_names[-1]))
202+
figure_object.savefig(
203+
panel_file_names[-1], dpi=FIGURE_RESOLUTION_DPI,
204+
pad_inches=0, bbox_inches='tight'
205+
)
206+
pyplot.close(figure_object)
207+
208+
figure_object, axes_object = pyplot.subplots(
209+
1, 1, figsize=(FIGURE_WIDTH_INCHES, FIGURE_HEIGHT_INCHES)
210+
)
211+
permutation_plotting.plot_multipass_test(
212+
permutation_dict=backwards_permutation_dict, axes_object=axes_object,
213+
num_predictors_to_plot=num_predictors_to_plot,
214+
plot_percent_increase=False, confidence_level=confidence_level,
215+
bar_face_colour=BAR_FACE_COLOUR
216+
)
217+
axes_object.set_title('Multi-pass backward')
218+
axes_object.set_xlabel('1 - FSS')
219+
axes_object.set_ylabel('')
220+
221+
this_file_name = '{0:s}/multi_pass_backward.jpg'.format(output_dir_name)
222+
panel_file_names.append(this_file_name)
223+
224+
print('Saving figure to: "{0:s}"...'.format(panel_file_names[-1]))
225+
figure_object.savefig(
226+
panel_file_names[-1], dpi=FIGURE_RESOLUTION_DPI,
227+
pad_inches=0, bbox_inches='tight'
228+
)
229+
pyplot.close(figure_object)
230+
231+
concat_figure_file_name = '{0:s}/permutation_test.jpg'.format(
232+
output_dir_name
233+
)
234+
print('Concatenating panels to: "{0:s}"...'.format(concat_figure_file_name))
235+
236+
imagemagick_utils.concatenate_images(
237+
input_file_names=panel_file_names,
238+
output_file_name=concat_figure_file_name,
239+
num_panel_rows=2, num_panel_columns=2
240+
)
241+
imagemagick_utils.trim_whitespace(
242+
input_file_name=concat_figure_file_name,
243+
output_file_name=concat_figure_file_name
244+
)
245+
imagemagick_utils.resize_image(
246+
input_file_name=concat_figure_file_name,
247+
output_file_name=concat_figure_file_name,
248+
output_size_pixels=CONCAT_FIGURE_SIZE_PX
249+
)
250+
251+
for this_file_name in panel_file_names:
252+
os.remove(this_file_name)
253+
254+
255+
if __name__ == '__main__':
256+
INPUT_ARG_OBJECT = INPUT_ARG_PARSER.parse_args()
257+
258+
_run(
259+
forward_file_name=getattr(INPUT_ARG_OBJECT, FORWARD_FILE_ARG_NAME),
260+
backwards_file_name=getattr(INPUT_ARG_OBJECT, BACKWARDS_FILE_ARG_NAME),
261+
num_predictors_to_plot=getattr(
262+
INPUT_ARG_OBJECT, NUM_PREDICTORS_ARG_NAME
263+
),
264+
confidence_level=getattr(INPUT_ARG_OBJECT, CONFIDENCE_LEVEL_ARG_NAME),
265+
output_dir_name=getattr(INPUT_ARG_OBJECT, OUTPUT_DIR_ARG_NAME)
266+
)

ml4convection/machine_learning/neural_net.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
"""Methods for training and applying neural nets."""
22

33
import copy
4-
import random
54
import os.path
65
import dill
76
import numpy
7+
numpy.random.seed(6695)
88
import keras
99
import tensorflow
10+
tensorflow.random.set_seed(6695)
1011
import tensorflow.keras as tf_keras
1112
from gewittergefahr.gg_utils import file_system_utils
1213
from gewittergefahr.gg_utils import error_checking
@@ -1497,7 +1498,9 @@ def generator_full_grid(option_dict):
14971498
' are available.'
14981499
)
14991500

1500-
random.shuffle(valid_date_strings)
1501+
valid_date_strings = numpy.array(valid_date_strings)
1502+
numpy.random.shuffle(valid_date_strings)
1503+
valid_date_strings = valid_date_strings.tolist()
15011504
date_index = 0
15021505

15031506
while True:
@@ -1743,7 +1746,8 @@ def train_model(
17431746
num_validation_batches_per_epoch, validation_option_dict,
17441747
mask_matrix, full_mask_matrix, loss_function_name, metric_names,
17451748
do_early_stopping=True,
1746-
plateau_lr_multiplier=DEFAULT_LEARNING_RATE_MULTIPLIER):
1749+
plateau_lr_multiplier=DEFAULT_LEARNING_RATE_MULTIPLIER,
1750+
save_every_epoch=True):
17471751
"""Trains neural net on either full grid or partial grids.
17481752
17491753
M = number of rows in full grid
@@ -1785,6 +1789,8 @@ def train_model(
17851789
:param plateau_lr_multiplier: Multiplier for learning rate. Learning
17861790
rate will be multiplied by this factor upon plateau in validation
17871791
performance.
1792+
:param save_every_epoch: Boolean flag. If True, will save new model after
1793+
every epoch.
17881794
"""
17891795

17901796
file_system_utils.mkdir_recursive_if_necessary(
@@ -1799,6 +1805,7 @@ def train_model(
17991805
error_checking.assert_is_integer(num_validation_batches_per_epoch)
18001806
error_checking.assert_is_geq(num_validation_batches_per_epoch, 2)
18011807
error_checking.assert_is_boolean(do_early_stopping)
1808+
error_checking.assert_is_boolean(save_every_epoch)
18021809

18031810
error_checking.assert_is_numpy_array(mask_matrix, num_dimensions=2)
18041811
error_checking.assert_is_numpy_array(full_mask_matrix, num_dimensions=2)
@@ -1840,16 +1847,23 @@ def train_model(
18401847
validation_option_dict[this_key] = training_option_dict[this_key]
18411848

18421849
validation_option_dict = _check_generator_args(validation_option_dict)
1843-
model_file_name = '{0:s}/model.h5'.format(output_dir_name)
1850+
1851+
if save_every_epoch:
1852+
model_file_name = (
1853+
output_dir_name +
1854+
'/model_epoch={epoch:03d}_val-loss={val_loss:.6f}.h5'
1855+
)
1856+
else:
1857+
model_file_name = '{0:s}/model.h5'.format(output_dir_name)
18441858

18451859
history_object = keras.callbacks.CSVLogger(
18461860
filename='{0:s}/history.csv'.format(output_dir_name),
18471861
separator=',', append=False
18481862
)
18491863
checkpoint_object = keras.callbacks.ModelCheckpoint(
18501864
filepath=model_file_name, monitor='val_loss', verbose=1,
1851-
save_best_only=do_early_stopping, save_weights_only=False, mode='min',
1852-
period=1
1865+
save_best_only=not save_every_epoch, save_weights_only=False,
1866+
mode='min', period=1
18531867
)
18541868
list_of_callback_objects = [history_object, checkpoint_object]
18551869

ml4convection/plotting/evaluation_plotting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -780,7 +780,7 @@ def plot_performance_diagram(
780780
if num_bootstrap_reps > 1:
781781
polygon_coord_matrix = confidence_interval_to_polygon(
782782
x_value_matrix=success_ratio_matrix, y_value_matrix=pod_matrix,
783-
confidence_level=confidence_level, same_order=False
783+
confidence_level=confidence_level, same_order=True
784784
)
785785

786786
polygon_colour = matplotlib.colors.to_rgba(line_colour, POLYGON_OPACITY)

ml4convection/scripts/plot_composite_saliency_map.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,8 @@ def _plot_predictors(brightness_temp_matrix_kelvins, band_numbers,
162162
cbar_orientation_string = None
163163

164164
colour_bar_object = satellite_plotting.plot_2d_grid_xy(
165-
brightness_temp_matrix_kelvins=brightness_temp_matrix_kelvins,
165+
brightness_temp_matrix_kelvins=
166+
brightness_temp_matrix_kelvins[..., j, k],
166167
axes_object=axes_object_matrix[j, k],
167168
cbar_orientation_string=cbar_orientation_string,
168169
font_size=FONT_SIZE

ml4convection/scripts/plot_evaluation_by_time.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,7 @@ def _run(input_dir_name, probability_threshold, confidence_level,
590590
# Plot hourly reliability curves.
591591
figure_object, axes_object = _plot_reliability_curves(
592592
score_tables_xarray=hourly_score_tables_xarray,
593-
confidence_level=None
593+
confidence_level=confidence_level
594594
)
595595
axes_object.set_title('Reliability curve by hour')
596596

@@ -608,7 +608,7 @@ def _run(input_dir_name, probability_threshold, confidence_level,
608608
# Plot monthly reliability curves.
609609
figure_object, axes_object = _plot_reliability_curves(
610610
score_tables_xarray=monthly_score_tables_xarray,
611-
confidence_level=None
611+
confidence_level=confidence_level
612612
)
613613
axes_object.set_title('Reliability curve by month')
614614

0 commit comments

Comments
 (0)