Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add test/config.py
  • Loading branch information
zhupr committed May 28, 2021
commit 98eacf8f88de66aa88ee877004a76c2a60f7c5f5
19 changes: 3 additions & 16 deletions examples/highfreq/workflow.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import sys
import fire
from pathlib import Path

import qlib
import pickle
import numpy as np
import pandas as pd
from qlib.config import REG_CN, HIGH_FREQ_CONFIG
from qlib.contrib.model.gbdt import LGBModel
from qlib.contrib.data.handler import Alpha158
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
from qlib.contrib.evaluate import (
backtest as normal_backtest,
risk_analysis,
)

from qlib.utils import init_instance_by_config, exists_qlib_data

from qlib.utils import init_instance_by_config
from qlib.data.dataset.handler import DataHandlerLP
from qlib.data.ops import Operators
from qlib.data.data import Cal
Expand Down Expand Up @@ -96,9 +85,7 @@ def _init_qlib(self):
# use yahoo_cn_1min data
QLIB_INIT_CONFIG = {**HIGH_FREQ_CONFIG, **self.SPEC_CONF}
provider_uri = QLIB_INIT_CONFIG.get("provider_uri")
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN)
GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN)
qlib.init(**QLIB_INIT_CONFIG)

def _prepare_calender_cache(self):
Expand Down
58 changes: 14 additions & 44 deletions examples/hyperparameter/LightGBM/hyperparameter_158.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,9 @@
import qlib
from qlib.config import REG_CN
from qlib.utils import exists_qlib_data, init_instance_by_config
import optuna

provider_uri = "~/.qlib/qlib_data/cn_data"
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
sys.path.append(str(scripts_dir))
from get_data import GetData

GetData().qlib_data(target_dir=provider_uri, region="cn")
qlib.init(provider_uri=provider_uri, region="cn")

market = "csi300"
benchmark = "SH000300"

data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": market,
}
dataset_task = {
"dataset": {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
},
},
},
}
dataset = init_instance_by_config(dataset_task["dataset"])
from qlib.config import REG_CN
from qlib.utils import init_instance_by_config
from qlib.tests.config import CSI300_DATASET_CONFIG
from qlib.tests.data import GetData


def objective(trial):
Expand All @@ -65,12 +28,19 @@ def objective(trial):
},
},
}

evals_result = dict()
model = init_instance_by_config(task["model"])
model.fit(dataset, evals_result=evals_result)
return min(evals_result["valid"])


study = optuna.Study(study_name="LGBM_158", storage="sqlite:///db.sqlite3")
study.optimize(objective, n_jobs=6)
if __name__ == "__main__":

    # Download the CN daily dataset (no-op if already present) before qlib init.
    provider_uri = "~/.qlib/qlib_data/cn_data"
    GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
    # Use REG_CN here as well, for consistency with the download call above
    # and with the sibling hyperparameter_360.py script (was the literal "cn").
    qlib.init(provider_uri=provider_uri, region=REG_CN)

    # Build the shared CSI300/Alpha158 dataset used by every trial.
    dataset = init_instance_by_config(CSI300_DATASET_CONFIG)

    # Persist trials in SQLite so the study can be inspected/resumed.
    study = optuna.Study(study_name="LGBM_158", storage="sqlite:///db.sqlite3")
    study.optimize(objective, n_jobs=6)
57 changes: 15 additions & 42 deletions examples/hyperparameter/LightGBM/hyperparameter_360.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,11 @@
import qlib
from qlib.config import REG_CN
from qlib.utils import exists_qlib_data, init_instance_by_config
import optuna
from qlib.config import REG_CN
from qlib.utils import init_instance_by_config
from qlib.tests.data import GetData
from qlib.tests.config import get_dataset_config, CSI300_MARKET, DATASET_ALPHA360_CLASS

provider_uri = "~/.qlib/qlib_data/cn_data"
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
sys.path.append(str(scripts_dir))
from get_data import GetData

GetData().qlib_data(target_dir=provider_uri, region="cn")
qlib.init(provider_uri=provider_uri, region="cn")

market = "csi300"
benchmark = "SH000300"

data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": market,
}
dataset_task = {
"dataset": {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha360",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
},
},
},
}
dataset = init_instance_by_config(dataset_task["dataset"])
DATASET_CONFIG = get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA360_CLASS)


def objective(trial):
Expand Down Expand Up @@ -72,5 +37,13 @@ def objective(trial):
return min(evals_result["valid"])


study = optuna.Study(study_name="LGBM_360", storage="sqlite:///db.sqlite3")
study.optimize(objective, n_jobs=6)
if __name__ == "__main__":

    # Make sure the CN daily data exists locally, then boot qlib on it.
    data_uri = "~/.qlib/qlib_data/cn_data"
    GetData().qlib_data(target_dir=data_uri, region=REG_CN)
    qlib.init(provider_uri=data_uri, region=REG_CN)

    # Instantiate the CSI300/Alpha360 dataset shared by all trials.
    dataset = init_instance_by_config(DATASET_CONFIG)

    # SQLite-backed study so trial history survives the process.
    study = optuna.Study(study_name="LGBM_360", storage="sqlite:///db.sqlite3")
    study.optimize(objective, n_jobs=6)
81 changes: 0 additions & 81 deletions examples/model_interpreter.py

This file was deleted.

32 changes: 32 additions & 0 deletions examples/model_interpreter/feature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.


import qlib
from qlib.config import REG_CN

from qlib.utils import init_instance_by_config
from qlib.tests.data import GetData
from qlib.tests.config import CSI300_GBDT_TASK


if __name__ == "__main__":

    # Fetch the default CN dataset if it is missing, then initialize qlib.
    data_dir = "~/.qlib/qlib_data/cn_data"  # target_dir
    GetData().qlib_data(target_dir=data_dir, region=REG_CN)

    qlib.init(provider_uri=data_dir, region=REG_CN)

    ###################################
    # train model
    ###################################
    # Build the GBDT model and its dataset from the shared task config and fit.
    model = init_instance_by_config(CSI300_GBDT_TASK["model"])
    dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
    model.fit(dataset)

    # Report the per-feature importance learned by the trained model.
    feature_importance = model.get_feature_importance()
    print("feature importance:")
    print(feature_importance)
62 changes: 4 additions & 58 deletions examples/model_rolling/task_manager_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,63 +17,7 @@
from qlib.workflow.task.collect import RecorderCollector
from qlib.model.ens.group import RollingGroup
from qlib.model.trainer import TrainerRM


data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": "csi100",
}

dataset_config = {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
},
},
}

record_config = [
{
"class": "SignalRecord",
"module_path": "qlib.workflow.record_temp",
},
{
"class": "SigAnaRecord",
"module_path": "qlib.workflow.record_temp",
},
]

# use lgb
task_lgb_config = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
},
"dataset": dataset_config,
"record": record_config,
}

# use xgboost
task_xgboost_config = {
"model": {
"class": "XGBModel",
"module_path": "qlib.contrib.model.xgboost",
},
"dataset": dataset_config,
"record": record_config,
}
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG


class RollingTaskExample:
Expand All @@ -85,11 +29,13 @@ def __init__(
task_db_name="rolling_db",
experiment_name="rolling_exp",
task_pool="rolling_task",
task_config=[task_xgboost_config, task_lgb_config],
task_config=None,
rolling_step=550,
rolling_type=RollingGen.ROLL_SD,
):
# TaskManager config
if task_config is None:
task_config = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG]
mongo_conf = {
"task_url": task_url,
"task_db_name": task_db_name,
Expand Down
Loading