From 5208da21d5025695d3272198e4c01b2d619066f1 Mon Sep 17 00:00:00 2001 From: Di Chen Date: Sat, 10 Jun 2023 12:07:26 +0800 Subject: [PATCH 1/9] Add multi pass port ana record --- qlib/workflow/record_temp.py | 96 ++++++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index fdb3f6c92a1..235042a5b00 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -4,6 +4,7 @@ import logging import warnings import pandas as pd +import numpy as np from pprint import pprint from typing import Union, List, Optional @@ -554,3 +555,98 @@ def list(self): else: warnings.warn(f"indicator_analysis freq {_analysis_freq} is not found") return list_path + +class MultiPassPortAnaRecord(PortAnaRecord): + """ + This is the Multiple Pass Portfolio Analysis Record class that run backtest multiple times and generates the analysis results such as those of backtest. This class inherits the ``PortAnaRecord`` class. + + If shuffle_init_score enabled, the prediction score of the first backtest date will be shuffled, so that initial position will be random. + The shuffle_init_score will only works when the signal is used as placeholder. The placeholder will be replaced by pred.pkl saved in recorder. + + The following files will be stored in recorder. + + - report_normal.pkl & positions_normal.pkl: + + - The return report and detailed positions of the backtest, returned by `qlib/contrib/evaluate.py:backtest` + - port_analysis.pkl : The last risk analysis of your portfolio, returned by `qlib/contrib/evaluate.py:risk_analysis` + - multi_pass_port_analysis.pkl: The aggregated risk analysis data from port_analysis.pkl + """ + depend_cls = None + + def __init__( + self, + recorder, + pass_num=10, + shuffle_init_score=True, + **kwargs + ): + self.pass_num = pass_num + self.shuffle_init_score = shuffle_init_score + + super().__init__(recorder, **kwargs) + + # Save original strategy so that pred df can be replaced in next generate + self.original_strategy = deepcopy_basic_type(self.strategy_config) + + def random_init(self): + pred_df = self.load("pred.pkl") + + all_pred_dates = pred_df.index.get_level_values("datetime") + bt_start_date = self.backtest_config.get("start_time") + if bt_start_date is None: + first_bt_pred_date = all_pred_dates.min() + else: + first_bt_pred_date = all_pred_dates[all_pred_dates >= bt_start_date].min() + + # Shuffle the first backtest date's pred score + first_date_score = pred_df.loc[first_bt_pred_date]["score"] + np.random.shuffle(first_date_score.values) + + # Reset strategy so that pred df can be replaced in next generate + self.strategy_config = deepcopy_basic_type(self.original_strategy) + + self.save(**{"pred.pkl": pred_df}) + + def _generate(self, **kwargs): + risk_analysis_df_map = {} + + # Collect each frequency's analysis df as df list + for i in range(self.pass_num): + if self.shuffle_init_score: + self.random_init() + + super()._generate(**kwargs) + + for _analysis_freq in self.risk_analysis_freq: + risk_analysis_df_list = risk_analysis_df_map.get(_analysis_freq, []) + risk_analysis_df_map[_analysis_freq] = risk_analysis_df_list + + analysis_df = self.load(f"port_analysis_{_analysis_freq}.pkl") + analysis_df["run_id"] = i + risk_analysis_df_list.append(analysis_df) + + # Concat df list + for _analysis_freq in self.risk_analysis_freq: + combined_df = pd.concat(risk_analysis_df_map[_analysis_freq]) + + # Calculate return and information ratio's mean, std and mean/std + multi_pass_port_analysis_df = combined_df.groupby(level=[0,1]).apply(lambda x: pd.Series({ + "mean": x["risk"].mean(), + "std": x["risk"].std(), + "mean_std": x["risk"].mean() / x["risk"].std() + })) + + # Only look at "annualized_return" and "information_ratio" + multi_pass_port_analysis_df = multi_pass_port_analysis_df.loc[(slice(None),["annualized_return", "information_ratio"]),:] + print(multi_pass_port_analysis_df) + + # Save new df + self.save(**{f"multi_pass_port_analysis_{_analysis_freq}.pkl": multi_pass_port_analysis_df}) + + # Log metrics + metrics = flatten_dict({ + "mean": multi_pass_port_analysis_df["mean"].unstack().T.to_dict(), + "std": multi_pass_port_analysis_df["std"].unstack().T.to_dict(), + "mean_std": multi_pass_port_analysis_df["mean_std"].unstack().T.to_dict(), + }) + self.recorder.log_metrics(**metrics) From 529a145bb5cd6b794c8d766efc2ac60c19d8267c Mon Sep 17 00:00:00 2001 From: Di Chen Date: Sat, 10 Jun 2023 12:33:22 +0800 Subject: [PATCH 2/9] Add list function --- qlib/workflow/record_temp.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index 235042a5b00..3ff826cef99 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -571,7 +571,7 @@ class MultiPassPortAnaRecord(PortAnaRecord): - port_analysis.pkl : The last risk analysis of your portfolio, returned by `qlib/contrib/evaluate.py:risk_analysis` - multi_pass_port_analysis.pkl: The aggregated risk analysis data from port_analysis.pkl """ - depend_cls = None + depend_cls = SignalRecord def __init__( self, @@ -615,6 +615,7 @@ def _generate(self, **kwargs): if self.shuffle_init_score: self.random_init() + # Not check for cache file list super()._generate(**kwargs) for _analysis_freq in self.risk_analysis_freq: @@ -638,7 +639,7 @@ def _generate(self, **kwargs): # Only look at "annualized_return" and "information_ratio" multi_pass_port_analysis_df = multi_pass_port_analysis_df.loc[(slice(None),["annualized_return", "information_ratio"]),:] - print(multi_pass_port_analysis_df) + pprint(multi_pass_port_analysis_df) # Save new df self.save(**{f"multi_pass_port_analysis_{_analysis_freq}.pkl": multi_pass_port_analysis_df}) @@ -650,3 +651,12 @@ def _generate(self, **kwargs): "mean_std": multi_pass_port_analysis_df["mean_std"].unstack().T.to_dict(), }) self.recorder.log_metrics(**metrics) + + def list(self): + list_path = [] + for _analysis_freq in self.risk_analysis_freq: + if _analysis_freq in self.all_freq: + list_path.append(f"multi_pass_port_analysis_{_analysis_freq}.pkl") + else: + warnings.warn(f"risk_analysis freq {_analysis_freq} is not found") + return list_path From 6c99d8db9b983e029b98efe35fabd04e73835658 Mon Sep 17 00:00:00 2001 From: Di Chen Date: Sat, 10 Jun 2023 12:57:05 +0800 Subject: [PATCH 3/9] Add documentation and support tag --- qlib/workflow/record_temp.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index 3ff826cef99..1bb81701e65 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -563,13 +563,14 @@ class MultiPassPortAnaRecord(PortAnaRecord): If shuffle_init_score enabled, the prediction score of the first backtest date will be shuffled, so that initial position will be random. The shuffle_init_score will only works when the signal is used as placeholder. The placeholder will be replaced by pred.pkl saved in recorder. - The following files will be stored in recorder. - - - report_normal.pkl & positions_normal.pkl: - - - The return report and detailed positions of the backtest, returned by `qlib/contrib/evaluate.py:backtest` - - port_analysis.pkl : The last risk analysis of your portfolio, returned by `qlib/contrib/evaluate.py:risk_analysis` - - multi_pass_port_analysis.pkl: The aggregated risk analysis data from port_analysis.pkl + Parameters + ---------- + recorder : Recorder + The recorder used to save the backtest results. + pass_num : int + The number of backtest passes. + shuffle_init_score : bool + Whether to shuffle the prediction score of the first backtest date. """ depend_cls = SignalRecord @@ -580,6 +581,16 @@ def __init__( shuffle_init_score=True, **kwargs ): + """ + Parameters + ---------- + recorder : Recorder + The recorder used to save the backtest results. + pass_num : int + The number of backtest passes. + shuffle_init_score : bool + Whether to shuffle the prediction score of the first backtest date. + """ self.pass_num = pass_num self.shuffle_init_score = shuffle_init_score @@ -587,12 +598,14 @@ def __init__( # Save original strategy so that pred df can be replaced in next generate self.original_strategy = deepcopy_basic_type(self.strategy_config) + if isinstance(self.original_strategy, dict) and ("signal" in self.original_strategy.get("kwargs", {})): + self.original_strategy["kwargs"]["signal"] = "" def random_init(self): pred_df = self.load("pred.pkl") all_pred_dates = pred_df.index.get_level_values("datetime") - bt_start_date = self.backtest_config.get("start_time") + bt_start_date = pd.to_datetime(self.backtest_config.get("start_time")) if bt_start_date is None: first_bt_pred_date = all_pred_dates.min() else: From 82353b34b713013a123e55e2739f810cb22945bc Mon Sep 17 00:00:00 2001 From: Di Chen Date: Sat, 10 Jun 2023 13:00:00 +0800 Subject: [PATCH 4/9] Add drop in replacement example --- ..._config_linear_Alpha158_multi_pass_bt.yaml | 78 +++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 examples/benchmarks/Linear/workflow_config_linear_Alpha158_multi_pass_bt.yaml diff --git a/examples/benchmarks/Linear/workflow_config_linear_Alpha158_multi_pass_bt.yaml b/examples/benchmarks/Linear/workflow_config_linear_Alpha158_multi_pass_bt.yaml new file mode 100644 index 00000000000..edd9d81e411 --- /dev/null +++ b/examples/benchmarks/Linear/workflow_config_linear_Alpha158_multi_pass_bt.yaml @@ -0,0 +1,78 @@ +qlib_init: + provider_uri: "~/.qlib/qlib_data/cn_data" + region: cn +market: &market csi300 +benchmark: &benchmark SH000300 +data_handler_config: &data_handler_config + start_time: 2008-01-01 + end_time: 2020-08-01 + fit_start_time: 2008-01-01 + fit_end_time: 2014-12-31 + instruments: *market + infer_processors: + - class: RobustZScoreNorm + kwargs: + fields_group: feature + clip_outlier: true + - class: Fillna + kwargs: + fields_group: feature + learn_processors: + - class: DropnaLabel + - class: CSRankNorm + kwargs: + fields_group: label +port_analysis_config: &port_analysis_config + strategy: + class: TopkDropoutStrategy + module_path: qlib.contrib.strategy + kwargs: + signal: + - + - + topk: 50 + n_drop: 5 + backtest: + start_time: 2017-01-01 + end_time: 2020-08-01 + account: 100000000 + benchmark: *benchmark + exchange_kwargs: + limit_threshold: 0.095 + deal_price: close + open_cost: 0.0005 + close_cost: 0.0015 + min_cost: 5 +task: + model: + class: LinearModel + module_path: qlib.contrib.model.linear + kwargs: + estimator: ols + dataset: + class: DatasetH + module_path: qlib.data.dataset + kwargs: + handler: + class: Alpha158 + module_path: qlib.contrib.data.handler + kwargs: *data_handler_config + segments: + train: [2008-01-01, 2014-12-31] + valid: [2015-01-01, 2016-12-31] + test: [2017-01-01, 2020-08-01] + record: + - class: SignalRecord + module_path: qlib.workflow.record_temp + kwargs: + model: + dataset: + - class: SigAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + ana_long_short: True + ann_scaler: 252 + - class: MultiPassPortAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + config: *port_analysis_config From 6a24f87df928d72de8aec3968b0f08491ddfd53e Mon Sep 17 00:00:00 2001 From: Di Chen Date: Sat, 10 Jun 2023 13:29:49 +0800 Subject: [PATCH 5/9] reformat --- qlib/workflow/record_temp.py | 61 ++++++++++++++++++------------------ 1 file changed, 30 insertions(+), 31 deletions(-) diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index 1bb81701e65..43ad1e97f5a 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -137,7 +137,6 @@ def check(self, include_self: bool = False, parents: bool = True): whether the records are stored properly. """ if include_self: - # Some mlflow backend will not list the directly recursively. # So we force to the directly artifacts = {} @@ -556,6 +555,7 @@ def list(self): warnings.warn(f"indicator_analysis freq {_analysis_freq} is not found") return list_path + class MultiPassPortAnaRecord(PortAnaRecord): """ This is the Multiple Pass Portfolio Analysis Record class that run backtest multiple times and generates the analysis results such as those of backtest. This class inherits the ``PortAnaRecord`` class. @@ -572,15 +572,10 @@ class MultiPassPortAnaRecord(PortAnaRecord): shuffle_init_score : bool Whether to shuffle the prediction score of the first backtest date. """ + depend_cls = SignalRecord - def __init__( - self, - recorder, - pass_num=10, - shuffle_init_score=True, - **kwargs - ): + def __init__(self, recorder, pass_num=10, shuffle_init_score=True, **kwargs): """ Parameters ---------- @@ -593,9 +588,9 @@ def __init__( """ self.pass_num = pass_num self.shuffle_init_score = shuffle_init_score - + super().__init__(recorder, **kwargs) - + # Save original strategy so that pred df can be replaced in next generate self.original_strategy = deepcopy_basic_type(self.strategy_config) if isinstance(self.original_strategy, dict) and ("signal" in self.original_strategy.get("kwargs", {})): @@ -603,21 +598,21 @@ def __init__( def random_init(self): pred_df = self.load("pred.pkl") - + all_pred_dates = pred_df.index.get_level_values("datetime") bt_start_date = pd.to_datetime(self.backtest_config.get("start_time")) if bt_start_date is None: first_bt_pred_date = all_pred_dates.min() else: first_bt_pred_date = all_pred_dates[all_pred_dates >= bt_start_date].min() - + # Shuffle the first backtest date's pred score first_date_score = pred_df.loc[first_bt_pred_date]["score"] np.random.shuffle(first_date_score.values) - + # Reset strategy so that pred df can be replaced in next generate self.strategy_config = deepcopy_basic_type(self.original_strategy) - + self.save(**{"pred.pkl": pred_df}) def _generate(self, **kwargs): @@ -627,10 +622,10 @@ def _generate(self, **kwargs): for i in range(self.pass_num): if self.shuffle_init_score: self.random_init() - + # Not check for cache file list super()._generate(**kwargs) - + for _analysis_freq in self.risk_analysis_freq: risk_analysis_df_list = risk_analysis_df_map.get(_analysis_freq, []) risk_analysis_df_map[_analysis_freq] = risk_analysis_df_list @@ -638,31 +633,35 @@ def _generate(self, **kwargs): analysis_df = self.load(f"port_analysis_{_analysis_freq}.pkl") analysis_df["run_id"] = i risk_analysis_df_list.append(analysis_df) - + # Concat df list for _analysis_freq in self.risk_analysis_freq: combined_df = pd.concat(risk_analysis_df_map[_analysis_freq]) - + # Calculate return and information ratio's mean, std and mean/std - multi_pass_port_analysis_df = combined_df.groupby(level=[0,1]).apply(lambda x: pd.Series({ - "mean": x["risk"].mean(), - "std": x["risk"].std(), - "mean_std": x["risk"].mean() / x["risk"].std() - })) - + multi_pass_port_analysis_df = combined_df.groupby(level=[0, 1]).apply( + lambda x: pd.Series( + {"mean": x["risk"].mean(), "std": x["risk"].std(), "mean_std": x["risk"].mean() / x["risk"].std()} + ) + ) + # Only look at "annualized_return" and "information_ratio" - multi_pass_port_analysis_df = multi_pass_port_analysis_df.loc[(slice(None),["annualized_return", "information_ratio"]),:] + multi_pass_port_analysis_df = multi_pass_port_analysis_df.loc[ + (slice(None), ["annualized_return", "information_ratio"]), : + ] pprint(multi_pass_port_analysis_df) - + # Save new df self.save(**{f"multi_pass_port_analysis_{_analysis_freq}.pkl": multi_pass_port_analysis_df}) # Log metrics - metrics = flatten_dict({ - "mean": multi_pass_port_analysis_df["mean"].unstack().T.to_dict(), - "std": multi_pass_port_analysis_df["std"].unstack().T.to_dict(), - "mean_std": multi_pass_port_analysis_df["mean_std"].unstack().T.to_dict(), - }) + metrics = flatten_dict( + { + "mean": multi_pass_port_analysis_df["mean"].unstack().T.to_dict(), + "std": multi_pass_port_analysis_df["std"].unstack().T.to_dict(), + "mean_std": multi_pass_port_analysis_df["mean_std"].unstack().T.to_dict(), + } + ) self.recorder.log_metrics(**metrics) def list(self): From 7f75aec87ec6901f830c2aab67007e139b224954 Mon Sep 17 00:00:00 2001 From: Di Chen Date: Sun, 18 Jun 2023 10:37:11 +0800 Subject: [PATCH 6/9] Change according to comments --- qlib/workflow/record_temp.py | 47 ++++++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 18 deletions(-) diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index 43ad1e97f5a..9c503062081 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -5,6 +5,7 @@ import warnings import pandas as pd import numpy as np +from tqdm import trange from pprint import pprint from typing import Union, List, Optional @@ -231,9 +232,16 @@ def generate(self, *args, **kwargs): except FileNotFoundError: logger.warning("The dependent data does not exists. Generation skipped.") return - return self._generate(*args, **kwargs) + artifact_dict = self._generate(*args, **kwargs) + if isinstance(artifact_dict, dict): + self.save(**artifact_dict) + return artifact_dict - def _generate(self, *args, **kwargs): + def _generate(self, *args, **kwargs) -> dict[str, object]: + """ + Run the concrete generating task, return the dictionary of the generated results. + The caller method will save the results to the recorder. + """ raise NotImplementedError(f"Please implement the `_generate` method") @@ -337,8 +345,8 @@ def _generate(self, label: Optional[pd.DataFrame] = None, **kwargs): } ) self.recorder.log_metrics(**metrics) - self.save(**objects) pprint(metrics) + return objects def list(self): paths = ["ic.pkl", "ric.pkl"] @@ -469,17 +477,18 @@ def _generate(self, **kwargs): if self.backtest_config["end_time"] is None: self.backtest_config["end_time"] = get_date_by_shift(dt_values.max(), 1) + artifact_objects = {} # custom strategy and get backtest portfolio_metric_dict, indicator_dict = normal_backtest( executor=self.executor_config, strategy=self.strategy_config, **self.backtest_config ) for _freq, (report_normal, positions_normal) in portfolio_metric_dict.items(): - self.save(**{f"report_normal_{_freq}.pkl": report_normal}) - self.save(**{f"positions_normal_{_freq}.pkl": positions_normal}) + artifact_objects.update({f"report_normal_{_freq}.pkl": report_normal}) + artifact_objects.update({f"positions_normal_{_freq}.pkl": positions_normal}) for _freq, indicators_normal in indicator_dict.items(): - self.save(**{f"indicators_normal_{_freq}.pkl": indicators_normal[0]}) - self.save(**{f"indicators_normal_{_freq}_obj.pkl": indicators_normal[1]}) + artifact_objects.update({f"indicators_normal_{_freq}.pkl": indicators_normal[0]}) + artifact_objects.update({f"indicators_normal_{_freq}_obj.pkl": indicators_normal[1]}) for _analysis_freq in self.risk_analysis_freq: if _analysis_freq not in portfolio_metric_dict: @@ -501,7 +510,7 @@ def _generate(self, **kwargs): analysis_dict = flatten_dict(analysis_df["risk"].unstack().T.to_dict()) self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()}) # save results - self.save(**{f"port_analysis_{_analysis_freq}.pkl": analysis_df}) + artifact_objects.update({f"port_analysis_{_analysis_freq}.pkl": analysis_df}) logger.info( f"Portfolio analysis record 'port_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}" ) @@ -526,12 +535,13 @@ def _generate(self, **kwargs): analysis_dict = analysis_df["value"].to_dict() self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()}) # save results - self.save(**{f"indicator_analysis_{_analysis_freq}.pkl": analysis_df}) + artifact_objects.update({f"indicator_analysis_{_analysis_freq}.pkl": analysis_df}) logger.info( f"Indicator analysis record 'indicator_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}" ) pprint(f"The following are analysis results of indicators({_analysis_freq}).") pprint(analysis_df) + return artifact_objects def list(self): list_path = [] @@ -593,8 +603,8 @@ def __init__(self, recorder, pass_num=10, shuffle_init_score=True, **kwargs): # Save original strategy so that pred df can be replaced in next generate self.original_strategy = deepcopy_basic_type(self.strategy_config) - if isinstance(self.original_strategy, dict) and ("signal" in self.original_strategy.get("kwargs", {})): - self.original_strategy["kwargs"]["signal"] = "" + if (not isinstance(self.original_strategy, dict)) or ("signal" not in self.original_strategy.get("kwargs", {}))): + raise Exception("MultiPassPortAnaRecord require the passed in strategy to be a dict and contains ['kwargs']['signal'] field") def random_init(self): pred_df = self.load("pred.pkl") @@ -610,30 +620,30 @@ def random_init(self): first_date_score = pred_df.loc[first_bt_pred_date]["score"] np.random.shuffle(first_date_score.values) - # Reset strategy so that pred df can be replaced in next generate + # Use shuffled signal as the strategy signal self.strategy_config = deepcopy_basic_type(self.original_strategy) - - self.save(**{"pred.pkl": pred_df}) + self.strategy_config["kwargs"]["signal"] = pred_df def _generate(self, **kwargs): risk_analysis_df_map = {} # Collect each frequency's analysis df as df list - for i in range(self.pass_num): + for i in trange(self.pass_num): if self.shuffle_init_score: self.random_init() # Not check for cache file list - super()._generate(**kwargs) + single_run_artifacts = super()._generate(**kwargs) for _analysis_freq in self.risk_analysis_freq: risk_analysis_df_list = risk_analysis_df_map.get(_analysis_freq, []) risk_analysis_df_map[_analysis_freq] = risk_analysis_df_list - analysis_df = self.load(f"port_analysis_{_analysis_freq}.pkl") + analysis_df = single_run_artifacts[f"port_analysis_{_analysis_freq}.pkl"] analysis_df["run_id"] = i risk_analysis_df_list.append(analysis_df) + result_artifacts = {} # Concat df list for _analysis_freq in self.risk_analysis_freq: combined_df = pd.concat(risk_analysis_df_map[_analysis_freq]) @@ -652,7 +662,7 @@ def _generate(self, **kwargs): pprint(multi_pass_port_analysis_df) # Save new df - self.save(**{f"multi_pass_port_analysis_{_analysis_freq}.pkl": multi_pass_port_analysis_df}) + result_artifacts.update({f"multi_pass_port_analysis_{_analysis_freq}.pkl": multi_pass_port_analysis_df}) # Log metrics metrics = flatten_dict( @@ -663,6 +673,7 @@ def _generate(self, **kwargs): } ) self.recorder.log_metrics(**metrics) + return result_artifacts def list(self): list_path = [] From 5caf2a3510111ecda55d4e3e39650bef4d1880d5 Mon Sep 17 00:00:00 2001 From: Di Chen Date: Sun, 18 Jun 2023 10:49:34 +0800 Subject: [PATCH 7/9] update format --- qlib/workflow/record_temp.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index 9c503062081..2b8759d6139 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -603,8 +603,10 @@ def __init__(self, recorder, pass_num=10, shuffle_init_score=True, **kwargs): # Save original strategy so that pred df can be replaced in next generate self.original_strategy = deepcopy_basic_type(self.strategy_config) - if (not isinstance(self.original_strategy, dict)) or ("signal" not in self.original_strategy.get("kwargs", {}))): - raise Exception("MultiPassPortAnaRecord require the passed in strategy to be a dict and contains ['kwargs']['signal'] field") + if not isinstance(self.original_strategy, dict): + raise Exception("MultiPassPortAnaRecord require the passed in strategy to be a dict") + if "signal" not in self.original_strategy.get("kwargs", {}): + raise Exception("MultiPassPortAnaRecord require the passed in strategy to have signal as a parameter") def random_init(self): pred_df = self.load("pred.pkl") From 75e274ff2fc4e369d604776fe6fb20d69510799c Mon Sep 17 00:00:00 2001 From: Di Date: Sat, 15 Jul 2023 16:53:10 +0800 Subject: [PATCH 8/9] Update record_temp.py Fix type hint --- qlib/workflow/record_temp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index 2b8759d6139..aaf6e3361db 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -7,7 +7,7 @@ import numpy as np from tqdm import trange from pprint import pprint -from typing import Union, List, Optional +from typing import Union, List, Optional, Dict from qlib.utils.exceptions import LoadObjectError from ..contrib.evaluate import risk_analysis, indicator_analysis @@ -237,7 +237,7 @@ def generate(self, *args, **kwargs): self.save(**artifact_dict) return artifact_dict - def _generate(self, *args, **kwargs) -> dict[str, object]: + def _generate(self, *args, **kwargs) -> Dict[str, object]: """ Run the concrete generating task, return the dictionary of the generated results. The caller method will save the results to the recorder. From 7a81f6fa94c0cd6eb46657f6f45a1a3fb16341c5 Mon Sep 17 00:00:00 2001 From: Di Date: Sat, 15 Jul 2023 17:01:40 +0800 Subject: [PATCH 9/9] Update record_temp.py --- qlib/workflow/record_temp.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index aaf6e3361db..4c230e6e5e4 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -19,6 +19,7 @@ from ..utils import fill_placeholder, flatten_dict, class_casting, get_date_by_shift from ..utils.time import Freq from ..utils.data import deepcopy_basic_type +from ..utils.exceptions import QlibException from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec @@ -604,9 +605,9 @@ def __init__(self, recorder, pass_num=10, shuffle_init_score=True, **kwargs): # Save original strategy so that pred df can be replaced in next generate self.original_strategy = deepcopy_basic_type(self.strategy_config) if not isinstance(self.original_strategy, dict): - raise Exception("MultiPassPortAnaRecord require the passed in strategy to be a dict") + raise QlibException("MultiPassPortAnaRecord require the passed in strategy to be a dict") if "signal" not in self.original_strategy.get("kwargs", {}): - raise Exception("MultiPassPortAnaRecord require the passed in strategy to have signal as a parameter") + raise QlibException("MultiPassPortAnaRecord require the passed in strategy to have signal as a parameter") def random_init(self): pred_df = self.load("pred.pkl")