Skip to content

Commit 4e8dc5a

Browse files
committed
Merge remote-tracking branch 'downstream/dev' into dev
2 parents da14708 + 81ffaa6 commit 4e8dc5a

File tree

6 files changed

+109
-60
lines changed

6 files changed

+109
-60
lines changed

torchhydro/datasets/data_dict.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from torchhydro.datasets.data_sets import (
1212
BaseDataset,
13-
BaseDatasetValidSame,
13+
ForecastDataset,
1414
HFDataset,
1515
BasinSingleFlowDataset,
1616
DplDataset,
@@ -24,7 +24,7 @@
2424

2525
datasets_dict = {
2626
"StreamflowDataset": BaseDataset,
27-
"BaseDatasetValidSame": BaseDatasetValidSame,
27+
"ForecastDataset": ForecastDataset,
2828
"HFDataset": HFDataset,
2929
"SingleflowDataset": BasinSingleFlowDataset,
3030
"DplDataset": DplDataset,

torchhydro/datasets/data_sets.py

Lines changed: 57 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -976,65 +976,100 @@ def _read_xyc(self):
976976
time_unit = self.data_cfgs["min_time_unit"]
977977

978978
# Determine the date format
979-
date_format = detect_date_format(end_date)
979+
date_format = detect_date_format(start_date)
980980

981981
# Adjust the end date based on the time unit
982-
end_date_dt = datetime.strptime(end_date, date_format)
982+
start_date_dt = datetime.strptime(start_date, date_format)
983983
if time_unit == "h":
984-
adjusted_end_date = (end_date_dt + timedelta(hours=interval)).strftime(
984+
adjusted_start_date = (start_date_dt - timedelta(hours=interval)).strftime(
985985
date_format
986986
)
987987
elif time_unit == "D":
988-
adjusted_end_date = (end_date_dt + timedelta(days=interval)).strftime(
988+
adjusted_start_date = (start_date_dt - timedelta(days=interval)).strftime(
989989
date_format
990990
)
991991
else:
992992
raise ValueError(f"Unsupported time unit: {time_unit}")
993-
return self._read_xyc_specified_time(start_date, adjusted_end_date)
993+
return self._read_xyc_specified_time(adjusted_start_date, end_date)
994994

995-
def _normalize(self):
996-
x, y, c = super()._normalize()
997-
# TODO: this work for minio? maybe better to move to basedataset
998-
return x.compute(), y.compute(), c.compute()
995+
def denormalize(self, norm_data, is_real_time=True):
996+
"""Denormalize the norm_data
997+
998+
Parameters
999+
----------
1000+
norm_data : np.ndarray
1001+
batch-first data
1002+
is_real_time : bool, optional
1003+
whether the data is real time data, by default True
1004+
sometimes we may have multiple results for one time period and we flatten them
1005+
so we need a temp time to replace real one
1006+
1007+
Returns
1008+
-------
1009+
xr.Dataset
1010+
denormalized data
1011+
"""
1012+
target_scaler = self.target_scaler
1013+
target_data = target_scaler.data_target
1014+
# the units are dimensionless for pure DL models
1015+
units = {k: "dimensionless" for k in target_data.attrs["units"].keys()}
1016+
if target_scaler.pbm_norm:
1017+
units = {**units, **target_data.attrs["units"]}
1018+
warmup_length = self.warmup_length
1019+
selected_time_points = target_data.coords["time"][warmup_length:-1]
1020+
selected_data = target_data.sel(time=selected_time_points)
1021+
denorm_xr_ds = target_scaler.inverse_transform(
1022+
xr.DataArray(
1023+
norm_data,
1024+
dims=selected_data.dims,
1025+
coords=selected_data.coords,
1026+
attrs={"units": units},
1027+
)
1028+
)
1029+
return set_unit_to_var(denorm_xr_ds)
9991030

10001031
def __getitem__(self, item: int):
10011032
basin, time = self.lookup_table[item]
10021033
rho = self.rho
10031034
horizon = self.horizon
10041035
hindcast_output_window = self.data_cfgs.get("hindcast_output_window", 0)
10051036
# p cover all encoder-decoder periods; +1 means the period while +0 means start of the current period
1006-
p = self.x[basin, time + 1 : time + rho + horizon + 1, 0].reshape(-1, 1)
1037+
p = self.x[basin, time + 1 : time + rho + horizon + 1, :1]
10071038
# s only cover encoder periods
10081039
s = self.x[basin, time : time + rho, 1:]
1009-
x = np.concatenate((p[:rho], s), axis=1)
1040+
# xe = np.concatenate((p[:rho], s), axis=1)
10101041

10111042
if self.c is None or self.c.shape[-1] == 0:
1012-
xc = x
1043+
pc = p
10131044
else:
10141045
c = self.c[basin, :]
10151046
c = np.tile(c, (rho + horizon, 1))
1016-
xc = np.concatenate((x, c[:rho]), axis=1)
1047+
pc = np.concatenate((p[:rho], c[:rho]), axis=1)
1048+
xe = np.concatenate((pc[:rho], s), axis=1)
10171049
# xh cover decoder periods
10181050
try:
1019-
xh = np.concatenate((p[rho:], c[rho:]), axis=1)
1051+
xd = np.concatenate((p[rho:], c[rho:]), axis=1)
10201052
except ValueError as e:
10211053
print(f"Error in np.concatenate: {e}")
10221054
print(f"p[rho:].shape: {p[rho:].shape}, c[rho:].shape: {c[rho:].shape}")
10231055
raise
10241056
# y cover specified encoder size (hindcast_output_window) and all decoder periods
10251057
y = self.y[
10261058
basin, time + rho - hindcast_output_window + 1 : time + rho + horizon + 1, :
1027-
]
1059+
] # qs
1060+
# y_q = y[:, :1]
1061+
# y_s = y[:, 1:]
1062+
# y = np.concatenate((y_s, y_q), axis=1)
10281063

10291064
if self.is_tra_val_te == "train":
10301065
return [
1031-
torch.from_numpy(xc).float(),
1032-
torch.from_numpy(xh).float(),
1066+
torch.from_numpy(xe).float(),
1067+
torch.from_numpy(xd).float(),
10331068
torch.from_numpy(y).float(),
10341069
], torch.from_numpy(y).float()
10351070
return [
1036-
torch.from_numpy(xc).float(),
1037-
torch.from_numpy(xh).float(),
1071+
torch.from_numpy(xe).float(),
1072+
torch.from_numpy(xd).float(),
10381073
], torch.from_numpy(y).float()
10391074

10401075

@@ -1086,15 +1121,15 @@ def __getitem__(self, item: int):
10861121
], torch.from_numpy(y).float()
10871122

10881123

1089-
class BaseDatasetValidSame(BaseDataset):
1124+
class ForecastDataset(BaseDataset):
10901125
def __init__(self, data_cfgs: dict, is_tra_val_te: str):
1091-
super(BaseDatasetValidSame, self).__init__(data_cfgs, is_tra_val_te)
1126+
super(ForecastDataset, self).__init__(data_cfgs, is_tra_val_te)
10921127

10931128
def __getitem__(self, item):
10941129
basin, idx = self.lookup_table[item]
10951130
warmup_length = self.warmup_length
10961131
x = self.x[basin, idx - warmup_length : idx + self.rho + self.horizon, :]
1097-
y = self.y[basin, idx : idx + self.rho + self.horizon, :]
1132+
y = self.y[basin, idx + self.rho : idx + self.rho + self.horizon, :]
10981133
if self.c is None or self.c.shape[-1] == 0:
10991134
return torch.from_numpy(x).float(), torch.from_numpy(y).float()
11001135
c = self.c[basin, :]

torchhydro/models/seq2seq.py

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __init__(self, input_dim, hidden_dim, output_dim, num_layers=1, dropout=0.3)
6565
self.hidden_dim = hidden_dim
6666
self.pre_fc = nn.Linear(input_dim, hidden_dim)
6767
self.pre_relu = nn.ReLU()
68-
self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers, batch_first=True)
68+
self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers)
6969
self.dropout = nn.Dropout(dropout)
7070
self.fc = nn.Linear(hidden_dim, output_dim)
7171

