Skip to content

Commit 7b096f6

Browse files
committed
add dropout for simplelstm; add scalers from sklearn for torchhydro
1 parent f6651f4 commit 7b096f6

File tree

6 files changed

+293
-119
lines changed

6 files changed

+293
-119
lines changed

tests/test_data_scalers.py

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
import tempfile
21
import pandas as pd
32
import pytest
43
import numpy as np
4+
from sklearn.discriminant_analysis import StandardScaler
55
import xarray as xr
6-
import json
76
import os
8-
from torchhydro.datasets.data_scalers import DapengScaler
7+
import pickle as pkl
8+
9+
from torchhydro.datasets.data_scalers import DapengScaler, ScalerHub
910
from hydrodatasource.reader.data_source import SelfMadeHydroDataset
1011

1112

@@ -128,3 +129,81 @@ def test_dapeng_scaler_load_data_and_denorm(sample_data):
128129
denorm_y.coords[coord].values,
129130
err_msg=f"{coord} is inconsistent",
130131
)
132+
133+
134+
def test_sklearn_scale_train_mode(sample_data):
135+
target_vars, relevant_vars, constant_vars, data_cfgs = sample_data
136+
scaler_hub = ScalerHub(
137+
target_vars=target_vars,
138+
relevant_vars=relevant_vars,
139+
constant_vars=constant_vars,
140+
data_cfgs=data_cfgs,
141+
is_tra_val_te="train",
142+
)
143+
norm_key = "target_vars"
144+
scaler = StandardScaler()
145+
data_tmp = target_vars.to_numpy().reshape(-1, target_vars.shape[-1])
146+
147+
# Call the _sklearn_scale method
148+
scaler, data_norm = scaler_hub._sklearn_scale(
149+
data_cfgs, "train", norm_key, scaler, data_tmp
150+
)
151+
152+
# Check if the scaler is fitted and data is normalized
153+
assert hasattr(scaler, "mean_"), "Scaler is not fitted"
154+
assert data_norm.shape == data_tmp.shape, "Normalized data shape mismatch"
155+
156+
# Check if the scaler file is saved
157+
save_file = os.path.join(data_cfgs["case_dir"], f"{norm_key}_scaler.pkl")
158+
assert os.path.isfile(save_file), "Scaler file was not saved"
159+
160+
161+
def test_sklearn_scale_test_mode_with_existing_scaler(sample_data):
162+
target_vars, relevant_vars, constant_vars, data_cfgs = sample_data
163+
scaler_hub = ScalerHub(
164+
target_vars=target_vars,
165+
relevant_vars=relevant_vars,
166+
constant_vars=constant_vars,
167+
data_cfgs=data_cfgs,
168+
is_tra_val_te="train",
169+
)
170+
norm_key = "target_vars"
171+
scaler = StandardScaler()
172+
data_tmp = target_vars.to_numpy().reshape(-1, target_vars.shape[-1])
173+
174+
# Save a pre-fitted scaler for testing
175+
save_file = os.path.join(data_cfgs["case_dir"], f"{norm_key}_scaler.pkl")
176+
with open(save_file, "wb") as outfile:
177+
pkl.dump(scaler.fit(data_tmp), outfile)
178+
179+
# Call the _sklearn_scale method in test mode
180+
scaler, data_norm = scaler_hub._sklearn_scale(
181+
data_cfgs, "test", norm_key, scaler, data_tmp
182+
)
183+
184+
# Check if the scaler is loaded and data is normalized
185+
assert hasattr(scaler, "mean_"), "Scaler is not loaded correctly"
186+
assert data_norm.shape == data_tmp.shape, "Normalized data shape mismatch"
187+
188+
189+
def test_sklearn_scale_test_mode_without_scaler_file(sample_data):
190+
target_vars, relevant_vars, constant_vars, data_cfgs = sample_data
191+
scaler_hub = ScalerHub(
192+
target_vars=target_vars,
193+
relevant_vars=relevant_vars,
194+
constant_vars=constant_vars,
195+
data_cfgs=data_cfgs,
196+
is_tra_val_te="test",
197+
)
198+
norm_key = "target_vars"
199+
scaler = StandardScaler()
200+
data_tmp = target_vars.to_numpy().reshape(-1, target_vars.shape[-1])
201+
202+
# Ensure no scaler file exists
203+
save_file = os.path.join(data_cfgs["case_dir"], f"{norm_key}_scaler.pkl")
204+
if os.path.isfile(save_file):
205+
os.remove(save_file)
206+
207+
# Expect a FileNotFoundError
208+
with pytest.raises(FileNotFoundError):
209+
scaler_hub._sklearn_scale(data_cfgs, "test", norm_key, scaler, data_tmp)

