Skip to content

Commit 0aee46e

Browse files
authored
Merge pull request #466 from you-n-g/online_hotfix
Online bug fix, enhancement & docs for dataset, workflow, trainer ...
2 parents 5a50d7c + 9c8d423 commit 0aee46e

File tree

15 files changed

+167
-56
lines changed

15 files changed

+167
-56
lines changed

qlib/__init__.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,17 @@ def init(default_conf="client", **kwargs):
2020
from .config import C
2121
from .data.cache import H
2222

23-
H.clear()
24-
2523
# FIXME: this logger ignored the level in config
2624
logger = get_module_logger("Initialization", level=logging.INFO)
2725

26+
skip_if_reg = kwargs.pop("skip_if_reg", False)
27+
if skip_if_reg and C.registered:
28+
# If we reinitialize Qlib while an experiment is running (i.e. inside `R.start`),
29+
# it will result in loss of the recorder
30+
logger.warning("Skip initialization because `skip_if_reg is True`")
31+
return
32+
33+
H.clear()
2834
C.set(default_conf, **kwargs)
2935

3036
# check path if server/local
@@ -197,14 +203,15 @@ def auto_init(**kwargs):
197203
- Find the project configuration and init qlib
198204
- The parsing process will be affected by the `conf_type` of the configuration file
199205
- Init qlib with default config
206+
- Skip initialization if already initialized
200207
"""
208+
kwargs["skip_if_reg"] = kwargs.get("skip_if_reg", True)
201209

202210
try:
203211
pp = get_project_path(cur_path=kwargs.pop("cur_path", None))
204212
except FileNotFoundError:
205213
init(**kwargs)
206214
else:
207-
208215
conf_pp = pp / "config.yaml"
209216
with conf_pp.open() as f:
210217
conf = yaml.safe_load(f)

qlib/data/dataset/__init__.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from ...utils.serial import Serializable
22
from typing import Union, List, Tuple, Dict, Text, Optional
3-
from ...utils import init_instance_by_config, np_ffill
3+
from ...utils import init_instance_by_config, np_ffill, time_to_slc_point
44
from ...log import get_module_logger
55
from .handler import DataHandler, DataHandlerLP
66
from copy import deepcopy
@@ -243,6 +243,8 @@ class TSDataSampler:
243243
244244
It works like `torch.data.utils.Dataset`, it provides a very convenient interface for constructing time-series
245245
dataset based on tabular data.
246+
- On time step dimension, the smaller index indicates the historical data and the larger index indicates the future
247+
data.
246248
247249
If users have further requirements for processing data, they can process it based on `TSDataSampler` or create
248250
more powerful subclasses.
@@ -309,11 +311,19 @@ def __init__(
309311
self.data_index = deepcopy(self.data.index)
310312

311313
if flt_data is not None:
312-
self.flt_data = np.array(flt_data.reindex(self.data_index)).reshape(-1)
314+
if isinstance(flt_data, pd.DataFrame):
315+
assert len(flt_data.columns) == 1
316+
flt_data = flt_data.iloc[:, 0]
317+
# NOTE: bool(np.nan) is True !!!!!!!!
318+
# make sure reindex comes first. Otherwise extra NaN may appear.
319+
flt_data = flt_data.reindex(self.data_index).fillna(False).astype(np.bool)
320+
self.flt_data = flt_data.values
313321
self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
314322
self.data_index = self.data_index[np.where(self.flt_data == True)[0]]
315323

316-
self.start_idx, self.end_idx = self.data_index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
324+
self.start_idx, self.end_idx = self.data_index.slice_locs(
325+
start=time_to_slc_point(start), end=time_to_slc_point(end)
326+
)
317327
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance
318328

319329
del self.data # save memory
@@ -341,7 +351,7 @@ def config(self, **kwargs):
341351
setattr(self, k, v)
342352

343353
@staticmethod
344-
def build_index(data: pd.DataFrame) -> dict:
354+
def build_index(data: pd.DataFrame) -> Tuple[pd.DataFrame, dict]:
345355
"""
346356
The relation of the data
347357
@@ -352,9 +362,15 @@ def build_index(data: pd.DataFrame) -> dict:
352362
353363
Returns
354364
-------
355-
dict:
356-
{<index>: <prev_index or None>}
357-
# get the previous index of a line given index
365+
Tuple[pd.DataFrame, dict]:
366+
1) the first element: reshape the original index into a <datetime(row), instrument(column)> 2D dataframe
367+
instrument SH600000 SH600004 SH600006 SH600007 SH600008 SH600009 ...
368+
datetime
369+
2021-01-11 0 1 2 3 4 5 ...
370+
2021-01-12 4146 4147 4148 4149 4150 4151 ...
371+
2021-01-13 8293 8294 8295 8296 8297 8298 ...
372+
2021-01-14 12441 12442 12443 12444 12445 12446 ...
373+
2) the second element: {<original index>: <row, col>}
358374
"""
359375
# use dtype=object in case pandas converts int to float
360376
idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=object)

qlib/log.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,18 @@ class QlibLogger(metaclass=MetaLogger):
2828

2929
def __init__(self, module_name):
3030
self.module_name = module_name
31-
self.level = 0
31+
# the attribute name `level` conflicts with the `level` attribute of `logging.Logger`;
32+
# rename it to avoid some corner cases that result in comparing `str` and `int`
33+
self.__level = 0
3234

3335
@property
3436
def logger(self):
3537
logger = logging.getLogger(self.module_name)
36-
logger.setLevel(self.level)
38+
logger.setLevel(self.__level)
3739
return logger
3840

3941
def setLevel(self, level):
40-
self.level = level
42+
self.__level = level
4143

4244
def __getattr__(self, name):
4345
# During unpickling, python will call __getattr__. Use this line to avoid maximum recursion error.

qlib/model/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def finetune(self, dataset: Dataset):
9797
9898
# Finetune model based on previous trained model
9999
with R.start(experiment_name="finetune model"):
100-
recorder = R.get_recorder(rid, experiment_name="init models")
100+
recorder = R.get_recorder(recorder_id=rid, experiment_name="init models")
101101
model = recorder.load_object("init_model")
102102
model.finetune(dataset, num_boost_round=10)
103103

qlib/model/trainer.py

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
This is a concept called ``DelayTrainer``, which can be used in online simulating for parallel training.
99
In ``DelayTrainer``, the first step is only to save some necessary info to model recorders, and the second step which will be finished in the end can do some concurrent and time-consuming operations such as model fitting.
1010
11-
``Qlib`` offer two kinds of Trainer, ``TrainerR`` is the simplest way and ``TrainerRM`` is based on TaskManager to help manager tasks lifecycle automatically.
11+
``Qlib`` offers two kinds of Trainer: ``TrainerR`` is the simplest way, and ``TrainerRM`` is based on TaskManager to help manage the task lifecycle automatically.
1212
"""
1313

