@@ -1103,50 +1103,14 @@ def __getitem__(self, item):
11031103 return torch .from_numpy (xc ).float (), torch .from_numpy (y ).float ()
11041104
11051105
1106- class HoDataset (BaseDataset ):
1106+ class HFDataset (BaseDataset ):
11071107 def __init__ (self , data_cfgs : dict , is_tra_val_te : str ):
1108- super (HoDataset , self ).__init__ (data_cfgs , is_tra_val_te )
1108+ super (HFDataset , self ).__init__ (data_cfgs , is_tra_val_te )
11091109
11101110 @property
11111111 def streamflow_input_name (self ):
11121112 return self .data_cfgs ["relevant_cols" ][- 1 ]
11131113
1114- def __getitem__ (self , item : int ):
1115- if not self .train_mode :
1116- xf = self .x [item , 1 :, :- 1 ]
1117- # xq = self.x[item, :-1, -1]
1118- # xq = xq.reshape(xq.size, 1)
1119- # x = np.concatenate((xf, xq), axis=1)
1120- x = xf
1121- y = self .y [item , :, :]
1122- if self .c is None or self .c .shape [- 1 ] == 0 :
1123- return torch .from_numpy (x ).float (), torch .from_numpy (y ).float ()
1124- c = self .c [item , :]
1125- c = np .repeat (c , x .shape [0 ], axis = 0 ).reshape (c .shape [0 ], - 1 ).T
1126- xc = np .concatenate ((x , c ), axis = 1 )
1127- return torch .from_numpy (xc ).float (), torch .from_numpy (y ).float ()
1128- basin , idx = self .lookup_table [item ]
1129- warmup_length = self .warmup_length
1130- xf = self .x [
1131- basin ,
1132- idx - warmup_length + 1 : idx + self .rho + self .horizon + 1 ,
1133- :- 1 ,
1134- ]
1135- xq = self .x [basin , idx - warmup_length : idx + self .rho + self .horizon , - 1 ]
1136- xq = xq .reshape (xq .size , 1 )
1137- # x = np.concatenate((xf, xq), axis=1)
1138- x = xf
1139- y = self .y [basin , idx : idx + self .rho + self .horizon , :]
1140- if self .c is None or self .c .shape [- 1 ] == 0 :
1141- return torch .from_numpy (x ).float (), torch .from_numpy (y ).float ()
1142- c = self .c [basin , :]
1143- c = np .repeat (c , x .shape [0 ], axis = 0 ).reshape (c .shape [0 ], - 1 ).T
1144- xc = np .concatenate ((x , c ), axis = 1 )
1145- return [
1146- torch .from_numpy (xc ).float (),
1147- torch .from_numpy (xq ).float (),
1148- ], torch .from_numpy (y ).float ()
1149-
11501114 def _read_xyc_specified_time (self , start_date , end_date ):
11511115 """Read x, y, c data from data source with specified time range
11521116 We set this function as sometimes we need adjust the time range for some specific dataset,
@@ -1161,16 +1125,23 @@ def _read_xyc_specified_time(self, start_date, end_date):
11611125 """
11621126 date_format = detect_date_format (start_date )
11631127 time_unit = self .data_cfgs ["min_time_unit" ]
1164- horizon = self .horizon
11651128 start_date_dt = datetime .strptime (start_date , date_format )
11661129 if time_unit == "h" :
11671130 adjusted_start_date = (start_date_dt - timedelta (hours = 1 )).strftime (
11681131 date_format
11691132 )
1133+ adjusted_start_date_y = (
1134+ start_date_dt
1135+ + timedelta (hours = self .horizon * self .data_cfgs ["min_time_interval" ])
1136+ ).strftime (date_format )
11701137 elif time_unit == "D" :
1171- adjusted_start_date = (start_date_dt - timedelta (days = 1 )).strftime (
1172- date_format
1173- )
1138+ adjusted_start_date = (
1139+ start_date_dt - timedelta (days = self .data_cfgs ["min_time_interval" ])
1140+ ).strftime (date_format )
1141+ adjusted_start_date_y = (
1142+ start_date_dt
1143+ + timedelta (days = self .horizon * self .data_cfgs ["min_time_interval" ])
1144+ ).strftime (date_format )
11741145 else :
11751146 raise ValueError (f"Unsupported time unit: { time_unit } " )
11761147 data_forcing_ds_ = self .data_source .read_ts_xrdataset (
@@ -1252,25 +1223,7 @@ def standardize_unit(unit):
12521223
12531224 return data_forcing_ds , data_output_ds
12541225
1255-
1256- class HoEvalWithValidStreamflowDataset (HoDataset ):
1257- def __init__ (self , data_cfgs : dict , is_tra_val_te : str ):
1258- super (HoEvalWithValidStreamflowDataset , self ).__init__ (data_cfgs , is_tra_val_te )
1259-
12601226 def __getitem__ (self , item : int ):
1261- if not self .train_mode :
1262- xf = self .x [item , 1 :, :- 1 ]
1263- xq = self .x [item , :- 1 , - 1 ]
1264- xq = xq .reshape (xq .size , 1 )
1265- x = np .concatenate ((xf , xq ), axis = 1 )
1266- # x = xf
1267- y = self .y [item , :, :]
1268- if self .c is None or self .c .shape [- 1 ] == 0 :
1269- return torch .from_numpy (x ).float (), torch .from_numpy (y ).float ()
1270- c = self .c [item , :]
1271- c = np .repeat (c , x .shape [0 ], axis = 0 ).reshape (c .shape [0 ], - 1 ).T
1272- xc = np .concatenate ((x , c ), axis = 1 )
1273- return torch .from_numpy (xc ).float (), torch .from_numpy (y ).float ()
12741227 basin , idx = self .lookup_table [item ]
12751228 warmup_length = self .warmup_length
12761229 xf = self .x [
@@ -1280,123 +1233,19 @@ def __getitem__(self, item: int):
12801233 ]
12811234 xq = self .x [basin , idx - warmup_length : idx + self .rho + self .horizon , - 1 ]
12821235 xq = xq .reshape (xq .size , 1 )
1283- x = np .concatenate ((xf , xq ), axis = 1 )
1284- # x = xf
1285- y = self .y [basin , idx : idx + self .rho + self .horizon , :]
1286- if self .c is None or self .c .shape [- 1 ] == 0 :
1287- return torch .from_numpy (x ).float (), torch .from_numpy (y ).float ()
1288- c = self .c [basin , :]
1289- c = np .repeat (c , x .shape [0 ], axis = 0 ).reshape (c .shape [0 ], - 1 ).T
1290- xc = np .concatenate ((x , c ), axis = 1 )
1291- return torch .from_numpy (xc ).float (), torch .from_numpy (y ).float ()
1292-
1293-
1294- class FoDataset (HoDataset ):
1295- def __init__ (self , data_cfgs : dict , is_tra_val_te : str ):
1296- super (FoDataset , self ).__init__ (data_cfgs , is_tra_val_te )
1297-
1298- def __getitem__ (self , item : int ):
1299- if not self .train_mode :
1300- x = self .x [item , 1 :, :- 1 ]
1301- y = self .y [item , :, :]
1302- if self .c is None or self .c .shape [- 1 ] == 0 :
1303- return torch .from_numpy (x ).float (), torch .from_numpy (y ).float ()
1304- c = self .c [item , :]
1305- c = np .repeat (c , x .shape [0 ], axis = 0 ).reshape (c .shape [0 ], - 1 ).T
1306- xc = np .concatenate ((x , c ), axis = 1 )
1307- return torch .from_numpy (xc ).float (), torch .from_numpy (y ).float ()
1308- basin , idx = self .lookup_table [item ]
1309- warmup_length = self .warmup_length
1310- x = self .x [
1311- basin ,
1312- idx - warmup_length + 1 : idx + self .rho + self .horizon + 1 ,
1313- :- 1 ,
1314- ]
1315- xy = self .x [basin , idx - warmup_length : idx + self .rho + self .horizon , - 1 ]
1316- xy = xy .reshape (xy .size , 1 )
1317- y = self .y [basin , idx : idx + self .rho + self .horizon , :]
1318- if self .c is None or self .c .shape [- 1 ] == 0 :
1319- return torch .from_numpy (x ).float (), torch .from_numpy (y ).float ()
1320- c = self .c [basin , :]
1321- c = np .repeat (c , x .shape [0 ], axis = 0 ).reshape (c .shape [0 ], - 1 ).T
1322- xc = np .concatenate ((x , c ), axis = 1 )
1323- return [
1324- torch .from_numpy (xc ).float (),
1325- torch .from_numpy (xy ).float (),
1326- ], torch .from_numpy (y ).float ()
1327-
1328-
1329- class HFDataset (HoDataset ):
1330- def __init__ (self , data_cfgs : dict , is_tra_val_te : str ):
1331- super (HFDataset , self ).__init__ (data_cfgs , is_tra_val_te )
1332-
1333- def __getitem__ (self , item : int ):
1334- if not self .train_mode :
1335- # 先把forcing都取出来
1336- xf = self .x [item , 1 :, :- 1 ]
1337- xf = xf .reshape (1 , xf .shape [0 ], xf .shape [1 ])
1338- # 再把hindcast输入的streamflow取出来
1339- xq = self .x [item , : self .rho , - 1 ]
1340- xq = xq .reshape (xq .size , 1 )
1341- # 取hindcast部分的forcing
1342- xf_hind = xf [:, : self .rho , :]
1343- # 取forecast部分的forcing
1344- xf_fore = xf [:, self .rho :, :]
1345- # 取y
1346- y = self .y [item , :, :]
1347- # 取c
1348- c = self .c [item , :]
1349- # 转到二维
1350- xf_hind = xf_hind .squeeze (0 )
1351- xf_fore = xf_fore .squeeze (0 )
1352- # hindcast部分和c拼接
1353- hind_c = np .repeat (c , xf_hind .shape [0 ], axis = 0 ).reshape (c .shape [0 ], - 1 ).T
1354- xf_hind_c = np .concatenate ((xf_hind , hind_c ), axis = 1 )
1355- x_hind_c = np .concatenate ((xf_hind_c , xq ), axis = - 1 )
1356- # forecast部分和c拼接
1357- fore_c = np .repeat (c , xf_fore .shape [0 ], axis = 0 ).reshape (c .shape [0 ], - 1 ).T
1358- xf_fore_c = np .concatenate ((xf_fore , fore_c ), axis = 1 )
1359- return [
1360- torch .from_numpy (x_hind_c ).float (),
1361- torch .from_numpy (xf_fore_c ).float (),
1362- ], torch .from_numpy (y ).float ()
1363- basin , idx = self .lookup_table [item ]
1364- warmup_length = self .warmup_length
1365- # 先把hindcast和forecast的forcing取出来
1366- xf = self .x [
1367- basin ,
1368- idx - warmup_length + 1 : idx + self .rho + self .horizon + 1 ,
1369- :- 1 ,
1370- ]
1371- xf = xf .reshape (1 , xf .shape [0 ], xf .shape [1 ])
1372- xf_hind = xf [:, : self .rho , :]
1373- xf_fore = xf [:, self .rho :, :]
1374- # 再把hindcast和forecast输入的streamflow取出来
1375- xq = self .x [basin , idx - warmup_length : idx + self .rho + self .horizon , - 1 ]
1376- xq = xq .reshape (1 , xq .size , 1 )
1377- # 取hindcast部分的流量
1378- xq_hind = xq [:, : self .rho , :]
1379- # 取forecast部分的流量
1380- xq_fore = xq [:, : self .rho , :]
1381- # 取c
1236+ xf_rho = xf [: self .rho , :]
1237+ xf_hor = xf [self .rho :, :]
1238+ xq_rho = xq [: self .rho , :]
1239+ xq_hor = xq [self .rho :, :]
1240+ y = self .y [basin , idx + self .rho : idx + self .rho + self .horizon , :]
13821241 c = self .c [basin , :]
1383- # 取y
1384- y = self .y [basin , idx : idx + self .rho + self .horizon , :]
1385- # 转到二维
1386- xf_hind = xf_hind .squeeze (0 )
1387- xf_fore = xf_fore .squeeze (0 )
1388- xq_hind = xq_hind .squeeze (0 )
1389- xq_fore = xq_fore .squeeze (0 )
1390- # hindcast部分和c拼接
1391- hind_c = np .repeat (c , xf_hind .shape [0 ], axis = 0 ).reshape (c .shape [0 ], - 1 ).T
1392- xf_hind_c = np .concatenate ((xf_hind , hind_c ), axis = 1 )
1393- x_hind_c = np .concatenate ((xf_hind_c , xq_hind ), axis = - 1 )
1394- # forecast部分和c拼接
1395- fore_c = np .repeat (c , xf_fore .shape [0 ], axis = 0 ).reshape (c .shape [0 ], - 1 ).T
1396- xf_fore_c = np .concatenate ((xf_fore , fore_c ), axis = 1 )
1397-
1242+ c_rho = np .repeat (c , xf_rho .shape [0 ], axis = 0 ).reshape (c .shape [0 ], - 1 ).T
1243+ c_hor = np .repeat (c , xf_hor .shape [0 ], axis = 0 ).reshape (c .shape [0 ], - 1 ).T
1244+ xfc_rho = np .concatenate ((xf_rho , c_rho ), axis = 1 )
1245+ xfc_hor = np .concatenate ((xf_hor , c_hor ), axis = 1 )
13981246 return [
1399- torch .from_numpy (x_hind_c ).float (),
1400- torch .from_numpy (xf_fore_c ).float (),
1401- torch .from_numpy (xq_fore ).float (),
1247+ torch .from_numpy (xfc_rho ).float (),
1248+ torch .from_numpy (xfc_hor ).float (),
1249+ torch .from_numpy (xq_rho ).float (),
1250+ torch .from_numpy (xq_hor ).float (),
14021251 ], torch .from_numpy (y ).float ()
0 commit comments