Skip to content

Commit 9e14062

Browse files
committed
concat x and f for training
1 parent a8dec71 commit 9e14062

File tree

2 files changed

+69
-14
lines changed

2 files changed

+69
-14
lines changed

torchhydro/configs/config.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,11 @@ def default_config_file():
170170
# hence we list them as a separate type
171171
"forecast_cols": None,
172172
"forecast_rm_nan": True,
173+
# same variable but different names for obs and forecast
174+
# key is obs and value is forecast
175+
"feature_mapping": {
176+
"total_precipitation_hourly": "total_precipitation_surface",
177+
},
173178
# global variables such as ENSO indicators are used in some long-term models
174179
"global_cols": None,
175180
# specify the data source of each variable
@@ -373,11 +378,12 @@ def cmd(
373378
weight_path_add=None,
374379
var_t_type=None,
375380
var_f=None,
381+
f_rm_nan=None,
382+
feature_mapping=None,
376383
var_g=None,
377384
var_out=None,
378385
var_to_source_map=None,
379386
out_rm_nan=None,
380-
f_rm_nan=None,
381387
target_as_input=None,
382388
constant_only=None,
383389
gage_id_screen=None,
@@ -724,6 +730,12 @@ def cmd(
724730
default=var_f,
725731
type=json.loads,
726732
)
733+
parser.add_argument(
734+
"--feature_mapping",
735+
type=json.loads,
736+
help="same variables from obs and forecast",
737+
default=feature_mapping,
738+
)
727739
parser.add_argument(
728740
"--var_g",
729741
dest="var_g",
@@ -1049,6 +1061,8 @@ def update_cfg(cfg_file, new_args):
10491061
cfg_file["data_cfgs"]["target_rm_nan"] = bool(new_args.out_rm_nan > 0)
10501062
if new_args.f_rm_nan is not None:
10511063
cfg_file["data_cfgs"]["forecast_rm_nan"] = bool(new_args.f_rm_nan > 0)
1064+
if new_args.feature_mapping is not None:
1065+
cfg_file["data_cfgs"]["feature_mapping"] = new_args.feature_mapping
10521066
if new_args.target_as_input is not None:
10531067
cfg_file["data_cfgs"]["target_as_input"] = bool(new_args.target_as_input > 0)
10541068
if new_args.constant_only is not None:

torchhydro/datasets/data_sets.py

Lines changed: 54 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,8 @@ def __getitem__(self, item: int):
277277

278278
def _pre_load_data(self):
279279
self.train_mode = self.is_tra_val_te == "train"
280+
self.valid_mode = self.is_tra_val_te == "valid"
281+
self.test_mode = self.is_tra_val_te == "test"
280282
self.t_s_dict = wrap_t_s_dict(self.data_cfgs, self.is_tra_val_te)
281283
self.rho = self.data_cfgs["hindcast_length"]
282284
self.warmup_length = self.data_cfgs["warmup_length"]
@@ -680,6 +682,24 @@ def __init__(self, data_cfgs: dict, is_tra_val_te: str):
680682
self.lead_time_start, self.lead_time_start + horizon
681683
)
682684
self.horizon_offset = offset
685+
feature_mapping = self.data_cfgs["feature_mapping"]
686+
#
687+
xf_var_indices = {}
688+
for obs_var, fore_var in feature_mapping.items():
689+
# find the index of the variable in x that needs to be replaced
690+
x_var_indice = [
691+
i
692+
for i, var in enumerate(self.data_cfgs["relevant_cols"])
693+
if var == obs_var
694+
][0]
695+
# find the index of the corresponding variable in f
696+
f_var_indice = [
697+
i
698+
for i, var in enumerate(self.data_cfgs["forecast_cols"])
699+
if var == fore_var
700+
][0]
701+
xf_var_indices[x_var_indice] = f_var_indice
702+
self.xf_var_indices = xf_var_indices
683703

684704
def _read_xyc_specified_time(self, start_date, end_date, **kwargs):
685705
"""read f data from data source with specified time range and add it to the whole dict"""
@@ -707,19 +727,18 @@ def __getitem__(self, item: int):
707727
Parameters
708728
----------
709729
item : int
710-
样本索引
730+
index of sample
711731
712732
Returns
713733
-------
714734
tuple
715-
(x, y) 数据对,其中 x 包含输入特征和预见期标志,y 包含目标值
735+
A pair of (x, y) data, where x contains input features and lead time flags,
736+
and y contains target values
716737
"""
717-
# 获取基础数据
718-
if not self.train_mode:
738+
if not (self.train_mode and self.valid_mode):
719739
x = self.x[item, :, :]
720740
y = self.y[item, :, :]
721741
f = self.f[item, :, :]
722-
# 添加预见期标志到输入特征
723742
if self.c is None or self.c.shape[-1] == 0:
724743
xc = np.concatenate((x, f), axis=1)
725744
else:
@@ -729,25 +748,47 @@ def __getitem__(self, item: int):
729748

730749
return torch.from_numpy(xc).float(), torch.from_numpy(y).float()
731750

732-
# 训练模式
751+
# train mode
733752
basin, idx = self.lookup_table[item]
734753
warmup_length = self.warmup_length
754+
# for x, we only need data before the horizon, but forecast data may exist for only some variables
755+
# hence, to avoid nan values for some variables without forecast in horizon
756+
# we still get data from the first time to the end of horizon
735757
x = self.x[basin, idx - warmup_length : idx + self.rho + self.horizon, :]
758+
# for y, we choose data after warmup_length
736759
y = self.y[basin, idx : idx + self.rho + self.horizon, :]
737-
738760
# use offset to get forecast data
739-
f = self.f[basin, idx : idx + self.rho + self.horizon, :]
740-
741-
# 添加预见期标志到输入特征
761+
offset = self.horizon_offset
762+
if self.lead_time_type == "fixed":
763+
# Fixed lead_time mode - All forecast steps use the same lead_step
764+
f = self.f[
765+
basin, idx + self.rho : idx + self.rho + self.horizon, offset[0], :
766+
]
767+
else:
768+
# Increasing lead_time mode - Each forecast step uses a different lead_step
769+
f = self.f[basin, idx + self.rho, offset, :]
770+
xf = self._concat_xf(x, f)
742771
if self.c is None or self.c.shape[-1] == 0:
743-
xfc = np.concatenate((x, f), axis=1)
772+
xfc = xf
744773
else:
745774
c = self.c[basin, :]
746-
c = np.repeat(c, x.shape[0], axis=0).reshape(c.shape[0], -1).T
747-
xfc = np.concatenate((x, c, f), axis=1)
775+
c = np.repeat(c, xf.shape[0], axis=0).reshape(c.shape[0], -1).T
776+
xfc = np.concatenate((xf, c), axis=1)
748777

749778
return torch.from_numpy(xfc).float(), torch.from_numpy(y).float()
750779

780+
def _concat_xf(self, x, f):
    """Overwrite the forecast-period rows of ``x`` with mapped forecast variables.

    For every ``x`` column -> ``f`` column pair in ``self.xf_var_indices``,
    the trailing rows of that ``x`` column are replaced by the matching
    column of ``f``.

    Parameters
    ----------
    x
        2-D array indexed as (time step, variable). The caller slices it
        starting ``warmup_length`` steps before the sample index, so it may
        carry warmup rows in front of the hindcast rows.
    f
        2-D array indexed as (forecast step, variable); its rows are
        aligned with the last ``f.shape[0]`` rows of ``x``.

    Returns
    -------
    A copy of ``x`` with the mapped columns replaced over the forecast
    period; ``x`` itself is left untouched.

    Raises
    ------
    ValueError
        If ``f`` has more rows than ``x``.
    """
    n_forecast = f.shape[0]
    if n_forecast > x.shape[0]:
        raise ValueError(
            f"forecast length {n_forecast} exceeds input length {x.shape[0]}"
        )

    # Work on a copy so the dataset's cached array is never modified
    x_combined = x.copy()

    # Locate the forecast period from the tail of x rather than from
    # self.rho: x may start with warmup rows, in which case the forecast
    # period begins at warmup_length + rho and slicing from rho would
    # select too many rows. Without warmup rows this still equals rho.
    forecast_start = x.shape[0] - n_forecast

    for x_idx, f_idx in self.xf_var_indices.items():
        x_combined[forecast_start:, x_idx] = f[:, f_idx]

    return x_combined
791+
751792

752793
class BasinSingleFlowDataset(BaseDataset):
753794
"""one time length output for each grid in a batch"""

0 commit comments

Comments
 (0)