Skip to content

Commit cf3868f

Browse files
committed
Merge remote-tracking branch 'downstream/dev' into dev
2 parents 901ef39 + 38ba197 commit cf3868f

File tree

3 files changed

+28
-2
lines changed

3 files changed

+28
-2
lines changed

torchhydro/models/model_dict_function.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
Author: Wenyu Ouyang
33
Date: 2021-12-31 11:08:29
4-
LastEditTime: 2025-01-14 19:30:54
4+
LastEditTime: 2025-01-25 09:33:46
55
LastEditors: Wenyu Ouyang
66
Description: Dicts including models (which are seq-first), losses, and optims
77
FilePath: /torchhydro/torchhydro/models/model_dict_function.py
@@ -19,6 +19,7 @@
1919
)
2020

2121
from torchhydro.models.simple_lstm import (
22+
LinearMultiLayerLSTMModel,
2223
LinearSimpleLSTMModel,
2324
MultiLayerLSTM,
2425
SimpleLSTM,
@@ -79,6 +80,7 @@
7980
"SimpleLSTM": SimpleLSTM,
8081
"LinearSimpleLSTMModel": LinearSimpleLSTMModel,
8182
"MultiLayerLSTM": MultiLayerLSTM,
83+
"LinearMultiLayerLSTMModel": LinearMultiLayerLSTMModel,
8284
}
8385

8486
pytorch_criterion_dict = {

torchhydro/models/simple_lstm.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,25 @@ def forward(self, x):
7272
x0 = F.relu(self.former_linear(x))
7373
return super(LinearSimpleLSTMModel, self).forward(x0)
7474

75+
class LinearMultiLayerLSTMModel(MultiLayerLSTM):
76+
"""
77+
This model is nonlinear layer + MultiLayerLSTM.
78+
"""
79+
80+
def __init__(self, linear_size, **kwargs):
81+
"""
82+
83+
Parameters
84+
----------
85+
linear_size
86+
the number of input features for the first input linear layer
87+
"""
88+
super(LinearMultiLayerLSTMModel, self).__init__(**kwargs)
89+
self.former_linear = nn.Linear(linear_size, kwargs["input_size"])
90+
91+
def forward(self, x):
92+
x0 = F.relu(self.former_linear(x))
93+
return super(LinearMultiLayerLSTMModel, self).forward(x0)
7594

7695
class SimpleLSTMForecast(SimpleLSTM):
7796
def __init__(self, input_size, output_size, hidden_size, forecast_length, dr=0.0):

torchhydro/trainers/deep_hydro.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -737,7 +737,12 @@ def load_model(self, mode="train"):
737737
raise NotImplementedError(
738738
"For transfer learning, we need a pre-trained model"
739739
)
740-
model = super().load_model(mode)
740+
if mode == "train":
741+
model = super().load_model(mode)
742+
elif mode == "infer":
743+
self.weight_path = self._get_trained_model()
744+
model = self._load_model_from_pth()
745+
model.to(self.device)
741746
if (
742747
"weight_path_add" in model_cfgs
743748
and "freeze_params" in model_cfgs["weight_path_add"]

0 commit comments

Comments
 (0)