Skip to content

Commit 901ef39

Browse files
committed
Merge branch 'dev' of github.com:OuyangWenyu/torchhydro into dev
2 parents e687a5e + e11fd72 commit 901ef39

File tree

3 files changed

+17
-11
lines changed

3 files changed

+17
-11
lines changed

torchhydro/models/simple_lstm.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
Author: Wenyu Ouyang
33
Date: 2023-09-19 09:36:25
4-
LastEditTime: 2025-01-25 09:19:43
4+
LastEditTime: 2025-01-25 09:32:12
55
LastEditors: Wenyu Ouyang
66
Description: Some self-made LSTMs
77
FilePath: /torchhydro/torchhydro/models/simple_lstm.py
@@ -68,11 +68,9 @@ def __init__(self, linear_size, **kwargs):
6868
super(LinearSimpleLSTMModel, self).__init__(**kwargs)
6969
self.former_linear = nn.Linear(linear_size, kwargs["input_size"])
7070

71-
def forward(self, x, do_drop_mc=False, dropout_false=False):
71+
def forward(self, x):
7272
x0 = F.relu(self.former_linear(x))
73-
return super(LinearSimpleLSTMModel, self).forward(
74-
x0, do_drop_mc=do_drop_mc, dropout_false=dropout_false
75-
)
73+
return super(LinearSimpleLSTMModel, self).forward(x0)
7674

7775

7876
class SimpleLSTMForecast(SimpleLSTM):

torchhydro/trainers/resulter.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,10 @@ def save_result(self, pred, obs):
102102
save_dir = self.result_dir
103103
flow_pred_file = os.path.join(save_dir, self.pred_name)
104104
flow_obs_file = os.path.join(save_dir, self.obs_name)
105-
pred.to_netcdf(flow_pred_file + ".nc")
106-
obs.to_netcdf(flow_obs_file + ".nc")
105+
max_len = max(len(basin) for basin in pred.basin.values)
106+
encoding = {"basin": {"dtype": f"U{max_len}"}}
107+
pred.to_netcdf(flow_pred_file + ".nc", encoding=encoding)
108+
obs.to_netcdf(flow_obs_file + ".nc", encoding=encoding)
107109

108110
def eval_result(self, preds_xr, obss_xr):
109111
# types of observations

torchhydro/trainers/trainer.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
Author: Wenyu Ouyang
33
Date: 2021-12-05 11:21:58
4-
LastEditTime: 2024-05-23 15:21:17
4+
LastEditTime: 2025-01-14 19:29:34
55
LastEditors: Wenyu Ouyang
66
Description: Main function for training and testing
77
FilePath: \torchhydro\torchhydro\trainers\trainer.py
@@ -61,13 +61,19 @@ def train_and_evaluate(cfgs: Dict):
6161
set_random_seed(random_seed)
6262
resulter = Resulter(cfgs)
6363
deephydro = _get_deep_hydro(cfgs)
64-
if cfgs["training_cfgs"]["train_mode"] and (
64+
# if train_mode is False, we only evaluate the model
65+
train_mode = deephydro.cfgs["training_cfgs"]["train_mode"]
66+
# but if train_mode is True, we still need some conditions to train the model
67+
continue_train = deephydro.cfgs["model_cfgs"]["continue_train"]
68+
is_transfer_learning = deephydro.cfgs["model_cfgs"]["model_type"] == "TransLearn"
69+
is_train = train_mode and (
6570
(
6671
deephydro.weight_path is not None
67-
and deephydro.cfgs["model_cfgs"]["continue_train"]
72+
and (continue_train or is_transfer_learning)
6873
)
6974
or (deephydro.weight_path is None)
70-
):
75+
)
76+
if is_train:
7177
deephydro.model_train()
7278
preds, obss = deephydro.model_evaluate()
7379
resulter.save_cfg(deephydro.cfgs)

0 commit comments

Comments
 (0)