@@ -88,7 +88,7 @@ def __init__(self, input_dim, output_dim, hidden_dim, num_layers=1, dropout=0.3)
8888
self.hidden_dim = hidden_dim
8989
self.pre_fc = nn.Linear(input_dim, hidden_dim)
9090
self.pre_relu = nn.ReLU()
91-
self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers, batch_first=True)
91+
self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers)
9292
self.dropout = nn.Dropout(dropout)
9393
self.fc_out = nn.Linear(hidden_dim, output_dim)
9494

@@ -157,6 +157,12 @@ def __init__(
157157
)
158158
self.transfer = StateTransferNetwork(hidden_dim=hidden_size)
159159

160+
def _teacher_forcing_preparation(self, trgs):
161+
# teacher forcing preparation
162+
valid_mask = ~torch.isnan(trgs)
163+
random_vals = torch.rand_like(valid_mask, dtype=torch.float)
164+
return (random_vals < self.teacher_forcing_ratio) * valid_mask
165+
160166
def forward(self, *src):
161167
if len(src) == 3:
162168
encoder_input, decoder_input, trgs = src
@@ -165,40 +171,43 @@ def forward(self, *src):
165171
device = decoder_input.device
166172
trgs = torch.full(
167173
(
168-
decoder_input.shape[0], # batch_size
169174
self.hindcast_output_window + self.trg_len, # seq
175+
decoder_input.shape[1], # batch_size
170176
self.output_size, # features
171177
),
172178
float("nan"),
173179
).to(device)
174-
encoder_outputs, hidden_, cell_ = self.encoder(encoder_input)
180+
trgs_q = trgs[:, :, :1]
181+
trgs_s = trgs[:, :, 1:]
182+
trgs = torch.cat((trgs_s, trgs_q), dim=2) # sq
183+
encoder_outputs, hidden_, cell_ = self.encoder(encoder_input) # sq
175184
hidden, cell = self.transfer(hidden_, cell_)
176185
outputs = []
177-
current_input = encoder_outputs[:, -1, :].unsqueeze(1)
186+
prev_output = encoder_outputs[-1, :, :].unsqueeze(0) # sq
187+
_, batch_size, _ = decoder_input.size()
178188