1414
import socket
@@ -153,6 +153,9 @@ def is_delay(self) -> bool:
153153
"""
154154
return self.delay
155155

156+
def __call__(self, *args, **kwargs) -> list:
157+
return self.end_train(self.train(*args, **kwargs))
158+
156159

157160
class TrainerR(Trainer):
158161
"""
@@ -286,19 +289,26 @@ class TrainerRM(Trainer):
286289
# This tag is the _id in TaskManager to distinguish tasks.
287290
TM_ID = "_id in TaskManager"
288291

289-
def __init__(self, experiment_name: str = None, task_pool: str = None, train_func=task_train):
292+
def __init__(
293+
self, experiment_name: str = None, task_pool: str = None, train_func=task_train, skip_run_task: bool = False
294+
):
290295
"""
291296
Init TrainerRM.
292297
293298
Args:
294299
experiment_name (str): the default name of experiment.
295300
task_pool (str): task pool name in TaskManager. None for use same name as experiment_name.
296301
train_func (Callable, optional): default training method. Defaults to `task_train`.
302+
skip_run_task (bool):
303+
If skip_run_task == True:
304+
Do not call run_task in the trainer; tasks are only run by workers. Otherwise the trainer calls run_task itself.
297305
"""
306+
298307
super().__init__()
299308
self.experiment_name = experiment_name
300309
self.task_pool = task_pool
301310
self.train_func = train_func
311+
self.skip_run_task = skip_run_task
302312

