Skip to content

Commit 802dac8

Browse files
committed
move freq params to dataloader
1 parent bdc70c1 commit 802dac8

File tree

6 files changed

+51
-31
lines changed

6 files changed

+51
-31
lines changed

examples/highfreq/highfreq_handler.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ def __init__(
1010
instruments="csi300",
1111
start_time=None,
1212
end_time=None,
13-
freq="1min",
1413
infer_processors=[],
1514
learn_processors=[],
1615
fit_start_time=None,
@@ -37,13 +36,13 @@ def check_transform_proc(proc_l):
3736
"kwargs": {
3837
"config": self.get_feature_config(),
3938
"swap_level": False,
39+
"freq": "1min",
4040
},
4141
}
4242
super().__init__(
4343
instruments=instruments,
4444
start_time=start_time,
4545
end_time=end_time,
46-
freq=freq,
4746
data_loader=data_loader,
4847
infer_processors=infer_processors,
4948
learn_processors=learn_processors,
@@ -124,20 +123,19 @@ def __init__(
124123
instruments="csi300",
125124
start_time=None,
126125
end_time=None,
127-
freq="1min",
128126
):
129127
data_loader = {
130128
"class": "QlibDataLoader",
131129
"kwargs": {
132130
"config": self.get_feature_config(),
133131
"swap_level": False,
132+
"freq": "1min",
134133
},
135134
}
136135
super().__init__(
137136
instruments=instruments,
138137
start_time=start_time,
139138
end_time=end_time,
140-
freq=freq,
141139
data_loader=data_loader,
142140
)
143141

qlib/config.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,6 @@ def set_conf_from_C(self, config_c):
9090
# How many tasks belong to one process. Recommend 1 for high-frequency data and None for daily data.
9191
"maxtasksperchild": None,
9292
"default_disk_cache": 1, # 0:skip/1:use
93-
"disable_disk_cache": False, # disable disk cache; if High-frequency data generally disable_disk_cache=True
9493
"mem_cache_size_limit": 500,
9594
# memory cache expire second, only in used 'DatasetURICache' and 'client D.calendar'
9695
# default 1 hour

qlib/data/data.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -961,8 +961,7 @@ def features(
961961
is a provider class.
962962
"""
963963
disk_cache = C.default_disk_cache if disk_cache is None else disk_cache
964-
if C.disable_disk_cache:
965-
disk_cache = False
964+
fields = list(fields) # In case of tuple.
966965
try:
967966
return DatasetD.dataset(instruments, fields, start_time, end_time, freq, disk_cache)
968967
except TypeError:

qlib/data/dataset/handler.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,10 @@ def __init__(
5757
instruments=None,
5858
start_time=None,
5959
end_time=None,
60-
freq="day",
6160
data_loader: Tuple[dict, str, DataLoader] = None,
6261
init_data=True,
6362
fetch_orig=True,
63+
**kwargs,
6464
):
6565
"""
6666
Parameters
@@ -71,38 +71,58 @@ def __init__(
7171
start_time of the original data.
7272
end_time :
7373
end_time of the original data.
74-
freq :
75-
frequency of data
7674
data_loader : Tuple[dict, str, DataLoader]
7775
data loader to load the data.
7876
init_data :
7977
initialize the original data in the constructor.
8078
fetch_orig : bool
8179
Return the original data instead of copy if possible.
80+
**kwargs:
81+
it will be passed into data_loader
8282
"""
8383
# Set logger
8484
self.logger = get_module_logger("DataHandler")
8585

8686
# Setup data loader
8787
assert data_loader is not None # to make start_time end_time could have None default value
8888

89+
# what data source to load data
8990
self.data_loader = init_instance_by_config(
9091
data_loader,
9192
None if (isinstance(data_loader, dict) and "module_path" in data_loader) else data_loader_module,
9293
accept_types=DataLoader,
94+
**kwargs,
9395
)
9496

97+
# what data to be loaded from data source
98+
# For IDE auto-completion.
9599
self.instruments = instruments
96100
self.start_time = start_time
97101
self.end_time = end_time
98-
self.freq = freq
102+
99103
self.fetch_orig = fetch_orig
100104
if init_data:
101105
with TimeInspector.logt("Init data"):
102106
self.init()
103107
super().__init__()
104108

105-
def init(self, enable_cache: bool = True):
109+
def conf_data(self, **kwargs):
110+
"""
111+
configuration of data.
112+
# what data to be loaded from data source
113+
114+
This method will be used when loading pickled handler from dataset.
115+
The data will be initialized with different time range.
116+
117+
"""
118+
attr_list = {"instruments", "start_time", "end_time"}
119+
for k, v in kwargs.items():
120+
if k in attr_list:
121+
setattr(self, k, v)
122+
else:
123+
raise KeyError("Such config is not supported.")
124+
125+
def init(self, enable_cache: bool = False):
106126
"""
107127
initialize the data.
108128
In case of running initialization multiple times, it will do nothing from the second time on.
@@ -123,7 +143,7 @@ def init(self, enable_cache: bool = True):
123143
# Setup data.
124144
# _data may be with multiple column index level. The outer level indicates the feature set name
125145
with TimeInspector.logt("Loading data"):
126-
self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time, self.freq)
146+
self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time)
127147
# TODO: cache
128148

129149
CS_ALL = "__all" # return all columns with single-level index column
@@ -262,7 +282,6 @@ def __init__(
262282
instruments=None,
263283
start_time=None,
264284
end_time=None,
265-
freq="day",
266285
data_loader: Tuple[dict, str, DataLoader] = None,
267286
infer_processors=[],
268287
learn_processors=[],
@@ -328,7 +347,7 @@ def __init__(
328347

329348
self.process_type = process_type
330349
self.drop_raw = drop_raw
331-
super().__init__(instruments, start_time, end_time, freq, data_loader, **kwargs)
350+
super().__init__(instruments, start_time, end_time, data_loader, **kwargs)
332351

333352
def get_all_processors(self):
334353
return self.infer_processors + self.learn_processors

qlib/data/dataset/loader.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class DataLoader(abc.ABC):
2121
"""
2222

2323
@abc.abstractmethod
24-
def load(self, instruments, start_time=None, end_time=None, freq="day") -> pd.DataFrame:
24+
def load(self, instruments, start_time=None, end_time=None) -> pd.DataFrame:
2525
"""
2626
load the data as pd.DataFrame.
2727
@@ -78,6 +78,7 @@ def __init__(self, config: Tuple[list, tuple, dict]):
7878
<config> := <fields_info>
7979
8080
<fields_info> := ["expr", ...] | (["expr", ...], ["col_name", ...])
81+
# NOTE: list or tuple will be treated as the things when parsing
8182
"""
8283
self.is_group = isinstance(config, dict)
8384

@@ -87,18 +88,22 @@ def __init__(self, config: Tuple[list, tuple, dict]):
8788
self.fields = self._parse_fields_info(config)
8889

8990
def _parse_fields_info(self, fields_info: Tuple[list, tuple]) -> Tuple[list, list]:
90-
if isinstance(fields_info, list):
91+
if len(fields_info) == 0:
92+
raise ValueError("The size of fields must be greater than 0")
93+
94+
if not isinstance(fields_info, (list, tuple)):
95+
raise TypeError("Unsupported type")
96+
97+
if isinstance(fields_info[0], str):
9198
exprs = names = fields_info
92-
elif isinstance(fields_info, tuple):
99+
elif isinstance(fields_info[0], (list, tuple)):
93100
exprs, names = fields_info
94101
else:
95102
raise NotImplementedError(f"This type of input is not supported")
96103
return exprs, names
97104

98105
@abc.abstractmethod
99-
def load_group_df(
100-
self, instruments, exprs: list, names: list, start_time=None, end_time=None, freq="day"
101-
) -> pd.DataFrame:
106+
def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame:
102107
"""
103108
load the dataframe for specific group
104109
@@ -118,25 +123,25 @@ def load_group_df(
118123
"""
119124
pass
120125

121-
def load(self, instruments=None, start_time=None, end_time=None, freq="day") -> pd.DataFrame:
126+
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
122127
if self.is_group:
123128
df = pd.concat(
124129
{
125-
grp: self.load_group_df(instruments, exprs, names, start_time, end_time, freq)
130+
grp: self.load_group_df(instruments, exprs, names, start_time, end_time)
126131
for grp, (exprs, names) in self.fields.items()
127132
},
128133
axis=1,
129134
)
130135
else:
131136
exprs, names = self.fields
132-
df = self.load_group_df(instruments, exprs, names, start_time, end_time, freq)
137+
df = self.load_group_df(instruments, exprs, names, start_time, end_time)
133138
return df
134139

135140

136141
class QlibDataLoader(DLWParser):
137142
"""Same as QlibDataLoader. The fields can be define by config"""
138143

139-
def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None, swap_level=True):
144+
def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None, swap_level=True, freq="day"):
140145
"""
141146
Parameters
142147
----------
@@ -156,11 +161,10 @@ def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None, swap_leve
156161

157162
self.filter_pipe = filter_pipe
158163
self.swap_level = swap_level
164+
self.freq = freq
159165
super().__init__(config)
160166

161-
def load_group_df(
162-
self, instruments, exprs: list, names: list, start_time=None, end_time=None, freq="day"
163-
) -> pd.DataFrame:
167+
def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame:
164168
if instruments is None:
165169
warnings.warn("`instruments` is not set, will load all stocks")
166170
instruments = "all"
@@ -169,7 +173,7 @@ def load_group_df(
169173
elif self.filter_pipe is not None:
170174
warnings.warn("`filter_pipe` is not None, but it will not be used with `instruments` as list")
171175

172-
df = D.features(instruments, exprs, start_time, end_time, freq)
176+
df = D.features(instruments, exprs, start_time, end_time, self.freq)
173177
df.columns = names
174178
if self.swap_level:
175179
df = df.swaplevel().sort_index() # NOTE: if swaplevel, return <datetime, instrument>
@@ -194,7 +198,7 @@ def __init__(self, config: dict, join="outer"):
194198
self.join = join
195199
self._data = None
196200

197-
def load(self, instruments=None, start_time=None, end_time=None, freq="day") -> pd.DataFrame:
201+
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
198202
self._maybe_load_raw_data()
199203
if instruments is None:
200204
df = self._data

qlib/workflow/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from contextlib import contextmanager
55
from .expm import MLflowExpManager
6+
from .exp import Experiment
67
from .recorder import Recorder
78
from ..utils import Wrapper
89

@@ -165,7 +166,7 @@ def list_recorders(self, experiment_id=None, experiment_name=None):
165166
"""
166167
return self.get_exp(experiment_id, experiment_name).list_recorders()
167168

168-
def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True):
169+
def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True) -> Experiment:
169170
"""
170171
Method for retrieving an experiment with given id or name. Once the `create` argument is set to
171172
True, if no valid experiment is found, this method will create one for you. Otherwise, it will

0 commit comments

Comments
 (0)