Skip to content

Commit 60d45ad

Browse files
authored
Enhance pytorch nn (#917)
* enhance pytorch_nn * fix dim bug * Black format * Fix pylint error
1 parent 0e8b94a commit 60d45ad

File tree

13 files changed

+285
-143
lines changed

13 files changed

+285
-143
lines changed

.pylintrc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[TYPECHECK]
2+
# https://stackoverflow.com/a/53572939
3+
# List of members which are set dynamically and missed by Pylint inference
4+
# system, and so shouldn't trigger E1101 when accessed.
5+
generated-members=numpy.*, torch.*

docs/developer/code_standard.rst

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,19 @@ Continuous Integration (CI) tools help you stick to the quality standards by run
1414

1515
When you submit a PR request, you can check whether your code passes the CI tests in the "check" section at the bottom of the web page.
1616

17-
A common error is the mixed use of space and tab. You can fix the bug by inputing the following code in the command line.
17+
1. Qlib will check the code format with black. The PR will raise an error if your code does not align with the standard of Qlib (e.g., a common error is the mixed use of spaces and tabs).
18+
You can fix the bug by inputting the following code in the command line.
1819

1920
.. code-block:: python
2021
2122
pip install black
2223
python -m black . -l 120
24+
25+
26+
2. Qlib will check your code style with pylint. The checking command is implemented in [github action workflow](https://github.com/microsoft/qlib/blob/0e8b94a552f1c457cfa6cd2c1bb3b87ebb3fb279/.github/workflows/test.yml#L66).
27+
Sometimes pylint's restrictions are not that reasonable. You can ignore specific errors like this:
28+
29+
.. code-block:: python
30+
31+
return -ICLoss()(pred, target, index) # pylint: disable=E1130
32+

examples/benchmarks/MLP/workflow_config_mlp_Alpha158.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,6 @@ task:
6363
module_path: qlib.contrib.model.pytorch_nn
6464
kwargs:
6565
loss: mse
66-
input_dim: 157
67-
output_dim: 1
6866
lr: 0.002
6967
lr_decay: 0.96
7068
lr_decay_steps: 100
@@ -73,6 +71,8 @@ task:
7371
batch_size: 8192
7472
GPU: 0
7573
weight_decay: 0.0002
74+
pt_model_kwargs:
75+
input_dim: 157
7676
dataset:
7777
class: DatasetH
7878
module_path: qlib.data.dataset

examples/benchmarks/MLP/workflow_config_mlp_Alpha360.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,15 @@ task:
5151
module_path: qlib.contrib.model.pytorch_nn
5252
kwargs:
5353
loss: mse
54-
input_dim: 360
55-
output_dim: 1
5654
lr: 0.002
5755
lr_decay: 0.96
5856
lr_decay_steps: 100
5957
optimizer: adam
6058
max_steps: 8000
6159
batch_size: 4096
6260
GPU: 0
61+
pt_model_kwargs:
62+
input_dim: 360
6363
dataset:
6464
class: DatasetH
6565
module_path: qlib.data.dataset

qlib/contrib/meta/data_selection/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
class ICLoss(nn.Module):
1010
def forward(self, pred, y, idx, skip_size=50):
1111
"""forward.
12+
FIXME:
13+
- Some times it will be a slightly different from the result from `pandas.corr()`
14+
- It may be caused by the precision problem of model;
1215
1316
:param pred:
1417
:param y:

qlib/contrib/model/gbdt.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from ...data.dataset.handler import DataHandlerLP
1111
from ...model.interpret.base import LightGBMFInt
1212
from ...data.dataset.weight import Reweighter
13+
from qlib.workflow import R
1314

1415

1516
class LGBModel(ModelFT, LightGBMFInt):
@@ -59,10 +60,12 @@ def fit(
5960
num_boost_round=None,
6061
early_stopping_rounds=None,
6162
verbose_eval=20,
62-
evals_result=dict(),
63+
evals_result=None,
6364
reweighter=None,
64-
**kwargs
65+
**kwargs,
6566
):
67+
if evals_result is None:
68+
evals_result = {} # in case of unsafety of Python default values
6669
ds_l = self._prepare_data(dataset, reweighter)
6770
ds, names = list(zip(*ds_l))
6871
self.model = lgb.train(
@@ -76,10 +79,13 @@ def fit(
7679
),
7780
verbose_eval=verbose_eval,
7881
evals_result=evals_result,
79-
**kwargs
82+
**kwargs,
8083
)
8184
for k in names:
82-
evals_result[k] = list(evals_result[k].values())[0]
85+
for key, val in evals_result[k].items():
86+
name = f"{key}.{k}"
87+
for epoch, m in enumerate(val):
88+
R.log_metrics(**{name.replace("@", "_"): m}, step=epoch)
8389

8490
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
8591
if self.model is None:

qlib/contrib/model/pytorch_gats.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,8 +263,8 @@ def fit(
263263

264264
model_dict = self.GAT_model.state_dict()
265265
pretrained_dict = {
266-
k: v for k, v in pretrained_model.state_dict().items() if k in model_dict
267-
} # pylint: disable=E1135
266+
k: v for k, v in pretrained_model.state_dict().items() if k in model_dict # pylint: disable=E1135
267+
}
268268
model_dict.update(pretrained_dict)
269269
self.GAT_model.load_state_dict(model_dict)
270270
self.logger.info("Loading pretrained model Done...")

qlib/contrib/model/pytorch_gats_ts.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -278,8 +278,8 @@ def fit(
278278

279279
model_dict = self.GAT_model.state_dict()
280280
pretrained_dict = {
281-
k: v for k, v in pretrained_model.state_dict().items() if k in model_dict
282-
} # pylint: disable=E1135
281+
k: v for k, v in pretrained_model.state_dict().items() if k in model_dict # pylint: disable=E1135
282+
}
283283
model_dict.update(pretrained_dict)
284284
self.GAT_model.load_state_dict(model_dict)
285285
self.logger.info("Loading pretrained model Done...")

0 commit comments

Comments
 (0)