|
1 | | -import tempfile |
2 | 1 | import pandas as pd |
3 | 2 | import pytest |
4 | 3 | import numpy as np |
| 4 | +from sklearn.discriminant_analysis import StandardScaler |
5 | 5 | import xarray as xr |
6 | | -import json |
7 | 6 | import os |
8 | | -from torchhydro.datasets.data_scalers import DapengScaler |
| 7 | +import pickle as pkl |
| 8 | + |
| 9 | +from torchhydro.datasets.data_scalers import DapengScaler, ScalerHub |
9 | 10 | from hydrodatasource.reader.data_source import SelfMadeHydroDataset |
10 | 11 |
|
11 | 12 |
|
@@ -128,3 +129,81 @@ def test_dapeng_scaler_load_data_and_denorm(sample_data): |
128 | 129 | denorm_y.coords[coord].values, |
129 | 130 | err_msg=f"{coord} is inconsistent", |
130 | 131 | ) |
| 132 | + |
| 133 | + |
def test_sklearn_scale_train_mode(sample_data):
    """In train mode, _sklearn_scale must fit the scaler, normalize the data,
    and persist the fitted scaler to the case directory."""
    targets, relevants, constants, cfgs = sample_data
    hub = ScalerHub(
        target_vars=targets,
        relevant_vars=relevants,
        constant_vars=constants,
        data_cfgs=cfgs,
        is_tra_val_te="train",
    )
    key = "target_vars"
    # Flatten to 2-D (samples x features), the shape sklearn scalers expect.
    flat = targets.to_numpy().reshape(-1, targets.shape[-1])

    fitted, normalized = hub._sklearn_scale(cfgs, "train", key, StandardScaler(), flat)

    # A fitted StandardScaler exposes mean_; the output keeps the input shape.
    assert hasattr(fitted, "mean_"), "Scaler is not fitted"
    assert normalized.shape == flat.shape, "Normalized data shape mismatch"

    # Train mode must also write the scaler pickle for later test/valid runs.
    scaler_path = os.path.join(cfgs["case_dir"], f"{key}_scaler.pkl")
    assert os.path.isfile(scaler_path), "Scaler file was not saved"
| 159 | + |
| 160 | + |
def test_sklearn_scale_test_mode_with_existing_scaler(sample_data):
    """In test mode, _sklearn_scale must LOAD the previously saved scaler
    from disk rather than fitting a new one.

    The original version of this test fitted the very scaler object it then
    passed into ``_sklearn_scale``, so ``hasattr(scaler, "mean_")`` succeeded
    even if the pickle-loading branch was broken or never ran. Here we save a
    fitted scaler, pass in a *fresh unfitted* one, and check the returned
    scaler carries the saved statistics — which can only come from the file.
    """
    target_vars, relevant_vars, constant_vars, data_cfgs = sample_data
    scaler_hub = ScalerHub(
        target_vars=target_vars,
        relevant_vars=relevant_vars,
        constant_vars=constant_vars,
        data_cfgs=data_cfgs,
        is_tra_val_te="train",
    )
    norm_key = "target_vars"
    data_tmp = target_vars.to_numpy().reshape(-1, target_vars.shape[-1])

    # Persist a pre-fitted scaler for the test-mode code path to load.
    fitted_scaler = StandardScaler().fit(data_tmp)
    save_file = os.path.join(data_cfgs["case_dir"], f"{norm_key}_scaler.pkl")
    with open(save_file, "wb") as outfile:
        pkl.dump(fitted_scaler, outfile)

    # Pass an UNFITTED scaler: any mean_ on the result must come from disk.
    scaler, data_norm = scaler_hub._sklearn_scale(
        data_cfgs, "test", norm_key, StandardScaler(), data_tmp
    )

    # The returned scaler must be fitted and match the saved statistics.
    assert hasattr(scaler, "mean_"), "Scaler is not loaded correctly"
    np.testing.assert_allclose(
        scaler.mean_,
        fitted_scaler.mean_,
        err_msg="Loaded scaler statistics differ from the saved scaler",
    )
    assert data_norm.shape == data_tmp.shape, "Normalized data shape mismatch"
| 187 | + |
| 188 | + |
def test_sklearn_scale_test_mode_without_scaler_file(sample_data):
    """Test mode must raise FileNotFoundError when no saved scaler exists."""
    targets, relevants, constants, cfgs = sample_data
    hub = ScalerHub(
        target_vars=targets,
        relevant_vars=relevants,
        constant_vars=constants,
        data_cfgs=cfgs,
        is_tra_val_te="test",
    )
    key = "target_vars"
    flat = targets.to_numpy().reshape(-1, targets.shape[-1])

    # Guarantee the scaler pickle is absent so the load attempt must fail.
    stale = os.path.join(cfgs["case_dir"], f"{key}_scaler.pkl")
    if os.path.isfile(stale):
        os.remove(stale)

    with pytest.raises(FileNotFoundError):
        hub._sklearn_scale(cfgs, "test", key, StandardScaler(), flat)
0 commit comments