From 1e39cdfd63a52ea54485e2b7fbddc596fca0f3d7 Mon Sep 17 00:00:00 2001 From: Christian Clauss Date: Mon, 13 Sep 2021 19:47:47 +0200 Subject: [PATCH 1/3] Update pytorch_tabnet.py $ `flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics` ``` ./qlib/qlib/contrib/model/pytorch_tabnet.py:567:38: F821 undefined name 'inp' self.independ.append(GLU(inp, out_dim, vbs=vbs)) ^ ./qlib/examples/model_rolling/task_manager_rolling.py:75:18: F821 undefined name 'task_train' run_task(task_train, self.task_pool, experiment_name=self.experiment_name) ^ 2 F821 undefined name 'task_train' 2 ``` --- qlib/contrib/model/pytorch_tabnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qlib/contrib/model/pytorch_tabnet.py b/qlib/contrib/model/pytorch_tabnet.py index b05d9a026d0..bd8f085ec1d 100644 --- a/qlib/contrib/model/pytorch_tabnet.py +++ b/qlib/contrib/model/pytorch_tabnet.py @@ -564,7 +564,7 @@ def __init__(self, inp_dim, out_dim, shared, n_ind, vbs): self.shared = None self.independ = nn.ModuleList() if first: - self.independ.append(GLU(inp, out_dim, vbs=vbs)) + self.independ.append(GLU(inp_dim, out_dim, vbs=vbs)) for x in range(first, n_ind): self.independ.append(GLU(out_dim, out_dim, vbs=vbs)) self.scale = float(np.sqrt(0.5)) From c86228df8973ebeb6ca1571b552aa4d0ac7ee0e0 Mon Sep 17 00:00:00 2001 From: Christian Clauss Date: Mon, 13 Sep 2021 19:53:02 +0200 Subject: [PATCH 2/3] Fix undefined names in Python code --- examples/model_rolling/task_manager_rolling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/model_rolling/task_manager_rolling.py b/examples/model_rolling/task_manager_rolling.py index 844f1819804..50489c9bca0 100644 --- a/examples/model_rolling/task_manager_rolling.py +++ b/examples/model_rolling/task_manager_rolling.py @@ -72,7 +72,7 @@ def task_training(self, tasks): def worker(self): # train tasks by other progress or machines for multiprocessing. It is same as TrainerRM.worker. print("========== worker ==========") - run_task(task_train, self.task_pool, experiment_name=self.experiment_name) + run_task(self.task_training, self.task_pool, experiment_name=self.experiment_name) def task_collecting(self): print("========== task_collecting ==========") From 2d3bebc61f2d52131cde9492f0520bdf62c8b9a7 Mon Sep 17 00:00:00 2001 From: Christian Clauss Date: Mon, 13 Sep 2021 19:57:59 +0200 Subject: [PATCH 3/3] from qlib.model.trainer import task_train --- examples/model_rolling/task_manager_rolling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/model_rolling/task_manager_rolling.py b/examples/model_rolling/task_manager_rolling.py index 50489c9bca0..091a87862fe 100644 --- a/examples/model_rolling/task_manager_rolling.py +++ b/examples/model_rolling/task_manager_rolling.py @@ -17,7 +17,7 @@ from qlib.workflow.task.manage import TaskManager, run_task from qlib.workflow.task.collect import RecorderCollector from qlib.model.ens.group import RollingGroup -from qlib.model.trainer import TrainerRM +from qlib.model.trainer import TrainerRM, task_train from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG @@ -72,7 +72,7 @@ def task_training(self, tasks): def worker(self): # train tasks by other progress or machines for multiprocessing. It is same as TrainerRM.worker. print("========== worker ==========") - run_task(self.task_training, self.task_pool, experiment_name=self.experiment_name) + run_task(task_train, self.task_pool, experiment_name=self.experiment_name) def task_collecting(self): print("========== task_collecting ==========")