@@ -164,7 +164,7 @@ def load_norm_data(self, vars_data):
164164 },
165165 dims = ["basin" , "time" , "variable" ],
166166 )
167- else :
167+ elif v . ndim == 2 :
168168 num_instances , num_features = v .shape
169169 v_np = v .to_numpy ().reshape (- 1 , num_features )
170170 scaler , data_norm = self ._sklearn_scale (
@@ -179,6 +179,30 @@ def load_norm_data(self, vars_data):
179179 },
180180 dims = ["basin" , "variable" ],
181181 )
182+ elif v .ndim == 4 :
183+ # for forecast data
184+ num_instances , num_time_steps , num_lead_steps , num_features = v .shape
185+ v_np = v .to_numpy ().reshape (- 1 , num_features )
186+ scaler , data_norm = self ._sklearn_scale (
187+ self .data_cfgs , self .is_tra_val_te , scaler , k , v_np
188+ )
189+ data_norm = data_norm .reshape (
190+ num_instances , num_time_steps , num_lead_steps , num_features
191+ )
192+ norm_xrarray = xr .DataArray (
193+ data_norm ,
194+ coords = {
195+ "basin" : v .coords ["basin" ],
196+ "time" : v .coords ["time" ],
197+ "lead_step" : v .coords ["lead_step" ],
198+ "variable" : v .coords ["variable" ],
199+ },
200+ dims = ["basin" , "time" , "lead_step" , "variable" ],
201+ )
202+ else :
203+ raise NotImplementedError (
204+ "Please check your data, the dim of data must be 2, 3 or 4"
205+ )
182206
183207 norm_dict [k ] = norm_xrarray
184208 if k == "target_cols" :
0 commit comments