diff --git a/qlib/data/data.py b/qlib/data/data.py index 3dcb226990b..5dad558e7d4 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -25,7 +25,12 @@ from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields from .base import Feature from .cache import DiskDatasetCache, DiskExpressionCache -from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path +from ..utils import ( + Wrapper, + init_instance_by_config, + register_wrapper, + get_module_by_module_path, +) class CalendarProvider(abc.ABC): @@ -221,7 +226,11 @@ def convert_instruments(self, instrument): _df_list = [] # FIXME: each process will read these files for _path in Path(C.get_data_path()).joinpath("instruments").glob("*.txt"): - _df = pd.read_csv(_path, sep="\t", names=["inst", "start_datetime", "end_datetime", "save_inst"]) + _df = pd.read_csv( + _path, + sep="\t", + names=["inst", "start_datetime", "end_datetime", "save_inst"], + ) _df_list.append(_df.iloc[:, [0, -1]]) df = pd.concat(_df_list, sort=False).sort_values("save_inst") df = df.drop_duplicates(subset=["save_inst"], keep="first").fillna(axis=1, method="ffill") @@ -588,7 +597,11 @@ def _load_instruments(self, market): if not os.path.exists(fname): raise ValueError("instruments not exists for market " + market) _instruments = dict() - df = pd.read_csv(fname, sep="\t", names=["inst", "start_datetime", "end_datetime", "save_inst"]) + df = pd.read_csv( + fname, + sep="\t", + names=["inst", "start_datetime", "end_datetime", "save_inst"], + ) df["start_datetime"] = pd.to_datetime(df["start_datetime"]) df["end_datetime"] = pd.to_datetime(df["end_datetime"]) for row in df.itertuples(index=False): @@ -1035,12 +1048,31 @@ def __init__(self): DatasetD.set_conn(self.client) -Cal = Wrapper() -Inst = Wrapper() -FeatureD = Wrapper() -ExpressionD = Wrapper() -DatasetD = Wrapper() -D = Wrapper() +import sys + +if sys.version_info >= (3, 9): + from typing import Annotated + + CalendarProviderWrapper = Annotated[CalendarProvider, Wrapper] + InstrumentProviderWrapper = Annotated[InstrumentProvider, Wrapper] + FeatureProviderWrapper = Annotated[FeatureProvider, Wrapper] + ExpressionProviderWrapper = Annotated[ExpressionProvider, Wrapper] + DatasetProviderWrapper = Annotated[DatasetProvider, Wrapper] + BaseProviderWrapper = Annotated[BaseProvider, Wrapper] +else: + CalendarProviderWrapper = CalendarProvider + InstrumentProviderWrapper = InstrumentProvider + FeatureProviderWrapper = FeatureProvider + ExpressionProviderWrapper = ExpressionProvider + DatasetProviderWrapper = DatasetProvider + BaseProviderWrapper = BaseProvider + +Cal: CalendarProviderWrapper = Wrapper() +Inst: InstrumentProviderWrapper = Wrapper() +FeatureD: FeatureProviderWrapper = Wrapper() +ExpressionD: ExpressionProviderWrapper = Wrapper() +DatasetD: DatasetProviderWrapper = Wrapper() +D: BaseProviderWrapper = Wrapper() def register_all_wrappers(): diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index e65bfb03ff9..15faa0da10d 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -461,5 +461,14 @@ def set_tags(self, **kwargs): self.get_exp().get_recorder().set_tags(**kwargs) +import sys + +if sys.version_info >= (3, 9): + from typing import Annotated + + QlibRecorderWrapper = Annotated[QlibRecorder, Wrapper] +else: + QlibRecorderWrapper = QlibRecorder + # global record -R = Wrapper() +R: QlibRecorderWrapper = Wrapper()