Skip to content

Commit 29f4e11

Browse files
authored
Merge pull request microsoft#374 from bxdd/qlib_loaderhandler
Add DataLoader Based on DataHandler & Add Rolling Process Example & Restructure the Config & Setup_data
2 parents 272aa00 + 8facd69 commit 29f4e11

File tree

8 files changed

+372
-71
lines changed

8 files changed

+372
-71
lines changed

examples/highfreq/workflow.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,11 @@
2727
from highfreq_ops import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut
2828

2929

30-
class HighfreqWorkflow(object):
30+
class HighfreqWorkflow:
3131

3232
SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut], "expression_cache": None}
3333

3434
MARKET = "all"
35-
BENCHMARK = "SH000300"
3635

3736
start_time = "2020-09-15 00:00:00"
3837
end_time = "2021-01-18 16:00:00"
@@ -146,35 +145,40 @@ def dump_and_load_dataset(self):
146145

147146
self._prepare_calender_cache()
148147
##=============reinit dataset=============
149-
dataset.init(
148+
dataset.config(
150149
handler_kwargs={
151-
"init_type": DataHandlerLP.IT_LS,
152150
"start_time": "2021-01-19 00:00:00",
153151
"end_time": "2021-01-25 16:00:00",
154152
},
155-
segment_kwargs={
153+
segments={
156154
"test": (
157155
"2021-01-19 00:00:00",
158156
"2021-01-25 16:00:00",
159157
),
160158
},
161159
)
162-
dataset_backtest.init(
160+
dataset.setup_data(
161+
handler_kwargs={
162+
"init_type": DataHandlerLP.IT_LS,
163+
},
164+
)
165+
dataset_backtest.config(
163166
handler_kwargs={
164167
"start_time": "2021-01-19 00:00:00",
165168
"end_time": "2021-01-25 16:00:00",
166169
},
167-
segment_kwargs={
170+
segments={
168171
"test": (
169172
"2021-01-19 00:00:00",
170173
"2021-01-25 16:00:00",
171174
),
172175
},
173176
)
177+
dataset_backtest.setup_data(handler_kwargs={})
174178

175179
##=============get data=============
176-
xtest = dataset.prepare(["test"])
177-
backtest_test = dataset_backtest.prepare(["test"])
180+
xtest = dataset.prepare("test")
181+
backtest_test = dataset_backtest.prepare("test")
178182

179183
print(xtest, backtest_test)
180184
return
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Rolling Process Data
2+
3+
This workflow is an example of `Rolling Process Data`.
4+
5+
## Background
6+
7+
When rolling-training models, data also needs to be generated for each rolling window. As the rolling window moves, the training data changes, and the processors' learnable state (such as the standard deviation and mean) changes as well.
8+
9+
In order to avoid regenerating data, this example uses the `DataHandler-based DataLoader` to load the raw features that are not related to the rolling window, and then uses Processors to generate processed features that are related to the rolling window.
10+
11+
12+
## Run the Code
13+
14+
Run the example by running the following command:
15+
```bash
16+
python workflow.py rolling_process
17+
```
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from qlib.data.dataset.handler import DataHandlerLP
from qlib.data.dataset.loader import DataLoaderDH
from qlib.contrib.data.handler import check_transform_proc


class RollingDataHandler(DataHandlerLP):
    """A ``DataHandlerLP`` whose raw features come from another DataHandler.

    It wires a ``DataLoaderDH`` (DataHandler-based DataLoader) in as its data
    loader, so a rolling workflow can reuse a pre-dumped handler's raw data
    and only re-run the window-dependent processors for each rolling window.
    """

    def __init__(
        self,
        start_time=None,
        end_time=None,
        infer_processors=None,
        learn_processors=None,
        fit_start_time=None,
        fit_end_time=None,
        data_loader_kwargs=None,
    ):
        """
        Parameters
        ----------
        start_time, end_time :
            time range of the data served by this handler.
        infer_processors : list, optional
            processor configs for inference data; defaults to no processors.
        learn_processors : list, optional
            processor configs for learning data; defaults to no processors.
        fit_start_time, fit_end_time :
            time range used to fit the learnable processors.
        data_loader_kwargs : dict, optional
            extra kwargs forwarded to ``DataLoaderDH``.
        """
        # Use None sentinels instead of mutable defaults ([] / {}) so call
        # sites cannot accidentally share (and mutate) one default object.
        infer_processors = [] if infer_processors is None else infer_processors
        learn_processors = [] if learn_processors is None else learn_processors
        data_loader_kwargs = {} if data_loader_kwargs is None else data_loader_kwargs

        infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
        learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)

        data_loader = {
            "class": "DataLoaderDH",
            "kwargs": {**data_loader_kwargs},
        }

        super().__init__(
            instruments=None,  # instruments are determined by the wrapped handler
            start_time=start_time,
            end_time=end_time,
            data_loader=data_loader,
            infer_processors=infer_processors,
            learn_processors=learn_processors,
        )
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import qlib
import fire
import pickle
import pandas as pd

from datetime import datetime
from qlib.config import REG_CN
from qlib.data.dataset.handler import DataHandlerLP
from qlib.contrib.data.handler import Alpha158
from qlib.utils import exists_qlib_data, init_instance_by_config
from qlib.tests.data import GetData


class RollingDataWorkflow:
    """Example workflow for rolling-window data generation.

    Dumps a full-period Alpha158 handler once, then for each rolling window
    reconfigures a ``DatasetH`` (shifting the time segments and the processor
    fit range by one year) and refits only the learnable processors.
    """

    MARKET = "csi300"
    start_time = "2010-01-01"
    end_time = "2019-12-31"
    rolling_cnt = 5  # number of rolling windows to iterate over

    def _init_qlib(self):
        """initialize qlib"""
        # use daily cn_data; download it on first run
        provider_uri = "~/.qlib/qlib_data/cn_data"  # target_dir
        if not exists_qlib_data(provider_uri):
            print(f"Qlib data is not found in {provider_uri}")
            GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
        qlib.init(provider_uri=provider_uri, region=REG_CN)

    def _dump_pre_handler(self, path):
        """Build an Alpha158 handler over the whole period and pickle it to *path*.

        ``dump_all=True`` makes the pickle carry the handler's data, so later
        rolling iterations can reload it instead of regenerating raw features.
        """
        handler_config = {
            "class": "Alpha158",
            "module_path": "qlib.contrib.data.handler",
            "kwargs": {
                "start_time": self.start_time,
                "end_time": self.end_time,
                "instruments": self.MARKET,
                # raw features only: window-dependent processing happens later
                "infer_processors": [],
                "learn_processors": [],
            },
        }
        pre_handler = init_instance_by_config(handler_config)
        pre_handler.config(dump_all=True)
        pre_handler.to_pickle(path)

    def _load_pre_handler(self, path):
        """Unpickle and return the handler previously dumped by ``_dump_pre_handler``."""
        with open(path, "rb") as file_dataset:
            pre_handler = pickle.load(file_dataset)
        return pre_handler

    @staticmethod
    def _shift_year(date_tuple, offset):
        """Return ``datetime(*date_tuple)`` with its year advanced by *offset*."""
        return datetime(date_tuple[0] + offset, *date_tuple[1:])

    def rolling_process(self):
        """Run ``rolling_cnt`` iterations: shift the window, refit processors, prepare data."""
        self._init_qlib()
        self._dump_pre_handler("pre_handler.pkl")
        pre_handler = self._load_pre_handler("pre_handler.pkl")

        train_start_time = (2010, 1, 1)
        train_end_time = (2012, 12, 31)
        valid_start_time = (2013, 1, 1)
        valid_end_time = (2013, 12, 31)
        test_start_time = (2014, 1, 1)
        test_end_time = (2014, 12, 31)

        dataset_config = {
            "class": "DatasetH",
            "module_path": "qlib.data.dataset",
            "kwargs": {
                "handler": {
                    "class": "RollingDataHandler",
                    "module_path": "rolling_handler",
                    "kwargs": {
                        "start_time": datetime(*train_start_time),
                        "end_time": datetime(*test_end_time),
                        "fit_start_time": datetime(*train_start_time),
                        "fit_end_time": datetime(*train_end_time),
                        "infer_processors": [
                            {"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature"}},
                        ],
                        "learn_processors": [
                            {"class": "DropnaLabel"},
                            {"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}},
                        ],
                        "data_loader_kwargs": {
                            "handler_config": pre_handler,
                        },
                    },
                },
                "segments": {
                    "train": (datetime(*train_start_time), datetime(*train_end_time)),
                    "valid": (datetime(*valid_start_time), datetime(*valid_end_time)),
                    "test": (datetime(*test_start_time), datetime(*test_end_time)),
                },
            },
        }

        dataset = init_instance_by_config(dataset_config)

        for rolling_offset in range(self.rolling_cnt):

            print(f"===========rolling{rolling_offset} start===========")
            if rolling_offset:
                # Shift every boundary forward by `rolling_offset` years; the
                # first iteration keeps the configuration built above.
                shift = self._shift_year
                dataset.config(
                    handler_kwargs={
                        "start_time": shift(train_start_time, rolling_offset),
                        "end_time": shift(test_end_time, rolling_offset),
                        "processor_kwargs": {
                            "fit_start_time": shift(train_start_time, rolling_offset),
                            "fit_end_time": shift(train_end_time, rolling_offset),
                        },
                    },
                    segments={
                        "train": (
                            shift(train_start_time, rolling_offset),
                            shift(train_end_time, rolling_offset),
                        ),
                        "valid": (
                            shift(valid_start_time, rolling_offset),
                            shift(valid_end_time, rolling_offset),
                        ),
                        "test": (
                            shift(test_start_time, rolling_offset),
                            shift(test_end_time, rolling_offset),
                        ),
                    },
                )
            # IT_FIT_SEQ: refit the learnable processors sequentially on the
            # (possibly shifted) fit range, reusing the already-loaded raw data.
            dataset.setup_data(
                handler_kwargs={
                    "init_type": DataHandlerLP.IT_FIT_SEQ,
                }
            )

            dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"])
            print(dtrain, dvalid, dtest)
            ## print or dump data
            print(f"===========rolling{rolling_offset} end===========")


if __name__ == "__main__":
    fire.Fire(RollingDataWorkflow)

0 commit comments

Comments
 (0)