Skip to content

Commit fb16bf6

Browse files
committed
refactor rolling evaluate func; refactor denormalize -- set it into dataset
1 parent d23d905 commit fb16bf6

File tree

6 files changed

+173
-114
lines changed

6 files changed

+173
-114
lines changed

torchhydro/datasets/data_scalers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
Author: Wenyu Ouyang
33
Date: 2024-04-08 18:17:44
4-
LastEditTime: 2024-11-05 09:21:24
4+
LastEditTime: 2025-01-12 15:23:28
55
LastEditors: Wenyu Ouyang
66
Description: normalize the data
77
FilePath: \torchhydro\torchhydro\datasets\data_scalers.py

torchhydro/datasets/data_sets.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
"""
22
Author: Wenyu Ouyang
33
Date: 2024-04-08 18:16:53
4-
LastEditTime: 2025-01-02 14:34:59
4+
LastEditTime: 2025-01-12 15:16:28
55
LastEditors: Wenyu Ouyang
66
Description: A pytorch dataset class; references to https://github.com/neuralhydrology/neuralhydrology
7-
FilePath: /torchhydro/torchhydro/datasets/data_sets.py
7+
FilePath: \torchhydro\torchhydro\datasets\data_sets.py
88
Copyright (c) 2024-2024 Wenyu Ouyang. All rights reserved.
99
"""
1010

@@ -26,6 +26,7 @@
2626
from torchhydro.datasets.data_sources import data_sources_dict
2727

2828
from torchhydro.datasets.data_utils import (
29+
set_unit_to_var,
2930
warn_if_nan,
3031
wrap_t_s_dict,
3132
)
@@ -279,6 +280,49 @@ def _normalize(self):
279280
self.target_scaler = scaler_hub.target_scaler
280281
return scaler_hub.x, scaler_hub.y, scaler_hub.c
281282

283+
def denormalize(self, norm_data, rolling=0):
284+
"""Denormalize the norm_data
285+
286+
Parameters
287+
----------
288+
norm_data : np.ndarray
289+
batch-first data
290+
rolling: int
291+
default 0; if rolling > 0, forecasting is performed with a rolling window of this size
292+
293+
Returns
294+
-------
295+
xr.Dataset
296+
denormlized data
297+
"""
298+
target_scaler = self.target_scaler
299+
target_data = target_scaler.data_target
300+
# the units are dimensionless for pure DL models
301+
units = {k: "dimensionless" for k in target_data.attrs["units"].keys()}
302+
if target_scaler.pbm_norm:
303+
units = {**units, **target_data.attrs["units"]}
304+
if rolling > 0:
305+
hindcast_output_window = target_scaler.data_cfgs["hindcast_output_window"]
306+
rho = target_scaler.data_cfgs["hindcast_length"]
307+
# TODO: -1 because seq2seqdataset has one more time, hence we need to cut it, as rolling will be refactored, we will modify it later
308+
selected_time_points = target_data.coords["time"][
309+
rho - hindcast_output_window : -1
310+
]
311+
else:
312+
warmup_length = self.warmup_length
313+
selected_time_points = target_data.coords["time"][warmup_length:]
314+
315+
selected_data = target_data.sel(time=selected_time_points)
316+
denorm_xr_ds = target_scaler.inverse_transform(
317+
xr.DataArray(
318+
norm_data.transpose(2, 0, 1),
319+
dims=selected_data.dims,
320+
coords=selected_data.coords,
321+
attrs={"units": units},
322+
)
323+
)
324+
return set_unit_to_var(denorm_xr_ds)
325+
282326
def _to_dataarray_with_unit(self, data_forcing_ds, data_output_ds, data_attr_ds):
283327
# trans to dataarray to better use xbatch
284328
if data_output_ds is not None:

torchhydro/datasets/data_utils.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
"""
22
Author: Wenyu Ouyang
33
Date: 2023-09-21 15:37:58
4-
LastEditTime: 2025-01-02 14:06:24
4+
LastEditTime: 2025-01-12 15:31:29
55
LastEditors: Wenyu Ouyang
66
Description: Some basic functions for dealing with data
7-
FilePath: /torchhydro/torchhydro/datasets/data_utils.py
7+
FilePath: \torchhydro\torchhydro\datasets\data_utils.py
88
Copyright (c) 2023-2024 Wenyu Ouyang. All rights reserved.
99
"""
1010

@@ -390,3 +390,26 @@ def dam_num_chosen(gages, usgs_id, dam_num):
390390
usgs_id[i] for i in range(data_attr.size) if data_attr[:, 0][i] == dam_num
391391
]
392392
)
393+
394+
395+
def set_unit_to_var(ds):
396+
"""returned xa.Dataset need has units for each variable -- xr.DataArray
397+
or the dataset cannot be saved to a netCDF file
398+
399+
Parameters
400+
----------
401+
ds : xr.Dataset
402+
the dataset with units as attributes
403+
404+
Returns
405+
-------
406+
ds : xr.Dataset
407+
unit attrs are set on each variable's DataArray
408+
"""
409+
units_dict = ds.attrs["units"]
410+
for var_name, units in units_dict.items():
411+
if var_name in ds:
412+
ds[var_name].attrs["units"] = units
413+
if "units" in ds.attrs:
414+
del ds.attrs["units"]
415+
return ds

