Skip to content

Commit 5a16a5a

Browse files
authored
Auto injecting model and dataset for Recorder (microsoft#645)
* Auto injecting model and dataset for Recorder * Support using Feature in expression
1 parent 3b11912 commit 5a16a5a

File tree

3 files changed

+36
-11
lines changed

3 files changed

+36
-11
lines changed

qlib/data/ops.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from typing import Union, List, Type
1414
from scipy.stats import percentileofscore
1515

16-
from .base import Expression, ExpressionOps
16+
from .base import Expression, ExpressionOps, Feature
1717
from ..log import get_module_logger
1818
from ..utils import get_callable_kwargs
1919

@@ -1485,6 +1485,7 @@ def __init__(self, feature_left, feature_right, N):
14851485
IdxMax,
14861486
IdxMin,
14871487
If,
1488+
Feature,
14881489
]
14891490

14901491

@@ -1517,7 +1518,7 @@ def register(self, ops_list: List[Union[Type[ExpressionOps], dict]]):
15171518
else:
15181519
_ops_class = _operator
15191520

1520-
if not issubclass(_ops_class, ExpressionOps):
1521+
if not issubclass(_ops_class, Expression):
15211522
raise TypeError("operator must be subclass of ExpressionOps, not {}".format(_ops_class))
15221523

15231524
if _ops_class.__name__ in self._ops:

qlib/model/trainer.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,19 +70,19 @@ def fill_placeholder(config: dict, config_extend: dict):
7070
# bfs
7171
top = 0
7272
tail = 1
73-
item_quene = [config]
73+
item_queue = [config]
7474
while top < tail:
75-
now_item = item_quene[top]
75+
now_item = item_queue[top]
7676
top += 1
7777
if isinstance(now_item, list):
7878
item_keys = range(len(now_item))
7979
elif isinstance(now_item, dict):
8080
item_keys = now_item.keys()
8181
for key in item_keys:
8282
if isinstance(now_item[key], list) or isinstance(now_item[key], dict):
83-
item_quene.append(now_item[key])
83+
item_queue.append(now_item[key])
8484
tail += 1
85-
elif now_item[key] in config_extend.keys():
85+
elif isinstance(now_item[key], str) and now_item[key] in config_extend.keys():
8686
now_item[key] = config_extend[now_item[key]]
8787
return config
8888

@@ -114,10 +114,19 @@ def end_task_train(rec: Recorder, experiment_name: str) -> Recorder:
114114
task_config = fill_placeholder(task_config, placehorder_value)
115115
# generate records: prediction, backtest, and analysis
116116
records = task_config.get("record", [])
117-
if isinstance(records, dict): # prevent only one dict
117+
if isinstance(records, dict): # uniform the data format to list
118118
records = [records]
119+
119120
for record in records:
120-
r = init_instance_by_config(record, recorder=rec)
121+
# Some recorders require the parameters `model` and `dataset`.
122+
# Try to pass them to the initialization function automatically
123+
# to make defining the task easier.
124+
r = init_instance_by_config(
125+
record,
126+
recorder=rec,
127+
default_module="qlib.workflow.record_temp",
128+
try_kwargs={"model": model, "dataset": dataset},
129+
)
121130
r.generate()
122131
return rec
123132

qlib/utils/__init__.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
import numpy as np
2828
import pandas as pd
2929
from pathlib import Path
30-
from typing import Union, Tuple, Any, Text, Optional
30+
from typing import Dict, Union, Tuple, Any, Text, Optional
3131
from types import ModuleType
3232
from urllib.parse import urlparse
3333

@@ -232,7 +232,11 @@ def get_callable_kwargs(config: Union[dict, str], default_module: Union[str, Mod
232232

233233

234234
def init_instance_by_config(
235-
config: Union[str, dict, object], default_module=None, accept_types: Union[type, Tuple[type]] = (), **kwargs
235+
config: Union[str, dict, object],
236+
default_module=None,
237+
accept_types: Union[type, Tuple[type]] = (),
238+
try_kwargs: Dict = {},
239+
**kwargs,
236240
) -> Any:
237241
"""
238242
get initialized instance with config
@@ -270,6 +274,10 @@ def init_instance_by_config(
270274
Optional. If the config is an instance of a specific type, return the config directly.
271275
This will be passed into the second parameter of isinstance.
272276
277+
try_kwargs: Dict
278+
Try to pass in the kwargs in `try_kwargs` when initializing the instance.
279+
If an error occurs, it will fall back to initialization without `try_kwargs`.
280+
273281
Returns
274282
-------
275283
object:
@@ -286,7 +294,14 @@ def init_instance_by_config(
286294
return pickle.load(f)
287295

288296
klass, cls_kwargs = get_callable_kwargs(config, default_module=default_module)
289-
return klass(**cls_kwargs, **kwargs)
297+
298+
try:
299+
return klass(**cls_kwargs, **try_kwargs, **kwargs)
300+
except (TypeError,):
301+
# TypeError for handling errors like
302+
# 1: `XXX() got multiple values for keyword argument 'YYY'`
303+
# 2: `XXX() got an unexpected keyword argument 'YYY'`
304+
return klass(**cls_kwargs, **kwargs)
290305

291306

292307
@contextlib.contextmanager

0 commit comments

Comments
 (0)