189+
outputs = torch.zeros(self.trg_len, batch_size, self.output_size).to(
190+
decoder_input.device
191+
)
192+
use_teacher_forcing = self._teacher_forcing_preparation(trgs)
179193
for t in range(self.trg_len):
180-
p = decoder_input[:, t, :].unsqueeze(1)
181-
current_input = torch.cat((current_input, p), dim=2)
182-
output, hidden, cell = self.decoder(current_input, hidden, cell)
183-
outputs.append(output.squeeze(1))
184-
trg = trgs[:, (self.hindcast_output_window + t), :].unsqueeze(1)
185-
valid_mask = ~torch.isnan(trg)
186-
random_vals = torch.rand_like(valid_mask, dtype=torch.float)
187-
use_teacher_forcing = (
188-
random_vals < self.teacher_forcing_ratio
189-
) * valid_mask
190-
current_input = torch.where(
191-
torch.isnan(trg), # if trg is nan
192-
output, # then use output
193-
trg * use_teacher_forcing
194-
+ output
195-
* (~use_teacher_forcing), # else calculate with teacher forcing
194+
pc = decoder_input[t : t + 1, :, :] # sq
195+
obs = trgs[self.hindcast_output_window + t, :, :].unsqueeze(0) # sq
196+
safe_obs = torch.where(torch.isnan(obs), torch.zeros_like(obs), obs)
197+
prev_output = torch.where( # sq
198+
use_teacher_forcing[t : t + 1, :, :],
199+
safe_obs,
200+
prev_output,
196201
)
197-
198-
outputs = torch.stack(outputs, dim=1)
202+
current_input = torch.cat((pc, prev_output), dim=2) # pcsq
203+
output, hidden, cell = self.decoder(current_input, hidden, cell)
204+
outputs[t, :, :] = output.squeeze(0) # sq
199205
if self.hindcast_output_window > 0:
200-
prec_outputs = encoder_outputs[:, -self.hindcast_output_window :, :]
201-
outputs = torch.cat((prec_outputs, outputs), dim=1)
206+
prec_outputs = encoder_outputs[-self.hindcast_output_window :, :, :]
207+
outputs = torch.cat((prec_outputs, outputs), dim=0)
208+
outputs_s = outputs[:, :, :1]
209+
outputs_q = outputs[:, :, 1:]
210+
outputs = torch.cat((outputs_q, outputs_s), dim=2) # qs
202211
return outputs
203212

204213

torchhydro/models/simple_lstm.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,23 +92,26 @@ def forward(self, *x):
9292
xfc_rho, xfc_hor, xq_rho, xq_hor = x
9393

9494
x_rho = torch.cat((xfc_rho, xq_rho), dim=-1)
95-
seq_len, batch_size, _ = xfc_hor.size()
96-
97-
use_teacher_forcing = self._teacher_forcing_preparation(xq_hor)
95+
hor_len, batch_size, _ = xfc_hor.size()
9896

9997
# hindcast-forecast, we do not have forecast-hindcast situation
10098
# do rho forward first, prev_output is the last output of rho (seq_length = 1, batch_size, feature = output_size)
10199
if self.hindcast_with_output:
102100
_, h_n, c_n, prev_output = self._rho_forward(x_rho)
101+
seq_len = hor_len
103102
else:
104103
# TODO: need more test
105-
seq_len = xfc_rho.shape[0] + seq_len
104+
seq_len = xfc_rho.shape[0] + hor_len
105+
xfc_hor = torch.cat((xfc_rho, xfc_hor), dim=0)
106+
xq_hor = torch.cat((xq_rho, xq_hor), dim=0)
106107
h_n = torch.randn(1, batch_size, self.hidden_size).to(xfc_rho.device) * 0.1
107108
c_n = torch.randn(1, batch_size, self.hidden_size).to(xfc_rho.device) * 0.1
108109
prev_output = (
109-
torch.randn(1, batch_size, self.output_size).to(x.device) * 0.1
110+
torch.randn(1, batch_size, self.output_size).to(xfc_rho.device) * 0.1
110111
)
111112

113+
use_teacher_forcing = self._teacher_forcing_preparation(xq_hor)
114+
112115
# do hor forward
113116
outputs = torch.zeros(seq_len, batch_size, self.output_size).to(xfc_rho.device)
114117
# TODO: too slow here when seq_len is large, need to optimize
@@ -131,7 +134,7 @@ def forward(self, *x):
131134
prev_output = self.linearOut(out_lstm)
132135
outputs[t, :, :] = prev_output.squeeze(0)
133136
# Return the outputs
134-
return outputs
137+
return outputs[-hor_len:, :, :]
135138

136139

137140
class MultiLayerLSTM(nn.Module):

torchhydro/trainers/deep_hydro.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,9 @@ def inference(self) -> Tuple[xr.Dataset, xr.Dataset]:
385385
test_preds = []
386386
obss = []
387387
with torch.no_grad():
388-
for xs, ys in test_dataloader:
388+
for xs, ys in tqdm(
389+
test_dataloader, desc="Processing", total=len(test_dataloader)
390+
):
389391
# here the a batch doesn't mean a basin; it is only an index in lookup table
390392
# for NtoN mode, only basin is index in lookup table, so the batch is same as basin
391393
# for Nto1 mode, batch is only an index

torchhydro/trainers/train_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -392,13 +392,13 @@ def _recover_samples_to_basin(arr_3d, valorte_data_loader, pace_idx):
392392
for sample_idx in range(arr_3d.shape[0]):
393393
# Get the basin and start time index corresponding to this sample
394394
basin, start_time = dataset.lookup_table[sample_idx]
395-
# Take the value at the last time step of this sample (at the position of rho + horizon)
396-
value = arr_3d[sample_idx, pace_idx, :]
397395
# Calculate the time position in the result array
398396
if pace_idx < 0:
397+
value = arr_3d[sample_idx, pace_idx, :]
399398
result_time_idx = start_time + warmup_len + rho + horizon + pace_idx
400399
else:
401-
result_time_idx = start_time + warmup_len + rho + pace_idx
400+
value = arr_3d[sample_idx, pace_idx - 1, :]
401+
result_time_idx = start_time + warmup_len + rho + pace_idx - 1
402402
# Fill in the corresponding position
403403
basin_array[basin, result_time_idx, :] = value
404404

@@ -609,7 +609,7 @@ def compute_validation(
609609
pred_final = None
610610
with torch.no_grad():
611611
iter_num = 0
612-
for src, trg in data_loader:
612+
for src, trg in tqdm(data_loader, desc="Processing", total=len(data_loader)):
613613
trg, output = model_infer(seq_first, device, model, src, trg)
614614
obs.append(trg)
615615
preds.append(output)

0 commit comments

Comments
 (0)