Skip to content

Commit 3493f29

Browse files
authored
Enhance Task Dict Var (#778)
1 parent e33de44 commit 3493f29

File tree

3 files changed

+71
-10
lines changed

3 files changed

+71
-10
lines changed

examples/model_rolling/task_manager_rolling.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from qlib.workflow.task.manage import TaskManager, run_task
1818
from qlib.workflow.task.collect import RecorderCollector
1919
from qlib.model.ens.group import RollingGroup
20-
from qlib.model.trainer import TrainerRM, task_train
20+
from qlib.model.trainer import TrainerR, TrainerRM, task_train
2121
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
2222

2323

@@ -29,7 +29,7 @@ def __init__(
2929
task_url="mongodb://10.0.0.4:27017/",
3030
task_db_name="rolling_db",
3131
experiment_name="rolling_exp",
32-
task_pool="rolling_task",
32+
task_pool=None, # None trains with TrainerR; pass a pool name such as "rolling_task" to train with TrainerRM
3333
task_config=None,
3434
rolling_step=550,
3535
rolling_type=RollingGen.ROLL_SD,
@@ -43,14 +43,19 @@ def __init__(
4343
}
4444
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
4545
self.experiment_name = experiment_name
46-
self.task_pool = task_pool
46+
if task_pool is None:
47+
self.trainer = TrainerR(experiment_name=self.experiment_name)
48+
else:
49+
self.task_pool = task_pool
50+
self.trainer = TrainerRM(self.experiment_name, self.task_pool)
4751
self.task_config = task_config
4852
self.rolling_gen = RollingGen(step=rolling_step, rtype=rolling_type)
4953

5054
# Reset all things to the first status, be careful to save important data
5155
def reset(self):
5256
print("========== reset ==========")
53-
TaskManager(task_pool=self.task_pool).remove()
57+
if isinstance(self.trainer, TrainerRM):
58+
TaskManager(task_pool=self.task_pool).remove()
5459
exp = R.get_exp(experiment_name=self.experiment_name)
5560
for rid in exp.list_recorders():
5661
exp.delete_recorder(rid)
@@ -66,10 +71,10 @@ def task_generating(self):
6671

6772
def task_training(self, tasks):
6873
print("========== task_training ==========")
69-
trainer = TrainerRM(self.experiment_name, self.task_pool)
70-
trainer.train(tasks)
74+
self.trainer.train(tasks)
7175

7276
def worker(self):
77+
# NOTE: this is only used for TrainerRM
7378
# train tasks by other processes or machines for multiprocessing. It is the same as TrainerRM.worker.
7479
print("========== worker ==========")
7580
run_task(task_train, self.task_pool, experiment_name=self.experiment_name)

qlib/model/trainer.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,61 @@ def begin_task_train(task_config: dict, experiment_name: str, recorder_name: str
8686
return R.get_recorder()
8787

8888

89+
def get_item_from_obj(config: dict, name_path: str) -> object:
    """
    Follow the name_path to get values from config.

    The path is split on "." and walked one component at a time:
    a dict is indexed with the component as a string key, any other
    container is indexed with the component converted to an int.

    For example:
    If we follow the example in the Parameters section,
    Timestamp('2014-12-31 00:00:00') will be returned
    (index 1 of the `train` segment).

    Parameters
    ----------
    config : dict
        e.g.
        {'dataset': {'class': 'DatasetH',
                     'kwargs': {'handler': {'class': 'Alpha158',
                                            'kwargs': {'end_time': '2020-08-01',
                                                       'fit_end_time': '<dataset.kwargs.segments.train.1>',
                                                       'fit_start_time': '<dataset.kwargs.segments.train.0>',
                                                       'instruments': 'csi100',
                                                       'start_time': '2008-01-01'},
                                            'module_path': 'qlib.contrib.data.handler'},
                                'segments': {'test': (Timestamp('2017-01-03 00:00:00'),
                                                      Timestamp('2019-04-08 00:00:00')),
                                             'train': (Timestamp('2008-01-02 00:00:00'),
                                                       Timestamp('2014-12-31 00:00:00')),
                                             'valid': (Timestamp('2015-01-05 00:00:00'),
                                                       Timestamp('2016-12-30 00:00:00'))}}
        }}
    name_path : str
        e.g.
        "dataset.kwargs.segments.train.1"

    Returns
    -------
    object
        the retrieved object

    Raises
    ------
    ValueError
        if a path component is not made of digits while the current object is not a dict
    KeyError
        if a path component is missing from a dict (propagated from the lookup)
    IndexError
        if a numeric path component is out of range (propagated from the lookup)
    """
    cur_cfg = config
    for k in name_path.split("."):
        if isinstance(cur_cfg, dict):
            # dict levels are always addressed by the string key itself
            cur_cfg = cur_cfg[k]
        elif k.isdigit():
            # non-dict levels (e.g. list / tuple) are addressed by position
            cur_cfg = cur_cfg[int(k)]
        else:
            # report the failing key, the container type and the full path so the
            # offending placeholder can be located in the task config
            raise ValueError(
                f"Error when getting {k!r} from {type(cur_cfg).__name__} object "
                f"while following name_path {name_path!r}"
            )
    return cur_cfg
133+
134+
89135
def fill_placeholder(config: dict, config_extend: dict):
90136
"""
91137
Detect placeholder in config and fill them with config_extend.
92138
Each item of the dict must be a single value (int, str, etc.), a dict, or a list. Tuples are not supported.
139+
There are two types of variables:
140+
- user-defined variables :
141+
e.g. when config_extend is `{"<MODEL>": model, "<DATASET>": dataset}`, "<MODEL>" and "<DATASET>" in `config` will be replaced with `model` and `dataset`, respectively
142+
- variables extracted from `config` :
143+
e.g. the variables like "<dataset.kwargs.segments.train.0>" will be replaced with the values from `config`
93144
94145
Parameters
95146
----------
@@ -122,8 +173,13 @@ def fill_placeholder(config: dict, config_extend: dict):
122173
if isinstance(now_item[key], list) or isinstance(now_item[key], dict):
123174
item_queue.append(now_item[key])
124175
tail += 1
125-
elif isinstance(now_item[key], str) and now_item[key] in config_extend.keys():
126-
now_item[key] = config_extend[now_item[key]]
176+
elif isinstance(now_item[key], str):
177+
if now_item[key] in config_extend.keys():
178+
now_item[key] = config_extend[now_item[key]]
179+
else:
180+
m = re.match(r"<(?P<name_path>[^<>]+)>", now_item[key])
181+
if m is not None:
182+
now_item[key] = get_item_from_obj(config, m.groupdict()["name_path"])
127183
return config
128184

129185

qlib/tests/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@
5050
def get_data_handler_config(
5151
start_time="2008-01-01",
5252
end_time="2020-08-01",
53-
fit_start_time="2008-01-01",
54-
fit_end_time="2014-12-31",
53+
fit_start_time="<dataset.kwargs.segments.train.0>",
54+
fit_end_time="<dataset.kwargs.segments.train.1>",
5555
instruments=CSI300_MARKET,
5656
):
5757
return {

0 commit comments

Comments
 (0)