1010from qlib .log import get_module_logger
1111from qlib .utils .serial import Serializable
1212from qlib .workflow import R
13+ from qlib .workflow .exp import Experiment
1314
1415
1516class Collector (Serializable ):
@@ -146,7 +147,9 @@ def __init__(
146147 Init RecorderCollector.
147148
148149 Args:
149- experiment (Experiment or str): an instance of an Experiment or the name of an Experiment
150+ experiment:
151+ (Experiment or str): an instance of an Experiment or the name of an Experiment
152+ (Callable): an callable function, which returns a list of experiments
150153 process_list (list or Callable): the list of processors or the instance of a processor to process dict.
151154 rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
152155 rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
@@ -157,6 +160,7 @@ def __init__(
157160 super ().__init__ (process_list = process_list )
158161 if isinstance (experiment , str ):
159162 experiment = R .get_exp (experiment_name = experiment )
163+ assert isinstance (experiment , (Experiment , Callable ))
160164 self .experiment = experiment
161165 self .artifacts_path = artifacts_path
162166 if rec_key_func is None :
@@ -192,15 +196,16 @@ def collect(self, artifacts_key=None, rec_filter_func=None, only_exist=True) ->
192196 collect_dict = {}
193197 # filter records
194198
195- with TimeInspector .logt ("Time to `list_recorders` in RecorderCollector" ):
196- recs = self .experiment .list_recorders (** self .list_kwargs )
197- recs_flt = {}
198- for rid , rec in recs .items ():
199- if rec_filter_func is None or rec_filter_func (rec ):
200- recs_flt [rid ] = rec
199+ if isinstance (self .experiment , Experiment ):
200+ with TimeInspector .logt ("Time to `list_recorders` in RecorderCollector" ):
201+ recs = list (self .experiment .list_recorders (** self .list_kwargs ).values ())
202+ elif isinstance (self .experiment , Callable ):
203+ recs = self .experiment ()
204+
205+ recs = [rec for rec in recs if rec_filter_func is None or rec_filter_func (rec )]
201206
202207 logger = get_module_logger ("RecorderCollector" )
203- for _ , rec in recs_flt . items () :
208+ for rec in recs :
204209 rec_key = self .rec_key_func (rec )
205210 for key in artifacts_key :
206211 if self .ART_KEY_RAW == key :
0 commit comments