Skip to content

Commit e5c9bcf

Browse files
committed
refactor evaluation part
1 parent a9f3042 commit e5c9bcf

File tree

4 files changed

+177
-143
lines changed

4 files changed

+177
-143
lines changed

torchhydro/configs/config.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,26 @@ def default_config_file():
332332
# 0 means all testing periods belong to forecast periods without hindcast part
333333
"current_idx": 0,
334334
"calc_metrics": True,
335+
# we provide some different evaluators:
336+
# 1st -- once: for each time each var and each basin, only one result is evaluated
337+
# stride means if rolling is true, after evaluating, we need a stride to skip some periods
338+
# 2nd -- 1pace: we only choose one pace from results to evaluate
339+
# -1 means we choose the final result of each sample which will be used in hindcast-only/forecast-only model inference
340+
# 1 means we choose the first result of each sample which will be used in hindcast-forecast model inference
341+
# 3rd -- rolling: we perform evaluation for each sample of each basin,
342+
# stride means we will perform evaluation for each sample after stride periods
343+
"evaluator": {
344+
"eval_way": "once",
345+
"stride": 0,
346+
},
347+
# "evaluator": {
348+
# "eval_way": "1pace",
349+
# "pace_idx": -1,
350+
# },
351+
# "evaluator": {
352+
# "eval_way": "rolling",
353+
# "stride": 1,
354+
# },
335355
},
336356
}
337357

