Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions qlib/workflow/record_temp.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class SignalRecord(RecordTemp):
This is the Signal Record class that generates the signal prediction. This class inherits the ``RecordTemp`` class.
"""

def __init__(self, model=None, dataset=None, recorder=None, **kwargs):
def __init__(self, model=None, dataset=None, recorder=None):
super().__init__(recorder=recorder)
self.model = model
self.dataset = dataset
Expand Down Expand Up @@ -164,13 +164,15 @@ class SigAnaRecord(SignalRecord):
artifact_path = "sig_analysis"

def __init__(self, recorder, ana_long_short=False, ann_scaler=252, **kwargs):
super().__init__(recorder=recorder, **kwargs)
self.ana_long_short = ana_long_short
self.ann_scaler = ann_scaler
super().__init__(recorder=recorder, **kwargs)
# The name must be unique. Otherwise it will be overridden

def generate(self):
self.check(parent=True)
def generate(self, **kwargs):
try:
self.check(parent=True)
except:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please specify the specific error type to avoid unexpected outcomes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it, added :)

super().generate()

pred = self.load("pred.pkl")
label = self.load("label.pkl")
Expand Down Expand Up @@ -228,18 +230,21 @@ def __init__(self, recorder, config, **kwargs):
config["backtest"] : dict
define the backtest kwargs.
"""
super().__init__(recorder=recorder)
super().__init__(recorder=recorder, **kwargs)

self.strategy_config = config["strategy"]
self.backtest_config = config["backtest"]
self.strategy = init_instance_by_config(self.strategy_config, accept_types=BaseStrategy)

def generate(self, **kwargs):
# check previously stored prediction results
self.check(parent=True) # "Make sure the parent process is completed and store the data properly."
try:
self.check(parent=True) # "Make sure the parent process is completed and store the data properly."
except:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it, added :)

super().generate()

# custom strategy and get backtest
pred_score = super().load()
pred_score = super().load("pred.pkl")
report_dict = normal_backtest(pred_score, strategy=self.strategy, **self.backtest_config)
report_normal = report_dict.get("report_df")
positions_normal = report_dict.get("positions")
Expand Down
41 changes: 38 additions & 3 deletions tests/test_all_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,35 @@ def train():
return pred_score, {"ic": ic, "ric": ric}, rid


def train_with_sigana():
"""train model followed by SigAnaRecord

Returns
-------
pred_score: pandas.DataFrame
predict scores
performance: dict
model performance
"""
model = init_instance_by_config(task["model"])
dataset = init_instance_by_config(task["dataset"])

# start exp
with R.start(experiment_name="workflow_with_sigana"):
R.log_params(**flatten_dict(task))
model.fit(dataset)

# predict and calculate ic and ric
recorder = R.get_recorder()
sar = SigAnaRecord(recorder, model=model, dataset=dataset)
sar.generate()
ic = sar.load(sar.get_path("ic.pkl"))
ric = sar.load(sar.get_path("ric.pkl"))
pred_score = sar.load("pred.pkl")
uri_path = R.get_uri()
return pred_score, {"ic": ic, "ric": ric}, uri_path


def fake_experiment():
"""A fake experiment workflow to test uri

Expand Down Expand Up @@ -195,20 +224,26 @@ class TestAllFlow(TestAutoData):
def tearDownClass(cls) -> None:
shutil.rmtree(str(Path(C["exp_manager"]["kwargs"]["uri"].strip("file:")).resolve()))

def test_0_train(self):
def test_0_train_with_sigana(self):
TestAllFlow.PRED_SCORE, ic_ric, uri_path = train_with_sigana()
self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed")
self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed")
shutil.rmtree(str(Path(uri_path.strip("file:")).resolve()))

def test_1_train(self):
TestAllFlow.PRED_SCORE, ic_ric, TestAllFlow.RID = train()
self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed")
self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed")

def test_1_backtest(self):
def test_2_backtest(self):
analyze_df = backtest_analysis(TestAllFlow.PRED_SCORE, TestAllFlow.RID)
self.assertGreaterEqual(
analyze_df.loc(axis=0)["excess_return_with_cost", "annualized_return"].values[0],
0.10,
"backtest failed",
)

def test_2_expmanager(self):
def test_3_expmanager(self):
pass_default, pass_current, uri_path = fake_experiment()
self.assertTrue(pass_default, msg="default uri is incorrect")
self.assertTrue(pass_current, msg="current uri is incorrect")
Expand Down