|
1 | 1 | """ |
2 | 2 | Author: Wenyu Ouyang |
3 | 3 | Date: 2024-04-08 18:16:26 |
4 | | -LastEditTime: 2025-04-18 08:46:07 |
| 4 | +LastEditTime: 2025-04-27 18:36:57 |
5 | 5 | LastEditors: Wenyu Ouyang |
6 | 6 | Description: Some basic functions for training |
7 | 7 | FilePath: /HydroForecastEval/mnt/disk1/owen/code/torchhydro/torchhydro/trainers/train_utils.py |
@@ -308,15 +308,60 @@ def get_evaluation( |
308 | 308 | preds_xr = valte_dataset.denormalize(pred) |
309 | 309 | obss_xr = valte_dataset.denormalize(obs) |
310 | 310 | elif evaluator["eval_way"] == "rolling": |
311 | | - # TODO: to be implemented |
312 | | - raise NotImplementedError( |
313 | | - "we will implement this function in the future, please choose 1pace or once now" |
314 | | - ) |
| 311 | + # TODO: to be test |
| 312 | + pred = _recover_samples_to_4d(output, valorte_data_loader, evaluator["stride"]) |
| 313 | + obs = _recover_samples_to_4d(labels, valorte_data_loader, evaluator["stride"]) |
| 314 | + valte_dataset = valorte_data_loader.dataset |
| 315 | + preds_xr = valte_dataset.denormalize(pred) |
| 316 | + obss_xr = valte_dataset.denormalize(obs) |
315 | 317 | else: |
316 | 318 | raise ValueError("eval_way should be rolling or 1pace") |
317 | 319 | return obss_xr, preds_xr |
318 | 320 |
|
319 | 321 |
|
| 322 | +def _recover_samples_to_4d(arr_3d, valorte_data_loader, stride): |
| 323 | + """Reorganize the 3D prediction results to 4D |
| 324 | + TODO: to be finished |
| 325 | +
|
| 326 | + Parameters |
| 327 | + ---------- |
| 328 | + arr_3d : np.ndarray |
| 329 | + A 3D prediction array with the shape (total number of samples, number of time steps, number of features). |
| 330 | + valorte_data_loader: DataLoader |
| 331 | + The corresponding data loader used to obtain the basin-time index mapping. |
| 332 | + stride: int |
| 333 | + The stride of the rolling. |
| 334 | +
|
| 335 | + Returns |
| 336 | + ------- |
| 337 | + np.ndarray |
| 338 | + The reorganized 4D array with the shape (number of basins, length of time, forecast steps, number of features). |
| 339 | + """ |
| 340 | + dataset = valorte_data_loader.dataset |
| 341 | + batch_size = valorte_data_loader.batch_size |
| 342 | + basin_num = len(dataset.t_s_dict["sites_id"]) |
| 343 | + nt = dataset.nt |
| 344 | + rho = dataset.rho |
| 345 | + warmup_len = dataset.warmup_length |
| 346 | + horizon = dataset.horizon |
| 347 | + nf = dataset.noutputvar |
| 348 | + |
| 349 | + # Initialize the 4D array with NaN values |
| 350 | + basin_array = np.full((basin_num, nt - warmup_len - rho, horizon, nf), np.nan) |
| 351 | + |
| 352 | + for sample_idx in range(arr_3d.shape[0]): |
| 353 | + # Get the basin and start time index corresponding to this sample |
| 354 | + basin, start_time = dataset.lookup_table[sample_idx] |
| 355 | + # Take the value at the last time step of this sample (at the position of rho + horizon) |
| 356 | + value = arr_3d[sample_idx, warmup_len + rho :, :] |
| 357 | + # Calculate the time position in the result array |
| 358 | + result_time_idx = start_time + warmup_len + stride * (sample_idx % batch_size) |
| 359 | + # Fill in the corresponding position |
| 360 | + basin_array[basin, result_time_idx, :, :] = value |
| 361 | + |
| 362 | + return basin_array |
| 363 | + |
| 364 | + |
320 | 365 | def _recover_samples_to_basin(arr_3d, valorte_data_loader, pace_idx): |
321 | 366 | """Reorganize the 3D prediction results by basin |
322 | 367 |
|
|
0 commit comments