Skip to content

Commit 08d5c90

Browse files
committed
add rolling eval func but not finished yet
1 parent 3215c36 commit 08d5c90

File tree

3 files changed

+56
-11
lines changed

3 files changed

+56
-11
lines changed

torchhydro/configs/config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -237,11 +237,11 @@ def default_config_file():
237237
# for each batch, we fix length of hindcast and forecast length.
238238
# data from different lead time with a number representing the lead time,
239239
# for example, now is 2020-09-30, our min_time_interval is 1 day, hindcast length is 30 and forecast length is 1,
240-
# lead_time = 3 means 2020-09-01 to 2020-09-30, and the forecast data is 2020-10-01 from 2020-09-28
240+
# lead_time = 3 means 2020-09-01 to 2020-09-30, and the forecast data is 2020-10-01 forecast-performed at 2020-09-28
241241
# for forecast data, we have two different configurations:
242-
# 1st, we can set a same lead time for all forecast time
242+
# 1st "fixed", we can set a same lead time for all forecast time
243243
# 2020-09-30now, 30hindcast, 2forecast, 3leadtime means 2020-09-01 to 2020-09-30 obs concatenate with 2020-10-01 forecast data from 2020-09-28 and 2020-10-02 forecast data from 2020-09-29
244-
# 2nd, we can set a increasing lead time for each forecast time
244+
# 2nd "increasing", we can set a increasing lead time for each forecast time
245245
# 2020-09-30now, 30hindcast, 2forecast, [1, 2]leadtime means 2020-09-01 to 2020-09-30 obs concatenate with 2020-10-01 to 2010-10-02 forecast data from 2020-09-30
246246
"lead_time_type": "fixed", # must be fixed or increasing
247247
"lead_time_start": 1,

torchhydro/datasets/data_sets.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
"""
22
Author: Wenyu Ouyang
33
Date: 2024-04-08 18:16:53
4-
LastEditTime: 2025-04-19 17:35:29
4+
LastEditTime: 2025-04-27 14:28:24
55
LastEditors: Wenyu Ouyang
66
Description: A pytorch dataset class; references to https://github.com/neuralhydrology/neuralhydrology
7-
FilePath: /torchhydro/torchhydro/datasets/data_sets.py
7+
FilePath: /HydroForecastEval/mnt/disk1/owen/code/torchhydro/torchhydro/datasets/data_sets.py
88
Copyright (c) 2024-2024 Wenyu Ouyang. All rights reserved.
99
"""
1010

@@ -806,7 +806,7 @@ def _concat_xf(self, x, f):
806806
for x_idx, f_idx in self.xf_var_indices.items():
807807
# Replace the variables in the forecast period of x with the forecast variables in f
808808
# The forecast period of x starts from the rho position
809-
x_combined[self.rho :, x_idx] = f[:, f_idx]
809+
x_combined[self.warmup_length + self.rho :, x_idx] = f[:, f_idx]
810810

811811
return x_combined
812812

torchhydro/trainers/train_utils.py

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
Author: Wenyu Ouyang
33
Date: 2024-04-08 18:16:26
4-
LastEditTime: 2025-04-18 08:46:07
4+
LastEditTime: 2025-04-27 18:36:57
55
LastEditors: Wenyu Ouyang
66
Description: Some basic functions for training
77
FilePath: /HydroForecastEval/mnt/disk1/owen/code/torchhydro/torchhydro/trainers/train_utils.py
@@ -308,15 +308,60 @@ def get_evaluation(
308308
preds_xr = valte_dataset.denormalize(pred)
309309
obss_xr = valte_dataset.denormalize(obs)
310310
elif evaluator["eval_way"] == "rolling":
311-
# TODO: to be implemented
312-
raise NotImplementedError(
313-
"we will implement this function in the future, please choose 1pace or once now"
314-
)
311+
# TODO: to be test
312+
pred = _recover_samples_to_4d(output, valorte_data_loader, evaluator["stride"])
313+
obs = _recover_samples_to_4d(labels, valorte_data_loader, evaluator["stride"])
314+
valte_dataset = valorte_data_loader.dataset
315+
preds_xr = valte_dataset.denormalize(pred)
316+
obss_xr = valte_dataset.denormalize(obs)
315317
else:
316318
raise ValueError("eval_way should be rolling or 1pace")
317319
return obss_xr, preds_xr
318320

319321

322+
def _recover_samples_to_4d(arr_3d, valorte_data_loader, stride):
323+
"""Reorganize the 3D prediction results to 4D
324+
TODO: to be finished
325+
326+
Parameters
327+
----------
328+
arr_3d : np.ndarray
329+
A 3D prediction array with the shape (total number of samples, number of time steps, number of features).
330+
valorte_data_loader: DataLoader
331+
The corresponding data loader used to obtain the basin-time index mapping.
332+
stride: int
333+
The stride of the rolling.
334+
335+
Returns
336+
-------
337+
np.ndarray
338+
The reorganized 4D array with the shape (number of basins, length of time, forecast steps, number of features).
339+
"""
340+
dataset = valorte_data_loader.dataset
341+
batch_size = valorte_data_loader.batch_size
342+
basin_num = len(dataset.t_s_dict["sites_id"])
343+
nt = dataset.nt
344+
rho = dataset.rho
345+
warmup_len = dataset.warmup_length
346+
horizon = dataset.horizon
347+
nf = dataset.noutputvar
348+
349+
# Initialize the 4D array with NaN values
350+
basin_array = np.full((basin_num, nt - warmup_len - rho, horizon, nf), np.nan)
351+
352+
for sample_idx in range(arr_3d.shape[0]):
353+
# Get the basin and start time index corresponding to this sample
354+
basin, start_time = dataset.lookup_table[sample_idx]
355+
# Take the value at the last time step of this sample (at the position of rho + horizon)
356+
value = arr_3d[sample_idx, warmup_len + rho :, :]
357+
# Calculate the time position in the result array
358+
result_time_idx = start_time + warmup_len + stride * (sample_idx % batch_size)
359+
# Fill in the corresponding position
360+
basin_array[basin, result_time_idx, :, :] = value
361+
362+
return basin_array
363+
364+
320365
def _recover_samples_to_basin(arr_3d, valorte_data_loader, pace_idx):
321366
"""Reorganize the 3D prediction results by basin
322367

0 commit comments

Comments
 (0)