Skip to content

Commit a8dec71

Browse files
committed
test for forecast data dapengscaler
1 parent 97c09ca commit a8dec71

File tree

4 files changed

+145
-86
lines changed

4 files changed

+145
-86
lines changed

torchhydro/configs/config.py

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def default_config_file():
127127
"t_range_test": ["1993-01-01", "1994-01-01"],
128128
# the output
129129
"target_cols": [Q_CAMELS_US_NAME],
130-
"target_rm_nan": True,
130+
"target_rm_nan": False,
131131
# only for cases in which target data will be used as input:
132132
# data assimilation -- use streamflow from period 0 to t-1 (TODO: not included now)
133133
# for physics-based model -- use streamflow to calibrate models
@@ -169,6 +169,7 @@ def default_config_file():
169169
# for each period, they have multiple forecast data with different lead time
170170
# hence we list them as a separate type
171171
"forecast_cols": None,
172+
"forecast_rm_nan": True,
172173
# global variables such as ENSO indicators are used in some long term models
173174
"global_cols": None,
174175
# specify the data source of each variable
@@ -346,7 +347,7 @@ def cmd(
346347
lr_scheduler=None,
347348
opt_param=None,
348349
batch_size=None,
349-
warmup_length=0,
350+
warmup_length=None,
350351
# forecast_history will be deprecated in the future
351352
forecast_history=None,
352353
hindcast_length=None,
@@ -362,9 +363,9 @@ def cmd(
362363
weight_path=None,
363364
continue_train=None,
364365
var_c=None,
365-
c_rm_nan=1,
366+
c_rm_nan=None,
366367
var_t=None,
367-
t_rm_nan=1,
368+
t_rm_nan=None,
368369
n_output=None,
369370
loss_func=None,
370371
model_hyperparam=None,
@@ -375,21 +376,22 @@ def cmd(
375376
var_g=None,
376377
var_out=None,
377378
var_to_source_map=None,
378-
out_rm_nan=0,
379-
target_as_input=0,
380-
constant_only=0,
379+
out_rm_nan=None,
380+
f_rm_nan=None,
381+
target_as_input=None,
382+
constant_only=None,
381383
gage_id_screen=None,
382384
loss_param=None,
383385
metrics=None,
384386
fill_nan=None,
385387
explainer=None,
386388
rolling=None,
387389
calc_metrics=None,
388-
start_epoch=1,
390+
start_epoch=None,
389391
stat_dict_file=None,
390392
num_workers=None,
391393
which_first_tensor=None,
392-
ensemble=0,
394+
ensemble=None,
393395
ensemble_items=None,
394396
early_stopping=None,
395397
patience=None,
@@ -746,6 +748,13 @@ def cmd(
746748
default=out_rm_nan,
747749
type=int,
748750
)
751+
parser.add_argument(
752+
"--f_rm_nan",
753+
dest="f_rm_nan",
754+
help="if true, we remove NaN value for var_f data when scaling",
755+
default=f_rm_nan,
756+
type=int,
757+
)
749758
parser.add_argument(
750759
"--target_as_input",
751760
dest="target_as_input",
@@ -1013,7 +1022,8 @@ def update_cfg(cfg_file, new_args):
10131022
cfg_file["data_cfgs"]["constant_cols"] = []
10141023
else:
10151024
cfg_file["data_cfgs"]["constant_cols"] = new_args.var_c
1016-
cfg_file["data_cfgs"]["constant_rm_nan"] = bool(new_args.c_rm_nan != 0)
1025+
if new_args.c_rm_nan is not None:
1026+
cfg_file["data_cfgs"]["constant_rm_nan"] = bool(new_args.c_rm_nan > 0)
10171027
if new_args.var_t is not None:
10181028
cfg_file["data_cfgs"]["relevant_cols"] = new_args.var_t
10191029
print(
@@ -1022,7 +1032,8 @@ def update_cfg(cfg_file, new_args):
10221032
print("If you have POTENTIAL_EVAPOTRANSPIRATION, please set it the 2nd!!!-")
10231033
if new_args.var_t_type is not None:
10241034
cfg_file["data_cfgs"]["relevant_types"] = new_args.var_t_type
1025-
cfg_file["data_cfgs"]["relevant_rm_nan"] = bool(new_args.t_rm_nan != 0)
1035+
if new_args.t_rm_nan is not None:
1036+
cfg_file["data_cfgs"]["relevant_rm_nan"] = bool(new_args.t_rm_nan > 0)
10261037
if new_args.var_f is not None:
10271038
cfg_file["data_cfgs"]["forecast_cols"] = new_args.var_f
10281039
if new_args.var_g is not None:
@@ -1034,10 +1045,14 @@ def update_cfg(cfg_file, new_args):
10341045
)
10351046
if new_args.var_to_source_map is not None:
10361047
cfg_file["data_cfgs"]["var_to_source_map"] = new_args.var_to_source_map
1037-
cfg_file["data_cfgs"]["target_rm_nan"] = bool(new_args.out_rm_nan != 0)
1038-
if new_args.target_as_input == 0:
1039-
cfg_file["data_cfgs"]["target_as_input"] = False
1040-
cfg_file["data_cfgs"]["constant_only"] = bool(new_args.constant_only != 0)
1048+
if new_args.out_rm_nan is not None:
1049+
cfg_file["data_cfgs"]["target_rm_nan"] = bool(new_args.out_rm_nan > 0)
1050+
if new_args.f_rm_nan is not None:
1051+
cfg_file["data_cfgs"]["forecast_rm_nan"] = bool(new_args.f_rm_nan > 0)
1052+
if new_args.target_as_input is not None:
1053+
cfg_file["data_cfgs"]["target_as_input"] = bool(new_args.target_as_input > 0)
1054+
if new_args.constant_only is not None:
1055+
cfg_file["data_cfgs"]["constant_only"] = bool(new_args.constant_only > 0)
10411056
else:
10421057
cfg_file["data_cfgs"]["target_as_input"] = True
10431058
if new_args.calc_metrics is not None:
@@ -1055,7 +1070,7 @@ def update_cfg(cfg_file, new_args):
10551070
if new_args.weight_path is not None:
10561071
cfg_file["model_cfgs"]["weight_path"] = new_args.weight_path
10571072
continue_train = bool(
1058-
new_args.continue_train is not None and new_args.continue_train != 0
1073+
new_args.continue_train is not None and new_args.continue_train > 0
10591074
)
10601075
cfg_file["model_cfgs"]["continue_train"] = continue_train
10611076
if new_args.weight_path_add is not None:
@@ -1109,7 +1124,7 @@ def update_cfg(cfg_file, new_args):
11091124
cfg_file["evaluation_cfgs"]["rolling"] = new_args.rolling
11101125
if new_args.model_loader is not None:
11111126
cfg_file["evaluation_cfgs"]["model_loader"] = new_args.model_loader
1112-
if new_args.warmup_length > 0:
1127+
if new_args.warmup_length is not None:
11131128
cfg_file["data_cfgs"]["warmup_length"] = new_args.warmup_length
11141129
if (
11151130
"warmup_length" in new_args.model_hyperparam.keys()
@@ -1136,7 +1151,7 @@ def update_cfg(cfg_file, new_args):
11361151
if new_args.lead_time_start is None:
11371152
raise ValueError("lead_time_start must be set when lead_time_type is set")
11381153
cfg_file["data_cfgs"]["lead_time_start"] = new_args.lead_time_start
1139-
if new_args.start_epoch > 1:
1154+
if new_args.start_epoch is not None:
11401155
cfg_file["training_cfgs"]["start_epoch"] = new_args.start_epoch
11411156
if new_args.stat_dict_file is not None:
11421157
stat_dict_file = new_args.stat_dict_file

torchhydro/datasets/data_scalers.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,10 @@ def __init__(
6666
Parameters
6767
----------
6868
vars_data
69-
data for all variables used
69+
data for all variables used.
70+
the dim must be (basin, time, lead_step, var) for 4-d array;
71+
the dim must be (basin, time, var) for 3-d array;
72+
the dim must be (basin, time) for 2-d array;
7073
data_cfgs
7174
configs for reading data
7275
is_tra_val_te
@@ -77,21 +80,13 @@ def __init__(
7780
other optional parameters for ScalerHub
7881
"""
7982
self.data_cfgs = data_cfgs
80-
vars_data_map = {
81-
key: (
82-
var_data.transpose("basin", "time", "variable")
83-
if var_data.ndim == 3
84-
else var_data.transpose("basin", "variable")
85-
)
86-
for key, var_data in vars_data.items()
87-
}
8883
scaler_type = data_cfgs["scaler"]
8984
pbm_norm = data_cfgs["scaler_params"]["pbm_norm"]
9085
if scaler_type == "DapengScaler":
9186
gamma_norm_cols = data_cfgs["scaler_params"]["gamma_norm_cols"]
9287
prcp_norm_cols = data_cfgs["scaler_params"]["prcp_norm_cols"]
9388
scaler = DapengScaler(
94-
vars_data_map,
89+
vars_data,
9590
data_cfgs,
9691
is_tra_val_te,
9792
prcp_norm_cols=prcp_norm_cols,
@@ -101,7 +96,7 @@ def __init__(
10196
)
10297
elif scaler_type in SCALER_DICT.keys():
10398
scaler = SklearnScaler(
104-
vars_data_map,
99+
vars_data,
105100
data_cfgs,
106101
is_tra_val_te,
107102
pbm_norm=pbm_norm,
@@ -110,7 +105,7 @@ def __init__(
110105
raise NotImplementedError(
111106
"We don't provide this Scaler now!!! Please choose another one: DapengScaler or key in SCALER_DICT"
112107
)
113-
self.norm_data = scaler.load_norm_data(vars_data_map)
108+
self.norm_data = scaler.load_norm_data(vars_data)
114109
# we will use target_scaler during denormalization
115110
self.target_scaler = scaler
116111
print("Finish Normalization\n")

0 commit comments

Comments
 (0)