1616import re
1717from typing import Callable , List
1818
19+ from tqdm .auto import tqdm
1920from qlib .data .dataset import Dataset
2021from qlib .log import get_module_logger
2122from qlib .model .base import Model
2526from qlib .workflow .recorder import Recorder
2627from qlib .workflow .task .manage import TaskManager , run_task
2728
29+ # from qlib.data.dataset.weight import Reweighter
30+
31+
def _log_task_info(task_config: dict):
    """Record a task's metadata on the currently active recorder.

    Logs the flattened task config as experiment params, saves the raw
    config object (keeping its original format and datatype), and tags
    the run with the current hostname.

    Parameters
    ----------
    task_config : dict
        The config of a task.
    """
    R.log_params(**flatten_dict(task_config))
    R.save_objects(**{"task": task_config})  # keep the original format and datatype
    R.set_tags(**{"hostname": socket.gethostname()})
36+
37+
def _exe_task(task_config: dict):
    """Execute a single task on the active recorder.

    Trains the model on the dataset, saves the trained model and a
    data-free copy of the dataset to the recorder, fills the
    ``<MODEL>``/``<DATASET>`` placeholders in the task config, and then
    generates every configured record (prediction, backtest, analysis).

    Parameters
    ----------
    task_config : dict
        The config of a task; expected to contain at least the keys
        ``"model"`` and ``"dataset"``, and optionally ``"record"``.
    """
    rec = R.get_recorder()
    # model & dataset initiation
    model: Model = init_instance_by_config(task_config["model"])
    dataset: Dataset = init_instance_by_config(task_config["dataset"])
    # FIXME: resume reweighter after merging data selection
    # reweighter: Reweighter = task_config.get("reweighter", None)
    # model training
    # auto_filter_kwargs(model.fit)(dataset, reweighter=reweighter)
    model.fit(dataset)
    R.save_objects(**{"params.pkl": model})
    # this dataset is saved for online inference. So the concrete data should not be dumped
    dataset.config(dump_all=False, recursive=True)
    R.save_objects(**{"dataset": dataset})
    # fill placeholders such as "<MODEL>"/"<DATASET>" in the record configs
    # (fixed typo: was "placehorder")
    placeholder_value = {"<MODEL>": model, "<DATASET>": dataset}
    task_config = fill_placeholder(task_config, placeholder_value)
    # generate records: prediction, backtest, and analysis
    records = task_config.get("record", [])
    if isinstance(records, dict):  # uniform the data format to list
        records = [records]
    for record in records:
        # Some recorders require the parameters `model` and `dataset`;
        # try to pass them in automatically to make defining the task easier.
        r = init_instance_by_config(
            record,
            recorder=rec,
            default_module="qlib.workflow.record_temp",
            try_kwargs={"model": model, "dataset": dataset},
        )
        r.generate()
