diff --git a/examples/model_rolling/task_manager_rolling.py b/examples/model_rolling/task_manager_rolling.py index 844f1819804..091a87862fe 100644 --- a/examples/model_rolling/task_manager_rolling.py +++ b/examples/model_rolling/task_manager_rolling.py @@ -17,7 +17,7 @@ from qlib.workflow.task.manage import TaskManager, run_task from qlib.workflow.task.collect import RecorderCollector from qlib.model.ens.group import RollingGroup -from qlib.model.trainer import TrainerRM +from qlib.model.trainer import TrainerRM, task_train from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG diff --git a/qlib/contrib/model/pytorch_tabnet.py b/qlib/contrib/model/pytorch_tabnet.py index b05d9a026d0..bd8f085ec1d 100644 --- a/qlib/contrib/model/pytorch_tabnet.py +++ b/qlib/contrib/model/pytorch_tabnet.py @@ -564,7 +564,7 @@ def __init__(self, inp_dim, out_dim, shared, n_ind, vbs): self.shared = None self.independ = nn.ModuleList() if first: - self.independ.append(GLU(inp, out_dim, vbs=vbs)) + self.independ.append(GLU(inp_dim, out_dim, vbs=vbs)) for x in range(first, n_ind): self.independ.append(GLU(out_dim, out_dim, vbs=vbs)) self.scale = float(np.sqrt(0.5))