Skip to content

Commit cab2e2a

Browse files
committed
test for once and 1pace eval_ways
1 parent e5c9bcf commit cab2e2a

File tree

3 files changed

+54
-12
lines changed

3 files changed

+54
-12
lines changed

torchhydro/configs/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1269,6 +1269,8 @@ def update_cfg(cfg_file, new_args):
12691269
cfg_file["training_cfgs"]["lr_scheduler"] = new_args.lr_scheduler
12701270
if new_args.valid_batch_mode is not None:
12711271
cfg_file["training_cfgs"]["valid_batch_mode"] = new_args.valid_batch_mode
1272+
if new_args.evaluator is not None:
1273+
cfg_file["evaluation_cfgs"]["evaluator"] = new_args.evaluator
12721274
# print("the updated config:\n", json.dumps(cfg_file, indent=4, ensure_ascii=False))
12731275

12741276

torchhydro/datasets/data_sets.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,7 @@ def _create_lookup_table(self):
654654
for basin in tqdm(range(basin_coordinates), file=sys.stdout, disable=False):
655655
if not self.train_mode:
656656
# we don't need to ignore those with full nan in target vars for prediction without loss calculation
657+
# all samples should be included so that we can recover results to specified basins easily
657658
lookup.extend(
658659
(basin, f)
659660
for f in range(warmup_length, max_time_length - rho - horizon + 1)

torchhydro/trainers/train_utils.py

Lines changed: 51 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import os
1414
import re
1515
import shutil
16-
import dask
1716
from functools import reduce
1817
from pathlib import Path
1918
import pandas as pd
@@ -280,6 +279,9 @@ def get_evaluation(
280279
)
281280
else:
282281
# TODO: need more test
282+
raise NotImplementedError(
283+
"we only support the case that the stride is equal to 1 now, others need to be implemented"
284+
)
283285
obss_xr = _rolling_once_evaluate(
284286
(basin_num, horizon, nf),
285287
target_scaler.rho,
@@ -297,17 +299,14 @@ def get_evaluation(
297299
output.reshape(batch_size, horizon, nf),
298300
)
299301
elif evaluator["eval_way"] == "1pace":
300-
# TODO: to be implemented
301-
# in this case, you should set calc_metrics = False in the evaluation config or use BasinBatchSampler in your data config
302-
# baceuse we need to calculate the metrics for each time step
303-
# but we have multi-outputs for each time step in this case
304-
o = output[:, length + prec, :].reshape(basin_num, batch_size, len(target_col))
305-
l = labels[:, length + prec, :].reshape(basin_num, batch_size, len(target_col))
306-
valdataset = valorte_data_loader.dataset
307-
preds_xr = valdataset.denormalize(output)
308-
obss_xr = valdataset.denormalize(labels)
309-
obs = obss_xr[col].to_numpy()
310-
pred = preds_xr[col].to_numpy()
302+
pace_idx = evaluator["pace_idx"]
303+
# for 1pace with pace_idx meaning which value of output was chosen to show
304+
# 1st, we need to transpose data to 4-dim to show the whole data
305+
pred = _recover_samples_to_basin(output, valorte_data_loader, pace_idx)
306+
obs = _recover_samples_to_basin(labels, valorte_data_loader, pace_idx)
307+
valte_dataset = valorte_data_loader.dataset
308+
preds_xr = valte_dataset.denormalize(pred)
309+
obss_xr = valte_dataset.denormalize(obs)
311310
elif evaluator["eval_way"] == "rolling":
312311
# TODO: to be implemented
313312
raise NotImplementedError(
@@ -318,6 +317,46 @@ def get_evaluation(
318317
return obss_xr, preds_xr
319318

320319

320+
def _recover_samples_to_basin(arr_3d, valorte_data_loader, pace_idx):
    """Scatter one time step of every sample into a per-basin time series.

    Sample ``i`` of ``arr_3d`` belongs to the ``(basin, start_time)`` pair
    stored in ``dataset.lookup_table[i]``. The time step ``pace_idx`` of each
    sample is written to the matching time position of its basin; positions
    that no sample maps to stay NaN.

    Parameters
    ----------
    arr_3d : np.ndarray
        A 3D array with the shape
        (total number of samples, number of time steps, number of features).
    valorte_data_loader : DataLoader
        The validation or test data loader. Its ``dataset`` must provide
        ``t_s_dict["sites_id"]``, ``nt``, ``rho``, ``warmup_length``,
        ``horizon``, ``noutputvar`` and ``lookup_table``.
    pace_idx : int
        Which time step of each sample is chosen. Negative values count from
        the end of the sample (``-1`` is its last step); non-negative values
        count from its first step.

    Returns
    -------
    np.ndarray
        An array with the shape (number of basins, ``dataset.nt``,
        ``dataset.noutputvar``), filled with NaN wherever no sample landed.

    Raises
    ------
    IndexError
        If ``pace_idx`` is outside the time axis of ``arr_3d`` or the computed
        time position is outside ``dataset.nt``.

    Notes
    -----
    The time axis of ``arr_3d`` is treated as aligned with the END of the
    sample window, i.e. its last step sits at
    ``start_time + warmup_length + rho + horizon - 1``. This is the convention
    the previous implementation applied for negative ``pace_idx``; results for
    negative ``pace_idx`` are unchanged. For non-negative ``pace_idx`` the
    previous formula placed the value at or beyond the end of the window, so
    the index is now converted to its equivalent offset from the end first.

    NOTE(review): whether ``warmup_length`` has to be added depends on whether
    the start times in ``lookup_table`` already include the warm-up offset;
    confirm against the dataset's ``__getitem__``.
    """
    dataset = valorte_data_loader.dataset
    basin_num = len(dataset.t_s_dict["sites_id"])
    basin_array = np.full((basin_num, dataset.nt, dataset.noutputvar), np.nan)

    n_samples, n_steps = arr_3d.shape[0], arr_3d.shape[1]
    # Express pace_idx as a (negative) offset from the end of the sample so
    # that both index conventions share one position formula.
    offset_from_end = pace_idx if pace_idx < 0 else pace_idx - n_steps
    # Loop-invariant part of the target time index, hoisted out of the loop.
    window_shift = (
        dataset.warmup_length + dataset.rho + dataset.horizon + offset_from_end
    )

    for sample_idx in range(n_samples):
        # Basin and start time index this sample was cut from
        basin, start_time = dataset.lookup_table[sample_idx]
        # Write the chosen time step of this sample to its place in the basin
        basin_array[basin, start_time + window_shift, :] = arr_3d[
            sample_idx, pace_idx, :
        ]

    return basin_array
358+
359+
321360
def evaluate_validation(
322361
validation_data_loader,
323362
output,

0 commit comments

Comments
 (0)