Skip to content

Commit f3ffe1c

Browse files
committed
sklearn scaler for 4-dim data
1 parent cab2e2a commit f3ffe1c

File tree

1 file changed

+25
-1
lines changed

1 file changed

+25
-1
lines changed

torchhydro/datasets/data_scalers.py

Lines changed: 25 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -164,7 +164,7 @@ def load_norm_data(self, vars_data):
164164
},
165165
dims=["basin", "time", "variable"],
166166
)
167-
else:
167+
elif v.ndim == 2:
168168
num_instances, num_features = v.shape
169169
v_np = v.to_numpy().reshape(-1, num_features)
170170
scaler, data_norm = self._sklearn_scale(
@@ -179,6 +179,30 @@ def load_norm_data(self, vars_data):
179179
},
180180
dims=["basin", "variable"],
181181
)
182+
elif v.ndim == 4:
183+
# forecast data: 4-D array with dims (basin, time, lead_step, variable)
184+
num_instances, num_time_steps, num_lead_steps, num_features = v.shape
185+
v_np = v.to_numpy().reshape(-1, num_features)
186+
scaler, data_norm = self._sklearn_scale(
187+
self.data_cfgs, self.is_tra_val_te, scaler, k, v_np
188+
)
189+
data_norm = data_norm.reshape(
190+
num_instances, num_time_steps, num_lead_steps, num_features
191+
)
192+
norm_xrarray = xr.DataArray(
193+
data_norm,
194+
coords={
195+
"basin": v.coords["basin"],
196+
"time": v.coords["time"],
197+
"lead_step": v.coords["lead_step"],
198+
"variable": v.coords["variable"],
199+
},
200+
dims=["basin", "time", "lead_step", "variable"],
201+
)
202+
else:
203+
raise NotImplementedError(
204+
"Please check your data, the dim of data must be 2, 3 or 4"
205+
)
182206

183207
norm_dict[k] = norm_xrarray
184208
if k == "target_cols":

0 commit comments

Comments
 (0)