Skip to content

Commit c0e7cbc

Browse files
Derek-Wdsyou-n-g
authored andcommitted
Add filter_pipe API
1 parent 828993b commit c0e7cbc

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

qlib/contrib/data/handler.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def __init__(
5454
learn_processors=_DEFAULT_LEARN_PROCESSORS,
5555
fit_start_time=None,
5656
fit_end_time=None,
57+
filter_pipe=None,
5758
**kwargs,
5859
):
5960
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
@@ -66,6 +67,7 @@ def __init__(
6667
"feature": self.get_feature_config(),
6768
"label": kwargs.get("label", self.get_label_config()),
6869
},
70+
"filter_pipe": filter_pipe,
6971
},
7072
}
7173

@@ -138,6 +140,7 @@ def __init__(
138140
fit_start_time=None,
139141
fit_end_time=None,
140142
process_type=DataHandlerLP.PTYPE_A,
143+
filter_pipe=None,
141144
**kwargs,
142145
):
143146
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
@@ -146,7 +149,11 @@ def __init__(
146149
data_loader = {
147150
"class": "QlibDataLoader",
148151
"kwargs": {
149-
"config": {"feature": self.get_feature_config(), "label": kwargs.get("label", self.get_label_config())},
152+
"config": {
153+
"feature": self.get_feature_config(),
154+
"label": kwargs.get("label", self.get_label_config()),
155+
},
156+
"filter_pipe": filter_pipe,
150157
},
151158
}
152159
super().__init__(

qlib/data/dataset/loader.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
from typing import Tuple, Union
1111

1212
from qlib.data import D
13-
from qlib.utils import load_dataset
13+
from qlib.data import filter as filter_module
14+
from qlib.data.filter import BaseDFilter
15+
from qlib.utils import load_dataset, init_instance_by_config
1416

1517

1618
class DataLoader(abc.ABC):
@@ -145,6 +147,13 @@ def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None, swap_leve
145147
swap_level :
146148
Whether to swap level of MultiIndex
147149
"""
150+
if filter_pipe is not None:
151+
assert isinstance(filter_pipe, list), "The type of `filter_pipe` must be list."
152+
filter_pipe = [
153+
init_instance_by_config(fp, None if "module_path" in fp else filter_module, accept_types=BaseDFilter)
154+
for fp in filter_pipe
155+
]
156+
148157
self.filter_pipe = filter_pipe
149158
self.swap_level = swap_level
150159
super().__init__(config)

0 commit comments

Comments
 (0)