Skip to content

Commit 2a79d55

Browse files
committed
refactor valid and test batch mode
1 parent 9e14062 commit 2a79d55

File tree

8 files changed

+241
-192
lines changed

8 files changed

+241
-192
lines changed

tests/test_deep_hydro.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,12 @@
2424

2525
# Mock dataset class using random data
2626
class MockDataset(Dataset):
27-
def __init__(self, data_cfgs, is_tra_val_te):
27+
def __init__(self, cfgs, is_tra_val_te):
2828
super(MockDataset, self).__init__()
29-
self.data_cfgs = data_cfgs # Store the passed configuration for later use
29+
self.data_cfgs = cfgs["data_cfgs"]
30+
self.training_cfgs = cfgs[
31+
"training_cfgs"
32+
] # Store the passed configuration for later use
3033
# Simulate other configuration and setup steps
3134

3235
@property
@@ -38,11 +41,11 @@ def nt(self):
3841
return 100
3942

4043
def __len__(self):
41-
return self.ngrid * (self.nt - self.data_cfgs["forecast_length"] + 1)
44+
return self.ngrid * (self.nt - self.training_cfgs["forecast_length"] + 1)
4245

4346
def __getitem__(self, idx):
4447
# Use the stored configurations to generate mock data
45-
rho = self.data_cfgs["forecast_length"]
48+
rho = self.training_cfgs["forecast_length"]
4649
x = torch.randn(rho, self.data_cfgs["input_features"])
4750
y = torch.randn(rho, self.data_cfgs["output_features"])
4851
return x, y
@@ -61,20 +64,29 @@ def dummy_data_cfgs():
6164
"t_range_valid": None,
6265
"case_dir": test_path,
6366
"sampler": "KuaiSampler",
67+
"object_ids": ["02051500", "21401550"],
68+
}
69+
70+
71+
@pytest.fixture()
72+
def dummy_training_cfgs():
73+
return {
6474
"batch_size": 5,
6575
"hindcast_length": 0,
6676
"forecast_length": 30,
6777
"warmup_length": 0,
68-
"object_ids": ["02051500", "21401550"],
6978
}
7079

7180

72-
def test_using_mock_dataset(dummy_data_cfgs):
81+
def test_using_mock_dataset(dummy_data_cfgs, dummy_training_cfgs):
7382
datasets_dict["MockDataset"] = MockDataset
7483
is_tra_val_te = True
7584
dataset_name = "MockDataset"
7685

77-
dataset = datasets_dict[dataset_name](dummy_data_cfgs, is_tra_val_te)
86+
dataset = datasets_dict[dataset_name](
87+
{"data_cfgs": dummy_data_cfgs, "training_cfgs": dummy_training_cfgs},
88+
is_tra_val_te,
89+
)
7890

7991
assert len(dataset) == 710
8092
sample_x, sample_y = dataset[0]
@@ -83,11 +95,11 @@ def test_using_mock_dataset(dummy_data_cfgs):
8395
print(sample_x[2].shape)
8496
print(sample_y.shape)
8597
assert sample_x.shape == (
86-
dummy_data_cfgs["forecast_length"],
98+
dummy_training_cfgs["forecast_length"],
8799
dummy_data_cfgs["input_features"],
88100
)
89101
assert sample_y.shape == (
90-
dummy_data_cfgs["forecast_length"],
102+
dummy_training_cfgs["forecast_length"],
91103
dummy_data_cfgs["output_features"],
92104
)
93105

torchhydro/configs/config.py

