From 94eb94592f58a5ba7e354931aa2083beee3d9233 Mon Sep 17 00:00:00 2001 From: Dong Zhou Date: Wed, 21 Jul 2021 13:19:07 +0800 Subject: [PATCH 1/7] refactor TRA --- .../TRA/workflow_config_tra_Alpha158.yaml | 125 +++ .../workflow_config_tra_Alpha158_full.yaml | 118 +++ .../TRA/workflow_config_tra_Alpha360.yaml | 119 +++ qlib/contrib/data/dataset.py | 349 ++++++++ qlib/contrib/model/pytorch_tra.py | 820 ++++++++++++++++++ qlib/data/dataset/loader.py | 4 + 6 files changed, 1535 insertions(+) create mode 100644 examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml create mode 100644 examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml create mode 100644 examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml create mode 100644 qlib/contrib/data/dataset.py create mode 100644 qlib/contrib/model/pytorch_tra.py diff --git a/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml b/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml new file mode 100644 index 00000000000..59b1c8e734c --- /dev/null +++ b/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml @@ -0,0 +1,125 @@ +qlib_init: + provider_uri: "~/.qlib/qlib_data/cn_data" + region: cn + +market: &market csi300 +benchmark: &benchmark SH000300 + +data_handler_config: &data_handler_config + start_time: 2008-01-01 + end_time: 2020-08-01 + fit_start_time: 2008-01-01 + fit_end_time: 2014-12-31 + instruments: *market + infer_processors: + - class: FilterCol + kwargs: + fields_group: feature + col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10", + "ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5", + "RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"] + - class: RobustZScoreNorm + kwargs: + fields_group: feature + clip_outlier: true + - class: Fillna + kwargs: + fields_group: feature + learn_processors: + - class: CSRankNorm + kwargs: + fields_group: label + label: ["Ref($close, -2) / Ref($close, -1) - 1"] + +num_states: &num_states 3 + +memory_mode: &memory_mode sample + +tra_config: &tra_config + num_states: *num_states + hidden_size: 16 + tau: 1.0 + src_info: LR_TPE + +model_config: &model_config + input_size: 20 + hidden_size: 64 + num_layers: 2 + rnn_arch: LSTM + use_attn: True + dropout: 0.0 + +port_analysis_config: &port_analysis_config + strategy: + class: TopkDropoutStrategy + module_path: qlib.contrib.strategy.strategy + kwargs: + topk: 50 + n_drop: 5 + backtest: + verbose: False + limit_threshold: 0.095 + account: 100000000 + benchmark: *benchmark + deal_price: close + open_cost: 0.0005 + close_cost: 0.0015 + min_cost: 5 + +task: + model: + class: TRAModel + module_path: qlib.contrib.model.pytorch_tra + kwargs: + tra_config: *tra_config + model_config: *model_config + lr: 1e-3 + n_epochs: 100 + max_steps_per_epoch: 100 + early_stop: 10 + smooth_steps: 5 + seed: 0 + logdir: output/Alpha158/router + lamb: 1.0 + rho: 1.0 + transport_method: router + memory_mode: *memory_mode + eval_train: False + eval_test: True + pretrain: False + init_state: + freeze_model: False + freeze_predictors: False + dataset: + class: MTSDatasetH + module_path: qlib.contrib.data.dataset + kwargs: + handler: + class: Alpha158 + module_path: qlib.contrib.data.handler + kwargs: *data_handler_config + segments: + train: [2008-01-01, 2014-12-31] + valid: [2015-01-01, 2016-12-31] + test: [2017-01-01, 2020-08-01] + seq_len: 60 + horizon: 2 + input_size: + num_states: *num_states + batch_size: 1024 + n_samples: + memory_mode: *memory_mode + drop_last: True + record: + - class: SignalRecord + module_path: qlib.workflow.record_temp + kwargs: {} + - class: SigAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + ana_long_short: False + ann_scaler: 252 + - class: PortAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + config: *port_analysis_config diff --git a/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml b/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml new file mode 100644 index 00000000000..bb49798d4cd --- /dev/null +++ b/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml @@ -0,0 +1,118 @@ +qlib_init: + provider_uri: "~/.qlib/qlib_data/cn_data" + region: cn + +market: &market csi300 +benchmark: &benchmark SH000300 + +data_handler_config: &data_handler_config + start_time: 2008-01-01 + end_time: 2020-08-01 + fit_start_time: 2008-01-01 + fit_end_time: 2014-12-31 + instruments: *market + infer_processors: + - class: RobustZScoreNorm + kwargs: + fields_group: feature + clip_outlier: true + - class: Fillna + kwargs: + fields_group: feature + learn_processors: + - class: CSRankNorm + kwargs: + fields_group: label + label: ["Ref($close, -2) / Ref($close, -1) - 1"] + +num_states: &num_states 3 + +memory_mode: &memory_mode sample + +tra_config: &tra_config + num_states: *num_states + hidden_size: 16 + tau: 1.0 + src_info: TPE + +model_config: &model_config + input_size: 158 + hidden_size: 256 + num_layers: 2 + use_attn: True + dropout: 0.2 + +port_analysis_config: &port_analysis_config + strategy: + class: TopkDropoutStrategy + module_path: qlib.contrib.strategy.strategy + kwargs: + topk: 50 + n_drop: 5 + backtest: + verbose: False + limit_threshold: 0.095 + account: 100000000 + benchmark: *benchmark + deal_price: close + open_cost: 0.0005 + close_cost: 0.0015 + min_cost: 5 + +task: + model: + class: TRAModel + module_path: qlib.contrib.model.pytorch_tra + kwargs: + tra_config: *tra_config + model_config: *model_config + lr: 1e-3 + n_epochs: 100 + max_steps_per_epoch: 100 + early_stop: 10 + smooth_steps: 5 + seed: 0 + logdir: output/Alpha158_full/router + lamb: 1.0 + rho: 1.0 + transport_method: router + memory_mode: *memory_mode + eval_train: False + eval_test: True + pretrain: False + init_state: + freeze_model: False + freeze_predictors: False + dataset: + class: MTSDatasetH + module_path: qlib.contrib.data.dataset + kwargs: + handler: + class: Alpha158 + module_path: qlib.contrib.data.handler + kwargs: *data_handler_config + segments: + train: [2008-01-01, 2014-12-31] + valid: [2015-01-01, 2016-12-31] + test: [2017-01-01, 2020-08-01] + seq_len: 60 + horizon: 2 + input_size: + num_states: *num_states + batch_size: 1024 + n_samples: + memory_mode: *memory_mode + drop_last: True + record: + - class: SignalRecord + module_path: qlib.workflow.record_temp + kwargs: {} + - class: SigAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + ana_long_short: False + ann_scaler: 252 + - class: PortAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + config: *port_analysis_config diff --git a/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml b/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml new file mode 100644 index 00000000000..99c7aa42eeb --- /dev/null +++ b/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml @@ -0,0 +1,119 @@ +qlib_init: + provider_uri: "~/.qlib/qlib_data/cn_data" + region: cn + +market: &market csi300 +benchmark: &benchmark SH000300 + +data_handler_config: &data_handler_config + start_time: 2008-01-01 + end_time: 2020-08-01 + fit_start_time: 2008-01-01 + fit_end_time: 2014-12-31 + instruments: *market + infer_processors: + - class: RobustZScoreNorm + kwargs: + fields_group: feature + clip_outlier: true + - class: Fillna + kwargs: + fields_group: feature + learn_processors: + - class: CSRankNorm + kwargs: + fields_group: label + label: ["Ref($close, -2) / Ref($close, -1) - 1"] + +num_states: &num_states 3 + +memory_mode: &memory_mode sample + +tra_config: &tra_config + num_states: *num_states + hidden_size: 16 + tau: 1.0 + src_info: LR_TPE + +model_config: &model_config + input_size: 6 + hidden_size: 64 + num_layers: 2 + rnn_arch: LSTM + use_attn: True + dropout: 0.0 + +port_analysis_config: &port_analysis_config + strategy: + class: TopkDropoutStrategy + module_path: qlib.contrib.strategy.strategy + kwargs: + topk: 50 + n_drop: 5 + backtest: + verbose: False + limit_threshold: 0.095 + account: 100000000 + benchmark: *benchmark + deal_price: close + open_cost: 0.0005 + close_cost: 0.0015 + min_cost: 5 + +task: + model: + class: TRAModel + module_path: qlib.contrib.model.pytorch_tra + kwargs: + tra_config: *tra_config + model_config: *model_config + lr: 1e-3 + n_epochs: 100 + max_steps_per_epoch: 100 + early_stop: 10 + smooth_steps: 5 + logdir: output/Alpha360/router + seed: 0 + lamb: 1.0 + rho: 1.0 + transport_method: router + memory_mode: *memory_mode + eval_train: False + eval_test: True + pretrain: False + init_state: + freeze_model: False + freeze_predictors: False + dataset: + class: MTSDatasetH + module_path: qlib.contrib.data.dataset + kwargs: + handler: + class: Alpha360 + module_path: qlib.contrib.data.handler + kwargs: *data_handler_config + segments: + train: [2008-01-01, 2014-12-31] + valid: [2015-01-01, 2016-12-31] + test: [2017-01-01, 2020-08-01] + seq_len: 60 + horizon: 2 + input_size: 6 + num_states: *num_states + batch_size: 1024 + n_samples: + memory_mode: *memory_mode + drop_last: True + record: + - class: SignalRecord + module_path: qlib.workflow.record_temp + kwargs: {} + - class: SigAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + ana_long_short: False + ann_scaler: 252 + - class: PortAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + config: *port_analysis_config diff --git a/qlib/contrib/data/dataset.py b/qlib/contrib/data/dataset.py new file mode 100644 index 00000000000..8989a61566d --- /dev/null +++ b/qlib/contrib/data/dataset.py @@ -0,0 +1,349 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import copy +import torch +import warnings +import numpy as np +import pandas as pd + +from qlib.utils import init_instance_by_config +from qlib.data.dataset import DatasetH, DataHandler + + +device = "cuda" if torch.cuda.is_available() else "cpu" + + +def _to_tensor(x): + if not isinstance(x, torch.Tensor): + return torch.tensor(x, dtype=torch.float, device=device) + return x + + +def _create_ts_slices(index, seq_len): + """ + create time series slices from pandas index + + Args: + index (pd.MultiIndex): pandas multiindex with order + seq_len (int): sequence length + """ + assert isinstance(index, pd.MultiIndex), "unsupported index type" + assert seq_len > 0, "sequence length should be larger than 0" + assert index.is_monotonic_increasing, "index should be sorted" + + # number of dates for each instrument + sample_count_by_insts = index.to_series().groupby(level=0).size().values + + # start index for each instrument + start_index_of_insts = np.roll(np.cumsum(sample_count_by_insts), 1) + start_index_of_insts[0] = 0 + + # all the [start, stop) indices of features + # features between [start, stop) will be used to predict label at `stop - 1` + slices = [] + for cur_loc, cur_cnt in zip(start_index_of_insts, sample_count_by_insts): + for stop in range(1, cur_cnt + 1): + end = cur_loc + stop + start = max(end - seq_len, 0) + slices.append(slice(start, end)) + slices = np.array(slices, dtype="object") + + assert len(slices) == len(index) # the i-th slice = index[i] + + return slices + + +def _get_date_parse_fn(target): + """get date parse function + + This method is used to parse date arguments as target type. + + Example: + get_date_parse_fn('20120101')('2017-01-01') => '20170101' + get_date_parse_fn(20120101)('2017-01-01') => 20170101 + """ + if isinstance(target, pd.Timestamp): + _fn = lambda x: pd.Timestamp(x) # Timestamp('2020-01-01') + elif isinstance(target, int): + _fn = lambda x: int(str(x).replace("-", "")[:8]) # 20200201 + elif isinstance(target, str) and len(target) == 8: + _fn = lambda x: str(x).replace("-", "")[:8] # '20200201' + else: + _fn = lambda x: x # '2021-01-01' + return _fn + + +def _maybe_padding(x, seq_len, zeros=None): + """padding 2d