Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion qlib/data/dataset/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,12 @@ def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame
df = self._data.loc(axis=0)[:, instruments]
if start_time is None and end_time is None:
return df # NOTE: avoid copy by loc
return df.loc[pd.Timestamp(start_time) : pd.Timestamp(end_time)]
# pd.Timestamp(None) evaluates to NaT; slicing the index with NaT does not fetch the correct rows, so leave a None bound unchanged.
if start_time is not None:
start_time = pd.Timestamp(start_time)
if end_time is not None:
end_time = pd.Timestamp(end_time)
return df.loc[start_time:end_time]

def _maybe_load_raw_data(self):
if self._data is not None:
Expand Down
2 changes: 1 addition & 1 deletion qlib/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def get_module_logger(module_name, level: Optional[int] = None) -> logging.Logge

class TimeInspector:

timer_logger = get_module_logger("timer", level=logging.WARNING)
timer_logger = get_module_logger("timer", level=logging.INFO)

time_marks = []

Expand Down
6 changes: 3 additions & 3 deletions qlib/utils/serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,16 @@ def to_pickle(self, path: Union[Path, str], dump_all: bool = None, exclude: list
@classmethod
def load(cls, filepath):
"""
Load the collector from a filepath.
Load the serializable class from a filepath.

Args:
filepath (str): the path of file

Raises:
TypeError: the pickled file must be `Collector`
TypeError: the pickled file must be `type(cls)`

Returns:
Collector: the instance of Collector
`type(cls)`: the instance of `type(cls)`
"""
with open(filepath, "rb") as f:
object = cls.get_backend().load(f)
Expand Down
5 changes: 2 additions & 3 deletions qlib/workflow/online/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,9 @@ def update(self, dataset: DatasetH = None):
# RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
# https://github.com/pytorch/pytorch/issues/16797

start_time = get_date_by_shift(self.last_end, 1, freq=self.freq)
if start_time > self.to_date:
if self.last_end >= self.to_date:
self.logger.info(
f"The prediction in {self.record.info['id']} are latest ({start_time}). No need to update to {self.to_date}."
f"The prediction in {self.record.info['id']} are latest ({self.last_end}). No need to update to {self.to_date}."
)
return

Expand Down
56 changes: 43 additions & 13 deletions qlib/workflow/online/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from qlib.log import get_module_logger
from qlib.utils import get_cls_kwargs
from qlib.utils.exceptions import QlibException
from qlib.workflow.online.update import PredUpdater
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.utils import list_recorders
Expand Down Expand Up @@ -90,15 +91,15 @@ class OnlineToolR(OnlineTool):
The implementation of OnlineTool based on (R)ecorder.
"""

def __init__(self, experiment_name: str):
def __init__(self, default_exp_name: str = None):
"""
Init OnlineToolR.

Args:
experiment_name (str): the experiment name.
default_exp_name (str): the default experiment name.
"""
super().__init__()
self.exp_name = experiment_name
self.default_exp_name = default_exp_name

def set_online_tag(self, tag, recorder: Union[Recorder, List]):
"""
Expand Down Expand Up @@ -127,45 +128,74 @@ def get_online_tag(self, recorder: Recorder) -> str:
tags = recorder.list_tags()
return tags.get(self.ONLINE_KEY, self.OFFLINE_TAG)

def reset_online_tag(self, recorder: Union[Recorder, List]):
def reset_online_tag(self, recorder: Union[Recorder, List], exp_name: str = None):
"""
Set every recorder in the experiment to 'offline', then set the given recorders to 'online'.

Args:
recorder (Union[Recorder, List]):
the recorder you want to reset to 'online'.
exp_name (str): the experiment name. If None, then use default_exp_name.

"""
if exp_name is None:
if self.default_exp_name is None:
raise ValueError(
"Both default_exp_name and exp_name are None. OnlineToolR needs a specific experiment."
)
exp_name = self.default_exp_name
if isinstance(recorder, Recorder):
recorder = [recorder]
recs = list_recorders(self.exp_name)
recs = list_recorders(exp_name)
self.set_online_tag(self.OFFLINE_TAG, list(recs.values()))
self.set_online_tag(self.ONLINE_TAG, recorder)

def online_models(self) -> list:
def online_models(self, exp_name: str = None) -> list:
"""
Get current `online` models

Args:
exp_name (str): the experiment name. If None, then use default_exp_name.

Returns:
list: a list of `online` models.
"""
return list(list_recorders(self.exp_name, lambda rec: self.get_online_tag(rec) == self.ONLINE_TAG).values())
if exp_name is None:
if self.default_exp_name is None:
raise ValueError(
"Both default_exp_name and exp_name are None. OnlineToolR needs a specific experiment."
)
exp_name = self.default_exp_name
return list(list_recorders(exp_name, lambda rec: self.get_online_tag(rec) == self.ONLINE_TAG).values())

def update_online_pred(self, to_date=None):
def update_online_pred(self, to_date=None, exp_name: str = None):
"""
Update the predictions of online models to to_date.

Args:
to_date (pd.Timestamp): predictions before this date will be updated; if None, update to the latest time in the calendar.
"""
online_models = self.online_models()
exp_name (str): the experiment name. If None, then use default_exp_name.
"""
if exp_name is None:
if self.default_exp_name is None:
raise ValueError(
"Both default_exp_name and exp_name are None. OnlineToolR needs a specific experiment."
)
exp_name = self.default_exp_name
online_models = self.online_models(exp_name=exp_name)
for rec in online_models:
hist_ref = 0
task = rec.load_object("task")
# Special treatment of historical dependencies
cls, kwargs = get_cls_kwargs(task["dataset"], default_module="qlib.data.dataset")
if issubclass(cls, TSDatasetH):
hist_ref = kwargs.get("step_len", TSDatasetH.DEFAULT_STEP_LEN)
PredUpdater(rec, to_date=to_date, hist_ref=hist_ref).update()

self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.")
try:
updater = PredUpdater(rec, to_date=to_date, hist_ref=hist_ref)
except QlibException as e:
# skip recorders that have no saved prediction artifact (`pred.pkl`)
self.logger.warn(f"An exception `{str(e)}` happened when load `pred.pkl`, skip it.")
continue
updater.update()

self.logger.info(f"Finished updating {len(online_models)} online model predictions of {exp_name}.")
12 changes: 9 additions & 3 deletions qlib/workflow/recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import shutil, os, pickle, tempfile, codecs, pickle
from pathlib import Path
from datetime import datetime

from mlflow.exceptions import MlflowException
from qlib.utils.exceptions import QlibException
from ..utils.objm import FileManager
from ..log import get_module_logger

Expand Down Expand Up @@ -308,9 +311,12 @@ def save_objects(self, local_path=None, artifact_path=None, **kwargs):

def load_object(self, name):
assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly."
path = self.client.download_artifacts(self.id, name)
with Path(path).open("rb") as f:
return pickle.load(f)
try:
path = self.client.download_artifacts(self.id, name)
with Path(path).open("rb") as f:
return pickle.load(f)
except OSError as e:
raise QlibException(message=str(e))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This exception should be defined in the interface of Recorder


def log_params(self, **kwargs):
for name, data in kwargs.items():
Expand Down