Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions qlib/workflow/exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def delete_recorder(self, recorder_id=None, recorder_name=None):

UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!!

def list_recorders(self, max_results: int = UNLIMITED, status: Union[str, None] = None):
def list_recorders(self, max_results: int = UNLIMITED, status: Union[str, None] = None, filter_string: str = ""):
"""
Parameters
----------
Expand All @@ -334,8 +334,12 @@ def list_recorders(self, max_results: int = UNLIMITED, status: Union[str, None]
status : str
the criteria based on status to filter results.
`None` indicates no filtering.
filter_string : str
mlflow supported filter string like 'params."my_param"="a" and tags."my_tag"="b"', use this will help to reduce too much run number.
"""
runs = self._client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)
runs = self._client.search_runs(
self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results, filter_string=filter_string
)
recorders = dict()
for i in range(len(runs)):
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i])
Expand Down
5 changes: 4 additions & 1 deletion qlib/workflow/task/collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def __init__(
rec_filter_func=None,
artifacts_path={"pred": "pred.pkl"},
artifacts_key=None,
filter_string: str = "",
):
"""
Init RecorderCollector.
Expand All @@ -150,6 +151,7 @@ def __init__(
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
artifacts_path (dict, optional): The artifacts name and its path in Recorder. Defaults to {"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}.
artifacts_key (str or List, optional): the artifacts key you want to get. If None, get all artifacts.
filter_string (str): filter string that used to apply in recorder quering (only support mlflow for now).
"""
super().__init__(process_list=process_list)
if isinstance(experiment, str):
Expand All @@ -163,6 +165,7 @@ def __init__(
self.rec_key_func = rec_key_func
self.artifacts_key = artifacts_key
self.rec_filter_func = rec_filter_func
self.filter_string = filter_string

def collect(self, artifacts_key=None, rec_filter_func=None, only_exist=True) -> dict:
"""
Expand All @@ -187,7 +190,7 @@ def collect(self, artifacts_key=None, rec_filter_func=None, only_exist=True) ->

collect_dict = {}
# filter records
recs = self.experiment.list_recorders()
recs = self.experiment.list_recorders(filter_string=self.filter_string)
recs_flt = {}
for rid, rec in recs.items():
if rec_filter_func is None or rec_filter_func(rec):
Expand Down