Skip to content

Commit 97c09ca

Browse files
committed
refactor scalers for adding more var type data
1 parent 59b9bf7 commit 97c09ca

File tree

8 files changed

+382
-462
lines changed

8 files changed

+382
-462
lines changed

tests/test_data_scalers.py

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,7 @@ def sample_data():
6262
def test_dapeng_scaler_initialization(sample_data):
6363
target_vars, relevant_vars, constant_vars, data_cfgs = sample_data
6464
scaler = DapengScaler(
65-
target_vars=target_vars,
66-
relevant_vars=relevant_vars,
67-
constant_vars=constant_vars,
65+
vars_data=[target_vars, relevant_vars, constant_vars],
6866
data_cfgs=data_cfgs,
6967
is_tra_val_te="train",
7068
)
@@ -77,9 +75,7 @@ def test_dapeng_scaler_initialization(sample_data):
7775
def test_dapeng_scaler_cal_stat_all(sample_data):
7876
target_vars, relevant_vars, constant_vars, data_cfgs = sample_data
7977
scaler = DapengScaler(
80-
target_vars=target_vars,
81-
relevant_vars=relevant_vars,
82-
constant_vars=constant_vars,
78+
vars_data=[target_vars, relevant_vars, constant_vars],
8379
data_cfgs=data_cfgs,
8480
is_tra_val_te="train",
8581
data_source=SelfMadeHydroDataset(
@@ -96,9 +92,7 @@ def test_dapeng_scaler_cal_stat_all(sample_data):
9692
def test_dapeng_scaler_load_data_and_denorm(sample_data):
9793
target_vars, relevant_vars, constant_vars, data_cfgs = sample_data
9894
scaler = DapengScaler(
99-
target_vars=target_vars,
100-
relevant_vars=relevant_vars,
101-
constant_vars=constant_vars,
95+
[target_vars, relevant_vars, constant_vars],
10296
data_cfgs=data_cfgs,
10397
is_tra_val_te="train",
10498
data_source=SelfMadeHydroDataset(
@@ -134,9 +128,7 @@ def test_dapeng_scaler_load_data_and_denorm(sample_data):
134128
def test_sklearn_scale_train_mode(sample_data):
135129
target_vars, relevant_vars, constant_vars, data_cfgs = sample_data
136130
scaler_hub = ScalerHub(
137-
target_vars=target_vars,
138-
relevant_vars=relevant_vars,
139-
constant_vars=constant_vars,
131+
vars_data=[target_vars, relevant_vars, constant_vars],
140132
data_cfgs=data_cfgs,
141133
is_tra_val_te="train",
142134
)
@@ -161,9 +153,7 @@ def test_sklearn_scale_train_mode(sample_data):
161153
def test_sklearn_scale_test_mode_with_existing_scaler(sample_data):
162154
target_vars, relevant_vars, constant_vars, data_cfgs = sample_data
163155
scaler_hub = ScalerHub(
164-
target_vars=target_vars,
165-
relevant_vars=relevant_vars,
166-
constant_vars=constant_vars,
156+
vars_data=[target_vars, relevant_vars, constant_vars],
167157
data_cfgs=data_cfgs,
168158
is_tra_val_te="train",
169159
)
@@ -189,9 +179,7 @@ def test_sklearn_scale_test_mode_with_existing_scaler(sample_data):
189179
def test_sklearn_scale_test_mode_without_scaler_file(sample_data):
190180
target_vars, relevant_vars, constant_vars, data_cfgs = sample_data
191181
scaler_hub = ScalerHub(
192-
target_vars=target_vars,
193-
relevant_vars=relevant_vars,
194-
constant_vars=constant_vars,
182+
vars_data=[target_vars, relevant_vars, constant_vars],
195183
data_cfgs=data_cfgs,
196184
is_tra_val_te="test",
197185
)

tests/test_data_sets.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
"""
22
Author: Wenyu Ouyang
33
Date: 2024-05-27 13:33:08
4-
LastEditTime: 2024-11-05 18:20:19
4+
LastEditTime: 2025-04-19 07:54:43
55
LastEditors: Wenyu Ouyang
66
Description: Unit test for datasets
7-
FilePath: \torchhydro\tests\test_data_sets.py
7+
FilePath: /torchhydro/tests/test_data_sets.py
88
Copyright (c) 2023-2024 Wenyu Ouyang. All rights reserved.
99
"""
1010

@@ -17,10 +17,6 @@
1717
import pickle
1818
from sklearn.preprocessing import StandardScaler
1919
from torchhydro.datasets.data_sets import BaseDataset, Seq2SeqDataset
20-
from torchhydro.datasets.data_source import (
21-
SelfMadeForecastDataset,
22-
SelfMadeForecastDataset_P,
23-
)
2420
from torchhydro.datasets.data_sources import data_sources_dict
2521

2622

tests/test_resulter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
"""
22
Author: Wenyu Ouyang
33
Date: 2024-09-15 11:23:28
4-
LastEditTime: 2024-09-15 17:07:04
4+
LastEditTime: 2025-04-19 14:11:24
55
LastEditors: Wenyu Ouyang
66
Description: Test the Resulter class
7-
FilePath: \torchhydro\tests\test_resulter.py
7+
FilePath: /torchhydro/tests/test_resulter.py
88
Copyright (c) 2023-2024 Wenyu Ouyang. All rights reserved.
99
"""
1010

@@ -95,4 +95,4 @@ def mock_calculate_and_record_metrics(obs, pred, metrics, col, fill_nan, eval_lo
9595
for col in mock_resulter.cfgs["data_cfgs"]["target_cols"]:
9696
for metric in mock_resulter.cfgs["evaluation_cfgs"]["metrics"]:
9797
assert f"{metric} of {col}" in eval_log
98-
assert isinstance(eval_log[f"{metric} of {col}"], list)
98+
assert isinstance(eval_log[f"{metric} of {col}"], list)

torchhydro/configs/config.py

Lines changed: 69 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
"""
22
Author: Wenyu Ouyang
33
Date: 2021-12-31 11:08:29
4-
LastEditTime: 2025-04-17 21:07:02
4+
LastEditTime: 2025-04-19 11:41:02
55
LastEditors: Wenyu Ouyang
66
Description: Config for hydroDL
7-
FilePath: /torchhydro/torchhydro/configs/config.py
7+
FilePath: /HydroForecastEval/mnt/disk1/owen/code/torchhydro/torchhydro/configs/config.py
88
Copyright (c) 2021-2022 Wenyu Ouyang. All rights reserved.
99
"""
1010

@@ -103,6 +103,18 @@ def default_config_file():
103103
"hindcast_length": 30,
104104
# the length of the forecast data
105105
"forecast_length": 1,
106+
# config for data of "forecast_length" part
107+
# for each batch, we fix length of hindcast and forecast length.
108+
# data from different lead times are marked with a number representing the lead time,
109+
# for example, now is 2020-09-30, our min_time_interval is 1 day, hindcast length is 30 and forecast length is 1,
110+
# lead_time = 3 means 2020-09-01 to 2020-09-30, and the forecast data is 2020-10-01 from 2020-09-28
111+
# for forecast data, we have two different configurations:
112+
# 1st, we can set the same lead time for all forecast times
113+
# 2020-09-30now, 30hindcast, 2forecast, 3leadtime means 2020-09-01 to 2020-09-30 obs concatenate with 2020-10-01 forecast data from 2020-09-28 and 2020-10-02 forecast data from 2020-09-29
114+
# 2nd, we can set an increasing lead time for each forecast time
115+
# 2020-09-30now, 30hindcast, 2forecast, [1, 2]leadtime means 2020-09-01 to 2020-09-30 obs concatenate with 2020-10-01 to 2020-10-02 forecast data from 2020-09-30
116+
"lead_time_type": "fixed", # must be fixed or increasing
117+
"lead_time_start": 1,
106118
# the min time step of the input data
107119
"min_time_unit": "D",
108120
# the min time interval of the input data
@@ -153,6 +165,12 @@ def default_config_file():
153165
"geol_porostiy",
154166
"geol_permeability",
155167
],
168+
# for forecast variables such as data from GFS
169+
# for each period, they have multiple forecast data with different lead time
170+
# hence we list them as a separate type
171+
"forecast_cols": None,
172+
# global variables such as ENSO indicators are used in some long-term models
173+
"global_cols": None,
156174
# specify the data source of each variable
157175
"var_to_source_map": None,
158176
# {
@@ -164,8 +182,6 @@ def default_config_file():
164182
"constant_rm_nan": True,
165183
# if constant_only, we will only use constant data as DL models' input: this is only for dpl models now
166184
"constant_only": False,
167-
# more other cols, use dict to express!
168-
"other_cols": None,
169185
# only numerical scaler: for categorical vars, they are transformed to numerical vars when reading them
170186
"scaler": "StandardScaler",
171187
# Some parameters for the chosen scaler function, default is DapengScaler's
@@ -208,7 +224,7 @@ def default_config_file():
208224
"pbm_norm": False,
209225
},
210226
# For scaler from sklearn, we need to specify the stat_dict_file for three different parts:
211-
# target_vars, relevant_vars and constant_vars, and the sequence must be target_vars, relevant_vars, constant_vars
227+
# target_cols, relevant_cols and constant_cols, and the sequence must be target_cols, relevant_cols, constant_cols
212228
# the separator of three stat_dict_file is ";"
213229
# for example: "stat_dict_file": "target_stat_dict_file;relevant_stat_dict_file;constant_stat_dict_file"
214230
"stat_dict_file": None,
@@ -335,6 +351,8 @@ def cmd(
335351
forecast_history=None,
336352
hindcast_length=None,
337353
forecast_length=None,
354+
lead_time_type=None,
355+
lead_time_start=None,
338356
train_mode=None,
339357
train_epoch=None,
340358
save_epoch=None,
@@ -353,7 +371,8 @@ def cmd(
353371
dropout=None,
354372
weight_path_add=None,
355373
var_t_type=None,
356-
var_o=None,
374+
var_f=None,
375+
var_g=None,
357376
var_out=None,
358377
var_to_source_map=None,
359378
out_rm_nan=0,
@@ -589,6 +608,20 @@ def cmd(
589608
default=forecast_length,
590609
type=int,
591610
)
611+
parser.add_argument(
612+
"--lead_time_type",
613+
dest="lead_time_type",
614+
help="fixed or increasing",
615+
default=lead_time_type,
616+
type=str,
617+
)
618+
parser.add_argument(
619+
"--lead_time_start",
620+
dest="lead_time_start",
621+
help="the start lead time",
622+
default=lead_time_start,
623+
type=int,
624+
)
592625
parser.add_argument(
593626
"--model_type",
594627
dest="model_type",
@@ -683,10 +716,17 @@ def cmd(
683716
nargs="+",
684717
)
685718
parser.add_argument(
686-
"--var_o",
687-
dest="var_o",
688-
help="more other inputs except for var_c and var_t",
689-
default=var_o,
719+
"--var_f",
720+
dest="var_f",
721+
help="forecast variables such as precipitation from GFS",
722+
default=var_f,
723+
type=json.loads,
724+
)
725+
parser.add_argument(
726+
"--var_g",
727+
dest="var_g",
728+
help="global variables such as ENSO indicators",
729+
default=var_g,
690730
type=json.loads,
691731
)
692732
parser.add_argument(
@@ -969,11 +1009,7 @@ def update_cfg(cfg_file, new_args):
9691009
cfg_file["training_cfgs"]["optim_params"] = {}
9701010
if new_args.var_c is not None:
9711011
# I don't find a method to receive empty list for argparse, so if we input "None" or "" or " ", we treat it as []
972-
if (
973-
new_args.var_c == ["None"]
974-
or new_args.var_c == [""]
975-
or new_args.var_c == [" "]
976-
):
1012+
if new_args.var_c in [["None"], [""], [" "]]:
9771013
cfg_file["data_cfgs"]["constant_cols"] = []
9781014
else:
9791015
cfg_file["data_cfgs"]["constant_cols"] = new_args.var_c
@@ -987,8 +1023,10 @@ def update_cfg(cfg_file, new_args):
9871023
if new_args.var_t_type is not None:
9881024
cfg_file["data_cfgs"]["relevant_types"] = new_args.var_t_type
9891025
cfg_file["data_cfgs"]["relevant_rm_nan"] = bool(new_args.t_rm_nan != 0)
990-
if new_args.var_o is not None:
991-
cfg_file["data_cfgs"]["other_cols"] = new_args.var_o
1026+
if new_args.var_f is not None:
1027+
cfg_file["data_cfgs"]["forecast_cols"] = new_args.var_f
1028+
if new_args.var_g is not None:
1029+
cfg_file["data_cfgs"]["global_cols"] = new_args.var_g
9921030
if new_args.var_out is not None:
9931031
cfg_file["data_cfgs"]["target_cols"] = new_args.var_out
9941032
print(
@@ -1073,9 +1111,10 @@ def update_cfg(cfg_file, new_args):
10731111
cfg_file["evaluation_cfgs"]["model_loader"] = new_args.model_loader
10741112
if new_args.warmup_length > 0:
10751113
cfg_file["data_cfgs"]["warmup_length"] = new_args.warmup_length
1076-
if "warmup_length" in new_args.model_hyperparam.keys() and (
1077-
not cfg_file["data_cfgs"]["warmup_length"]
1078-
== new_args.model_hyperparam["warmup_length"]
1114+
if (
1115+
"warmup_length" in new_args.model_hyperparam.keys()
1116+
and cfg_file["data_cfgs"]["warmup_length"]
1117+
!= new_args.model_hyperparam["warmup_length"]
10791118
):
10801119
raise RuntimeError(
10811120
"Please set same warmup_length in model_cfgs and data_cfgs"
@@ -1090,16 +1129,23 @@ def update_cfg(cfg_file, new_args):
10901129
cfg_file["data_cfgs"]["hindcast_length"] = new_args.forecast_history
10911130
if new_args.forecast_length is not None:
10921131
cfg_file["data_cfgs"]["forecast_length"] = new_args.forecast_length
1132+
if new_args.lead_time_type is not None:
1133+
if new_args.lead_time_type not in ["fixed", "increasing"]:
1134+
raise ValueError("lead_time_type must be 'fixed' or 'increasing'")
1135+
cfg_file["data_cfgs"]["lead_time_type"] = new_args.lead_time_type
1136+
if new_args.lead_time_start is None:
1137+
raise ValueError("lead_time_start must be set when lead_time_type is set")
1138+
cfg_file["data_cfgs"]["lead_time_start"] = new_args.lead_time_start
10931139
if new_args.start_epoch > 1:
10941140
cfg_file["training_cfgs"]["start_epoch"] = new_args.start_epoch
10951141
if new_args.stat_dict_file is not None:
10961142
stat_dict_file = new_args.stat_dict_file
10971143
if len(stat_dict_file.split(";")) > 1:
10981144
target_, relevant_, constant_ = stat_dict_file.split(";")
10991145
stat_dict_file = {
1100-
"target_vars": target_,
1101-
"relevant_vars": relevant_,
1102-
"constant_vars": constant_,
1146+
"target_cols": target_,
1147+
"relevant_cols": relevant_,
1148+
"constant_cols": constant_,
11031149
}
11041150
cfg_file["data_cfgs"]["stat_dict_file"] = stat_dict_file
11051151
if new_args.num_workers is not None and new_args.num_workers > 0:

torchhydro/datasets/data_dict.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
Author: Wenyu Ouyang
33
Date: 2021-12-31 11:08:29
4-
LastEditTime: 2024-11-02 21:17:44
4+
LastEditTime: 2025-04-18 08:55:29
55
LastEditors: Wenyu Ouyang
66
Description: A dict used for data source and data loader
77
FilePath: /torchhydro/torchhydro/datasets/data_dict.py
@@ -13,6 +13,7 @@
1313
BasinSingleFlowDataset,
1414
DplDataset,
1515
FlexibleDataset,
16+
ObsForeDataset,
1617
Seq2SeqDataset,
1718
SeqForecastDataset,
1819
TransformerDataset,
@@ -27,4 +28,5 @@
2728
"Seq2SeqDataset": Seq2SeqDataset,
2829
"SeqForecastDataset": SeqForecastDataset,
2930
"TransformerDataset": TransformerDataset,
31+
"ObsForeDataset": ObsForeDataset,
3032
}

0 commit comments

Comments
 (0)