Skip to content

Commit 91037f2

Browse files
authored
Add A New Baseline: ADD (microsoft#704)
1 parent 680d8cb commit 91037f2

File tree

7 files changed

+703
-1
lines changed

7 files changed

+703
-1
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,7 @@ Here is a list of models built on `Qlib`.
298298
- [TRA based on pytorch (Hengxu, Dong, et al. KDD 2021)](examples/benchmarks/TRA/)
299299
- [TCN based on pytorch (Shaojie Bai, et al. 2018)](examples/benchmarks/TCN/)
300300
- [ADARNN based on pytorch (YunTao Du, et al. 2021)](examples/benchmarks/ADARNN/)
301+
- [ADD based on pytorch (Hongshun Tang, et al.2020)](examples/benchmarks/ADD/)
301302
302303
Your PR of new Quant models is highly welcomed.
303304

examples/benchmarks/ADD/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# AdaRNN
2+
* Paper: [ADD: Augmented Disentanglement Distillation Framework for Improving Stock Trend Forecasting](https://arxiv.org/abs/2012.06289).
3+
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
numpy==1.17.4
2+
pandas==1.1.2
3+
scikit_learn==0.23.2
4+
torch==1.7.0
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
qlib_init:
2+
provider_uri: "~/.qlib/qlib_data/cn_data"
3+
region: cn
4+
market: &market csi300
5+
benchmark: &benchmark SH000300
6+
data_handler_config: &data_handler_config
7+
start_time: 2008-01-01
8+
end_time: 2020-08-01
9+
fit_start_time: 2008-01-01
10+
fit_end_time: 2014-12-31
11+
instruments: *market
12+
infer_processors:
13+
- class: RobustZScoreNorm
14+
kwargs:
15+
fields_group: feature
16+
clip_outlier: true
17+
- class: Fillna
18+
kwargs:
19+
fields_group: feature
20+
learn_processors:
21+
- class: DropnaLabel
22+
- class: CSRankNorm
23+
kwargs:
24+
fields_group: label
25+
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
26+
port_analysis_config: &port_analysis_config
27+
strategy:
28+
class: TopkDropoutStrategy
29+
module_path: qlib.contrib.strategy
30+
kwargs:
31+
signal:
32+
- <MODEL>
33+
- <DATASET>
34+
topk: 50
35+
n_drop: 5
36+
backtest:
37+
start_time: 2017-01-01
38+
end_time: 2020-08-01
39+
account: 100000000
40+
benchmark: *benchmark
41+
exchange_kwargs:
42+
limit_threshold: 0.095
43+
deal_price: close
44+
open_cost: 0.0005
45+
close_cost: 0.0015
46+
min_cost: 5
47+
task:
48+
model:
49+
class: ADD
50+
module_path: qlib.contrib.model.pytorch_add
51+
kwargs:
52+
d_feat: 6
53+
hidden_size: 64
54+
num_layers: 2
55+
dropout: 0.1
56+
dec_dropout: 0.0
57+
n_epochs: 200
58+
lr: 1e-3
59+
early_stop: 20
60+
batch_size: 5000
61+
metric: ic
62+
base_model: GRU
63+
gamma: 0.1
64+
gamma_clip: 0.2
65+
optimizer: adam
66+
mu: 0.2
67+
GPU: 0
68+
dataset:
69+
class: DatasetH
70+
module_path: qlib.data.dataset
71+
kwargs:
72+
handler:
73+
class: Alpha360
74+
module_path: qlib.contrib.data.handler
75+
kwargs: *data_handler_config
76+
segments:
77+
train: [2008-01-01, 2014-12-31]
78+
valid: [2015-01-01, 2016-12-31]
79+
test: [2017-01-01, 2020-08-01]
80+
record:
81+
- class: SignalRecord
82+
module_path: qlib.workflow.record_temp
83+
kwargs:
84+
model: <MODEL>
85+
dataset: <DATASET>
86+
- class: SigAnaRecord
87+
module_path: qlib.workflow.record_temp
88+
kwargs:
89+
ana_long_short: False
90+
ann_scaler: 252
91+
- class: PortAnaRecord
92+
module_path: qlib.workflow.record_temp
93+
kwargs:
94+
config: *port_analysis_config

examples/benchmarks/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
5656
| TCN(Shaojie Bai, et al.) | Alpha360 | 0.0441±0.00 | 0.3301±0.02 | 0.0519±0.00 | 0.4130±0.01 | 0.0604±0.02 | 0.8295±0.34 | -0.1018±0.03 |
5757
| ALSTM (Yao Qin, et al.) | Alpha360 | 0.0497±0.00 | 0.3829±0.04 | 0.0599±0.00 | 0.4736±0.03 | 0.0626±0.02 | 0.8651±0.31 | -0.0994±0.03 |
5858
| LSTM(Sepp Hochreiter, et al.) | Alpha360 | 0.0448±0.00 | 0.3474±0.04 | 0.0549±0.00 | 0.4366±0.03 | 0.0647±0.03 | 0.8963±0.39 | -0.0875±0.02 |
59+
| ADD | Alpha360 | 0.0430±0.00 | 0.3188±0.04 | 0.0559±0.00 | 0.4301±0.03 | 0.0667±0.02 | 0.8992±0.34 | -0.0855±0.02 |
5960
| GRU(Kyunghyun Cho, et al.) | Alpha360 | 0.0493±0.00 | 0.3772±0.04 | 0.0584±0.00 | 0.4638±0.03 | 0.0720±0.02 | 0.9730±0.33 | -0.0821±0.02 |
6061
| AdaRNN(Yuntao Du, et al.) | Alpha360 | 0.0464±0.01 | 0.3619±0.08 | 0.0539±0.01 | 0.4287±0.06 | 0.0753±0.03 | 1.0200±0.40 | -0.0936±0.03 |
6162
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0476±0.00 | 0.3508±0.02 | 0.0598±0.00 | 0.4604±0.01 | 0.0824±0.02 | 1.1079±0.26 | -0.0894±0.03 |

qlib/contrib/model/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,9 @@
3131
from .pytorch_tabnet import TabnetModel
3232
from .pytorch_sfm import SFM_Model
3333
from .pytorch_tcn import TCN
34+
from .pytorch_add import ADD
3435

35-
pytorch_classes = (ALSTM, GATs, GRU, LSTM, DNNModelPytorch, TabnetModel, SFM_Model, TCN)
36+
pytorch_classes = (ALSTM, GATs, GRU, LSTM, DNNModelPytorch, TabnetModel, SFM_Model, TCN, ADD)
3637
except ModuleNotFoundError:
3738
pytorch_classes = ()
3839
print("Please install necessary libs for PyTorch models.")

0 commit comments

Comments
 (0)