Skip to content

Commit aa552fd

Browse files
authored
Merge pull request #345 from D-X-Y/main
Fix errors when SignalRecord is not called before SigAna/PortAna
2 parents d47e35d + 872ddc6 commit aa552fd

File tree

4 files changed

+100
-11
lines changed

4 files changed

+100
-11
lines changed

qlib/contrib/workflow/__init__.py

Whitespace-only changes.

qlib/contrib/workflow/record_temp.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
import re
5+
import pandas as pd
6+
from sklearn.metrics import mean_squared_error
7+
from pprint import pprint
8+
import numpy as np
9+
10+
from ...workflow.record_temp import SignalRecord
11+
from ...log import get_module_logger
12+
13+
logger = get_module_logger("workflow", "INFO")
14+
15+
16+
class SignalMseRecord(SignalRecord):
    """
    This is the Signal MSE Record class that computes the mean squared error (MSE).
    This class inherits the ``SignalRecord`` class.
    """

    artifact_path = "sig_analysis"

    def __init__(self, recorder, **kwargs):
        """
        Parameters
        ----------
        recorder :
            the recorder, forwarded to ``SignalRecord.__init__`` by keyword.
        **kwargs :
            extra keyword arguments forwarded unchanged to ``SignalRecord.__init__``.
        """
        super().__init__(recorder=recorder, **kwargs)

    def generate(self, **kwargs):
        """
        Compute the MSE and RMSE between the stored prediction and label, log
        both as metrics and save both as objects under this record's artifact path.

        ``**kwargs`` is accepted but not used by this method.
        """
        try:
            self.check(parent=True)
        except FileExistsError:
            # NOTE(review): presumably ``check`` raises when the parent record's
            # artifacts have not been produced yet; in that case generate them first.
            super().generate()

        pred = self.load("pred.pkl")
        label = self.load("label.pkl")
        # Only score the entries whose label is not NaN.
        # assumes pred and label have the same shape -- TODO confirm
        masks = ~np.isnan(label.values)
        # Index the raw arrays on BOTH sides so prediction and label are
        # flattened identically; indexing the DataFrame itself with a 2-D
        # boolean array does not select element-wise.
        mse = mean_squared_error(label.values[masks], pred.values[masks])
        rmse = np.sqrt(mse)
        metrics = {"MSE": mse, "RMSE": rmse}
        objects = {"mse.pkl": mse, "rmse.pkl": rmse}
        self.recorder.log_metrics(**metrics)
        self.recorder.save_objects(**objects, artifact_path=self.get_path())
        pprint(metrics)

    def list(self):
        """Return the artifact paths of the objects saved by ``generate``."""
        return [self.get_path("mse.pkl"), self.get_path("rmse.pkl")]

qlib/workflow/record_temp.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ class SignalRecord(RecordTemp):
110110
This is the Signal Record class that generates the signal prediction. This class inherits the ``RecordTemp`` class.
111111
"""
112112

113-
def __init__(self, model=None, dataset=None, recorder=None, **kwargs):
113+
def __init__(self, model=None, dataset=None, recorder=None):
114114
super().__init__(recorder=recorder)
115115
self.model = model
116116
self.dataset = dataset
@@ -164,13 +164,15 @@ class SigAnaRecord(SignalRecord):
164164
artifact_path = "sig_analysis"
165165

166166
def __init__(self, recorder, ana_long_short=False, ann_scaler=252, **kwargs):
167+
super().__init__(recorder=recorder, **kwargs)
167168
self.ana_long_short = ana_long_short
168169
self.ann_scaler = ann_scaler
169-
super().__init__(recorder=recorder, **kwargs)
170-
# The name must be unique. Otherwise it will be overridden
171170

172-
def generate(self):
173-
self.check(parent=True)
171+
def generate(self, **kwargs):
172+
try:
173+
self.check(parent=True)
174+
except FileExistsError:
175+
super().generate()
174176

175177
pred = self.load("pred.pkl")
176178
label = self.load("label.pkl")
@@ -228,18 +230,21 @@ def __init__(self, recorder, config, **kwargs):
228230
config["backtest"] : dict
229231
define the backtest kwargs.
230232
"""
231-
super().__init__(recorder=recorder)
233+
super().__init__(recorder=recorder, **kwargs)
232234