303313
def train(
304314
self,
@@ -340,15 +350,16 @@ def train(
340350
tm = TaskManager(task_pool=task_pool)
341351
_id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB
342352
query = {"_id": {"$in": _id_list}}
343-
run_task(
344-
train_func,
345-
task_pool,
346-
query=query, # only train these tasks
347-
experiment_name=experiment_name,
348-
before_status=before_status,
349-
after_status=after_status,
350-
**kwargs,
351-
)
353+
if not self.skip_run_task:
354+
run_task(
355+
train_func,
356+
task_pool,
357+
query=query, # only train these tasks
358+
experiment_name=experiment_name,
359+
before_status=before_status,
360+
after_status=after_status,
361+
**kwargs,
362+
)
352363

353364
if not self.is_delay():
354365
tm.wait(query=query)
@@ -411,6 +422,7 @@ def __init__(
411422
task_pool: str = None,
412423
train_func=begin_task_train,
413424
end_train_func=end_task_train,
425+
skip_run_task: bool = False,
414426
):
415427
"""
416428
Init DelayTrainerRM.
@@ -420,10 +432,15 @@ def __init__(
420432
task_pool (str): task pool name in TaskManager. None for use same name as experiment_name.
421433
train_func (Callable, optional): default train method. Defaults to `begin_task_train`.
422434
end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`.
435+
skip_run_task (bool):
436+
If skip_run_task == True:
437+
Do not call run_task in the trainer; tasks are only run by workers. Otherwise the trainer calls run_task itself.
438+
E.g. starting the trainer on a CPU VM and then waiting for tasks to be finished on GPU VMs.
423439
"""
424440
super().__init__(experiment_name, task_pool, train_func)
425441
self.end_train_func = end_train_func
426442
self.delay = True
443+
self.skip_run_task = skip_run_task
427444

428445
def train(self, tasks: list, train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
429446
"""
@@ -477,14 +494,15 @@ def end_train(self, recs, end_train_func=None, experiment_name: str = None, **kw
477494
_id_list.append(rec.list_tags()[self.TM_ID])
478495

479496
query = {"_id": {"$in": _id_list}}
480-
run_task(
481-
end_train_func,
482-
task_pool,
483-
query=query, # only train these tasks
484-
experiment_name=experiment_name,
485-
before_status=TaskManager.STATUS_PART_DONE,
486-
**kwargs,
487-
)
497+
if not self.skip_run_task:
498+
run_task(
499+
end_train_func,
500+
task_pool,
501+
query=query, # only train these tasks
502+
experiment_name=experiment_name,
503+
before_status=TaskManager.STATUS_PART_DONE,
504+
**kwargs,
505+
)
488506

489507
TaskManager(task_pool=task_pool).wait(query=query)
490508

qlib/utils/__init__.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,28 @@ def split_pred(pred, number=None, split_date=None):
642642
return pred_left, pred_right
643643

644644

645+
def time_to_slc_point(t: Union[None, str, pd.Timestamp]) -> Union[None, pd.Timestamp]:
646+
"""
647+
Time slicing in Qlib or Pandas is a frequently-used action.
648+
However, users often pass in all kinds of formats to represent time.
649+
This function converts these inputs into a uniform format that is friendly to time slicing.
650+
651+
Parameters
652+
----------
653+
t : Union[None, str, pd.Timestamp]
654+
original time
655+
656+
Returns
657+
-------
658+
Union[None, pd.Timestamp]:
659+
"""
660+
if t is None:
661+
# None represents unbounded in Qlib or Pandas(e.g. df.loc[slice(None, "20210303")]).
662+
return t
663+
else:
664+
return pd.Timestamp(t)
665+
666+
645667
def can_use_cache():
646668
res = True
647669
r = get_redis_connection()

qlib/workflow/__init__.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -216,9 +216,9 @@ def list_recorders(self, experiment_id=None, experiment_name=None):
216216
-------
217217
A dictionary (id -> recorder) of recorder information that is being stored.
218218
"""
219-
return self.get_exp(experiment_id, experiment_name).list_recorders()
219+
return self.get_exp(experiment_id=experiment_id, experiment_name=experiment_name).list_recorders()
220220

221-
def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True) -> Experiment:
221+
def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = True) -> Experiment:
222222
"""
223223
Method for retrieving an experiment with given id or name. Once the `create` argument is set to
224224
True, if no valid experiment is found, this method will create one for you. Otherwise, it will
@@ -263,7 +263,7 @@ def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True)
263263
264264
# Case 2
265265
with R.start('test'):
266-
exp = R.get_exp('test1')
266+
exp = R.get_exp(experiment_name='test1')
267267
268268
# Case 3
269269
exp = R.get_exp() -> a default experiment.
@@ -288,7 +288,9 @@ def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True)
288288
-------
289289
An experiment instance with given id or name.
290290
"""
291-
return self.exp_manager.get_exp(experiment_id, experiment_name, create, start=False)
291+
return self.exp_manager.get_exp(
292+
experiment_id=experiment_id, experiment_name=experiment_name, create=create, start=False
293+
)
292294

293295
def delete_exp(self, experiment_id=None, experiment_name=None):
294296
"""
@@ -332,7 +334,9 @@ def set_uri(self, uri: Optional[Text]):
332334
"""
333335
self.exp_manager.set_uri(uri)
334336

335-
def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None) -> Recorder:
337+
def get_recorder(
338+
self, *, recorder_id=None, recorder_name=None, experiment_id=None, experiment_name=None
339+
) -> Recorder:
336340
"""
337341
Method for retrieving a recorder.
338342
@@ -385,7 +389,7 @@ def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=Non
385389
-------
386390
A recorder instance.
387391
"""
388-
return self.get_exp(experiment_name=experiment_name, create=False).get_recorder(
392+
return self.get_exp(experiment_name=experiment_name, experiment_id=experiment_id, create=False).get_recorder(
389393
recorder_id, recorder_name, create=False, start=False
390394
)
391395

qlib/workflow/exp.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
33

4+
from typing import Union
45
import mlflow, logging
56
from mlflow.entities import ViewType
67
from mlflow.exceptions import MlflowException
@@ -213,11 +214,15 @@ def _get_recorder(self, recorder_id=None, recorder_name=None):
213214
"""
214215
raise NotImplementedError(f"Please implement the `_get_recorder` method")
215216

216-
def list_recorders(self):
217+
def list_recorders(self, **flt_kwargs):
217218
"""
218219
List all the existing recorders of this experiment. Please first get the experiment instance before calling this method.
219220
If users want to use the method `R.list_recorders()`, please refer to the related API document in `QlibRecorder`.
220221
222+
flt_kwargs : dict
223+
filter recorders by conditions
224+
e.g. list_recorders(status=Recorder.STATUS_FI)
225+
221226
Returns
222227
-------
223228
A dictionary (id -> recorder) of recorder information that is being stored.
@@ -320,11 +325,21 @@ def delete_recorder(self, recorder_id=None, recorder_name=None):
320325

321326
UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!!
322327

323-
def list_recorders(self, max_results=UNLIMITED):
328+
def list_recorders(self, max_results: int = UNLIMITED, status: Union[str, None] = None):
329+
"""
330+
Parameters
331+
----------
332+
max_results : int
333+
the number limitation of the results
334+
status : str
335+
the criteria based on status to filter results.
336+
`None` indicates no filtering.
337+
"""
324338
runs = self._client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)
325339
recorders = dict()
326340
for i in range(len(runs)):
327341
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i])
328-
recorders[runs[i].info.run_id] = recorder
342+
if status is None or recorder.status == status:
343+
recorders[runs[i].info.run_id] = recorder
329344

330345
return recorders

0 commit comments

Comments
 (0)