Skip to content

Commit d53ff38

Browse files
committed
Callable Exp
1 parent 01bdf6c commit d53ff38

File tree

2 files changed

+23
-11
lines changed

2 files changed

+23
-11
lines changed

qlib/workflow/record_temp.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pandas as pd
99
from pathlib import Path
1010
from pprint import pprint
11-
from typing import Union, List
11+
from typing import Union, List, Optional
1212
from collections import defaultdict
1313

1414
from qlib.utils.exceptions import LoadObjectError
@@ -270,7 +270,13 @@ def __init__(self, recorder, ana_long_short=False, ann_scaler=252, label_col=0,
270270
self.label_col = label_col
271271
self.skip_existing = skip_existing
272272

273-
def generate(self, **kwargs):
273+
def generate(self, label: Optional[pd.DataFrame]=None, **kwargs):
274+
"""
275+
Parameters
276+
----------
277+
label : Optional[pd.DataFrame]
278+
Label should be a dataframe.
279+
"""
274280
if self.skip_existing:
275281
try:
276282
self.check(include_self=True, parents=False)
@@ -283,7 +289,8 @@ def generate(self, **kwargs):
283289
self.check()
284290

285291
pred = self.load("pred.pkl")
286-
label = self.load("label.pkl")
292+
if label is None:
293+
label = self.load("label.pkl")
287294
if label is None or not isinstance(label, pd.DataFrame) or label.empty:
288295
logger.warn(f"Empty label.")
289296
return

qlib/workflow/task/collect.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from qlib.log import get_module_logger
1111
from qlib.utils.serial import Serializable
1212
from qlib.workflow import R
13+
from qlib.workflow.exp import Experiment
1314

1415

1516
class Collector(Serializable):
@@ -146,7 +147,9 @@ def __init__(
146147
Init RecorderCollector.
147148
148149
Args:
149-
experiment (Experiment or str): an instance of an Experiment or the name of an Experiment
150+
experiment:
151+
(Experiment or str): an instance of an Experiment or the name of an Experiment
152+
(Callable): an callable function, which returns a list of experiments
150153
process_list (list or Callable): the list of processors or the instance of a processor to process dict.
151154
rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
152155
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
@@ -157,6 +160,7 @@ def __init__(
157160
super().__init__(process_list=process_list)
158161
if isinstance(experiment, str):
159162
experiment = R.get_exp(experiment_name=experiment)
163+
assert isinstance(experiment, (Experiment, Callable))
160164
self.experiment = experiment
161165
self.artifacts_path = artifacts_path
162166
if rec_key_func is None:
@@ -192,15 +196,16 @@ def collect(self, artifacts_key=None, rec_filter_func=None, only_exist=True) ->
192196
collect_dict = {}
193197
# filter records
194198

195-
with TimeInspector.logt("Time to `list_recorders` in RecorderCollector"):
196-
recs = self.experiment.list_recorders(**self.list_kwargs)
197-
recs_flt = {}
198-
for rid, rec in recs.items():
199-
if rec_filter_func is None or rec_filter_func(rec):
200-
recs_flt[rid] = rec
199+
if isinstance(self.experiment, Experiment):
200+
with TimeInspector.logt("Time to `list_recorders` in RecorderCollector"):
201+
recs = list(self.experiment.list_recorders(**self.list_kwargs).values())
202+
elif isinstance(self.experiment, Callable):
203+
recs = self.experiment()
204+
205+
recs = [rec for rec in recs if rec_filter_func is None or rec_filter_func(rec)]
201206

202207
logger = get_module_logger("RecorderCollector")
203-
for _, rec in recs_flt.items():
208+
for rec in recs:
204209
rec_key = self.rec_key_func(rec)
205210
for key in artifacts_key:
206211
if self.ART_KEY_RAW == key:

0 commit comments

Comments
 (0)