diff --git a/qlib/contrib/workflow/__init__.py b/qlib/contrib/workflow/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/qlib/contrib/workflow/record_temp.py b/qlib/contrib/workflow/record_temp.py new file mode 100644 index 00000000000..3fdf0c28157 --- /dev/null +++ b/qlib/contrib/workflow/record_temp.py @@ -0,0 +1,45 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import re +import pandas as pd +from sklearn.metrics import mean_squared_error +from pprint import pprint +import numpy as np + +from ...workflow.record_temp import SignalRecord +from ...log import get_module_logger + +logger = get_module_logger("workflow", "INFO") + + +class SignalMseRecord(SignalRecord): + """ + This is the Signal MSE Record class that computes the mean squared error (MSE). + This class inherits the ``SignalRecord`` class. + """ + + artifact_path = "sig_analysis" + + def __init__(self, recorder, **kwargs): + super().__init__(recorder=recorder, **kwargs) + + def generate(self, **kwargs): + try: + self.check(parent=True) + except FileExistsError: + super().generate() + + pred = self.load("pred.pkl") + label = self.load("label.pkl") + masks = ~np.isnan(label.values) + mse = mean_squared_error(pred.values[masks], label[masks]) + metrics = {"MSE": mse, "RMSE": np.sqrt(mse)} + objects = {"mse.pkl": mse, "rmse.pkl": np.sqrt(mse)} + self.recorder.log_metrics(**metrics) + self.recorder.save_objects(**objects, artifact_path=self.get_path()) + pprint(metrics) + + def list(self): + paths = [self.get_path("mse.pkl"), self.get_path("rmse.pkl")] + return paths diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index be458a24d29..2c1b6fecce0 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -110,7 +110,7 @@ class SignalRecord(RecordTemp): This is the Signal Record class that generates the signal prediction. This class inherits the ``RecordTemp`` class. 
""" - def __init__(self, model=None, dataset=None, recorder=None, **kwargs): + def __init__(self, model=None, dataset=None, recorder=None): super().__init__(recorder=recorder) self.model = model self.dataset = dataset @@ -164,13 +164,15 @@ class SigAnaRecord(SignalRecord): artifact_path = "sig_analysis" def __init__(self, recorder, ana_long_short=False, ann_scaler=252, **kwargs): + super().__init__(recorder=recorder, **kwargs) self.ana_long_short = ana_long_short self.ann_scaler = ann_scaler - super().__init__(recorder=recorder, **kwargs) - # The name must be unique. Otherwise it will be overridden - def generate(self): - self.check(parent=True) + def generate(self, **kwargs): + try: + self.check(parent=True) + except FileExistsError: + super().generate() pred = self.load("pred.pkl") label = self.load("label.pkl") @@ -228,7 +230,7 @@ def __init__(self, recorder, config, **kwargs): config["backtest"] : dict define the backtest kwargs. """ - super().__init__(recorder=recorder) + super().__init__(recorder=recorder, **kwargs) self.strategy_config = config["strategy"] self.backtest_config = config["backtest"] @@ -236,10 +238,13 @@ def __init__(self, recorder, config, **kwargs): def generate(self, **kwargs): # check previously stored prediction results - self.check(parent=True) # "Make sure the parent process is completed and store the data properly." + try: + self.check(parent=True) # "Make sure the parent process is completed and store the data properly." 
+ except FileExistsError: + super().generate() # custom strategy and get backtest - pred_score = super().load() + pred_score = super().load("pred.pkl") report_dict = normal_backtest(pred_score, strategy=self.strategy, **self.backtest_config) report_normal = report_dict.get("report_df") positions_normal = report_dict.get("positions") diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py index fbf15d29ad7..29d39179d1f 100644 --- a/tests/test_all_pipeline.py +++ b/tests/test_all_pipeline.py @@ -19,6 +19,7 @@ backtest as normal_backtest, risk_analysis, ) +from qlib.contrib.workflow.record_temp import SignalMseRecord from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict from qlib.workflow import R from qlib.workflow.record_temp import SignalRecord, SigAnaRecord, PortAnaRecord @@ -139,6 +140,38 @@ def train(): return pred_score, {"ic": ic, "ric": ric}, rid +def train_with_sigana(): + """train model followed by SigAnaRecord + + Returns + ------- + pred_score: pandas.DataFrame + predict scores + performance: dict + model performance + """ + model = init_instance_by_config(task["model"]) + dataset = init_instance_by_config(task["dataset"]) + + # start exp + with R.start(experiment_name="workflow_with_sigana"): + R.log_params(**flatten_dict(task)) + model.fit(dataset) + + # predict and calculate ic and ric + recorder = R.get_recorder() + sar = SigAnaRecord(recorder, model=model, dataset=dataset) + sar.generate() + ic = sar.load(sar.get_path("ic.pkl")) + ric = sar.load(sar.get_path("ric.pkl")) + pred_score = sar.load("pred.pkl") + + smr = SignalMseRecord(recorder) + smr.generate() + uri_path = R.get_uri() + return pred_score, {"ic": ic, "ric": ric}, uri_path + + def fake_experiment(): """A fake experiment workflow to test uri @@ -195,12 +228,18 @@ class TestAllFlow(TestAutoData): def tearDownClass(cls) -> None: shutil.rmtree(str(Path(C["exp_manager"]["kwargs"]["uri"].strip("file:")).resolve())) - def test_0_train(self): + def 
test_0_train_with_sigana(self): + TestAllFlow.PRED_SCORE, ic_ric, uri_path = train_with_sigana() + self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed") + self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed") + shutil.rmtree(str(Path(uri_path.strip("file:")).resolve())) + + def test_1_train(self): TestAllFlow.PRED_SCORE, ic_ric, TestAllFlow.RID = train() self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed") self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed") - def test_1_backtest(self): + def test_2_backtest(self): analyze_df = backtest_analysis(TestAllFlow.PRED_SCORE, TestAllFlow.RID) self.assertGreaterEqual( analyze_df.loc(axis=0)["excess_return_with_cost", "annualized_return"].values[0], @@ -208,7 +247,7 @@ def test_1_backtest(self): "backtest failed", ) - def test_2_expmanager(self): + def test_3_expmanager(self): pass_default, pass_current, uri_path = fake_experiment() self.assertTrue(pass_default, msg="default uri is incorrect") self.assertTrue(pass_current, msg="current uri is incorrect")