Skip to content

Commit a9f3042

Browse files
committed
make valid loss calculation way same as training
1 parent 2a79d55 commit a9f3042

File tree

2 files changed

+5
-12
lines changed

2 files changed

+5
-12
lines changed

torchhydro/datasets/data_sets.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ def times(self):
300300
return times_
301301

302302
def __len__(self):
303-
return self.num_samples if self.train_mode else self.ngrid
303+
return self.num_samples
304304

305305
def __getitem__(self, item: int):
306306
basin, idx = self.lookup_table[item]
@@ -818,9 +818,6 @@ def __getitem__(self, index):
818818
y = ys[-1, :]
819819
return xc, y
820820

821-
def __len__(self):
822-
return self.num_samples
823-
824821

825822
class DplDataset(BaseDataset):
826823
"""pytorch dataset for Differential parameter learning"""
@@ -886,9 +883,6 @@ def __getitem__(self, item):
886883
z_train,
887884
), torch.from_numpy(y_train).float()
888885

889-
def __len__(self):
890-
return self.num_samples if self.train_mode else len(self.t_s_dict["sites_id"])
891-
892886

893887
class FlexibleDataset(BaseDataset):
894888
"""A dataset whose datasources are from multiple sources according to the configuration"""
@@ -999,9 +993,6 @@ def _normalize(self):
999993
# TODO: this work for minio? maybe better to move to basedataset
1000994
return x.compute(), y.compute(), c.compute()
1001995

1002-
def __len__(self):
1003-
return self.num_samples
1004-
1005996
def __getitem__(self, item: int):
1006997
basin, time = self.lookup_table[item]
1007998
rho = self.rho

torchhydro/trainers/train_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -482,18 +482,20 @@ def compute_validation(
482482
seq_first = kwargs["which_first_tensor"] != "batch"
483483
obs = []
484484
preds = []
485+
valid_loss = 0.0
485486
with torch.no_grad():
486487
for src, trg in data_loader:
487488
trg, output = model_infer(seq_first, device, model, src, trg)
488489
obs.append(trg)
489490
preds.append(output)
491+
valid_loss_ = compute_loss(trg, output, criterion)
492+
valid_loss = valid_loss + valid_loss_.item()
490493
# clear memory to save GPU memory
491494
torch.cuda.empty_cache()
492495
# first dim is batch
493496
obs_final = torch.cat(obs, dim=0)
494497
pred_final = torch.cat(preds, dim=0)
495-
496-
valid_loss = compute_loss(obs_final, pred_final, criterion)
498+
valid_loss = valid_loss / len(data_loader)
497499
y_obs = obs_final.detach().cpu().numpy()
498500
y_pred = pred_final.detach().cpu().numpy()
499501
return y_obs, y_pred, valid_loss

0 commit comments

Comments
 (0)