
Commit a267d65 (merge of parents 78896f2 and 0aa3cc7)

Merge pull request microsoft#464 from lwwang1995/main: Add TCTS baseline.

File tree: 5 files changed, +538 / -0 lines
examples/benchmarks/TCTS/TCTS.md

Lines changed: 52 additions & 0 deletions
# Temporally Correlated Task Scheduling for Sequence Learning

We provide the [code](https://github.com/microsoft/qlib/blob/main/qlib/contrib/model/pytorch_tcts.py) for reproducing the stock trend forecasting experiments.

### Background

Sequence learning has attracted much research attention from the machine learning community in recent years. In many applications, a sequence learning task is associated with multiple temporally correlated auxiliary tasks, which differ in how much input information they use or which future step they predict. In stock trend forecasting, as illustrated in Figure 1, one can predict the price of a stock on different future days (e.g., tomorrow, or the day after tomorrow). In this paper, we propose a framework that makes these temporally correlated tasks help each other.
<p align="center">
<img src="task_description.png" width="600" height="200"/>
</p>
### Method

Given multiple temporally correlated tasks, the key challenge lies in which tasks to use and when to use them in the training process. In this work, we introduce a learnable task scheduler for sequence learning, which adaptively selects temporally correlated tasks during the training process. The scheduler accesses the model status and the current training data (e.g., the current minibatch) and selects the auxiliary task best suited to help the training of the main task. The scheduler and the model for the main task are jointly trained through bi-level optimization: the scheduler is trained to maximize the validation performance of the model, and the model is trained to minimize the training loss guided by the scheduler. The process is demonstrated in Figure 2.

<p align="center">
<img src="workflow.png"/>
</p>

At step <img src="https://render.githubusercontent.com/render/math?math=s">, with training data <img src="https://render.githubusercontent.com/render/math?math=x_s,y_s">, the scheduler <img src="https://render.githubusercontent.com/render/math?math=\varphi"> chooses a suitable task <img src="https://render.githubusercontent.com/render/math?math=T_{i_s}"> (green solid lines) to update the model <img src="https://render.githubusercontent.com/render/math?math=f"> (blue solid lines). After <img src="https://render.githubusercontent.com/render/math?math=S"> steps, we evaluate the model <img src="https://render.githubusercontent.com/render/math?math=f"> on the validation set and update the scheduler <img src="https://render.githubusercontent.com/render/math?math=\varphi"> (green dashed lines).
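The bi-level loop above can be illustrated on a deliberately tiny toy problem (NOT the paper's stock data or the qlib implementation): a scalar linear model, K noisy targets standing in for the temporally correlated tasks, and a REINFORCE-style scheduler whose reward is the improvement of the main task's validation loss after each model step. All names and hyperparameters here are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data: K "temporally correlated" targets sharing the same slope.
n, K = 200, 3
x = rng.normal(size=n)
targets = [x + 0.1 * k + rng.normal(scale=0.1, size=n) for k in range(K)]

w = 0.0              # scalar linear model f(x) = w * x
theta = np.zeros(K)  # scheduler logits over the K tasks

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def main_task_val_loss(w):
    # "validation" loss on the main task (index 0)
    return float(np.mean((w * x - targets[0]) ** 2))

for step in range(500):
    probs = softmax(theta)
    k = rng.choice(K, p=probs)   # scheduler samples a task for this step
    i = rng.integers(n)          # one training example ("minibatch" of 1)
    grad = 2.0 * (w * x[i] - targets[k][i]) * x[i]
    before = main_task_val_loss(w)
    w -= 0.05 * grad             # model step guided by the scheduler
    reward = before - main_task_val_loss(w)
    # REINFORCE-style update: raise the probability of tasks whose
    # gradient steps reduced the main task's validation loss.
    theta += 0.5 * reward * (np.eye(K)[k] - probs)
```

The real TCTS model in `qlib/contrib/model/pytorch_tcts.py` trains neural forecaster and scheduler networks; this sketch only shows the alternating structure (task sampling, model step, validation-driven scheduler update).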
### DataSet

* We use the historical transaction data of the 300 stocks in the [CSI300](http://www.csindex.com.cn/en/indices/index-detail/000300) index from 01/01/2008 to 08/01/2020.
* We split the data into training (01/01/2008-12/31/2013), validation (01/01/2014-12/31/2015), and test (01/01/2016-08/01/2020) sets based on the transaction time.
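A chronological split like the one above can be sketched in pandas (toy frame with placeholder features; real features come from qlib's data handlers):

```python
import pandas as pd

# Hypothetical daily observations over the full sample period.
dates = pd.date_range("2008-01-01", "2020-08-01", freq="B")
df = pd.DataFrame({"datetime": dates, "feature": range(len(dates))})

# Split boundaries as stated in this README.
segments = {
    "train": ("2008-01-01", "2013-12-31"),
    "valid": ("2014-01-01", "2015-12-31"),
    "test":  ("2016-01-01", "2020-08-01"),
}

splits = {
    name: df[(df["datetime"] >= start) & (df["datetime"] <= end)]
    for name, (start, end) in segments.items()
}
```

Splitting by transaction time (rather than randomly) avoids look-ahead leakage: the validation and test periods lie strictly after the training period.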
### Experiments

#### Task Description

* The main task <img src="https://render.githubusercontent.com/render/math?math=T_k"> (<img src="https://render.githubusercontent.com/render/math?math=task_k"> in Figure 1) refers to forecasting the return of stock <img src="https://render.githubusercontent.com/render/math?math=i"> as follows,

<div align=center>
<img src="https://render.githubusercontent.com/render/math?math=r_{i}^k = \frac{price_i^{t+k}}{price_i^{t+k-1}} - 1">
</div>

* Temporally correlated task sets are defined as <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_k = \{T_1, T_2, ... , T_k\}">; in this paper, <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">, <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5"> and <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}"> are used.
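The k-step return defined above can be computed directly from a price series; a minimal pandas sketch (toy prices, and the helper name is illustrative):

```python
import pandas as pd

def k_step_return(close: pd.Series, k: int) -> pd.Series:
    """r^k_t = close[t+k] / close[t+k-1] - 1.

    Uses future prices as labels, so the last k entries are NaN.
    """
    return close.shift(-k) / close.shift(-(k - 1)) - 1

close = pd.Series([10.0, 10.5, 10.2, 10.8])
r1 = k_step_return(close, 1)  # next-day return
r2 = k_step_return(close, 2)  # return between day t+1 and day t+2
```

This matches the style of the label expressions in the workflow config, where `Ref($close, -k)` denotes the close price k days in the future.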
#### Baselines

* GRU / MLP / LightGBM (LGB) / Graph Attention Networks (GAT)

* Multi-task learning (MTL): in multi-task learning, multiple tasks are jointly trained and mutually boosted, with each task treated equally, whereas in our setting we focus on the main task.

* Curriculum transfer learning (CL): transfer learning also leverages auxiliary tasks to boost the main task. [Curriculum transfer learning](https://arxiv.org/pdf/1804.00810.pdf) is a kind of transfer learning that schedules auxiliary tasks according to pre-defined rules. Our problem can also be regarded as a special kind of transfer learning in which the auxiliary tasks are temporally correlated with the main task, but our learning process is dynamically controlled by a scheduler rather than by pre-defined rules. In the CL baseline, we start from task <img src="https://render.githubusercontent.com/render/math?math=T_1">, then <img src="https://render.githubusercontent.com/render/math?math=T_2">, and gradually move to the last one.
#### Result
38+
| Methods | <img src="https://render.githubusercontent.com/render/math?math=T_1" > | <img src="https://render.githubusercontent.com/render/math?math=T_2"> | <img src="https://render.githubusercontent.com/render/math?math=T_3"> |
39+
| :----: | :----: | :----: | :----: |
40+
| GRU | 0.049 / 1.903 | 0.018 / 1.972 | 0.014 / 1.989 |
41+
| MLP | 0.023 / 1.961 | 0.022 / 1.962 | 0.015 / 1.978 |
42+
| LGB | 0.038 / 1.883 | 0.023 / 1.952 | 0.007 / 1.987 |
43+
| GAT | 0.052 / 1.898 | 0.024 / 1.954 | 0.015 / 1.973 |
44+
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.061 / 1.862 | 0.023 / 1.942 | 0.012 / 1.956 |
45+
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.051 / 1.880 | 0.028 / 1.941 | 0.016 / 1.962 |
46+
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.071 / 1.851 | 0.030 / 1.939 | 0.017 / 1.963 |
47+
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.057 / 1.875 | 0.021 / 1.939 | 0.017 / 1.959 |
48+
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.056 / 1.877 | 0.028 / 1.942 | 0.015 / 1.962 |
49+
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.075 / 1.849 | 0.032 /1.939 | 0.021 / 1.955 |
50+
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.052 / 1.882 | 0.020 / 1.947 | 0.019 / 1.952 |
51+
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.051 / 1.882 | 0.028 / 1.950 | 0.016 / 1.961 |
52+
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.067 / 1.867 | 0.030 / 1.960 | 0.022 / 1.942|
Lines changed: 93 additions & 0 deletions
```yaml
qlib_init:
    provider_uri: "~/.qlib/qlib_data/cn_data"
    region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
    start_time: 2008-01-01
    end_time: 2020-08-01
    fit_start_time: 2008-01-01
    fit_end_time: 2014-12-31
    instruments: *market
    infer_processors:
        - class: RobustZScoreNorm
          kwargs:
              fields_group: feature
              clip_outlier: true
        - class: Fillna
          kwargs:
              fields_group: feature
    learn_processors:
        - class: DropnaLabel
        - class: CSRankNorm
          kwargs:
              fields_group: label
    label: ["Ref($close, -2) / Ref($close, -1) - 1",
            "Ref($close, -3) / Ref($close, -1) - 1",
            "Ref($close, -4) / Ref($close, -1) - 1",
            "Ref($close, -5) / Ref($close, -1) - 1",
            "Ref($close, -6) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
    strategy:
        class: TopkDropoutStrategy
        module_path: qlib.contrib.strategy.strategy
        kwargs:
            topk: 50
            n_drop: 5
    backtest:
        verbose: False
        limit_threshold: 0.095
        account: 100000000
        benchmark: *benchmark
        deal_price: close
        open_cost: 0.0005
        close_cost: 0.0015
        min_cost: 5
task:
    model:
        class: TCTS
        module_path: qlib.contrib.model.pytorch_tcts
        kwargs:
            d_feat: 6
            hidden_size: 64
            num_layers: 2
            dropout: 0.0
            n_epochs: 200
            lr: 1e-3
            early_stop: 20
            batch_size: 800
            metric: loss
            loss: mse
            GPU: 0
            fore_optimizer: adam
            weight_optimizer: adam
            output_dim: 5
            fore_lr: 5e-7
            weight_lr: 5e-7
            steps: 3
            target_label: 0
    dataset:
        class: DatasetH
        module_path: qlib.data.dataset
        kwargs:
            handler:
                class: Alpha360
                module_path: qlib.contrib.data.handler
                kwargs: *data_handler_config
            segments:
                train: [2008-01-01, 2014-12-31]
                valid: [2015-01-01, 2016-12-31]
                test: [2017-01-01, 2020-08-01]
    record:
        - class: SignalRecord
          module_path: qlib.workflow.record_temp
          kwargs: {}
        - class: SigAnaRecord
          module_path: qlib.workflow.record_temp
          kwargs:
              ana_long_short: False
              ann_scaler: 252
        - class: PortAnaRecord
          module_path: qlib.workflow.record_temp
          kwargs:
              config: *port_analysis_config
```
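Assuming a standard qlib installation, a workflow config like this is typically executed with qlib's `qrun` command; the path below is a placeholder for wherever the YAML file is saved:

```shell
# Run the full workflow from the config: model training (SignalRecord),
# signal analysis (SigAnaRecord), and backtest (PortAnaRecord).
qrun path/to/config_tcts.yaml
```

The YAML anchors (`&market`, `&data_handler_config`, `&port_analysis_config`) let one definition be reused via `*` aliases elsewhere in the file, so the handler config and backtest config are each written once.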
