@@ -21,7 +21,7 @@ class DataLoader(abc.ABC):
2121 """
2222
2323 @abc .abstractmethod
24- def load (self , instruments , start_time = None , end_time = None , freq = "day" ) -> pd .DataFrame :
24+ def load (self , instruments , start_time = None , end_time = None ) -> pd .DataFrame :
2525 """
2626 load the data as pd.DataFrame.
2727
@@ -78,6 +78,7 @@ def __init__(self, config: Tuple[list, tuple, dict]):
7878 <config> := <fields_info>
7979
8080 <fields_info> := ["expr", ...] | (["expr", ...], ["col_name", ...])
81+ # NOTE: list or tuple will be treated as the things when parsing
8182 """
8283 self .is_group = isinstance (config , dict )
8384
@@ -87,18 +88,22 @@ def __init__(self, config: Tuple[list, tuple, dict]):
8788 self .fields = self ._parse_fields_info (config )
8889
8990 def _parse_fields_info (self , fields_info : Tuple [list , tuple ]) -> Tuple [list , list ]:
90- if isinstance (fields_info , list ):
91+ if len (fields_info ) == 0 :
92+ raise ValueError ("The size of fields must be greater than 0" )
93+
94+ if not isinstance (fields_info , (list , tuple )):
95+ raise TypeError ("Unsupported type" )
96+
97+ if isinstance (fields_info [0 ], str ):
9198 exprs = names = fields_info
92- elif isinstance (fields_info , tuple ):
99+ elif isinstance (fields_info [ 0 ], ( list , tuple ) ):
93100 exprs , names = fields_info
94101 else :
95102 raise NotImplementedError (f"This type of input is not supported" )
96103 return exprs , names
97104
98105 @abc .abstractmethod
99- def load_group_df (
100- self , instruments , exprs : list , names : list , start_time = None , end_time = None , freq = "day"
101- ) -> pd .DataFrame :
106+ def load_group_df (self , instruments , exprs : list , names : list , start_time = None , end_time = None ) -> pd .DataFrame :
102107 """
103108 load the dataframe for specific group
104109
@@ -118,25 +123,25 @@ def load_group_df(
118123 """
119124 pass
120125
121- def load (self , instruments = None , start_time = None , end_time = None , freq = "day" ) -> pd .DataFrame :
126+ def load (self , instruments = None , start_time = None , end_time = None ) -> pd .DataFrame :
122127 if self .is_group :
123128 df = pd .concat (
124129 {
125- grp : self .load_group_df (instruments , exprs , names , start_time , end_time , freq )
130+ grp : self .load_group_df (instruments , exprs , names , start_time , end_time )
126131 for grp , (exprs , names ) in self .fields .items ()
127132 },
128133 axis = 1 ,
129134 )
130135 else :
131136 exprs , names = self .fields
132- df = self .load_group_df (instruments , exprs , names , start_time , end_time , freq )
137+ df = self .load_group_df (instruments , exprs , names , start_time , end_time )
133138 return df
134139
135140
136141class QlibDataLoader (DLWParser ):
137142 """Same as QlibDataLoader. The fields can be define by config"""
138143
139- def __init__ (self , config : Tuple [list , tuple , dict ], filter_pipe = None , swap_level = True ):
144+ def __init__ (self , config : Tuple [list , tuple , dict ], filter_pipe = None , swap_level = True , freq = "day" ):
140145 """
141146 Parameters
142147 ----------
@@ -156,11 +161,10 @@ def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None, swap_leve
156161
157162 self .filter_pipe = filter_pipe
158163 self .swap_level = swap_level
164+ self .freq = freq
159165 super ().__init__ (config )
160166
161- def load_group_df (
162- self , instruments , exprs : list , names : list , start_time = None , end_time = None , freq = "day"
163- ) -> pd .DataFrame :
167+ def load_group_df (self , instruments , exprs : list , names : list , start_time = None , end_time = None ) -> pd .DataFrame :
164168 if instruments is None :
165169 warnings .warn ("`instruments` is not set, will load all stocks" )
166170 instruments = "all"
@@ -169,7 +173,7 @@ def load_group_df(
169173 elif self .filter_pipe is not None :
170174 warnings .warn ("`filter_pipe` is not None, but it will not be used with `instruments` as list" )
171175
172- df = D .features (instruments , exprs , start_time , end_time , freq )
176+ df = D .features (instruments , exprs , start_time , end_time , self . freq )
173177 df .columns = names
174178 if self .swap_level :
175179 df = df .swaplevel ().sort_index () # NOTE: if swaplevel, return <datetime, instrument>
@@ -194,7 +198,7 @@ def __init__(self, config: dict, join="outer"):
194198 self .join = join
195199 self ._data = None
196200
197- def load (self , instruments = None , start_time = None , end_time = None , freq = "day" ) -> pd .DataFrame :
201+ def load (self , instruments = None , start_time = None , end_time = None ) -> pd .DataFrame :
198202 self ._maybe_load_raw_data ()
199203 if instruments is None :
200204 df = self ._data
0 commit comments