1313import os
1414import re
1515import shutil
16- import dask
1716from functools import reduce
1817from pathlib import Path
1918import pandas as pd
@@ -280,6 +279,9 @@ def get_evaluation(
280279 )
281280 else :
282281 # TODO: need more test
282+ raise NotImplementedError (
283+ "we only support the case that the stride is equal to 1 now, others need to be implemented"
284+ )
283285 obss_xr = _rolling_once_evaluate (
284286 (basin_num , horizon , nf ),
285287 target_scaler .rho ,
@@ -297,17 +299,14 @@ def get_evaluation(
297299 output .reshape (batch_size , horizon , nf ),
298300 )
299301 elif evaluator ["eval_way" ] == "1pace" :
300- # TODO: to be implemented
301- # in this case, you should set calc_metrics = False in the evaluation config or use BasinBatchSampler in your data config
302- # baceuse we need to calculate the metrics for each time step
303- # but we have multi-outputs for each time step in this case
304- o = output [:, length + prec , :].reshape (basin_num , batch_size , len (target_col ))
305- l = labels [:, length + prec , :].reshape (basin_num , batch_size , len (target_col ))
306- valdataset = valorte_data_loader .dataset
307- preds_xr = valdataset .denormalize (output )
308- obss_xr = valdataset .denormalize (labels )
309- obs = obss_xr [col ].to_numpy ()
310- pred = preds_xr [col ].to_numpy ()
302+ pace_idx = evaluator ["pace_idx" ]
303+ # for 1pace with pace_idx meaning which value of output was chosen to show
304+ # 1st, we need to transpose data to 4-dim to show the whole data
305+ pred = _recover_samples_to_basin (output , valorte_data_loader , pace_idx )
306+ obs = _recover_samples_to_basin (labels , valorte_data_loader , pace_idx )
307+ valte_dataset = valorte_data_loader .dataset
308+ preds_xr = valte_dataset .denormalize (pred )
309+ obss_xr = valte_dataset .denormalize (obs )
311310 elif evaluator ["eval_way" ] == "rolling" :
312311 # TODO: to be implemented
313312 raise NotImplementedError (
@@ -318,6 +317,46 @@ def get_evaluation(
318317 return obss_xr , preds_xr
319318
320319
def _recover_samples_to_basin(arr_3d, valorte_data_loader, pace_idx):
    """Scatter one time step of every sample into a per-basin time series.

    For each sample, the values found at index ``pace_idx`` along the second
    axis of ``arr_3d`` are written into a NaN-initialised array indexed by
    (basin, time, feature). The basin index and the sample start time are
    read from ``dataset.lookup_table``.

    Parameters
    ----------
    arr_3d : np.ndarray
        A 3D array with the shape
        (total number of samples, number of time steps, number of features).
    valorte_data_loader : DataLoader
        The validation or test data loader. Its ``dataset`` must expose
        ``t_s_dict["sites_id"]``, ``nt``, ``rho``, ``warmup_length``,
        ``horizon``, ``noutputvar`` and ``lookup_table``.
    pace_idx : int
        Index along the time-step axis of ``arr_3d`` selecting which value of
        each sample is kept. It is also added to the destination time index.
        NOTE(review): the destination formula looks like it expects a
        negative index (e.g. -1 for the final step) -- confirm with callers.

    Returns
    -------
    np.ndarray
        An array with the shape
        (number of basins, ``dataset.nt``, ``dataset.noutputvar``).
        Positions that no sample maps to remain NaN.
    """
    ds = valorte_data_loader.dataset
    n_basins = len(ds.t_s_dict["sites_id"])
    n_times = ds.nt
    n_features = ds.noutputvar

    # Offset from a sample's start time to its destination time index.
    # It is the same for every sample, so it is computed once up front.
    time_shift = ds.warmup_length + ds.rho + ds.horizon + pace_idx

    recovered = np.full((n_basins, n_times, n_features), np.nan)

    n_samples = arr_3d.shape[0]
    for sample_idx in range(n_samples):
        # lookup_table maps a sample index to (basin index, start time index)
        basin_idx, t0 = ds.lookup_table[sample_idx]
        recovered[basin_idx, t0 + time_shift, :] = arr_3d[sample_idx, pace_idx, :]

    return recovered
358+
359+
321360def evaluate_validation (
322361 validation_data_loader ,
323362 output ,
0 commit comments