torchhydro/configs/config.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
"""
22
Author: Wenyu Ouyang
33
Date: 2021-12-31 11:08:29
4-
LastEditTime: 2025-01-12 10:12:48
4+
LastEditTime: 2025-04-17 10:11:03
55
LastEditors: Wenyu Ouyang
66
Description: Config for hydroDL
7-
FilePath: \torchhydro\torchhydro\configs\config.py
7+
FilePath: /torchhydro/torchhydro/configs/config.py
88
Copyright (c) 2021-2022 Wenyu Ouyang. All rights reserved.
99
"""
1010

@@ -207,6 +207,10 @@ def default_config_file():
207207
# NOTE: pbm_norm is True means norm and denorm for differentiable models; if you use pure data-driven models, you should set it as False
208208
"pbm_norm": False,
209209
},
210+
# For scaler from sklearn, we need to specify the stat_dict_file for three different parts:
211+
# target_vars, relevant_vars and constant_vars, and the sequence must be target_vars, relevant_vars, constant_vars
212+
# the separator of the three stat_dict_file paths is ";"
213+
# for example: "stat_dict_file": "target_stat_dict_file;relevant_stat_dict_file;constant_stat_dict_file"
210214
"stat_dict_file": None,
211215
# dataset for pytorch dataset
212216
"dataset": "StreamflowDataset",
@@ -232,13 +236,13 @@ def default_config_file():
232236
# start from 0, each value means the decay rate
233237
# if initial lr is 0.001, then 0: 0.5 means the lr of epoch 0 is 0.001*0.5=0.0005
234238
# "lr_scheduler": {0: 1, 1: 0.5, 2: 0.2},
235-
# 3rd opt config, lr as an initial value (will cover the lr setting in "optim_params")
239+
# 3rd opt config, initial lr needs to be set in "optim_params" or it will use the default one
236240
# lr_factor as an exponential decay factor
237-
# "lr": 0.001, "lr_factor": 0.1,
238-
# 4th opt config, lr as a initial value, it will cover the lr setting in "optim_params"
241+
# "lr_factor": 0.1,
242+
# 4th opt config, initial lr needs to be set in "optim_params" or it will use the default one
239243
# lr_patience represent how many epochs without opt (we watch val_loss) could be tolerated
240244
# if lr_patience is satisfied, then lr will be decayed by lr_factor by a linear way
241-
# "lr": 0.001, "lr_factor": 0.1, "lr_patience": 1,
245+
# "lr_factor": 0.1, "lr_patience": 1,
242246
},
243247
"early_stopping": False,
244248
"patience": 1,
@@ -283,7 +287,7 @@ def default_config_file():
283287
"model_loader": {"load_way": "specified", "test_epoch": 20},
284288
# "model_loader": {"load_way": "best"},
285289
# "model_loader": {"load_way": "latest"},
286-
# "model_loader": {"load_way": "pth", "pth": "path/to/weights"},
290+
# "model_loader": {"load_way": "pth", "pth_path": "path/to/weights"},
287291
"metrics": ["NSE", "RMSE", "R2", "KGE", "FHV", "FLV"],
288292
"fill_nan": "no",
289293
"explainer": None,
@@ -1087,7 +1091,15 @@ def update_cfg(cfg_file, new_args):
10871091
if new_args.start_epoch > 1:
10881092
cfg_file["training_cfgs"]["start_epoch"] = new_args.start_epoch
10891093
if new_args.stat_dict_file is not None:
1090-
cfg_file["data_cfgs"]["stat_dict_file"] = new_args.stat_dict_file
1094+
stat_dict_file = new_args.stat_dict_file
1095+
if len(stat_dict_file.split(";")) > 1:
1096+
target_, relevant_, constant_ = stat_dict_file.split(";")
1097+
stat_dict_file = {
1098+
"target_vars": target_,
1099+
"relevant_vars": relevant_,
1100+
"constant_vars": constant_,
1101+
}
1102+
cfg_file["data_cfgs"]["stat_dict_file"] = stat_dict_file
10911103
if new_args.num_workers is not None and new_args.num_workers > 0:
10921104
cfg_file["training_cfgs"]["num_workers"] = new_args.num_workers
10931105
if new_args.which_first_tensor is not None:

0 commit comments

Comments
 (0)