Skip to content

Commit 6f71f8a

Browse files
PalanQuJiabao Qu
andauthored
chore: remove hard code input dimension of model pytorch_tcts (#843)
Co-authored-by: Jiabao Qu <qujiabao@logiocean.com>
1 parent edd8bad commit 6f71f8a

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

qlib/contrib/model/pytorch_tcts.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def __init__(
5656
loss="mse",
5757
fore_optimizer="adam",
5858
weight_optimizer="adam",
59+
input_dim=360,
5960
output_dim=5,
6061
fore_lr=5e-7,
6162
weight_lr=5e-7,
@@ -83,6 +84,7 @@ def __init__(
8384
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
8485
self.use_gpu = torch.cuda.is_available()
8586
self.seed = seed
87+
self.input_dim = input_dim
8688
self.output_dim = output_dim
8789
self.fore_lr = fore_lr
8890
self.weight_lr = weight_lr
@@ -139,7 +141,6 @@ def loss_fn(self, pred, label, weight):
139141
raise NotImplementedError("mode {} is not supported!".format(self.mode))
140142

141143
def train_epoch(self, x_train, y_train, x_valid, y_valid):
142-
143144
x_train_values = x_train.values
144145
y_train_values = np.squeeze(y_train.values)
145146

@@ -297,7 +298,7 @@ def training(
297298
dropout=self.dropout,
298299
)
299300
self.weight_model = MLPModel(
300-
d_feat=360 + 3 * self.output_dim + 1,
301+
d_feat=self.input_dim + 3 * self.output_dim + 1,
301302
hidden_size=self.hidden_size,
302303
num_layers=self.num_layers,
303304
dropout=self.dropout,

0 commit comments

Comments
 (0)