@@ -418,6 +438,7 @@ def cmd(
418438
min_time_unit=None,
419439
min_time_interval=None,
420440
valid_batch_mode=None,
441+
evaluator=None,
421442
):
422443
"""input args from cmd"""
423444
parser = argparse.ArgumentParser(
@@ -958,6 +979,13 @@ def cmd(
958979
help="The batch organization mode of valid data, train means same as train; test means same as test",
959980
default=valid_batch_mode,
960981
)
982+
parser.add_argument(
983+
"--evaluator",
984+
dest="evaluator",
985+
help="evaluation way",
986+
default=evaluator,
987+
type=json.loads,
988+
)
961989
# To make pytest work in PyCharm, here we use the following code instead of "args = parser.parse_args()":
962990
# https://blog.csdn.net/u014742995/article/details/100119905
963991
args, unknown = parser.parse_known_args()

torchhydro/datasets/data_sets.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,16 @@ def ngrid(self):
219219
"""
220220
return len(self.basins)
221221

222+
@property
223+
def noutputvar(self):
224+
"""How many output variables in the dataset
225+
Returns
226+
-------
227+
int
228+
number of variables
229+
"""
230+
return len(self.data_cfgs["target_cols"])
231+
222232
@property
223233
def nt(self):
224234
"""length of longest time series in all basins
@@ -377,15 +387,17 @@ def _normalize(
377387
self.target_scaler = scaler_hub.target_scaler
378388
return scaler_hub.norm_data
379389

380-
def denormalize(self, norm_data, rolling=0):
390+
def denormalize(self, norm_data, is_real_time=True):
381391
"""Denormalize the norm_data
382392
383393
Parameters
384394
----------
385395
norm_data : np.ndarray
386396
batch-first data
387-
rolling: int
388-
default 0, if rolling is used, perform forecasting using rolling window size
397+
is_real_time : bool, optional
398+
whether the data is real time data, by default True
399+
sometimes we may have multiple results for one time period and we flatten them
400+
so we need a temp time to replace real one
389401
390402
Returns
391403
-------
@@ -398,17 +410,8 @@ def denormalize(self, norm_data, rolling=0):
398410
units = {k: "dimensionless" for k in target_data.attrs["units"].keys()}
399411
if target_scaler.pbm_norm:
400412
units = {**units, **target_data.attrs["units"]}
401-
if rolling > 0:
402-
hindcast_output_window = target_scaler.data_cfgs["hindcast_output_window"]
403-
rho = target_scaler.training_cfgs["hindcast_length"]
404-
# TODO: -1 because seq2seqdataset has one more time, hence we need to cut it, as rolling will be refactored, we will modify it later
405-
selected_time_points = target_data.coords["time"][
406-
rho - hindcast_output_window : -1
407-
]
408-
else:
409-
warmup_length = self.warmup_length
410-
selected_time_points = target_data.coords["time"][warmup_length:]
411-
413+
warmup_length = self.warmup_length
414+
selected_time_points = target_data.coords["time"][warmup_length:]
412415
selected_data = target_data.sel(time=selected_time_points)
413416
denorm_xr_ds = target_scaler.inverse_transform(
414417
xr.DataArray(

torchhydro/trainers/deep_hydro.py

Lines changed: 12 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,13 @@
4343
from torchhydro.trainers.train_logger import TrainLogger
4444
from torchhydro.trainers.train_utils import (
4545
EarlyStopper,
46-
rolling_evaluate,
4746
average_weights,
4847
evaluate_validation,
4948
compute_validation,
5049
model_infer,
5150
read_pth_from_model_loader,
5251
torch_single_train,
52+
get_evaluation,
5353
)
5454

5555

@@ -377,7 +377,6 @@ def inference(self) -> Tuple[xr.Dataset, xr.Dataset]:
377377
"""infer using trained model and unnormalized results"""
378378
data_cfgs = self.cfgs["data_cfgs"]
379379
training_cfgs = self.cfgs["training_cfgs"]
380-
evaluation_cfgs = self.cfgs["evaluation_cfgs"]
381380
device = get_the_device(self.cfgs["training_cfgs"]["device"])
382381
test_dataloader = self._get_dataloader(training_cfgs, data_cfgs, mode="infer")
383382
seq_first = training_cfgs["which_first_tensor"] == "sequence"
@@ -404,33 +403,13 @@ def inference(self) -> Tuple[xr.Dataset, xr.Dataset]:
404403
# params of reshape should be (basin size, time length)
405404
pred = pred.flatten().reshape(test_dataloader.test_data.y.shape[0], -1, 1)
406405
obs = obs.flatten().reshape(test_dataloader.test_data.y.shape[0], -1, 1)
407-
408-
if evaluation_cfgs["rolling"] > 0:
409-
ngrid = self.testdataset.ngrid
410-
nt = self.testdataset.nt
411-
nf = len(data_cfgs["target_cols"])
412-
rolling = evaluation_cfgs["rolling"]
413-
forecast_length = training_cfgs["forecast_length"]
414-
hindcast_output_window = data_cfgs["hindcast_output_window"]
415-
rho = training_cfgs["hindcast_length"]
416-
pred = rolling_evaluate(
417-
(ngrid, nt, nf),
418-
rho,
419-
forecast_length,
420-
rolling,
421-
hindcast_output_window,
422-
pred,
423-
)
424-
obs = rolling_evaluate(
425-
(ngrid, nt, nf),
426-
rho,
427-
forecast_length,
428-
rolling,
429-
hindcast_output_window,
430-
obs,
431-
)
432-
pred_xr = self.testdataset.denormalize(pred, rolling=evaluation_cfgs["rolling"])
433-
obs_xr = self.testdataset.denormalize(obs, rolling=evaluation_cfgs["rolling"])
406+
evaluation_cfgs = self.cfgs["evaluation_cfgs"]
407+
obs_xr, pred_xr = get_evaluation(
408+
test_dataloader,
409+
evaluation_cfgs,
410+
pred,
411+
obs,
412+
)
434413
return pred_xr, obs_xr
435414

436415
def _get_optimizer(self, training_cfgs):
@@ -457,26 +436,15 @@ def _get_loss_func(self, training_cfgs):
457436

458437
def _get_dataloader(self, training_cfgs, data_cfgs, mode="train"):
459438
if mode == "infer":
460-
ngrid = self.testdataset.ngrid
461-
if data_cfgs["sampler"] != "BasinBatchSampler":
462-
# TODO: this case should be tested more
463-
return DataLoader(
464-
self.testdataset,
465-
batch_size=training_cfgs["batch_size"],
466-
shuffle=False,
467-
sampler=None,
468-
batch_sampler=None,
469-
drop_last=False,
470-
timeout=0,
471-
worker_init_fn=None,
472-
)
473-
test_num_samples = self.testdataset.num_samples
474439
return DataLoader(
475440
self.testdataset,
476-
batch_size=test_num_samples // ngrid,
441+
batch_size=training_cfgs["batch_size"],
477442
shuffle=False,
443+
sampler=None,
444+
batch_sampler=None,
478445
drop_last=False,
479446
timeout=0,
447+
worker_init_fn=None,
480448
)
481449
worker_num = 0
482450
pin_memory = False

0 commit comments

Comments
 (0)