@@ -976,65 +976,100 @@ def _read_xyc(self):
976976 time_unit = self .data_cfgs ["min_time_unit" ]
977977
978978 # Determine the date format
979- date_format = detect_date_format (end_date )
979+ date_format = detect_date_format (start_date )
980980
981981 # Adjust the start date based on the time unit
982- end_date_dt = datetime .strptime (end_date , date_format )
982+ start_date_dt = datetime .strptime (start_date , date_format )
983983 if time_unit == "h" :
984- adjusted_end_date = (end_date_dt + timedelta (hours = interval )).strftime (
984+ adjusted_start_date = (start_date_dt - timedelta (hours = interval )).strftime (
985985 date_format
986986 )
987987 elif time_unit == "D" :
988- adjusted_end_date = (end_date_dt + timedelta (days = interval )).strftime (
988+ adjusted_start_date = (start_date_dt - timedelta (days = interval )).strftime (
989989 date_format
990990 )
991991 else :
992992 raise ValueError (f"Unsupported time unit: { time_unit } " )
993- return self ._read_xyc_specified_time (start_date , adjusted_end_date )
993+ return self ._read_xyc_specified_time (adjusted_start_date , end_date )
994994
995- def _normalize (self ):
996- x , y , c = super ()._normalize ()
997- # TODO: this work for minio? maybe better to move to basedataset
998- return x .compute (), y .compute (), c .compute ()
995+ def denormalize (self , norm_data , is_real_time = True ):
996+ """Denormalize the norm_data
997+
998+ Parameters
999+ ----------
1000+ norm_data : np.ndarray
1001+ batch-first data
1002+ is_real_time : bool, optional
1003+ whether the data is real-time data, by default True.
1004+ Sometimes we have multiple results for one time period and flatten them,
1005+ so we need a temporary time to replace the real one.
1006+
1007+ Returns
1008+ -------
1009+ xr.Dataset
1010+ denormalized data
1011+ """
1012+ target_scaler = self .target_scaler
1013+ target_data = target_scaler .data_target
1014+ # the units are dimensionless for pure DL models
1015+ units = {k : "dimensionless" for k in target_data .attrs ["units" ].keys ()}
1016+ if target_scaler .pbm_norm :
1017+ units = {** units , ** target_data .attrs ["units" ]}
1018+ warmup_length = self .warmup_length
1019+ selected_time_points = target_data .coords ["time" ][warmup_length :- 1 ]
1020+ selected_data = target_data .sel (time = selected_time_points )
1021+ denorm_xr_ds = target_scaler .inverse_transform (
1022+ xr .DataArray (
1023+ norm_data ,
1024+ dims = selected_data .dims ,
1025+ coords = selected_data .coords ,
1026+ attrs = {"units" : units },
1027+ )
1028+ )
1029+ return set_unit_to_var (denorm_xr_ds )
9991030
10001031 def __getitem__ (self , item : int ):
10011032 basin , time = self .lookup_table [item ]
10021033 rho = self .rho
10031034 horizon = self .horizon
10041035 hindcast_output_window = self .data_cfgs .get ("hindcast_output_window" , 0 )
10051036 # p covers all encoder-decoder periods; +1 means the period itself, while +0 means the start of the current period
1006- p = self .x [basin , time + 1 : time + rho + horizon + 1 , 0 ]. reshape ( - 1 , 1 )
1037+ p = self .x [basin , time + 1 : time + rho + horizon + 1 , : 1 ]
10071038 # s only covers encoder periods
10081039 s = self .x [basin , time : time + rho , 1 :]
1009- x = np .concatenate ((p [:rho ], s ), axis = 1 )
1040+ # xe = np.concatenate((p[:rho], s), axis=1)
10101041
10111042 if self .c is None or self .c .shape [- 1 ] == 0 :
1012- xc = x
1043+ pc = p
10131044 else :
10141045 c = self .c [basin , :]
10151046 c = np .tile (c , (rho + horizon , 1 ))
1016- xc = np .concatenate ((x , c [:rho ]), axis = 1 )
1047+ pc = np .concatenate ((p [:rho ], c [:rho ]), axis = 1 )
1048+ xe = np .concatenate ((pc [:rho ], s ), axis = 1 )
10171049 # xd covers decoder periods
10181050 try :
1019- xh = np .concatenate ((p [rho :], c [rho :]), axis = 1 )
1051+ xd = np .concatenate ((p [rho :], c [rho :]), axis = 1 )
10201052 except ValueError as e :
10211053 print (f"Error in np.concatenate: { e } " )
10221054 print (f"p[rho:].shape: { p [rho :].shape } , c[rho:].shape: { c [rho :].shape } " )
10231055 raise
10241056 # y covers the specified encoder size (hindcast_output_window) and all decoder periods
10251057 y = self .y [
10261058 basin , time + rho - hindcast_output_window + 1 : time + rho + horizon + 1 , :
1027- ]
1059+ ] # qs
1060+ # y_q = y[:, :1]
1061+ # y_s = y[:, 1:]
1062+ # y = np.concatenate((y_s, y_q), axis=1)
10281063
10291064 if self .is_tra_val_te == "train" :
10301065 return [
1031- torch .from_numpy (xc ).float (),
1032- torch .from_numpy (xh ).float (),
1066+ torch .from_numpy (xe ).float (),
1067+ torch .from_numpy (xd ).float (),
10331068 torch .from_numpy (y ).float (),
10341069 ], torch .from_numpy (y ).float ()
10351070 return [
1036- torch .from_numpy (xc ).float (),
1037- torch .from_numpy (xh ).float (),
1071+ torch .from_numpy (xe ).float (),
1072+ torch .from_numpy (xd ).float (),
10381073 ], torch .from_numpy (y ).float ()
10391074
10401075
@@ -1086,15 +1121,15 @@ def __getitem__(self, item: int):
10861121 ], torch .from_numpy (y ).float ()
10871122
10881123
1089- class BaseDatasetValidSame (BaseDataset ):
1124+ class ForecastDataset (BaseDataset ):
10901125 def __init__ (self , data_cfgs : dict , is_tra_val_te : str ):
1091- super (BaseDatasetValidSame , self ).__init__ (data_cfgs , is_tra_val_te )
1126+ super (ForecastDataset , self ).__init__ (data_cfgs , is_tra_val_te )
10921127
10931128 def __getitem__ (self , item ):
10941129 basin , idx = self .lookup_table [item ]
10951130 warmup_length = self .warmup_length
10961131 x = self .x [basin , idx - warmup_length : idx + self .rho + self .horizon , :]
1097- y = self .y [basin , idx : idx + self .rho + self .horizon , :]
1132+ y = self .y [basin , idx + self . rho : idx + self .rho + self .horizon , :]
10981133 if self .c is None or self .c .shape [- 1 ] == 0 :
10991134 return torch .from_numpy (x ).float (), torch .from_numpy (y ).float ()
11001135 c = self .c [basin , :]
0 commit comments