@@ -56,39 +56,8 @@ def train(uri_path: str = None):
5656 ic = sar .load ("ic.pkl" )
5757 ric = sar .load ("ric.pkl" )
5858
59- return pred_score , {"ic" : ic , "ric" : ric }, rid
60-
61-
62- def train_with_sigana (uri_path : str = None ):
63- """train model followed by SigAnaRecord
64-
65- Returns
66- -------
67- pred_score: pandas.DataFrame
68- predict scores
69- performance: dict
70- model performance
71- """
72- model = init_instance_by_config (CSI300_GBDT_TASK ["model" ])
73- dataset = init_instance_by_config (CSI300_GBDT_TASK ["dataset" ])
74- # start exp
75- with R .start (experiment_name = "workflow_with_sigana" , uri = uri_path ):
76- R .log_params (** flatten_dict (CSI300_GBDT_TASK ))
77- model .fit (dataset )
78- recorder = R .get_recorder ()
79-
80- sr = SignalRecord (model , dataset , recorder )
81- sr .generate ()
82- pred_score = sr .load ("pred.pkl" )
83-
84- # predict and calculate ic and ric
85- sar = SigAnaRecord (recorder )
86- sar .generate ()
87- ic = sar .load ("ic.pkl" )
88- ric = sar .load ("ric.pkl" )
89-
9059 uri_path = R .get_uri ()
91- return pred_score , {"ic" : ic , "ric" : ric }, uri_path
60+ return pred_score , {"ic" : ic , "ric" : ric }, rid , uri_path
9261
9362
9463def fake_experiment ():
@@ -186,19 +155,13 @@ def tearDownClass(cls) -> None:
186155 shutil .rmtree (cls .URI_PATH .lstrip ("file:" ))
187156
188157 @pytest .mark .slow
189- def test_0_train_with_sigana (self ):
190- TestAllFlow .PRED_SCORE , ic_ric , uri_path = train_with_sigana (self .URI_PATH )
191- self .assertGreaterEqual (ic_ric ["ic" ].all (), 0 , "train failed" )
192- self .assertGreaterEqual (ic_ric ["ric" ].all (), 0 , "train failed" )
193-
194- @pytest .mark .slow
195- def test_1_train (self ):
196- TestAllFlow .PRED_SCORE , ic_ric , TestAllFlow .RID = train (self .URI_PATH )
158+ def test_0_train (self ):
159+ TestAllFlow .PRED_SCORE , ic_ric , TestAllFlow .RID , uri_path = train (self .URI_PATH )
197160 self .assertGreaterEqual (ic_ric ["ic" ].all (), 0 , "train failed" )
198161 self .assertGreaterEqual (ic_ric ["ric" ].all (), 0 , "train failed" )
199162
200163 @pytest .mark .slow
201- def test_2_backtest (self ):
164+ def test_1_backtest (self ):
202165 analyze_df = backtest_analysis (TestAllFlow .PRED_SCORE , TestAllFlow .RID , self .URI_PATH )
203166 self .assertGreaterEqual (
204167 analyze_df .loc (axis = 0 )["excess_return_with_cost" , "annualized_return" ].values [0 ],
@@ -208,7 +171,7 @@ def test_2_backtest(self):
208171 self .assertTrue (not analyze_df .isna ().any ().any (), "backtest failed" )
209172
210173 @pytest .mark .slow
211- def test_3_expmanager (self ):
174+ def test_2_expmanager (self ):
212175 pass_default , pass_current , uri_path = fake_experiment ()
213176 self .assertTrue (pass_default , msg = "default uri is incorrect" )
214177 self .assertTrue (pass_current , msg = "current uri is incorrect" )
@@ -217,10 +180,9 @@ def test_3_expmanager(self):
217180
218181def suite ():
219182 _suite = unittest .TestSuite ()
220- _suite .addTest (TestAllFlow ("test_0_train_with_sigana" ))
221- _suite .addTest (TestAllFlow ("test_1_train" ))
222- _suite .addTest (TestAllFlow ("test_2_backtest" ))
223- _suite .addTest (TestAllFlow ("test_3_expmanager" ))
183+ _suite .addTest (TestAllFlow ("test_0_train" ))
184+ _suite .addTest (TestAllFlow ("test_1_backtest" ))
185+ _suite .addTest (TestAllFlow ("test_2_expmanager" ))
224186 return _suite
225187
226188
0 commit comments