From 917dee4e51a6e89ee17ae4f5ad09045a729a9ce5 Mon Sep 17 00:00:00 2001 From: CTai Date: Wed, 5 Jan 2022 03:03:38 +0800 Subject: [PATCH 1/4] Fix some warnings in log.py. --- qlib/log.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/qlib/log.py b/qlib/log.py index 0b8b04b4864..8c6bec61392 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -13,12 +13,12 @@ class MetaLogger(type): - def __new__(cls, name, bases, dict): + def __new__(mcs, name, bases, dict_): wrapper_dict = logging.Logger.__dict__.copy() for key in wrapper_dict: - if key not in dict and key != "__reduce__": - dict[key] = wrapper_dict[key] - return type.__new__(cls, name, bases, dict) + if key not in dict_ and key != "__reduce__": + dict_[key] = wrapper_dict[key] + return type.__new__(mcs, name, bases, dict_) class QlibLogger(metaclass=MetaLogger): @@ -48,7 +48,7 @@ def __getattr__(self, name): return self.logger.__getattribute__(name) -def get_module_logger(module_name, level: Optional[int] = None) -> logging.Logger: +def get_module_logger(module_name, level: Optional[int] = None) -> QlibLogger: """ Get a logger for a specific module. @@ -146,6 +146,7 @@ def set_log_with_config(log_config: Dict[Text, Any]): class LogFilter(logging.Filter): def __init__(self, param=None): + super().__init__() self.param = param @staticmethod From f0986ed506b4c98c0f5f3a06a5ccf0e70a4850c8 Mon Sep 17 00:00:00 2001 From: CTai Date: Wed, 5 Jan 2022 09:13:19 +0800 Subject: [PATCH 2/4] Fix typo and using black format. --- qlib/config.py | 37 +++++++++++++++++++++--------- qlib/log.py | 2 +- qlib/workflow/__init__.py | 47 ++++++++++++++++++++++++++++++--------- 3 files changed, 64 insertions(+), 22 deletions(-) diff --git a/qlib/config.py b/qlib/config.py index c29338f5a71..a2e66a4ed88 100644 --- a/qlib/config.py +++ b/qlib/config.py @@ -4,7 +4,7 @@ About the configs ================= -The config will based on _default_config. +The config will be based on _default_config. Two modes are supported - client - server @@ -28,7 +28,9 @@ class Config: def __init__(self, default_conf): - self.__dict__["_default_config"] = copy.deepcopy(default_conf) # avoiding conflictions with __getattr__ + self.__dict__["_default_config"] = copy.deepcopy( + default_conf + ) # avoiding conflicts with __getattr__ self.reset() def __getitem__(self, key): @@ -271,7 +273,11 @@ def __init__(self, default_conf): self._registered = False class DataPathManager: - def __init__(self, provider_uri: Union[str, Path, dict], mount_path: Union[str, Path, dict]): + def __init__( + self, + provider_uri: Union[str, Path, dict], + mount_path: Union[str, Path, dict], + ): self.provider_uri = provider_uri self.mount_path = mount_path @@ -285,14 +291,19 @@ def format_provider_uri(provider_uri: Union[str, dict, Path]) -> dict: else: raise TypeError(f"provider_uri does not support {type(provider_uri)}") for freq, _uri in provider_uri.items(): - if QlibConfig.DataPathManager.get_uri_type(_uri) == QlibConfig.LOCAL_URI: + if ( + QlibConfig.DataPathManager.get_uri_type(_uri) + == QlibConfig.LOCAL_URI + ): provider_uri[freq] = str(Path(_uri).expanduser().resolve()) return provider_uri @staticmethod def get_uri_type(uri: Union[str, Path]): uri = uri if isinstance(uri, str) else str(uri.expanduser().resolve()) - is_win = re.match("^[a-zA-Z]:.*", uri) is not None # such as 'C:\\data', 'D:' + is_win = ( + re.match("^[a-zA-Z]:.*", uri) is not None + ) # such as 'C:\\data', 'D:' # such as 'host:/data/' (User may define short hostname by themselves or use localhost) is_nfs_or_win = re.match("^[^/]+:.+", uri) is not None @@ -360,10 +371,10 @@ def set(self, default_conf: str = "client", **kwargs): """ configure qlib based on the input parameters - The configure will act like a dictionary. + The configuration will act like a dictionary. - Normally, it literally replace the value according to the keys. - However, sometimes it is hard for users to set the config when the configure is nested and complicated + Normally, it literally is replaced the value according to the keys. + However, sometimes it is hard for users to set the config when the configuration is nested and complicated So this API provides some special parameters for users to set the keys in a more convenient way. - region: REG_CN, REG_US @@ -391,7 +402,9 @@ def set(self, default_conf: str = "client", **kwargs): logger.info(f"default_conf: {default_conf}.") self.set_mode(default_conf) - self.set_region(kwargs.get("region", self["region"] if "region" in self else REG_CN)) + self.set_region( + kwargs.get("region", self["region"] if "region" in self else REG_CN) + ) for k, v in kwargs.items(): if k not in self: @@ -410,7 +423,11 @@ def set(self, default_conf: str = "client", **kwargs): self["expression_cache"] = None # check dataset cache if self.is_depend_redis(self["dataset_cache"]): - log_str += f" and {self['dataset_cache']}" if log_str else self["dataset_cache"] + log_str += ( + f" and {self['dataset_cache']}" + if log_str + else self["dataset_cache"] + ) self["dataset_cache"] = None if log_str: logger.warning( diff --git a/qlib/log.py b/qlib/log.py index 8c6bec61392..71c59c7963e 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -107,7 +107,7 @@ def log_cost_time(cls, info="Done"): """ Get last time mark from stack, calculate time diff with current time, and log time diff and info. :param info: str - Info that will be log into stdout. + Info that will be logged into stdout. """ cost_time = time() - cls.time_marks.pop() cls.timer_logger.info("Time cost: {0:.3f}s | {1}".format(cost_time, info)) diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index 8e20f4816f3..906a568f98a 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -19,7 +19,9 @@ def __init__(self, exp_manager): self.exp_manager: ExpManager = exp_manager def __repr__(self): - return "{name}(manager={manager})".format(name=self.__class__.__name__, manager=self.exp_manager) + return "{name}(manager={manager})".format( + name=self.__class__.__name__, manager=self.exp_manager + ) @contextmanager def start( @@ -76,12 +78,21 @@ def start( try: yield run except Exception as e: - self.end_exp(Recorder.STATUS_FA) # end the experiment if something went wrong + self.end_exp( + Recorder.STATUS_FA + ) # end the experiment if something went wrong raise e self.end_exp(Recorder.STATUS_FI) def start_exp( - self, *, experiment_id=None, experiment_name=None, recorder_id=None, recorder_name=None, uri=None, resume=False + self, + *, + experiment_id=None, + experiment_name=None, + recorder_id=None, + recorder_name=None, + uri=None, + resume=False, ): """ Lower level method for starting an experiment. When use this method, one should end the experiment manually @@ -216,9 +227,13 @@ def list_recorders(self, experiment_id=None, experiment_name=None): ------- A dictionary (id -> recorder) of recorder information that being stored. """ - return self.get_exp(experiment_id=experiment_id, experiment_name=experiment_name).list_recorders() + return self.get_exp( + experiment_id=experiment_id, experiment_name=experiment_name + ).list_recorders() - def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = True) -> Experiment: + def get_exp( + self, *, experiment_id=None, experiment_name=None, create: bool = True + ) -> Experiment: """ Method for retrieving an experiment with given id or name. Once the `create` argument is set to True, if no valid experiment is found, this method will create one for you. Otherwise, it will @@ -289,7 +304,10 @@ def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = Tr An experiment instance with given id or name. """ return self.exp_manager.get_exp( - experiment_id=experiment_id, experiment_name=experiment_name, create=create, start=False + experiment_id=experiment_id, + experiment_name=experiment_name, + create=create, + start=False, ) def delete_exp(self, experiment_id=None, experiment_name=None): @@ -355,7 +373,12 @@ def uri_context(self, uri: Text): self.exp_manager.set_uri(prev_uri) def get_recorder( - self, *, recorder_id=None, recorder_name=None, experiment_id=None, experiment_name=None + self, + *, + recorder_id=None, + recorder_name=None, + experiment_id=None, + experiment_name=None, ) -> Recorder: """ Method for retrieving a recorder. @@ -409,9 +432,9 @@ def get_recorder( ------- A recorder instance. """ - return self.get_exp(experiment_name=experiment_name, experiment_id=experiment_id, create=False).get_recorder( - recorder_id, recorder_name, create=False, start=False - ) + return self.get_exp( + experiment_name=experiment_name, experiment_id=experiment_id, create=False + ).get_recorder(recorder_id, recorder_name, create=False, start=False) def delete_recorder(self, recorder_id=None, recorder_name=None): """ @@ -433,7 +456,9 @@ def delete_recorder(self, recorder_id=None, recorder_name=None): """ self.get_exp().delete_recorder(recorder_id, recorder_name) - def save_objects(self, local_path=None, artifact_path=None, **kwargs: Dict[Text, Any]): + def save_objects( + self, local_path=None, artifact_path=None, **kwargs: Dict[Text, Any] + ): """ Method for saving objects as artifacts in the experiment to the uri. It supports either saving from a local file/directory, or directly saving objects. User can use valid python's keywords arguments From 747c5b5f95074abd18bdc69e76114dec14fba4f0 Mon Sep 17 00:00:00 2001 From: CTai Date: Wed, 5 Jan 2022 22:06:59 +0800 Subject: [PATCH 3/4] Fix black. --- qlib/config.py | 23 +++++------------------ qlib/workflow/__init__.py | 26 ++++++++------------------ 2 files changed, 13 insertions(+), 36 deletions(-) diff --git a/qlib/config.py b/qlib/config.py index a2e66a4ed88..2459e916c1a 100644 --- a/qlib/config.py +++ b/qlib/config.py @@ -28,9 +28,7 @@ class Config: def __init__(self, default_conf): - self.__dict__["_default_config"] = copy.deepcopy( - default_conf - ) # avoiding conflicts with __getattr__ + self.__dict__["_default_config"] = copy.deepcopy(default_conf) # avoiding conflicts with __getattr__ self.reset() def __getitem__(self, key): @@ -291,19 +289,14 @@ def format_provider_uri(provider_uri: Union[str, dict, Path]) -> dict: else: raise TypeError(f"provider_uri does not support {type(provider_uri)}") for freq, _uri in provider_uri.items(): - if ( - QlibConfig.DataPathManager.get_uri_type(_uri) - == QlibConfig.LOCAL_URI - ): + if QlibConfig.DataPathManager.get_uri_type(_uri) == QlibConfig.LOCAL_URI: provider_uri[freq] = str(Path(_uri).expanduser().resolve()) return provider_uri @staticmethod def get_uri_type(uri: Union[str, Path]): uri = uri if isinstance(uri, str) else str(uri.expanduser().resolve()) - is_win = ( - re.match("^[a-zA-Z]:.*", uri) is not None - ) # such as 'C:\\data', 'D:' + is_win = re.match("^[a-zA-Z]:.*", uri) is not None # such as 'C:\\data', 'D:' # such as 'host:/data/' (User may define short hostname by themselves or use localhost) is_nfs_or_win = re.match("^[^/]+:.+", uri) is not None @@ -402,9 +395,7 @@ def set(self, default_conf: str = "client", **kwargs): logger.info(f"default_conf: {default_conf}.") self.set_mode(default_conf) - self.set_region( - kwargs.get("region", self["region"] if "region" in self else REG_CN) - ) + self.set_region(kwargs.get("region", self["region"] if "region" in self else REG_CN)) for k, v in kwargs.items(): if k not in self: @@ -423,11 +414,7 @@ def set(self, default_conf: str = "client", **kwargs): self["expression_cache"] = None # check dataset cache if self.is_depend_redis(self["dataset_cache"]): - log_str += ( - f" and {self['dataset_cache']}" - if log_str - else self["dataset_cache"] - ) + log_str += f" and {self['dataset_cache']}" if log_str else self["dataset_cache"] self["dataset_cache"] = None if log_str: logger.warning( diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index 906a568f98a..4162969ae6b 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -19,9 +19,7 @@ def __init__(self, exp_manager): self.exp_manager: ExpManager = exp_manager def __repr__(self): - return "{name}(manager={manager})".format( - name=self.__class__.__name__, manager=self.exp_manager - ) + return "{name}(manager={manager})".format(name=self.__class__.__name__, manager=self.exp_manager) @contextmanager def start( @@ -78,9 +76,7 @@ def start( try: yield run except Exception as e: - self.end_exp( - Recorder.STATUS_FA - ) # end the experiment if something went wrong + self.end_exp(Recorder.STATUS_FA) # end the experiment if something went wrong raise e self.end_exp(Recorder.STATUS_FI) @@ -227,13 +223,9 @@ def list_recorders(self, experiment_id=None, experiment_name=None): ------- A dictionary (id -> recorder) of recorder information that being stored. """ - return self.get_exp( - experiment_id=experiment_id, experiment_name=experiment_name - ).list_recorders() + return self.get_exp(experiment_id=experiment_id, experiment_name=experiment_name).list_recorders() - def get_exp( - self, *, experiment_id=None, experiment_name=None, create: bool = True - ) -> Experiment: + def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = True) -> Experiment: """ Method for retrieving an experiment with given id or name. Once the `create` argument is set to True, if no valid experiment is found, this method will create one for you. Otherwise, it will @@ -432,9 +424,9 @@ def get_recorder( ------- A recorder instance. """ - return self.get_exp( - experiment_name=experiment_name, experiment_id=experiment_id, create=False - ).get_recorder(recorder_id, recorder_name, create=False, start=False) + return self.get_exp(experiment_name=experiment_name, experiment_id=experiment_id, create=False).get_recorder( + recorder_id, recorder_name, create=False, start=False + ) def delete_recorder(self, recorder_id=None, recorder_name=None): """ @@ -456,9 +448,7 @@ def delete_recorder(self, recorder_id=None, recorder_name=None): """ self.get_exp().delete_recorder(recorder_id, recorder_name) - def save_objects( - self, local_path=None, artifact_path=None, **kwargs: Dict[Text, Any] - ): + def save_objects(self, local_path=None, artifact_path=None, **kwargs: Dict[Text, Any]): """ Method for saving objects as artifacts in the experiment to the uri. It supports either saving from a local file/directory, or directly saving objects. User can use valid python's keywords arguments From ce8011cd4674965a76f620e523b8f06a3efba15e Mon Sep 17 00:00:00 2001 From: CTai Date: Thu, 6 Jan 2022 13:23:53 +0800 Subject: [PATCH 4/4] Rename dict_ to attrs --- qlib/log.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/qlib/log.py b/qlib/log.py index 71c59c7963e..debc24fa418 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -13,12 +13,12 @@ class MetaLogger(type): - def __new__(mcs, name, bases, dict_): + def __new__(mcs, name, bases, attrs): wrapper_dict = logging.Logger.__dict__.copy() for key in wrapper_dict: - if key not in dict_ and key != "__reduce__": - dict_[key] = wrapper_dict[key] - return type.__new__(mcs, name, bases, dict_) + if key not in attrs and key != "__reduce__": + attrs[key] = wrapper_dict[key] + return type.__new__(mcs, name, bases, attrs) class QlibLogger(metaclass=MetaLogger):