|
| 1 | +# Copyright (c) Microsoft Corporation. |
| 2 | +# Licensed under the MIT License. |
| 3 | + |
| 4 | +import qlib |
| 5 | +import fire |
| 6 | +import pickle |
| 7 | +import pandas as pd |
| 8 | + |
| 9 | +from datetime import datetime |
| 10 | +from qlib.config import REG_CN |
| 11 | +from qlib.data.dataset.handler import DataHandlerLP |
| 12 | +from qlib.contrib.data.handler import Alpha158 |
| 13 | +from qlib.utils import exists_qlib_data, init_instance_by_config |
| 14 | +from qlib.tests.data import GetData |
| 15 | + |
| 16 | + |
class RollingDataWorkflow:
    """Demonstrate rolling-window dataset preparation with qlib.

    An Alpha158 handler covering the full period is computed once and
    pickled to disk; each rolling step then shifts the train/valid/test
    windows forward by one year and re-fits only the lightweight
    processors on top of the shared pre-handler data, avoiding a full
    feature recomputation per window.
    """

    MARKET = "csi300"
    start_time = "2010-01-01"
    end_time = "2019-12-31"
    rolling_cnt = 5  # number of one-year rolling steps to run

    @staticmethod
    def _shift_year(date_tuple, offset):
        """Return a ``datetime`` built from ``(year, month, day)`` *date_tuple*
        with the year advanced by *offset* (month/day unchanged)."""
        year, month, day = date_tuple
        return datetime(year + offset, month, day)

    def _init_qlib(self):
        """Initialize qlib with the daily CN dataset, downloading it if absent."""
        # daily cn_data bundle (NOT the yahoo_cn_1min data)
        provider_uri = "~/.qlib/qlib_data/cn_data"  # target_dir
        if not exists_qlib_data(provider_uri):
            print(f"Qlib data is not found in {provider_uri}")
            GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
        qlib.init(provider_uri=provider_uri, region=REG_CN)

    def _dump_pre_handler(self, path):
        """Build an Alpha158 handler over the full period and pickle it to *path*.

        ``dump_all=True`` makes ``to_pickle`` serialize the handler's data
        along with its configuration, so later loads need no recomputation.
        """
        handler_config = {
            "class": "Alpha158",
            "module_path": "qlib.contrib.data.handler",
            "kwargs": {
                "start_time": self.start_time,
                "end_time": self.end_time,
                "instruments": self.MARKET,
                # raw features/labels only; normalization is applied per window later
                "infer_processors": [],
                "learn_processors": [],
            },
        }
        pre_handler = init_instance_by_config(handler_config)
        pre_handler.config(dump_all=True)
        pre_handler.to_pickle(path)

    def _load_pre_handler(self, path):
        """Unpickle and return the handler dumped by ``_dump_pre_handler``.

        NOTE(review): ``pickle.load`` executes arbitrary code on load — only
        use files produced by this workflow itself, never untrusted input.
        """
        with open(path, "rb") as file_dataset:
            pre_handler = pickle.load(file_dataset)
        return pre_handler

    def rolling_process(self, pre_handler_path="pre_handler.pkl"):
        """Run ``rolling_cnt`` one-year rolling dataset preparations.

        Parameters
        ----------
        pre_handler_path : str
            Where to dump/load the shared pre-computed handler
            (default ``"pre_handler.pkl"``, matching the previous
            hard-coded behavior).
        """
        self._init_qlib()
        self._dump_pre_handler(pre_handler_path)
        pre_handler = self._load_pre_handler(pre_handler_path)

        # initial (offset 0) window boundaries as (year, month, day) tuples
        train_start_time = (2010, 1, 1)
        train_end_time = (2012, 12, 31)
        valid_start_time = (2013, 1, 1)
        valid_end_time = (2013, 12, 31)
        test_start_time = (2014, 1, 1)
        test_end_time = (2014, 12, 31)

        dataset_config = {
            "class": "DatasetH",
            "module_path": "qlib.data.dataset",
            "kwargs": {
                "handler": {
                    "class": "RollingDataHandler",
                    "module_path": "rolling_handler",
                    "kwargs": {
                        "start_time": datetime(*train_start_time),
                        "end_time": datetime(*test_end_time),
                        # processors are fit on the training window only
                        "fit_start_time": datetime(*train_start_time),
                        "fit_end_time": datetime(*train_end_time),
                        "infer_processors": [
                            {"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature"}},
                        ],
                        "learn_processors": [
                            {"class": "DropnaLabel"},
                            {"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}},
                        ],
                        "data_loader_kwargs": {
                            # reuse the pre-computed features instead of recomputing
                            "handler_config": pre_handler,
                        },
                    },
                },
                "segments": {
                    "train": (datetime(*train_start_time), datetime(*train_end_time)),
                    "valid": (datetime(*valid_start_time), datetime(*valid_end_time)),
                    "test": (datetime(*test_start_time), datetime(*test_end_time)),
                },
            },
        }

        dataset = init_instance_by_config(dataset_config)

        for rolling_offset in range(self.rolling_cnt):

            print(f"===========rolling{rolling_offset} start===========")
            if rolling_offset:
                # shift every window boundary forward by `rolling_offset` years
                dataset.config(
                    handler_kwargs={
                        "start_time": self._shift_year(train_start_time, rolling_offset),
                        "end_time": self._shift_year(test_end_time, rolling_offset),
                        "processor_kwargs": {
                            "fit_start_time": self._shift_year(train_start_time, rolling_offset),
                            "fit_end_time": self._shift_year(train_end_time, rolling_offset),
                        },
                    },
                    segments={
                        "train": (
                            self._shift_year(train_start_time, rolling_offset),
                            self._shift_year(train_end_time, rolling_offset),
                        ),
                        "valid": (
                            self._shift_year(valid_start_time, rolling_offset),
                            self._shift_year(valid_end_time, rolling_offset),
                        ),
                        "test": (
                            self._shift_year(test_start_time, rolling_offset),
                            self._shift_year(test_end_time, rolling_offset),
                        ),
                    },
                )
            # IT_FIT_SEQ: fit processors then process data sequentially for the window
            dataset.setup_data(
                handler_kwargs={
                    "init_type": DataHandlerLP.IT_FIT_SEQ,
                }
            )

            dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"])
            print(dtrain, dvalid, dtest)
            ## print or dump data
            print(f"===========rolling{rolling_offset} end===========")
| 139 | + |
if __name__ == "__main__":
    # Expose the workflow class as a CLI via python-fire, e.g.:
    #   python workflow.py rolling_process
    fire.Fire(RollingDataWorkflow)
0 commit comments