233235
self.strategy_config = config["strategy"]
234236
self.backtest_config = config["backtest"]
235237
self.strategy = init_instance_by_config(self.strategy_config, accept_types=BaseStrategy)
236238

237239
def generate(self, **kwargs):
238240
# check previously stored prediction results
239-
self.check(parent=True) # "Make sure the parent process is completed and store the data properly."
241+
try:
242+
self.check(parent=True) # "Make sure the parent process is completed and store the data properly."
243+
except FileExistsError:
244+
super().generate()
240245

241246
# custom strategy and get backtest
242-
pred_score = super().load()
247+
pred_score = super().load("pred.pkl")
243248
report_dict = normal_backtest(pred_score, strategy=self.strategy, **self.backtest_config)
244249
report_normal = report_dict.get("report_df")
245250
positions_normal = report_dict.get("positions")

tests/test_all_pipeline.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
backtest as normal_backtest,
2020
risk_analysis,
2121
)
22+
from qlib.contrib.workflow.record_temp import SignalMseRecord
2223
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
2324
from qlib.workflow import R
2425
from qlib.workflow.record_temp import SignalRecord, SigAnaRecord, PortAnaRecord
@@ -139,6 +140,38 @@ def train():
139140
return pred_score, {"ic": ic, "ric": ric}, rid
140141

141142

143+
def train_with_sigana():
    """train model followed by SigAnaRecord

    Returns
    -------
    pred_score: pandas.DataFrame
        predict scores
    performance: dict
        model performance, with keys ``"ic"`` and ``"ric"``
    uri_path: str
        the value of ``R.get_uri()`` read while the experiment is active
    """
    model = init_instance_by_config(task["model"])
    dataset = init_instance_by_config(task["dataset"])

    # start exp
    with R.start(experiment_name="workflow_with_sigana"):
        R.log_params(**flatten_dict(task))
        model.fit(dataset)

        # predict and calculate ic and ric
        # SigAnaRecord is used directly, with no separate SignalRecord step before it
        recorder = R.get_recorder()
        sar = SigAnaRecord(recorder, model=model, dataset=dataset)
        sar.generate()
        ic = sar.load(sar.get_path("ic.pkl"))
        ric = sar.load(sar.get_path("ric.pkl"))
        pred_score = sar.load("pred.pkl")

        # compute MSE / RMSE on the same recorder
        smr = SignalMseRecord(recorder)
        smr.generate()
        uri_path = R.get_uri()
    return pred_score, {"ic": ic, "ric": ric}, uri_path
173+
174+
142175
def fake_experiment():
143176
"""A fake experiment workflow to test uri
144177
@@ -195,20 +228,26 @@ class TestAllFlow(TestAutoData):
195228
def tearDownClass(cls) -> None:
196229
shutil.rmtree(str(Path(C["exp_manager"]["kwargs"]["uri"].strip("file:")).resolve()))
197230

198-
def test_0_train(self):
231+
def test_0_train_with_sigana(self):
232+
TestAllFlow.PRED_SCORE, ic_ric, uri_path = train_with_sigana()
233+
self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed")
234+
self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed")
235+
shutil.rmtree(str(Path(uri_path.strip("file:")).resolve()))
236+
237+
def test_1_train(self):
199238
TestAllFlow.PRED_SCORE, ic_ric, TestAllFlow.RID = train()
200239
self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed")
201240
self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed")
202241

203-
def test_1_backtest(self):
242+
def test_2_backtest(self):
204243
analyze_df = backtest_analysis(TestAllFlow.PRED_SCORE, TestAllFlow.RID)
205244
self.assertGreaterEqual(
206245
analyze_df.loc(axis=0)["excess_return_with_cost", "annualized_return"].values[0],
207246
0.10,
208247
"backtest failed",
209248
)
210249

211-
def test_2_expmanager(self):
250+
def test_3_expmanager(self):
212251
pass_default, pass_current, uri_path = fake_experiment()
213252
self.assertTrue(pass_default, msg="default uri is incorrect")
214253
self.assertTrue(pass_current, msg="current uri is incorrect")

0 commit comments

Comments
 (0)