Skip to content

Commit 69cb8a0

Browse files
committed
del gpu cache during validation
1 parent f3ffe1c commit 69cb8a0

File tree

2 files changed

+26
-13
lines changed

2 files changed

+26
-13
lines changed

torchhydro/datasets/data_sets.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -495,24 +495,24 @@ def _read_xyc(self):
495495
end_date = self.t_s_dict["t_final_range"][1]
496496
return self._read_xyc_specified_time(start_date, end_date)
497497

498-
def _rm_timeunit_key(self, data_output_ds_):
498+
def _rm_timeunit_key(self, ds_):
499499
"""this means the data source return a dict with key as time_unit
500500
in this BaseDataset, we only support a unified time range for all basins, so we choose the first key
501501
TODO: maybe this could be refactored better
502502
503503
Parameters
504504
----------
505-
data_output_ds_ : dict
506-
the output data with time_unit as key
505+
ds_ : dict
506+
the xarray data with time_unit as key
507507
508508
Returns
509509
-------
510-
data_output_ds_ : xr.Dataset
510+
ds_ : xr.Dataset
511511
the output data without time_unit
512512
"""
513-
if isinstance(data_output_ds_, dict):
514-
data_output_ds_ = data_output_ds_[list(data_output_ds_.keys())[0]]
515-
return data_output_ds_
513+
if isinstance(ds_, dict):
514+
ds_ = ds_[list(ds_.keys())[0]]
515+
return ds_
516516

517517
def _read_xyc_specified_time(self, start_date, end_date):
518518
"""Read x, y, c data from data source with specified time range

torchhydro/trainers/train_utils.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,7 @@ def compute_validation(
532532
data_loader: DataLoader,
533533
device: torch.device = None,
534534
**kwargs,
535-
) -> float:
535+
):
536536
"""
537537
Function to compute the validation loss metrics
538538
@@ -557,21 +557,34 @@ def compute_validation(
557557
obs = []
558558
preds = []
559559
valid_loss = 0.0
560+
obs_final = None
561+
pred_final = None
560562
with torch.no_grad():
563+
iter_num = 0
561564
for src, trg in data_loader:
562565
trg, output = model_infer(seq_first, device, model, src, trg)
563566
obs.append(trg)
564567
preds.append(output)
565568
valid_loss_ = compute_loss(trg, output, criterion)
569+
if torch.isnan(valid_loss_):
570+
# when not in training mode, trg may consist entirely of NaN values,
571+
# so we skip this batch
572+
continue
566573
valid_loss = valid_loss + valid_loss_.item()
574+
iter_num = iter_num + 1
567575
# clear memory to save GPU memory
576+
if obs_final is None:
577+
obs_final = trg.detach().cpu()
578+
pred_final = output.detach().cpu()
579+
else:
580+
obs_final = torch.cat([obs_final, trg.detach().cpu()], dim=0)
581+
pred_final = torch.cat([pred_final, output.detach().cpu()], dim=0)
582+
del trg, output
568583
torch.cuda.empty_cache()
569584
# first dim is batch
570-
obs_final = torch.cat(obs, dim=0)
571-
pred_final = torch.cat(preds, dim=0)
572-
valid_loss = valid_loss / len(data_loader)
573-
y_obs = obs_final.detach().cpu().numpy()
574-
y_pred = pred_final.detach().cpu().numpy()
585+
valid_loss = valid_loss / iter_num
586+
y_obs = obs_final.numpy()
587+
y_pred = pred_final.numpy()
575588
return y_obs, y_pred, valid_loss
576589

577590

0 commit comments

Comments
 (0)