@@ -277,6 +277,8 @@ def __getitem__(self, item: int):
277277
278278 def _pre_load_data (self ):
279279 self .train_mode = self .is_tra_val_te == "train"
280+ self .valid_mode = self .is_tra_val_te == "valid"
281+ self .test_mode = self .is_tra_val_te == "test"
280282 self .t_s_dict = wrap_t_s_dict (self .data_cfgs , self .is_tra_val_te )
281283 self .rho = self .data_cfgs ["hindcast_length" ]
282284 self .warmup_length = self .data_cfgs ["warmup_length" ]
@@ -680,6 +682,24 @@ def __init__(self, data_cfgs: dict, is_tra_val_te: str):
680682 self .lead_time_start , self .lead_time_start + horizon
681683 )
682684 self .horizon_offset = offset
685+ feature_mapping = self .data_cfgs ["feature_mapping" ]
686+ #
687+ xf_var_indices = {}
688+ for obs_var , fore_var in feature_mapping .items ():
689+ # 找到x中需要被替换的变量索引
690+ x_var_indice = [
691+ i
692+ for i , var in enumerate (self .data_cfgs ["relevant_cols" ])
693+ if var == obs_var
694+ ][0 ]
695+ # 找到f中对应的变量索引
696+ f_var_indice = [
697+ i
698+ for i , var in enumerate (self .data_cfgs ["forecast_cols" ])
699+ if var == fore_var
700+ ][0 ]
701+ xf_var_indices [x_var_indice ] = f_var_indice
702+ self .xf_var_indices = xf_var_indices
683703
684704 def _read_xyc_specified_time (self , start_date , end_date , ** kwargs ):
685705 """read f data from data source with specified time range and add it to the whole dict"""
@@ -707,19 +727,18 @@ def __getitem__(self, item: int):
707727 Parameters
708728 ----------
709729 item : int
710- 样本索引
730+ index of sample
711731
712732 Returns
713733 -------
714734 tuple
715- (x, y) 数据对,其中 x 包含输入特征和预见期标志,y 包含目标值
735+ A pair of (x, y) data, where x contains input features and lead time flags,
736+ and y contains target values
716737 """
717- # 获取基础数据
718- if not self .train_mode :
738+ if not (self .train_mode and self .valid_mode ):
719739 x = self .x [item , :, :]
720740 y = self .y [item , :, :]
721741 f = self .f [item , :, :]
722- # 添加预见期标志到输入特征
723742 if self .c is None or self .c .shape [- 1 ] == 0 :
724743 xc = np .concatenate ((x , f ), axis = 1 )
725744 else :
@@ -729,25 +748,47 @@ def __getitem__(self, item: int):
729748
730749 return torch .from_numpy (xc ).float (), torch .from_numpy (y ).float ()
731750
732- # 训练模式
751+ # train mode
733752 basin , idx = self .lookup_table [item ]
734753 warmup_length = self .warmup_length
754+ # for x, we only chose data before horizon, but we may need forecast data for not all variables
755+ # hence, to avoid nan values for some variables without forecast in horizon
756+ # we still get data from the first time to the end of horizon
735757 x = self .x [basin , idx - warmup_length : idx + self .rho + self .horizon , :]
758+ # for y, we chose data after warmup_length
736759 y = self .y [basin , idx : idx + self .rho + self .horizon , :]
737-
738760 # use offset to get forecast data
739- f = self .f [basin , idx : idx + self .rho + self .horizon , :]
740-
741- # 添加预见期标志到输入特征
761+ offset = self .horizon_offset
762+ if self .lead_time_type == "fixed" :
763+ # Fixed lead_time mode - All forecast steps use the same lead_step
764+ f = self .f [
765+ basin , idx + self .rho : idx + self .rho + self .horizon , offset [0 ], :
766+ ]
767+ else :
768+ # Increasing lead_time mode - Each forecast step uses a different lead_step
769+ f = self .f [basin , idx + self .rho , offset , :]
770+ xf = self ._concat_xf (x , f )
742771 if self .c is None or self .c .shape [- 1 ] == 0 :
743- xfc = np . concatenate (( x , f ), axis = 1 )
772+ xfc = xf
744773 else :
745774 c = self .c [basin , :]
746- c = np .repeat (c , x .shape [0 ], axis = 0 ).reshape (c .shape [0 ], - 1 ).T
747- xfc = np .concatenate ((x , c , f ), axis = 1 )
775+ c = np .repeat (c , xf .shape [0 ], axis = 0 ).reshape (c .shape [0 ], - 1 ).T
776+ xfc = np .concatenate ((xf , c ), axis = 1 )
748777
749778 return torch .from_numpy (xfc ).float (), torch .from_numpy (y ).float ()
750779
780+ def _concat_xf (self , x , f ):
781+ # Create a copy of x to avoid modifying the original data
782+ x_combined = x .copy ()
783+
784+ # Iterate through the variable mapping relationship
785+ for x_idx , f_idx in self .xf_var_indices .items ():
786+ # Replace the variables in the forecast period of x with the forecast variables in f
787+ # The forecast period of x starts from the rho position
788+ x_combined [self .rho :, x_idx ] = f [:, f_idx ]
789+
790+ return x_combined
791+
751792
752793class BasinSingleFlowDataset (BaseDataset ):
753794 """one time length output for each grid in a batch"""
0 commit comments