70+
2871
2972def begin_task_train (task_config : dict , experiment_name : str , recorder_name : str = None ) -> Recorder :
3073 """
@@ -39,11 +82,8 @@ def begin_task_train(task_config: dict, experiment_name: str, recorder_name: str
3982 Recorder: the model recorder
4083 """
4184 with R .start (experiment_name = experiment_name , recorder_name = recorder_name ):
42- R .log_params (** flatten_dict (task_config ))
43- R .save_objects (** {"task" : task_config }) # keep the original format and datatype
44- R .set_tags (** {"hostname" : socket .gethostname ()})
45- recorder : Recorder = R .get_recorder ()
46- return recorder
85+ _log_task_info (task_config )
86+ return R .get_recorder ()
4787
4888
4989def fill_placeholder (config : dict , config_extend : dict ):
@@ -100,38 +140,11 @@ def end_task_train(rec: Recorder, experiment_name: str) -> Recorder:
100140 """
101141 with R .start (experiment_name = experiment_name , recorder_id = rec .info ["id" ], resume = True ):
102142 task_config = R .load_object ("task" )
103- # model & dataset initiation
104- model : Model = init_instance_by_config (task_config ["model" ])
105- dataset : Dataset = init_instance_by_config (task_config ["dataset" ])
106- # model training
107- model .fit (dataset )
108- R .save_objects (** {"params.pkl" : model })
109- # this dataset is saved for online inference. So the concrete data should not be dumped
110- dataset .config (dump_all = False , recursive = True )
111- R .save_objects (** {"dataset" : dataset })
112- # fill placehorder
113- placehorder_value = {"<MODEL>" : model , "<DATASET>" : dataset }
114- task_config = fill_placeholder (task_config , placehorder_value )
115- # generate records: prediction, backtest, and analysis
116- records = task_config .get ("record" , [])
117- if isinstance (records , dict ): # uniform the data format to list
118- records = [records ]
119-
120- for record in records :
121- # Some recorder require the parameter `model` and `dataset`.
122- # try to automatically pass in them to the initialization function
123- # to make defining the tasking easier
124- r = init_instance_by_config (
125- record ,
126- recorder = rec ,
127- default_module = "qlib.workflow.record_temp" ,
128- try_kwargs = {"model" : model , "dataset" : dataset },
129- )
130- r .generate ()
143+ _exe_task (task_config )
131144 return rec
132145
133146
def task_train(task_config: dict, experiment_name: str, recorder_name: str = None) -> Recorder:
    """
    Task based training, will be divided into two steps.

    Parameters
    ----------
    task_config : dict
        The config of a task.
    experiment_name: str
        The name of experiment
    recorder_name: str
        The name of recorder

    Returns
    ----------
    Recorder: The instance of the recorder
    """
    # run the whole task (metadata logging + execution) inside one recorder context
    with R.start(experiment_name=experiment_name, recorder_name=recorder_name):
        _log_task_info(task_config)
        _exe_task(task_config)
        return R.get_recorder()
152168
153169
154170class Trainer :
@@ -204,6 +220,30 @@ def is_delay(self) -> bool:
204220 def __call__ (self , * args , ** kwargs ) -> list :
205221 return self .end_train (self .train (* args , ** kwargs ))
206222
223+ def has_worker (self ) -> bool :
224+ """
225+ Some trainer has backend worker to support parallel training
226+ This method can tell if the worker is enabled.
227+
228+ Returns
229+ -------
230+ bool:
231+ if the worker is enabled
232+
233+ """
234+ return False
235+
236+ def worker (self ):
237+ """
238+ start the worker
239+
240+ Raises
241+ ------
242+ NotImplementedError:
243+ If the worker is not supported
244+ """
245+ raise NotImplementedError (f"Please implement the `worker` method" )
246+
207247
208248class TrainerR (Trainer ):
209249 """
@@ -252,7 +292,7 @@ def train(self, tasks: list, train_func: Callable = None, experiment_name: str =
252292 if experiment_name is None :
253293 experiment_name = self .experiment_name
254294 recs = []
255- for task in tasks :
295+ for task in tqdm ( tasks ) :
256296 rec = train_func (task , experiment_name , ** kwargs )
257297 rec .set_tags (** {self .STATUS_KEY : self .STATUS_BEGIN })
258298 recs .append (rec )
@@ -457,6 +497,9 @@ def worker(
457497 task_pool = experiment_name
458498 run_task (train_func , task_pool = task_pool , experiment_name = experiment_name )
459499
500+ def has_worker (self ) -> bool :
501+ return True
502+
460503
461504class DelayTrainerRM (TrainerRM ):
462505 """
@@ -579,3 +622,6 @@ def worker(self, end_train_func=None, experiment_name: str = None):
579622 experiment_name = experiment_name ,
580623 before_status = TaskManager .STATUS_PART_DONE ,
581624 )
625+
626+ def has_worker (self ) -> bool :
627+ return True