Lines changed: 94 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -95,26 +95,6 @@ def default_config_file():
9595
"source_paths": ["../../example/camels_us"],
9696
},
9797
"case_dir": None,
98-
"batch_size": 100,
99-
# we generally have three times: [warmup, hindcast_length, forecast_length]
100-
# For physics-based models, we need warmup; default is 0 as DL models generally don't need it
101-
"warmup_length": 0,
102-
# the length of the history data for forecasting
103-
"hindcast_length": 30,
104-
# the length of the forecast data
105-
"forecast_length": 1,
106-
# config for data of "forecast_length" part
107-
# for each batch, we fix length of hindcast and forecast length.
108-
# data from different lead time with a number representing the lead time,
109-
# for example, now is 2020-09-30, our min_time_interval is 1 day, hindcast length is 30 and forecast length is 1,
110-
# lead_time = 3 means 2020-09-01 to 2020-09-30, and the forecast data is 2020-10-01 from 2020-09-28
111-
# for forecast data, we have two different configurations:
112-
# 1st, we can set a same lead time for all forecast time
113-
# 2020-09-30now, 30hindcast, 2forecast, 3leadtime means 2020-09-01 to 2020-09-30 obs concatenate with 2020-10-01 forecast data from 2020-09-28 and 2020-10-02 forecast data from 2020-09-29
114-
# 2nd, we can set a increasing lead time for each forecast time
115-
# 2020-09-30now, 30hindcast, 2forecast, [1, 2]leadtime means 2020-09-01 to 2020-09-30 obs concatenate with 2020-10-01 to 2010-10-02 forecast data from 2020-09-30
116-
"lead_time_type": "fixed", # must be fixed or increasing
117-
"lead_time_start": 1,
11898
# the min time step of the input data
11999
"min_time_unit": "D",
120100
# the min time interval of the input data
@@ -244,6 +224,29 @@ def default_config_file():
244224
"port": "12335",
245225
# if train_mode is False, don't train and evaluate
246226
"train_mode": True,
227+
"batch_size": 100,
228+
# we generally have three times: [warmup, hindcast_length, forecast_length]
229+
# warmup period means no observation will be used to calculate loss for it.
230+
# For physics-based models, we generally need warmup to get a better initial state
231+
# its default is 0 as DL models generally don't need it
232+
"warmup_length": 0,
233+
# the length of the history data to forecast
234+
"hindcast_length": 30,
235+
# the length of the forecast data
236+
"forecast_length": 1,
237+
# for each batch, we fix length of hindcast and forecast length.
238+
# data from different lead time with a number representing the lead time,
239+
# for example, now is 2020-09-30, our min_time_interval is 1 day, hindcast length is 30 and forecast length is 1,
240+
# lead_time = 3 means 2020-09-01 to 2020-09-30, and the forecast data is 2020-10-01 from 2020-09-28
241+
# for forecast data, we have two different configurations:
242+
# 1st, we can set a same lead time for all forecast time
243+
# 2020-09-30now, 30hindcast, 2forecast, 3leadtime means 2020-09-01 to 2020-09-30 obs concatenate with 2020-10-01 forecast data from 2020-09-28 and 2020-10-02 forecast data from 2020-09-29
244+
# 2nd, we can set a increasing lead time for each forecast time
245+
# 2020-09-30now, 30hindcast, 2forecast, [1, 2]leadtime means 2020-09-01 to 2020-09-30 obs concatenate with 2020-10-01 to 2010-10-02 forecast data from 2020-09-30
246+
"lead_time_type": "fixed", # must be fixed or increasing
247+
"lead_time_start": 1,
248+
# valid batch can be organized as same way with training or testing
249+
"valid_batch_mode": "test",
247250
"criterion": "RMSE",
248251
"criterion_params": None,
249252
# "weight_decay": None, a regularization term in loss func
@@ -278,7 +281,6 @@ def default_config_file():
278281
# when we train a model for long time, some accidents may interrupt our training.
279282
# Then we need retrain the model with saved weights, and the start_epoch is not 1 yet.
280283
"start_epoch": 1,
281-
"batch_size": 100,
282284
"random_seed": 1234,
283285
"device": [0, 1, 2],
284286
"multi_targets": 1,
@@ -315,11 +317,20 @@ def default_config_file():
315317
"metrics": ["NSE", "RMSE", "R2", "KGE", "FHV", "FLV"],
316318
"fill_nan": "no",
317319
"explainer": None,
318-
# rolling is 0 means decoder-only model's prediction -- each period has one prediction
320+
# rolling is stride, 0 means each period has only one prediction.
319321
# when rolling>0, such as 1, means perform forecasting each step after 1 period.
320322
# For example, at 8:00am we perform one forecasting and our time-step is 3h,
321323
# rolling=1 means 11:00, 14:00, 17:00 ..., we will perform forecasting
324+
# when rolling>0, we will perform rolling forecast, for each forecasting,
325+
# a rolling window (rwin) can be chosen:
326+
# hindcast (hrwin) and forecast (frwin) length in rwin need to be chosen
327+
# and then rwin = hrwin + frwin
322328
"rolling": 0,
329+
"hrwin": None,
330+
"frwin": None,
331+
# current_idx means we assume a current period in testing periods,
332+
# 0 means all testing periods belong to forecast periods without hindcast part
333+
"current_idx": 0,
323334
"calc_metrics": True,
324335
},
325336
}
@@ -392,6 +403,9 @@ def cmd(
392403
fill_nan=None,
393404
explainer=None,
394405
rolling=None,
406+
current_idx=None,
407+
hrwin=None,
408+
frwin=None,
395409
calc_metrics=None,
396410
start_epoch=None,
397411
stat_dict_file=None,
@@ -403,6 +417,7 @@ def cmd(
403417
patience=None,
404418
min_time_unit=None,
405419
min_time_interval=None,
420+
valid_batch_mode=None,
406421
):
407422
"""input args from cmd"""
408423
parser = argparse.ArgumentParser(
@@ -824,6 +839,27 @@ def cmd(
824839
default=rolling,
825840
type=int,
826841
)
842+
parser.add_argument(
843+
"--current_idx",
844+
dest="current_idx",
845+
help="current_idx",
846+
default=current_idx,
847+
type=int,
848+
)
849+
parser.add_argument(
850+
"--hrwin",
851+
dest="hrwin",
852+
help="hrwin",
853+
default=hrwin,
854+
type=int,
855+
)
856+
parser.add_argument(
857+
"--frwin",
858+
dest="frwin",
859+
help="frwin",
860+
default=frwin,
861+
type=int,
862+
)
827863
parser.add_argument(
828864
"--model_loader",
829865
dest="model_loader",
@@ -916,6 +952,12 @@ def cmd(
916952
default=min_time_interval,
917953
type=int,
918954
)
955+
parser.add_argument(
956+
"--valid_batch_mode",
957+
dest="valid_batch_mode",
958+
help="The batch organization mode of valid data, train means same as train; test means same as test",
959+
default=valid_batch_mode,
960+
)
919961
# To make pytest work in PyCharm, here we use the following code instead of "args = parser.parse_args()":
920962
# https://blog.csdn.net/u014742995/article/details/100119905
921963
args, unknown = parser.parse_known_args()
@@ -1098,18 +1140,20 @@ def update_cfg(cfg_file, new_args):
10981140
if new_args.model_hyperparam is not None:
10991141
# raise AttributeError("Please set the model_hyperparam!!!")
11001142
cfg_file["model_cfgs"]["model_hyperparam"] = new_args.model_hyperparam
1101-
if "batch_size" in new_args.model_hyperparam.keys():
1102-
# TODO: batch_size's setting may conflict with batch_size's direct setting
1103-
cfg_file["data_cfgs"]["batch_size"] = new_args.model_hyperparam[
1104-
"batch_size"
1105-
]
1106-
cfg_file["training_cfgs"]["batch_size"] = new_args.model_hyperparam[
1107-
"batch_size"
1108-
]
1109-
if "forecast_length" in new_args.model_hyperparam.keys():
1110-
cfg_file["data_cfgs"]["forecast_length"] = new_args.model_hyperparam[
1111-
"forecast_length"
1112-
]
1143+
if (
1144+
"batch_size" in new_args.model_hyperparam.keys()
1145+
and new_args.model_hyperparam["batch_size"] != new_args.batch_size
1146+
):
1147+
raise RuntimeError(
1148+
"Please set same batch_size in model_cfgs and training_cfgs"
1149+
)
1150+
if (
1151+
"forecast_length" in new_args.model_hyperparam.keys()
1152+
and new_args.forecast_length != new_args.model_hyperparam["forecast_length"]
1153+
):
1154+
raise RuntimeError(
1155+
"Please set same forecast_length in model_cfgs and training_cfgs"
1156+
)
11131157
# The following two configurations are for encoder-decoder models' seq2seqdataset
11141158
if "hindcast_output_window" in new_args.model_hyperparam.keys():
11151159
cfg_file["data_cfgs"]["hindcast_output_window"] = new_args.model_hyperparam[
@@ -1118,10 +1162,7 @@ def update_cfg(cfg_file, new_args):
11181162
else:
11191163
cfg_file["data_cfgs"]["hindcast_output_window"] = 0
11201164
if new_args.batch_size is not None:
1121-
# raise AttributeError("Please set the batch_size!!!")
1122-
batch_size = new_args.batch_size
1123-
cfg_file["data_cfgs"]["batch_size"] = batch_size
1124-
cfg_file["training_cfgs"]["batch_size"] = batch_size
1165+
cfg_file["training_cfgs"]["batch_size"] = new_args.batch_size
11251166
if new_args.min_time_unit is not None:
11261167
if new_args.min_time_unit not in ["h", "D"]:
11271168
raise ValueError("min_time_unit must be 'h' (HOURLY) or 'D' (DAILY)")
@@ -1136,35 +1177,40 @@ def update_cfg(cfg_file, new_args):
11361177
cfg_file["evaluation_cfgs"]["explainer"] = new_args.explainer
11371178
if new_args.rolling is not None:
11381179
cfg_file["evaluation_cfgs"]["rolling"] = new_args.rolling
1180+
if new_args.current_idx is not None:
1181+
cfg_file["evaluation_cfgs"]["current_idx"] = new_args.current_idx
1182+
if new_args.hrwin is not None:
1183+
cfg_file["evaluation_cfgs"]["hrwin"] = new_args.hrwin
1184+
if new_args.frwin is not None:
1185+
cfg_file["evaluation_cfgs"]["frwin"] = new_args.frwin
11391186
if new_args.model_loader is not None:
11401187
cfg_file["evaluation_cfgs"]["model_loader"] = new_args.model_loader
11411188
if new_args.warmup_length is not None:
1142-
cfg_file["data_cfgs"]["warmup_length"] = new_args.warmup_length
1189+
cfg_file["training_cfgs"]["warmup_length"] = new_args.warmup_length
11431190
if (
11441191
"warmup_length" in new_args.model_hyperparam.keys()
1145-
and cfg_file["data_cfgs"]["warmup_length"]
1146-
!= new_args.model_hyperparam["warmup_length"]
1192+
and new_args.warmup_length != new_args.model_hyperparam["warmup_length"]
11471193
):
11481194
raise RuntimeError(
11491195
"Please set same warmup_length in model_cfgs and data_cfgs"
11501196
)
11511197
if new_args.hindcast_length is not None:
1152-
cfg_file["data_cfgs"]["hindcast_length"] = new_args.hindcast_length
1198+
cfg_file["training_cfgs"]["hindcast_length"] = new_args.hindcast_length
11531199
if new_args.hindcast_length is None and new_args.forecast_history is not None:
11541200
# forecast_history will be deprecated in the future
11551201
warnings.warn(
11561202
"forecast_history will be deprecated in the future, please use hindcast_length instead"
11571203
)
1158-
cfg_file["data_cfgs"]["hindcast_length"] = new_args.forecast_history
1204+
cfg_file["training_cfgs"]["hindcast_length"] = new_args.forecast_history
11591205
if new_args.forecast_length is not None:
1160-
cfg_file["data_cfgs"]["forecast_length"] = new_args.forecast_length
1206+
cfg_file["training_cfgs"]["forecast_length"] = new_args.forecast_length
11611207
if new_args.lead_time_type is not None:
11621208
if new_args.lead_time_type not in ["fixed", "increasing"]:
11631209
raise ValueError("lead_time_type must be 'fixed' or 'increasing'")
1164-
cfg_file["data_cfgs"]["lead_time_type"] = new_args.lead_time_type
1210+
cfg_file["training_cfgs"]["lead_time_type"] = new_args.lead_time_type
11651211
if new_args.lead_time_start is None:
11661212
raise ValueError("lead_time_start must be set when lead_time_type is set")
1167-
cfg_file["data_cfgs"]["lead_time_start"] = new_args.lead_time_start
1213+
cfg_file["training_cfgs"]["lead_time_start"] = new_args.lead_time_start
11681214
if new_args.start_epoch is not None:
11691215
cfg_file["training_cfgs"]["start_epoch"] = new_args.start_epoch
11701216
if new_args.stat_dict_file is not None:
@@ -1193,6 +1239,8 @@ def update_cfg(cfg_file, new_args):
11931239
cfg_file["training_cfgs"]["early_stopping"] = new_args.early_stopping
11941240
if new_args.lr_scheduler is not None:
11951241
cfg_file["training_cfgs"]["lr_scheduler"] = new_args.lr_scheduler
1242+
if new_args.valid_batch_mode is not None:
1243+
cfg_file["training_cfgs"]["valid_batch_mode"] = new_args.valid_batch_mode
11961244
# print("the updated config:\n", json.dumps(cfg_file, indent=4, ensure_ascii=False))
11971245

11981246

0 commit comments

Comments
 (0)