torchhydro/trainers/deep_hydro.py

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
"""
22
Author: Wenyu Ouyang
33
Date: 2024-04-08 18:15:48
4-
LastEditTime: 2025-01-09 12:17:20
4+
LastEditTime: 2025-01-12 14:57:18
55
LastEditors: Wenyu Ouyang
66
Description: HydroDL model class
7-
FilePath: /torchhydro/torchhydro/trainers/deep_hydro.py
7+
FilePath: \torchhydro\torchhydro\trainers\deep_hydro.py
88
Copyright (c) 2024-2024 Wenyu Ouyang. All rights reserved.
99
"""
1010

@@ -42,8 +42,8 @@
4242
from torchhydro.trainers.train_logger import TrainLogger
4343
from torchhydro.trainers.train_utils import (
4444
EarlyStopper,
45+
rolling_evaluate,
4546
average_weights,
46-
denormalize4eval,
4747
evaluate_validation,
4848
compute_validation,
4949
model_infer,
@@ -399,37 +399,31 @@ def inference(self) -> Tuple[xr.Dataset, xr.Dataset]:
399399
obs = obs.flatten().reshape(test_dataloader.test_data.y.shape[0], -1, 1)
400400

401401
if evaluation_cfgs["rolling"] > 0:
402-
if evaluation_cfgs["rolling"] != data_cfgs["forecast_length"]:
403-
raise NotImplementedError(
404-
"rolling should be equal to forecast_length in data_cfgs now, others are not supported yet"
405-
)
406-
# TODO: now we only guarantee each time has only one value,
407-
# so we directly reshape the data rather than a real rolling
408402
ngrid = self.testdataset.ngrid
409403
nt = self.testdataset.nt
410-
target_len = len(data_cfgs["target_cols"])
411-
hindcast_output_window = data_cfgs["hindcast_output_window"]
404+
nf = len(data_cfgs["target_cols"])
405+
rolling = evaluation_cfgs["rolling"]
412406
forecast_length = data_cfgs["forecast_length"]
413-
window_size = hindcast_output_window + forecast_length
407+
hindcast_output_window = data_cfgs["hindcast_output_window"]
414408
rho = data_cfgs["hindcast_length"]
415-
recover_len = nt - rho + hindcast_output_window
416-
samples = int(pred.shape[0] / ngrid)
417-
pred_ = np.full((ngrid, recover_len, target_len), np.nan)
418-
obs_ = np.full((ngrid, recover_len, target_len), np.nan)
419-
# recover pred to pred_ and obs to obs_
420-
pred_4d = pred.reshape(ngrid, samples, window_size, target_len)
421-
obs_4d = obs.reshape(ngrid, samples, window_size, target_len)
422-
for i in range(ngrid):
423-
for j in range(0, recover_len - window_size + 1, window_size):
424-
pred_[i, j : j + window_size, :] = pred_4d[i, j, :, :]
425-
for i in range(ngrid):
426-
for j in range(0, recover_len - window_size + 1, window_size):
427-
obs_[i, j : j + window_size, :] = obs_4d[i, j, :, :]
428-
pred = pred_.reshape(ngrid, recover_len, target_len)
429-
obs = obs_.reshape(ngrid, recover_len, target_len)
430-
pred_xr, obs_xr = denormalize4eval(
431-
test_dataloader, pred, obs, rolling=evaluation_cfgs["rolling"]
432-
)
409+
pred = rolling_evaluate(
410+
(ngrid, nt, nf),
411+
rho,
412+
forecast_length,
413+
rolling,
414+
hindcast_output_window,
415+
pred,
416+
)
417+
obs = rolling_evaluate(
418+
(ngrid, nt, nf),
419+
rho,
420+
forecast_length,
421+
rolling,
422+
hindcast_output_window,
423+
obs,
424+
)
425+
pred_xr = self.testdataset.denormalize(pred, rolling=evaluation_cfgs["rolling"])
426+
obs_xr = self.testdataset.denormalize(obs, rolling=evaluation_cfgs["rolling"])
433427
return pred_xr, obs_xr
434428

435429
def _get_optimizer(self, training_cfgs):

torchhydro/trainers/resulter.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,6 @@
3131
from torchhydro.trainers.deep_hydro import DeepHydro
3232

3333

34-
def set_unit_to_var(ds):
35-
units_dict = ds.attrs["units"]
36-
for var_name, units in units_dict.items():
37-
if var_name in ds:
38-
ds[var_name].attrs["units"] = units
39-
if "units" in ds.attrs:
40-
del ds.attrs["units"]
41-
return ds
42-
43-
4434
class Resulter:
4535
def __init__(self, cfgs) -> None:
4636
self.cfgs = cfgs
@@ -112,8 +102,6 @@ def save_result(self, pred, obs):
112102
save_dir = self.result_dir
113103
flow_pred_file = os.path.join(save_dir, self.pred_name)
114104
flow_obs_file = os.path.join(save_dir, self.obs_name)
115-
pred = set_unit_to_var(pred)
116-
obs = set_unit_to_var(obs)
117105
pred.to_netcdf(flow_pred_file + ".nc")
118106
obs.to_netcdf(flow_obs_file + ".nc")
119107

0 commit comments

Comments
 (0)