Skip to content

Commit a2be6e2

Browse files
you-n-g and Wangwuyi123
authored
handler demo cache (#606)
* handler demo cache
* Update data_cache_demo.py
* example to reusing processed data in memory
* Skip dumping task of task_train
* FIX Black

Co-authored-by: Wangwuyi123 <51237097+Wangwuyi123@users.noreply.github.com>
1 parent fdbc666 commit a2be6e2

File tree

4 files changed

+198
-38
lines changed

4 files changed

+198
-38
lines changed

examples/data_demo/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Introduction
2+
The examples in this folder try to demonstrate some common usage of data-related modules of Qlib
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
The motivation of this demo
- To show the data modules of Qlib are serializable: users can dump processed
  data to disk to avoid duplicated data preprocessing.
"""
from copy import deepcopy
from pathlib import Path
import pickle
from pprint import pprint
import subprocess

import yaml

from qlib import init
from qlib.data.dataset.handler import DataHandlerLP
from qlib.log import TimeInspector
from qlib.utils import init_instance_by_config

# Resolve paths relative to this file so the demo can be run from any cwd.
DIRNAME = Path(__file__).absolute().resolve().parent

if __name__ == "__main__":
    init()

    config_path = DIRNAME.parent / "benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml"

    # 1) baseline: run the workflow without any handler cache
    with TimeInspector.logt("The original time without handler cache:"):
        # NOTE: config_path is an internal, trusted path; shell=True is only
        # used here to invoke the `qrun` CLI conveniently.
        subprocess.run(f"qrun {config_path}", shell=True)

    # 2) dump the processed handler to disk
    with config_path.open() as f:  # close the file handle explicitly
        task_config = yaml.safe_load(f)
    hd_conf = task_config["task"]["dataset"]["kwargs"]["handler"]
    pprint(hd_conf)
    hd: DataHandlerLP = init_instance_by_config(hd_conf)
    hd_path = DIRNAME / "handler.pkl"
    # dump_all=True keeps the processed data inside the pickle so later runs
    # can skip preprocessing entirely.
    hd.to_pickle(hd_path, dump_all=True)

    # 3) create a new task whose handler is loaded from the cache file
    new_task_config = deepcopy(task_config)
    new_task_config["task"]["dataset"]["kwargs"]["handler"] = f"file://{hd_path}"
    new_task_path = DIRNAME / "new_task.yaml"
    print("The location of the new task", new_task_path)

    # save new task
    with new_task_path.open("w") as f:
        yaml.safe_dump(new_task_config, f)

    # 4) train the model with the cached-handler task
    with TimeInspector.logt("The time for task with handler cache:"):
        subprocess.run(f"qrun {new_task_path}", shell=True)
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
The motivation of this demo
- To show the data modules of Qlib are serializable: users can reuse an
  already-processed handler in memory to avoid duplicated data preprocessing.
"""
from copy import deepcopy
from pathlib import Path
import pickle
from pprint import pprint
import subprocess

import yaml

from qlib import init
from qlib.data.dataset.handler import DataHandlerLP
from qlib.log import TimeInspector
from qlib.model.trainer import task_train
from qlib.utils import init_instance_by_config

# Resolve paths relative to this file so the demo can be run from any cwd.
DIRNAME = Path(__file__).absolute().resolve().parent

if __name__ == "__main__":
    init()

    repeat = 2
    exp_name = "data_mem_reuse_demo"

    config_path = DIRNAME.parent / "benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml"
    with config_path.open() as f:  # close the file handle explicitly
        task_config = yaml.safe_load(f)

    # 1) without using processed data in memory
    with TimeInspector.logt("The original time without reusing processed data in memory:"):
        for _ in range(repeat):
            task_train(task_config["task"], experiment_name=exp_name)

    # 2) prepare processed data in memory.
    hd_conf = task_config["task"]["dataset"]["kwargs"]["handler"]
    pprint(hd_conf)
    hd: DataHandlerLP = init_instance_by_config(hd_conf)

    # 3) with reusing processed data in memory
    new_task = deepcopy(task_config["task"])
    new_task["dataset"]["kwargs"]["handler"] = hd
    print(new_task)

    with TimeInspector.logt("The time with reusing processed data in memory:"):
        # this will save the time to reload and process data from disk (in `DataHandlerLP`)
        # It still takes a lot of time in the backtest phase
        for _ in range(repeat):
            task_train(new_task, experiment_name=exp_name)

    # 4) the user can change any part except the processed data in memory (the handler)
    new_task = deepcopy(task_config["task"])
    new_task["dataset"]["kwargs"]["segments"]["train"] = ("20100101", "20131231")
    # BUGFIX: the in-memory handler must be plugged in again here; the fresh
    # deepcopy of the raw config would otherwise re-create the handler from
    # scratch and reload/process the data from disk.
    new_task["dataset"]["kwargs"]["handler"] = hd
    with TimeInspector.logt("The time with reusing processed data in memory:"):
        task_train(new_task, experiment_name=exp_name)

qlib/model/trainer.py

Lines changed: 84 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import re
1717
from typing import Callable, List
1818

19+
from tqdm.auto import tqdm
1920
from qlib.data.dataset import Dataset
2021
from qlib.log import get_module_logger
2122
from qlib.model.base import Model
@@ -25,6 +26,48 @@
2526
from qlib.workflow.recorder import Recorder
2627
from qlib.workflow.task.manage import TaskManager, run_task
2728

29+
# from qlib.data.dataset.weight import Reweighter
30+
31+
32+
def _log_task_info(task_config: dict):
    """Record the task's metadata on the active recorder.

    Logs the flattened task config as experiment params, saves the raw
    config object as an artifact (keeping the original format and datatype),
    and tags the run with the current hostname.
    """
    R.log_params(**flatten_dict(task_config))
    R.save_objects(**{"task": task_config})  # keep the original format and datatype
    R.set_tags(**{"hostname": socket.gethostname()})
36+
37+
38+
def _exe_task(task_config: dict):
    """Execute one training task end-to-end on the active recorder.

    Instantiates the model and dataset from ``task_config``, fits the model,
    saves the trained model and a data-stripped dataset as recorder artifacts,
    fills the ``<MODEL>``/``<DATASET>`` placeholders in the config, and
    generates every configured record (prediction, backtest, analysis, ...).
    """
    rec = R.get_recorder()
    # model & dataset initiation
    model: Model = init_instance_by_config(task_config["model"])
    dataset: Dataset = init_instance_by_config(task_config["dataset"])
    # FIXME: resume reweighter after merging data selection
    # reweighter: Reweighter = task_config.get("reweighter", None)
    # model training
    # auto_filter_kwargs(model.fit)(dataset, reweighter=reweighter)
    model.fit(dataset)
    R.save_objects(**{"params.pkl": model})
    # this dataset is saved for online inference. So the concrete data should not be dumped
    dataset.config(dump_all=False, recursive=True)
    R.save_objects(**{"dataset": dataset})
    # fill placeholders such as "<MODEL>" and "<DATASET>" in the record configs
    placeholder_value = {"<MODEL>": model, "<DATASET>": dataset}
    task_config = fill_placeholder(task_config, placeholder_value)
    # generate records: prediction, backtest, and analysis
    records = task_config.get("record", [])
    if isinstance(records, dict):  # uniform the data format to list
        records = [records]
    for record in records:
        # Some recorders require the parameters `model` and `dataset`;
        # try to pass them to the initialization function automatically
        # to make defining the task easier.
        r = init_instance_by_config(
            record,
            recorder=rec,
            default_module="qlib.workflow.record_temp",
            try_kwargs={"model": model, "dataset": dataset},
        )
        r.generate()
70+
2871

2972
def begin_task_train(task_config: dict, experiment_name: str, recorder_name: str = None) -> Recorder:
3073
"""
@@ -39,11 +82,8 @@ def begin_task_train(task_config: dict, experiment_name: str, recorder_name: str
3982
Recorder: the model recorder
4083
"""
4184
with R.start(experiment_name=experiment_name, recorder_name=recorder_name):
42-
R.log_params(**flatten_dict(task_config))
43-
R.save_objects(**{"task": task_config}) # keep the original format and datatype
44-
R.set_tags(**{"hostname": socket.gethostname()})
45-
recorder: Recorder = R.get_recorder()
46-
return recorder
85+
_log_task_info(task_config)
86+
return R.get_recorder()
4787

4888

4989
def fill_placeholder(config: dict, config_extend: dict):
@@ -100,38 +140,11 @@ def end_task_train(rec: Recorder, experiment_name: str) -> Recorder:
100140
"""
101141
with R.start(experiment_name=experiment_name, recorder_id=rec.info["id"], resume=True):
102142
task_config = R.load_object("task")
103-
# model & dataset initiation
104-
model: Model = init_instance_by_config(task_config["model"])
105-
dataset: Dataset = init_instance_by_config(task_config["dataset"])
106-
# model training
107-
model.fit(dataset)
108-
R.save_objects(**{"params.pkl": model})
109-
# this dataset is saved for online inference. So the concrete data should not be dumped
110-
dataset.config(dump_all=False, recursive=True)
111-
R.save_objects(**{"dataset": dataset})
112-
# fill placehorder
113-
placehorder_value = {"<MODEL>": model, "<DATASET>": dataset}
114-
task_config = fill_placeholder(task_config, placehorder_value)
115-
# generate records: prediction, backtest, and analysis
116-
records = task_config.get("record", [])
117-
if isinstance(records, dict): # uniform the data format to list
118-
records = [records]
119-
120-
for record in records:
121-
# Some recorder require the parameter `model` and `dataset`.
122-
# try to automatically pass in them to the initialization function
123-
# to make defining the tasking easier
124-
r = init_instance_by_config(
125-
record,
126-
recorder=rec,
127-
default_module="qlib.workflow.record_temp",
128-
try_kwargs={"model": model, "dataset": dataset},
129-
)
130-
r.generate()
143+
_exe_task(task_config)
131144
return rec
132145

133146

134-
def task_train(task_config: dict, experiment_name: str) -> Recorder:
147+
def task_train(task_config: dict, experiment_name: str, recorder_name: str = None) -> Recorder:
135148
"""
136149
Task based training, will be divided into two steps.
137150
@@ -141,14 +154,17 @@ def task_train(task_config: dict, experiment_name: str) -> Recorder:
141154
The config of a task.
142155
experiment_name: str
143156
The name of experiment
157+
recorder_name: str
158+
The name of recorder
144159
145160
Returns
146161
----------
147162
Recorder: The instance of the recorder
148163
"""
149-
recorder = begin_task_train(task_config, experiment_name)
150-
recorder = end_task_train(recorder, experiment_name)
151-
return recorder
164+
with R.start(experiment_name=experiment_name, recorder_name=recorder_name):
165+
_log_task_info(task_config)
166+
_exe_task(task_config)
167+
return R.get_recorder()
152168

153169

154170
class Trainer:
@@ -204,6 +220,30 @@ def is_delay(self) -> bool:
204220
def __call__(self, *args, **kwargs) -> list:
205221
return self.end_train(self.train(*args, **kwargs))
206222

223+
def has_worker(self) -> bool:
    """Tell whether this trainer is backed by a parallel-training worker.

    Some trainers run a backend worker to support parallel training;
    this method reports whether such a worker is enabled.

    Returns
    -------
    bool:
        True when a backend worker is enabled. The base trainer has
        none, so it always answers False.
    """
    return False
235+
236+
def worker(self):
    """Start the backend worker.

    Raises
    ------
    NotImplementedError:
        The base trainer has no worker; subclasses that support
        parallel training must override this method.
    """
    # plain string literal: the original used an f-string with no placeholders
    raise NotImplementedError("Please implement the `worker` method")
246+
207247

208248
class TrainerR(Trainer):
209249
"""
@@ -252,7 +292,7 @@ def train(self, tasks: list, train_func: Callable = None, experiment_name: str =
252292
if experiment_name is None:
253293
experiment_name = self.experiment_name
254294
recs = []
255-
for task in tasks:
295+
for task in tqdm(tasks):
256296
rec = train_func(task, experiment_name, **kwargs)
257297
rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
258298
recs.append(rec)
@@ -457,6 +497,9 @@ def worker(
457497
task_pool = experiment_name
458498
run_task(train_func, task_pool=task_pool, experiment_name=experiment_name)
459499

500+
def has_worker(self) -> bool:
    """Report that a backend worker is available for this trainer.

    Returns
    -------
    bool:
        Always True for this trainer.
    """
    return True
502+
460503

461504
class DelayTrainerRM(TrainerRM):
462505
"""
@@ -579,3 +622,6 @@ def worker(self, end_train_func=None, experiment_name: str = None):
579622
experiment_name=experiment_name,
580623
before_status=TaskManager.STATUS_PART_DONE,
581624
)
625+
626+
def has_worker(self) -> bool:
    """Report that a backend worker is available for this trainer.

    Returns
    -------
    bool:
        Always True for this trainer.
    """
    return True

0 commit comments

Comments
 (0)