Skip to content

Commit 3215c36

Browse files
committed
refactor hf and add another situation in pace
1 parent de47cc0 commit 3215c36

File tree

5 files changed

+99
-393
lines changed

5 files changed

+99
-393
lines changed

torchhydro/datasets/data_dict.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,7 @@
1111
from torchhydro.datasets.data_sets import (
1212
BaseDataset,
1313
BaseDatasetValidSame,
14-
HoDataset,
15-
FoDataset,
1614
HFDataset,
17-
HoEvalWithValidStreamflowDataset,
1815
BasinSingleFlowDataset,
1916
DplDataset,
2017
FlexibleDataset,
@@ -28,9 +25,6 @@
2825
datasets_dict = {
2926
"StreamflowDataset": BaseDataset,
3027
"BaseDatasetValidSame": BaseDatasetValidSame,
31-
"HoDataset": HoDataset,
32-
"HoEvalWithValidStreamflowDataset": HoEvalWithValidStreamflowDataset,
33-
"FoDataset": FoDataset,
3428
"HFDataset": HFDataset,
3529
"SingleflowDataset": BasinSingleFlowDataset,
3630
"DplDataset": DplDataset,

torchhydro/datasets/data_sets.py

Lines changed: 26 additions & 177 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,50 +1103,14 @@ def __getitem__(self, item):
11031103
return torch.from_numpy(xc).float(), torch.from_numpy(y).float()
11041104

11051105

1106-
class HoDataset(BaseDataset):
1106+
class HFDataset(BaseDataset):
11071107
def __init__(self, data_cfgs: dict, is_tra_val_te: str):
1108-
super(HoDataset, self).__init__(data_cfgs, is_tra_val_te)
1108+
super(HFDataset, self).__init__(data_cfgs, is_tra_val_te)
11091109

11101110
@property
11111111
def streamflow_input_name(self):
11121112
return self.data_cfgs["relevant_cols"][-1]
11131113

1114-
def __getitem__(self, item: int):
1115-
if not self.train_mode:
1116-
xf = self.x[item, 1:, :-1]
1117-
# xq = self.x[item, :-1, -1]
1118-
# xq = xq.reshape(xq.size, 1)
1119-
# x = np.concatenate((xf, xq), axis=1)
1120-
x = xf
1121-
y = self.y[item, :, :]
1122-
if self.c is None or self.c.shape[-1] == 0:
1123-
return torch.from_numpy(x).float(), torch.from_numpy(y).float()
1124-
c = self.c[item, :]
1125-
c = np.repeat(c, x.shape[0], axis=0).reshape(c.shape[0], -1).T
1126-
xc = np.concatenate((x, c), axis=1)
1127-
return torch.from_numpy(xc).float(), torch.from_numpy(y).float()
1128-
basin, idx = self.lookup_table[item]
1129-
warmup_length = self.warmup_length
1130-
xf = self.x[
1131-
basin,
1132-
idx - warmup_length + 1 : idx + self.rho + self.horizon + 1,
1133-
:-1,
1134-
]
1135-
xq = self.x[basin, idx - warmup_length : idx + self.rho + self.horizon, -1]
1136-
xq = xq.reshape(xq.size, 1)
1137-
# x = np.concatenate((xf, xq), axis=1)
1138-
x = xf
1139-
y = self.y[basin, idx : idx + self.rho + self.horizon, :]
1140-
if self.c is None or self.c.shape[-1] == 0:
1141-
return torch.from_numpy(x).float(), torch.from_numpy(y).float()
1142-
c = self.c[basin, :]
1143-
c = np.repeat(c, x.shape[0], axis=0).reshape(c.shape[0], -1).T
1144-
xc = np.concatenate((x, c), axis=1)
1145-
return [
1146-
torch.from_numpy(xc).float(),
1147-
torch.from_numpy(xq).float(),
1148-
], torch.from_numpy(y).float()
1149-
11501114
def _read_xyc_specified_time(self, start_date, end_date):
11511115
"""Read x, y, c data from data source with specified time range
11521116
We set this function as sometimes we need adjust the time range for some specific dataset,
@@ -1161,16 +1125,23 @@ def _read_xyc_specified_time(self, start_date, end_date):
11611125
"""
11621126
date_format = detect_date_format(start_date)
11631127
time_unit = self.data_cfgs["min_time_unit"]
1164-
horizon = self.horizon
11651128
start_date_dt = datetime.strptime(start_date, date_format)
11661129
if time_unit == "h":
11671130
adjusted_start_date = (start_date_dt - timedelta(hours=1)).strftime(
11681131
date_format
11691132
)
1133+
adjusted_start_date_y = (
1134+
start_date_dt
1135+
+ timedelta(hours=self.horizon * self.data_cfgs["min_time_interval"])
1136+
).strftime(date_format)
11701137
elif time_unit == "D":
1171-
adjusted_start_date = (start_date_dt - timedelta(days=1)).strftime(
1172-
date_format
1173-
)
1138+
adjusted_start_date = (
1139+
start_date_dt - timedelta(days=self.data_cfgs["min_time_interval"])
1140+
).strftime(date_format)
1141+
adjusted_start_date_y = (
1142+
start_date_dt
1143+
+ timedelta(days=self.horizon * self.data_cfgs["min_time_interval"])
1144+
).strftime(date_format)
11741145
else:
11751146
raise ValueError(f"Unsupported time unit: {time_unit}")
11761147
data_forcing_ds_ = self.data_source.read_ts_xrdataset(
@@ -1252,25 +1223,7 @@ def standardize_unit(unit):
12521223

12531224
return data_forcing_ds, data_output_ds
12541225

1255-
1256-
class HoEvalWithValidStreamflowDataset(HoDataset):
1257-
def __init__(self, data_cfgs: dict, is_tra_val_te: str):
1258-
super(HoEvalWithValidStreamflowDataset, self).__init__(data_cfgs, is_tra_val_te)
1259-
12601226
def __getitem__(self, item: int):
1261-
if not self.train_mode:
1262-
xf = self.x[item, 1:, :-1]
1263-
xq = self.x[item, :-1, -1]
1264-
xq = xq.reshape(xq.size, 1)
1265-
x = np.concatenate((xf, xq), axis=1)
1266-
# x = xf
1267-
y = self.y[item, :, :]
1268-
if self.c is None or self.c.shape[-1] == 0:
1269-
return torch.from_numpy(x).float(), torch.from_numpy(y).float()
1270-
c = self.c[item, :]
1271-
c = np.repeat(c, x.shape[0], axis=0).reshape(c.shape[0], -1).T
1272-
xc = np.concatenate((x, c), axis=1)
1273-
return torch.from_numpy(xc).float(), torch.from_numpy(y).float()
12741227
basin, idx = self.lookup_table[item]
12751228
warmup_length = self.warmup_length
12761229
xf = self.x[
@@ -1280,123 +1233,19 @@ def __getitem__(self, item: int):
12801233
]
12811234
xq = self.x[basin, idx - warmup_length : idx + self.rho + self.horizon, -1]
12821235
xq = xq.reshape(xq.size, 1)
1283-
x = np.concatenate((xf, xq), axis=1)
1284-
# x = xf
1285-
y = self.y[basin, idx : idx + self.rho + self.horizon, :]
1286-
if self.c is None or self.c.shape[-1] == 0:
1287-
return torch.from_numpy(x).float(), torch.from_numpy(y).float()
1288-
c = self.c[basin, :]
1289-
c = np.repeat(c, x.shape[0], axis=0).reshape(c.shape[0], -1).T
1290-
xc = np.concatenate((x, c), axis=1)
1291-
return torch.from_numpy(xc).float(), torch.from_numpy(y).float()
1292-
1293-
1294-
class FoDataset(HoDataset):
1295-
def __init__(self, data_cfgs: dict, is_tra_val_te: str):
1296-
super(FoDataset, self).__init__(data_cfgs, is_tra_val_te)
1297-
1298-
def __getitem__(self, item: int):
1299-
if not self.train_mode:
1300-
x = self.x[item, 1:, :-1]
1301-
y = self.y[item, :, :]
1302-
if self.c is None or self.c.shape[-1] == 0:
1303-
return torch.from_numpy(x).float(), torch.from_numpy(y).float()
1304-
c = self.c[item, :]
1305-
c = np.repeat(c, x.shape[0], axis=0).reshape(c.shape[0], -1).T
1306-
xc = np.concatenate((x, c), axis=1)
1307-
return torch.from_numpy(xc).float(), torch.from_numpy(y).float()
1308-
basin, idx = self.lookup_table[item]
1309-
warmup_length = self.warmup_length
1310-
x = self.x[
1311-
basin,
1312-
idx - warmup_length + 1 : idx + self.rho + self.horizon + 1,
1313-
:-1,
1314-
]
1315-
xy = self.x[basin, idx - warmup_length : idx + self.rho + self.horizon, -1]
1316-
xy = xy.reshape(xy.size, 1)
1317-
y = self.y[basin, idx : idx + self.rho + self.horizon, :]
1318-
if self.c is None or self.c.shape[-1] == 0:
1319-
return torch.from_numpy(x).float(), torch.from_numpy(y).float()
1320-
c = self.c[basin, :]
1321-
c = np.repeat(c, x.shape[0], axis=0).reshape(c.shape[0], -1).T
1322-
xc = np.concatenate((x, c), axis=1)
1323-
return [
1324-
torch.from_numpy(xc).float(),
1325-
torch.from_numpy(xy).float(),
1326-
], torch.from_numpy(y).float()
1327-
1328-
1329-
class HFDataset(HoDataset):
1330-
def __init__(self, data_cfgs: dict, is_tra_val_te: str):
1331-
super(HFDataset, self).__init__(data_cfgs, is_tra_val_te)
1332-
1333-
def __getitem__(self, item: int):
1334-
if not self.train_mode:
1335-
# 先把forcing都取出来
1336-
xf = self.x[item, 1:, :-1]
1337-
xf = xf.reshape(1, xf.shape[0], xf.shape[1])
1338-
# 再把hindcast输入的streamflow取出来
1339-
xq = self.x[item, : self.rho, -1]
1340-
xq = xq.reshape(xq.size, 1)
1341-
# 取hindcast部分的forcing
1342-
xf_hind = xf[:, : self.rho, :]
1343-
# 取forecast部分的forcing
1344-
xf_fore = xf[:, self.rho :, :]
1345-
# 取y
1346-
y = self.y[item, :, :]
1347-
# 取c
1348-
c = self.c[item, :]
1349-
# 转到二维
1350-
xf_hind = xf_hind.squeeze(0)
1351-
xf_fore = xf_fore.squeeze(0)
1352-
# hindcast部分和c拼接
1353-
hind_c = np.repeat(c, xf_hind.shape[0], axis=0).reshape(c.shape[0], -1).T
1354-
xf_hind_c = np.concatenate((xf_hind, hind_c), axis=1)
1355-
x_hind_c = np.concatenate((xf_hind_c, xq), axis=-1)
1356-
# forecast部分和c拼接
1357-
fore_c = np.repeat(c, xf_fore.shape[0], axis=0).reshape(c.shape[0], -1).T
1358-
xf_fore_c = np.concatenate((xf_fore, fore_c), axis=1)
1359-
return [
1360-
torch.from_numpy(x_hind_c).float(),
1361-
torch.from_numpy(xf_fore_c).float(),
1362-
], torch.from_numpy(y).float()
1363-
basin, idx = self.lookup_table[item]
1364-
warmup_length = self.warmup_length
1365-
# 先把hindcast和forecast的forcing取出来
1366-
xf = self.x[
1367-
basin,
1368-
idx - warmup_length + 1 : idx + self.rho + self.horizon + 1,
1369-
:-1,
1370-
]
1371-
xf = xf.reshape(1, xf.shape[0], xf.shape[1])
1372-
xf_hind = xf[:, : self.rho, :]
1373-
xf_fore = xf[:, self.rho :, :]
1374-
# 再把hindcast和forecast输入的streamflow取出来
1375-
xq = self.x[basin, idx - warmup_length : idx + self.rho + self.horizon, -1]
1376-
xq = xq.reshape(1, xq.size, 1)
1377-
# 取hindcast部分的流量
1378-
xq_hind = xq[:, : self.rho, :]
1379-
# 取forecast部分的流量
1380-
xq_fore = xq[:, : self.rho, :]
1381-
# 取c
1236+
xf_rho = xf[: self.rho, :]
1237+
xf_hor = xf[self.rho :, :]
1238+
xq_rho = xq[: self.rho, :]
1239+
xq_hor = xq[self.rho :, :]
1240+
y = self.y[basin, idx + self.rho : idx + self.rho + self.horizon, :]
13821241
c = self.c[basin, :]
1383-
# 取y
1384-
y = self.y[basin, idx : idx + self.rho + self.horizon, :]
1385-
# 转到二维
1386-
xf_hind = xf_hind.squeeze(0)
1387-
xf_fore = xf_fore.squeeze(0)
1388-
xq_hind = xq_hind.squeeze(0)
1389-
xq_fore = xq_fore.squeeze(0)
1390-
# hindcast部分和c拼接
1391-
hind_c = np.repeat(c, xf_hind.shape[0], axis=0).reshape(c.shape[0], -1).T
1392-
xf_hind_c = np.concatenate((xf_hind, hind_c), axis=1)
1393-
x_hind_c = np.concatenate((xf_hind_c, xq_hind), axis=-1)
1394-
# forecast部分和c拼接
1395-
fore_c = np.repeat(c, xf_fore.shape[0], axis=0).reshape(c.shape[0], -1).T
1396-
xf_fore_c = np.concatenate((xf_fore, fore_c), axis=1)
1397-
1242+
c_rho = np.repeat(c, xf_rho.shape[0], axis=0).reshape(c.shape[0], -1).T
1243+
c_hor = np.repeat(c, xf_hor.shape[0], axis=0).reshape(c.shape[0], -1).T
1244+
xfc_rho = np.concatenate((xf_rho, c_rho), axis=1)
1245+
xfc_hor = np.concatenate((xf_hor, c_hor), axis=1)
13981246
return [
1399-
torch.from_numpy(x_hind_c).float(),
1400-
torch.from_numpy(xf_fore_c).float(),
1401-
torch.from_numpy(xq_fore).float(),
1247+
torch.from_numpy(xfc_rho).float(),
1248+
torch.from_numpy(xfc_hor).float(),
1249+
torch.from_numpy(xq_rho).float(),
1250+
torch.from_numpy(xq_hor).float(),
14021251
], torch.from_numpy(y).float()

torchhydro/models/model_dict_function.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424
MultiLayerLSTM,
2525
SimpleLSTM,
2626
SimpleLSTMForecast,
27-
FoLSTM,
28-
HoLSTM,
2927
HFLSTM,
3028
)
3129
from torchhydro.models.seqforecast import SequentialForecastLSTM
@@ -81,8 +79,6 @@
8179
"LinearMultiLayerLSTMModel": LinearMultiLayerLSTMModel,
8280
"SPPLSTM": SPP_LSTM_Model,
8381
"SimpleLSTMForecast": SimpleLSTMForecast,
84-
"HoLSTM": HoLSTM,
85-
"FoLSTM": FoLSTM,
8682
"HFLSTM": HFLSTM,
8783
"SPPLSTM2": SPP_LSTM_Model_2,
8884
"SeqForecastLSTM": SequentialForecastLSTM,

0 commit comments

Comments
 (0)