diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 88df0c3e2e6..f64b4d32809 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -35,7 +35,7 @@ jobs: pip install numpy==1.19.5 ruamel.yaml pip install pyqlib --ignore-installed - - name: Make html with sphnix + - name: Make html with sphinx run: | pip install -U sphinx pip install sphinx_rtd_theme readthedocs_sphinx_ext @@ -97,12 +97,21 @@ jobs: run: | pip install --upgrade pip pip install flake8 - flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 qlib + flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 --per-file-ignores="__init__.py:F401,F403" qlib + + # https://github.com/python/mypy/issues/10600 + - name: Check Qlib with mypy + run: | + pip install mypy + mypy qlib --install-types --non-interactive || true + mypy qlib - name: Test data downloads run: | python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data_simple --interval 1d --region cn python -c "import os; userpath=os.path.expanduser('~'); os.rename(userpath + '/.qlib/qlib_data/cn_data_simple', userpath + '/.qlib/qlib_data/cn_data')" + azcopy copy https://qlibpublic.blob.core.windows.net/data /tmp/qlibpublic --recursive + mv /tmp/qlibpublic/data tests/.data - name: Test workflow by config (install from pip) run: | @@ -113,6 +122,7 @@ jobs: - name: Install Qlib from source run: | pip install --upgrade cython jupyter jupyter_contrib_nbextensions numpy scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't + pip install gym tianshou torch pip install -e . 
- name: Install test dependencies @@ -129,4 +139,3 @@ jobs: - name: Test workflow by config (install from source) run: | python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml - diff --git a/.github/workflows/test_macos.yml b/.github/workflows/test_macos.yml index 6f40a3b1f79..0deab4d731c 100644 --- a/.github/workflows/test_macos.yml +++ b/.github/workflows/test_macos.yml @@ -38,7 +38,7 @@ jobs: run: | pip install --upgrade pip pip install flake8 - flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 qlib + flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 --per-file-ignores="__init__.py:F401,F403" qlib - name: Install Qlib with pip run: | @@ -65,6 +65,8 @@ jobs: run: | python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data_simple --interval 1d --region cn python -c "import os; userpath=os.path.expanduser('~'); os.rename(userpath + '/.qlib/qlib_data/cn_data_simple', userpath + '/.qlib/qlib_data/cn_data')" + azcopy copy https://qlibpublic.blob.core.windows.net/data /tmp/qlibpublic --recursive + mv /tmp/qlibpublic/data tests/.data - name: Test workflow by config (install from pip) run: | python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml @@ -75,6 +77,7 @@ jobs: python -m pip install --upgrade cython python -m pip install numpy jupyter jupyter_contrib_nbextensions python -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't + python -m pip install gym tianshou torch pip install -e . 
- name: Install test dependencies run: | diff --git a/.gitignore b/.gitignore index a563ed5c7f7..3f9c159aa4c 100644 --- a/.gitignore +++ b/.gitignore @@ -27,6 +27,10 @@ examples/estimator/estimator_example/ *.egg-info/ +# test related +test-output.xml +.output +.data # special software mlruns/ @@ -34,6 +38,7 @@ mlruns/ tags .pytest_cache/ +.mypy_cache/ .vscode/ *.swp diff --git a/.mypy.ini b/.mypy.ini new file mode 100644 index 00000000000..195a0505f70 --- /dev/null +++ b/.mypy.ini @@ -0,0 +1,17 @@ +[mypy] +exclude = (?x)( + ^qlib/backtest + | ^qlib/contrib + | ^qlib/data + | ^qlib/model + | ^qlib/strategy + | ^qlib/tests + | ^qlib/utils + | ^qlib/workflow + | ^qlib/config\.py$ + | ^qlib/log\.py$ + | ^qlib/__init__\.py$ + ) +ignore_missing_imports = true +disallow_incomplete_defs = true +follow_imports = skip diff --git a/qlib/constant.py b/qlib/constant.py index fa014a29570..458890957d2 100644 --- a/qlib/constant.py +++ b/qlib/constant.py @@ -8,3 +8,6 @@ # Epsilon for avoiding division by zero. EPS = 1e-12 + +# Infinity in integer +INF = 10**18 diff --git a/qlib/log.py b/qlib/log.py index 8533a7ba377..a2fc5deff09 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -61,7 +61,11 @@ def get_module_logger(module_name, level: Optional[int] = None) -> QlibLogger: if level is None: level = C.logging_level - module_name = "qlib.{}".format(module_name) + if not module_name.startswith("qlib."): + # Add a prefix of qlib. when the requested ``module_name`` doesn't start with ``qlib.``. + # If the module_name is already qlib.xxx, we do not format here. Otherwise, it will become qlib.qlib.xxx. + module_name = "qlib.{}".format(module_name) + # Get logger. module_logger = QlibLogger(module_name) module_logger.setLevel(level) diff --git a/qlib/rl/aux_info.py b/qlib/rl/aux_info.py new file mode 100644 index 00000000000..65cd95d5dd7 --- /dev/null +++ b/qlib/rl/aux_info.py @@ -0,0 +1,43 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from __future__ import annotations + +from typing import Generic, TYPE_CHECKING, TypeVar + +from qlib.typehint import final + +from .simulator import StateType + +if TYPE_CHECKING: + from .utils.env_wrapper import EnvWrapper + + +__all__ = ["AuxiliaryInfoCollector"] + +AuxInfoType = TypeVar("AuxInfoType") + + +class AuxiliaryInfoCollector(Generic[StateType, AuxInfoType]): + """Override this class to collect customized auxiliary information from environment.""" + + env: EnvWrapper | None = None + + @final + def __call__(self, simulator_state: StateType) -> AuxInfoType: + return self.collect(simulator_state) + + def collect(self, simulator_state: StateType) -> AuxInfoType: + """Override this for customized auxiliary info. + Usually useful in Multi-agent RL. + + Parameters + ---------- + simulator_state + Retrieved with ``simulator.get_state()``. + + Returns + ------- + Auxiliary information. + """ + raise NotImplementedError("collect is not implemented!") diff --git a/qlib/rl/data/__init__.py b/qlib/rl/data/__init__.py new file mode 100644 index 00000000000..d73517534c8 --- /dev/null +++ b/qlib/rl/data/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Common utilities to handle ad-hoc-styled data. + +Most of these snippets comes from research project (paper code). +Please take caution when using them in production. +""" diff --git a/qlib/rl/data/pickle_styled.py b/qlib/rl/data/pickle_styled.py new file mode 100644 index 00000000000..6cf38680147 --- /dev/null +++ b/qlib/rl/data/pickle_styled.py @@ -0,0 +1,257 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""This module contains utilities to read financial data from pickle-styled files. + +This is the format used in `OPD paper `__. NOT the standard data format in qlib. + +The data here are all wrapped with ``@lru_cache``, which saves the expensive IO cost to repetitively read the data. 
+We also encourage users to use ``get_xxx_yyy`` rather than ``XxxYyy`` (although they are the same thing), +because ``get_xxx_yyy`` is cache-optimized. + +Note that these pickle files are dumped with Python 3.8. Python lower than 3.7 might not be able to load them. +See `PEP 574 `__ for details. + +This file shows resemblance to qlib.backtest.high_performance_ds. We might merge those two in future. +""" + +# TODO: merge with qlib/backtest/high_performance_ds.py + +from __future__ import annotations + +from functools import lru_cache +from typing import List, Sequence, cast +from pathlib import Path + +import cachetools +import numpy as np +import pandas as pd +from cachetools.keys import hashkey + +from qlib.backtest.decision import OrderDir, Order +from qlib.typehint import Literal + + +DealPriceType = Literal["bid_or_ask", "bid_or_ask_fill", "close"] +"""Several ad-hoc deal price. +``bid_or_ask``: If sell, use column ``$bid0``; if buy, use column ``$ask0``. +``bid_or_ask_fill``: Based on ``bid_or_ask``. If price is 0, use another price (``$ask0`` / ``$bid0``) instead. +``close``: Use close price (``$close0``) as deal price. 
+""" + + +def _infer_processed_data_column_names(shape: int) -> list[str]: + if shape == 16: + return [ + "$open", + "$high", + "$low", + "$close", + "$vwap", + "$bid", + "$ask", + "$volume", + "$bidV", + "$bidV1", + "$bidV3", + "$bidV5", + "$askV", + "$askV1", + "$askV3", + "$askV5", + ] + if shape == 6: + return ["$high", "$low", "$open", "$close", "$vwap", "$volume"] + elif shape == 5: + return ["$high", "$low", "$open", "$close", "$volume"] + raise ValueError(f"Unrecognized data shape: {shape}") + + +def _find_pickle(filename_without_suffix: Path) -> Path: + suffix_list = [".pkl", ".pkl.backtest"] + paths: List[Path] = [] + for suffix in suffix_list: + path = filename_without_suffix.parent / (filename_without_suffix.name + suffix) + if path.exists(): + paths.append(path) + if not paths: + raise FileNotFoundError(f"No file starting with '{filename_without_suffix}' found") + if len(paths) > 1: + raise ValueError(f"Multiple paths are found with prefix '{filename_without_suffix}': {paths}") + return paths[0] + + +@lru_cache(maxsize=10) # 10 * 40M = 400MB +def _read_pickle(filename_without_suffix: Path) -> pd.DataFrame: + return pd.read_pickle(_find_pickle(filename_without_suffix)) + + +class IntradayBacktestData: + """Raw market data that is often used in backtesting (thus called BacktestData).""" + + def __init__( + self, + data_dir: Path, + stock_id: str, + date: pd.Timestamp, + deal_price: DealPriceType = "close", + order_dir: int | None = None, + ): + backtest = _read_pickle(data_dir / stock_id) + backtest = backtest.loc[pd.IndexSlice[stock_id, :, date]] + + # No longer need for pandas >= 1.4 + # backtest = backtest.droplevel([0, 2]) + + self.data: pd.DataFrame = backtest + self.deal_price_type: DealPriceType = deal_price + self.order_dir: int | None = order_dir + + def __repr__(self): + with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"): + return f"{self.__class__.__name__}({self.data})" + + def 
__len__(self): + return len(self.data) + + def get_deal_price(self) -> pd.Series: + """Return a pandas series that can be indexed with time. + See :attribute:`DealPriceType` for details.""" + if self.deal_price_type in ("bid_or_ask", "bid_or_ask_fill"): + if self.order_dir is None: + raise ValueError("Order direction cannot be none when deal_price_type is not close.") + if self.order_dir == OrderDir.SELL: + col = "$bid0" + else: # BUY + col = "$ask0" + elif self.deal_price_type == "close": + col = "$close0" + else: + raise ValueError(f"Unsupported deal_price_type: {self.deal_price_type}") + price = self.data[col] + + if self.deal_price_type == "bid_or_ask_fill": + if self.order_dir == OrderDir.SELL: + fill_col = "$ask0" + else: + fill_col = "$bid0" + price = price.replace(0, np.nan).fillna(self.data[fill_col]) + + return price + + def get_volume(self) -> pd.Series: + """Return a volume series that can be indexed with time.""" + return self.data["$volume0"] + + def get_time_index(self) -> pd.DatetimeIndex: + return cast(pd.DatetimeIndex, self.data.index) + + +class IntradayProcessedData: + """Processed market data after data cleanup and feature engineering. + + It contains both processed data for "today" and "yesterday", as some algorithms + might use the market information of the previous day to assist decision making. + """ + + today: pd.DataFrame + """Processed data for "today". + Number of records must be ``time_length``, and columns must be ``feature_dim``.""" + + yesterday: pd.DataFrame + """Processed data for "yesterday". + Number of records must be ``time_length``, and columns must be ``feature_dim``.""" + + def __init__(self, data_dir: Path, stock_id: str, date: pd.Timestamp, feature_dim: int, time_index: pd.Index): + proc = _read_pickle(data_dir / stock_id) + # We have to infer the names here because, + # unfortunately they are not included in the original data. 
+ cnames = _infer_processed_data_column_names(feature_dim) + + time_length: int = len(time_index) + + try: + # new data format + proc = proc.loc[pd.IndexSlice[stock_id, :, date]] + assert len(proc) == time_length and len(proc.columns) == feature_dim * 2 + proc_today = proc[cnames] + proc_yesterday = proc[[f"{c}_1" for c in cnames]].rename(columns=lambda c: c[:-2]) + except (IndexError, KeyError): + # legacy data + proc = proc.loc[pd.IndexSlice[stock_id, date]] + assert time_length * feature_dim * 2 == len(proc) + proc_today = proc.to_numpy()[: time_length * feature_dim].reshape((time_length, feature_dim)) + proc_yesterday = proc.to_numpy()[time_length * feature_dim :].reshape((time_length, feature_dim)) + proc_today = pd.DataFrame(proc_today, index=time_index, columns=cnames) + proc_yesterday = pd.DataFrame(proc_yesterday, index=time_index, columns=cnames) + + self.today: pd.DataFrame = proc_today + self.yesterday: pd.DataFrame = proc_yesterday + assert len(self.today.columns) == len(self.yesterday.columns) == feature_dim + assert len(self.today) == len(self.yesterday) == time_length + + def __repr__(self): + with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"): + return f"{self.__class__.__name__}({self.today}, {self.yesterday})" + + +@lru_cache(maxsize=100) # 100 * 50K = 5MB +def load_intraday_backtest_data( + data_dir: Path, stock_id: str, date: pd.Timestamp, deal_price: DealPriceType = "close", order_dir: int | None = None +) -> IntradayBacktestData: + return IntradayBacktestData(data_dir, stock_id, date, deal_price, order_dir) + + +@cachetools.cached( # type: ignore + cache=cachetools.LRUCache(100), # 100 * 50K = 5MB + key=lambda data_dir, stock_id, date, _, __: hashkey(data_dir, stock_id, date), +) +def load_intraday_processed_data( + data_dir: Path, stock_id: str, date: pd.Timestamp, feature_dim: int, time_index: pd.Index +) -> IntradayProcessedData: + return IntradayProcessedData(data_dir, stock_id, 
date, feature_dim, time_index) + + +def load_orders( + order_path: Path, start_time: pd.Timestamp | None = None, end_time: pd.Timestamp | None = None +) -> Sequence[Order]: + """Load orders, and set start time and end time for the orders.""" + + start_time = start_time or pd.Timestamp("0:00:00") + end_time = end_time or pd.Timestamp("23:59:59") + + if order_path.is_file(): + order_df = pd.read_pickle(order_path) + else: + order_df = [] + for file in order_path.iterdir(): + order_data = pd.read_pickle(file) + order_df.append(order_data) + order_df = pd.concat(order_df) + + order_df = order_df.reset_index() + + # Legacy-style orders have "date" instead of "datetime" + if "date" in order_df.columns: + order_df = order_df.rename(columns={"date": "datetime"}) + + # Sometimes "date" are str rather than Timestamp + order_df["datetime"] = pd.to_datetime(order_df["datetime"]) + + orders: List[Order] = [] + + for _, row in order_df.iterrows(): + # filter out orders with amount == 0 + if row["amount"] <= 0: + continue + orders.append( + Order( + row["instrument"], + row["amount"], + int(row["order_type"]), + row["datetime"].replace(hour=start_time.hour, minute=start_time.minute, second=start_time.second), + row["datetime"].replace(hour=end_time.hour, minute=end_time.minute, second=end_time.second), + ) + ) + + return orders diff --git a/qlib/rl/entries/__init__.py b/qlib/rl/entries/__init__.py new file mode 100644 index 00000000000..169fa985c50 --- /dev/null +++ b/qlib/rl/entries/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Train, test, inference utilities. + +The APIs in this directory are NOT considered final and are subject to change! +""" diff --git a/qlib/rl/entries/test.py b/qlib/rl/entries/test.py new file mode 100644 index 00000000000..ca311407b15 --- /dev/null +++ b/qlib/rl/entries/test.py @@ -0,0 +1,99 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from __future__ import annotations + +import copy +from typing import Callable, Sequence + +from tianshou.data import Collector +from tianshou.policy import BasePolicy + +from qlib.constant import INF +from qlib.log import get_module_logger +from qlib.rl.simulator import InitialStateType, Simulator +from qlib.rl.interpreter import StateInterpreter, ActionInterpreter +from qlib.rl.reward import Reward +from qlib.rl.utils import DataQueue, EnvWrapper, FiniteEnvType, LogCollector, LogWriter, vectorize_env + + +_logger = get_module_logger(__name__) + + +def backtest( + simulator_fn: Callable[[InitialStateType], Simulator], + state_interpreter: StateInterpreter, + action_interpreter: ActionInterpreter, + initial_states: Sequence[InitialStateType], + policy: BasePolicy, + logger: LogWriter | list[LogWriter], + reward: Reward | None = None, + finite_env_type: FiniteEnvType = "subproc", + concurrency: int = 2, +) -> None: + """Backtest with the parallelism provided by RL framework. + + Parameters + ---------- + simulator_fn + Callable receiving initial seed, returning a simulator. + state_interpreter + Interprets the state of simulators. + action_interpreter + Interprets the policy actions. + initial_states + Initial states to iterate over. Every state will be run exactly once. + policy + Policy to test against. + logger + Logger to record the backtest results. Logger must be present because + without logger, all information will be lost. + reward + Optional reward function. For backtest, this is for testing the rewards + and logging them only. + finite_env_type + Type of finite env implementation. + concurrency + Parallel workers. + """ + + # To save bandwidth + min_loglevel = min(lg.loglevel for lg in logger) if isinstance(logger, list) else logger.loglevel + + def env_factory(): + # FIXME: state_interpreter and action_interpreter are stateful (having a weakref of env), + # and could be thread unsafe. + # I'm not sure whether it's a design flaw. 
+ # I'll rethink about this when designing the trainer. + + if finite_env_type == "dummy": + # We could only experience the "threading-unsafe" problem in dummy. + state = copy.deepcopy(state_interpreter) + action = copy.deepcopy(action_interpreter) + rew = copy.deepcopy(reward) + else: + state, action, rew = state_interpreter, action_interpreter, reward + + return EnvWrapper( + simulator_fn, + state, + action, + seed_iterator, + rew, + logger=LogCollector(min_loglevel=min_loglevel), + ) + + with DataQueue(initial_states) as seed_iterator: + vector_env = vectorize_env( + env_factory, + finite_env_type, + concurrency, + logger, + ) + + policy.eval() + + with vector_env.collector_guard(): + test_collector = Collector(policy, vector_env) + _logger.info("All ready. Start backtest.") + test_collector.collect(n_step=INF * len(vector_env)) diff --git a/qlib/rl/entries/train.py b/qlib/rl/entries/train.py new file mode 100644 index 00000000000..c852e623528 --- /dev/null +++ b/qlib/rl/entries/train.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# TBD diff --git a/qlib/rl/env.py b/qlib/rl/env.py deleted file mode 100644 index 6173a27a869..00000000000 --- a/qlib/rl/env.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
- -from typing import Union - - -from ..backtest.executor import BaseExecutor -from .interpreter import StateInterpreter, ActionInterpreter -from ..utils import init_instance_by_config - - -class BaseRLEnv: - """Base environment for reinforcement learning""" - - def reset(self, **kwargs): - raise NotImplementedError("reset is not implemented!") - - def step(self, action): - """ - step method of rl env - Parameters - ---------- - action : - action from rl policy - - Returns - ------- - env state to rl policy - """ - raise NotImplementedError("step is not implemented!") - - -class QlibRLEnv: - """qlib-based RL env""" - - def __init__( - self, - executor: BaseExecutor, - ): - """ - Parameters - ---------- - executor : BaseExecutor - qlib multi-level/single-level executor, which can be regarded as gamecore in RL - """ - self.executor = executor - - def reset(self, **kwargs): - self.executor.reset(**kwargs) - - -class QlibIntRLEnv(QlibRLEnv): - """(Qlib)-based RL (Env) with (Interpreter)""" - - def __init__( - self, - executor: BaseExecutor, - state_interpreter: Union[dict, StateInterpreter], - action_interpreter: Union[dict, ActionInterpreter], - ): - """ - - Parameters - ---------- - state_interpreter : Union[dict, StateInterpreter] - interpreter that interprets the qlib execute result into rl env state. 
- - action_interpreter : Union[dict, ActionInterpreter] - interpreter that interprets the rl agent action into qlib order list - """ - super(QlibIntRLEnv, self).__init__(executor=executor) - self.state_interpreter = init_instance_by_config(state_interpreter, accept_types=StateInterpreter) - self.action_interpreter = init_instance_by_config(action_interpreter, accept_types=ActionInterpreter) - - def step(self, action): - """ - step method of rl env, it run as following step: - - Use `action_interpreter.interpret` method to interpret the agent action into order list - - Execute the order list with qlib executor, and get the executed result - - Use `state_interpreter.interpret` method to interpret the executed result into env state - - Parameters - ---------- - action : - action from rl policy - - Returns - ------- - env state to rl policy - """ - _interpret_decision = self.action_interpreter.interpret(action=action) - _execute_result = self.executor.execute(trade_decision=_interpret_decision) - _interpret_state = self.state_interpreter.interpret(execute_result=_execute_result) - return _interpret_state diff --git a/qlib/rl/interpreter.py b/qlib/rl/interpreter.py index c711b838089..3835b5b9238 100644 --- a/qlib/rl/interpreter.py +++ b/qlib/rl/interpreter.py @@ -1,47 +1,150 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
+from __future__ import annotations -class BaseInterpreter: - """Base Interpreter""" +from typing import TYPE_CHECKING, TypeVar, Generic, Any - def interpret(self, **kwargs): - raise NotImplementedError("interpret is not implemented!") +import numpy as np +from qlib.typehint import final -class ActionInterpreter(BaseInterpreter): - """Action Interpreter that interpret rl agent action into qlib orders""" +from .simulator import StateType, ActType + +if TYPE_CHECKING: + from .utils.env_wrapper import EnvWrapper + +import gym +from gym import spaces + +ObsType = TypeVar("ObsType") +PolicyActType = TypeVar("PolicyActType") + + +class Interpreter: + """Interpreter is a media between states produced by simulators and states needed by RL policies. + Interpreters are two-way: + + 1. From simulator state to policy state (aka observation), see :class:`StateInterpreter`. + 2. From policy action to action accepted by simulator, see :class:`ActionInterpreter`. + + Inherit one of the two sub-classes to define your own interpreter. + This super-class is only used for isinstance check. - def interpret(self, action, **kwargs): - """interpret method + Interpreters are recommended to be stateless, meaning that storing temporary information with ``self.xxx`` + in interpreter is anti-pattern. In future, we might support register some interpreter-related + states by calling ``self.env.register_state()``, but it's not planned for first iteration. 
+ """ + + +class StateInterpreter(Generic[StateType, ObsType], Interpreter): + """State Interpreter that interpret execution result of qlib executor into rl env state""" + + env: EnvWrapper | None = None + + @property + def observation_space(self) -> gym.Space: + raise NotImplementedError() + + @final # no overridden + def __call__(self, simulator_state: StateType) -> ObsType: + obs = self.interpret(simulator_state) + self.validate(obs) + return obs + + def validate(self, obs: ObsType) -> None: + """Validate whether an observation belongs to the pre-defined observation space.""" + _gym_space_contains(self.observation_space, obs) + + def interpret(self, simulator_state: StateType) -> ObsType: + """Interpret the state of simulator. Parameters ---------- - action : - rl agent action + simulator_state + Retrieved with ``simulator.get_state()``. Returns ------- - qlib orders - + State needed by policy. Should conform with the state space defined in ``observation_space``. """ - raise NotImplementedError("interpret is not implemented!") -class StateInterpreter(BaseInterpreter): - """State Interpreter that interpret execution result of qlib executor into rl env state""" +class ActionInterpreter(Generic[StateType, PolicyActType, ActType], Interpreter): + """Action Interpreter that interpret rl agent action into qlib orders""" + + env: "EnvWrapper" | None = None + + @property + def action_space(self) -> gym.Space: + raise NotImplementedError() - def interpret(self, execute_result, **kwargs): - """interpret method + @final # no overridden + def __call__(self, simulator_state: StateType, action: PolicyActType) -> ActType: + self.validate(action) + obs = self.interpret(simulator_state, action) + return obs + + def validate(self, action: PolicyActType) -> None: + """Validate whether an action belongs to the pre-defined action space.""" + _gym_space_contains(self.action_space, action) + + def interpret(self, simulator_state: StateType, action: PolicyActType) -> ActType: + 
"""Convert the policy action to simulator action. Parameters ---------- - execute_result : - qlib execution result + simulator_state + Retrieved with ``simulator.get_state()``. + action + Raw action given by policy. Returns - ---------- - rl env state + ------- + The action needed by simulator, """ raise NotImplementedError("interpret is not implemented!") + + +def _gym_space_contains(space: gym.Space, x: Any) -> None: + """Strengthened version of gym.Space.contains. + Giving more diagnostic information on why validation fails. + + Throw exception rather than returning true or false. + """ + if isinstance(space, spaces.Dict): + if not isinstance(x, dict) or len(x) != len(space): + raise GymSpaceValidationError("Sample must be a dict with same length as space.", space, x) + for k, subspace in space.spaces.items(): + if k not in x: + raise GymSpaceValidationError(f"Key {k} not found in sample.", space, x) + try: + _gym_space_contains(subspace, x[k]) + except GymSpaceValidationError as e: + raise GymSpaceValidationError(f"Subspace of key {k} validation error.", space, x) from e + + elif isinstance(space, spaces.Tuple): + if isinstance(x, (list, np.ndarray)): + x = tuple(x) # Promote list and ndarray to tuple for contains check + if not isinstance(x, tuple) or len(x) != len(space): + raise GymSpaceValidationError("Sample must be a tuple with same length as space.", space, x) + for i, (subspace, part) in enumerate(zip(space, x)): + try: + _gym_space_contains(subspace, part) + except GymSpaceValidationError as e: + raise GymSpaceValidationError(f"Subspace of index {i} validation error.", space, x) from e + + else: + if not space.contains(x): + raise GymSpaceValidationError("Validation error reported by gym.", space, x) + + +class GymSpaceValidationError(Exception): + def __init__(self, message: str, space: gym.Space, x: Any): + self.message = message + self.space = space + self.x = x + + def __str__(self): + return f"{self.message}\n Space: {self.space}\n Sample: 
{self.x}" diff --git a/qlib/rl/order_execution/__init__.py b/qlib/rl/order_execution/__init__.py new file mode 100644 index 00000000000..048dfecac9d --- /dev/null +++ b/qlib/rl/order_execution/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Currently it supports single-asset order execution. +Multi-asset is on the way. +""" + +from .interpreter import * +from .network import * +from .policy import * +from .simulator_simple import * diff --git a/qlib/rl/order_execution/interpreter.py b/qlib/rl/order_execution/interpreter.py new file mode 100644 index 00000000000..9bb5dc2cf13 --- /dev/null +++ b/qlib/rl/order_execution/interpreter.py @@ -0,0 +1,222 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +import math +from pathlib import Path +from typing import Any, cast + +import numpy as np +import pandas as pd +from gym import spaces + +from qlib.constant import EPS +from qlib.rl.interpreter import StateInterpreter, ActionInterpreter +from qlib.rl.data import pickle_styled +from qlib.typehint import TypedDict + +from .simulator_simple import SAOEState + +__all__ = [ + "FullHistoryStateInterpreter", + "CurrentStepStateInterpreter", + "CategoricalActionInterpreter", + "TwapRelativeActionInterpreter", +] + + +def canonicalize(value: int | float | np.ndarray | pd.DataFrame | dict) -> np.ndarray | dict: + """To 32-bit numeric types. 
# NOTE(review): the opening of `canonicalize` falls before this chunk; the
# signature and the docstring head below are reconstructed from the visible
# body — confirm against the original file.
def canonicalize(value: int | float | np.ndarray | pd.DataFrame | dict) -> np.ndarray | dict | Any:
    """To 32-bit numeric types. Recursively."""
    if isinstance(value, pd.DataFrame):
        return value.to_numpy()
    if isinstance(value, (float, np.floating)) or (isinstance(value, np.ndarray) and value.dtype.kind == "f"):
        return np.array(value, dtype=np.float32)
    elif isinstance(value, (int, bool, np.integer)) or (isinstance(value, np.ndarray) and value.dtype.kind == "i"):
        return np.array(value, dtype=np.int32)
    elif isinstance(value, dict):
        return {k: canonicalize(v) for k, v in value.items()}
    else:
        return value


class FullHistoryObs(TypedDict):
    data_processed: Any
    data_processed_prev: Any
    acquiring: Any
    cur_tick: Any
    cur_step: Any
    num_step: Any
    target: Any
    position: Any
    position_history: Any


class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]):
    """The observation of all the history, including today (until this moment), and yesterday.

    Parameters
    ----------
    data_dir
        Path to load data after feature engineering.
    max_step
        Total number of steps (an upper-bound estimation). For example, 390min / 30min-per-step = 13 steps.
    data_ticks
        Equal to the total number of records. For example, in SAOE per minute,
        the total ticks is the length of day in minutes.
    data_dim
        Number of dimensions in data.
    """

    def __init__(self, data_dir: Path, max_step: int, data_ticks: int, data_dim: int) -> None:
        self.data_dir = data_dir
        self.max_step = max_step
        self.data_ticks = data_ticks
        self.data_dim = data_dim

    def interpret(self, state: SAOEState) -> FullHistoryObs:
        processed = pickle_styled.load_intraday_processed_data(
            self.data_dir,
            state.order.stock_id,
            pd.Timestamp(state.order.start_time.date()),
            self.data_dim,
            state.ticks_index,
        )

        # Position before the first step is the full order amount;
        # subsequent entries come from the recorded step history.
        position_history = np.full(self.max_step + 1, 0.0, dtype=np.float32)
        position_history[0] = state.order.amount
        position_history[1 : len(state.history_steps) + 1] = state.history_steps["position"].to_numpy()

        assert self.env is not None

        # The min, slice here are to make sure that indices fit into the range,
        # even after the final step of the simulator (in the done step),
        # to make network in policy happy.
        return cast(
            FullHistoryObs,
            canonicalize(
                {
                    "data_processed": self._mask_future_info(processed.today, state.cur_time),
                    "data_processed_prev": processed.yesterday,
                    "acquiring": state.order.direction == state.order.BUY,
                    "cur_tick": min(np.sum(state.ticks_index < state.cur_time), self.data_ticks - 1),
                    "cur_step": min(self.env.status["cur_step"], self.max_step - 1),
                    "num_step": self.max_step,
                    "target": state.order.amount,
                    "position": state.position,
                    "position_history": position_history[: self.max_step],
                }
            ),
        )

    @property
    def observation_space(self):
        space = {
            "data_processed": spaces.Box(-np.inf, np.inf, shape=(self.data_ticks, self.data_dim)),
            "data_processed_prev": spaces.Box(-np.inf, np.inf, shape=(self.data_ticks, self.data_dim)),
            "acquiring": spaces.Discrete(2),
            "cur_tick": spaces.Box(0, self.data_ticks - 1, shape=(), dtype=np.int32),
            "cur_step": spaces.Box(0, self.max_step - 1, shape=(), dtype=np.int32),
            # TODO: support arbitrary length index
            "num_step": spaces.Box(self.max_step, self.max_step, shape=(), dtype=np.int32),
            "target": spaces.Box(-EPS, np.inf, shape=()),
            "position": spaces.Box(-EPS, np.inf, shape=()),
            "position_history": spaces.Box(-EPS, np.inf, shape=(self.max_step,)),
        }
        return spaces.Dict(space)

    @staticmethod
    def _mask_future_info(arr: pd.DataFrame, current: pd.Timestamp) -> pd.DataFrame:
        arr = arr.copy(deep=True)
        arr.loc[current:] = 0.0  # mask out data after this moment (inclusive)
        return arr


class CurrentStateObs(TypedDict):
    acquiring: bool
    cur_step: int
    num_step: int
    target: float
    position: float


class CurrentStepStateInterpreter(StateInterpreter[SAOEState, CurrentStateObs]):
    """The observation of current step.

    Used when policy only depends on the latest state, but not history.
    The key list is not full. You can add more if more information is needed by your policy.
    """

    def __init__(self, max_step: int):
        self.max_step = max_step

    @property
    def observation_space(self):
        space = {
            "acquiring": spaces.Discrete(2),
            "cur_step": spaces.Box(0, self.max_step - 1, shape=(), dtype=np.int32),
            "num_step": spaces.Box(self.max_step, self.max_step, shape=(), dtype=np.int32),
            "target": spaces.Box(-EPS, np.inf, shape=()),
            "position": spaces.Box(-EPS, np.inf, shape=()),
        }
        return spaces.Dict(space)

    def interpret(self, state: SAOEState) -> CurrentStateObs:
        assert self.env is not None
        assert self.env.status["cur_step"] <= self.max_step
        obs = CurrentStateObs(
            {
                "acquiring": state.order.direction == state.order.BUY,
                "cur_step": self.env.status["cur_step"],
                "num_step": self.max_step,
                "target": state.order.amount,
                "position": state.position,
            }
        )
        return obs


class CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]):
    """Convert a discrete policy action to a continuous action, then multiplied by ``order.amount``.

    Parameters
    ----------
    values
        It can be a list of length $L$: $[a_1, a_2, \\ldots, a_L]$.
        Then when policy gives decision $x$, $a_x$ times order amount is the output.
        It can also be an integer $n$, in which case the list of length $n+1$ is auto-generated,
        i.e., $[0, 1/n, 2/n, \\ldots, n/n]$.
    """

    def __init__(self, values: int | list[float]):
        if isinstance(values, int):
            values = [i / values for i in range(0, values + 1)]
        self.action_values = values

    @property
    def action_space(self) -> spaces.Discrete:
        return spaces.Discrete(len(self.action_values))

    def interpret(self, state: SAOEState, action: int) -> float:
        assert 0 <= action < len(self.action_values)
        # Never trade more than the remaining position.
        return min(state.position, state.order.amount * self.action_values[action])


class TwapRelativeActionInterpreter(ActionInterpreter[SAOEState, float, float]):
    """Convert a continuous ratio to deal amount.

    The ratio is relative to TWAP on the remainder of the day.
    For example, there are 5 steps left, and the left position is 300.
    With TWAP strategy, in each position, 60 should be traded.
    When this interpreter receives action $a$, its output is $60 \\cdot a$.
    """

    @property
    def action_space(self) -> spaces.Box:
        return spaces.Box(0, np.inf, shape=(), dtype=np.float32)

    def interpret(self, state: SAOEState, action: float) -> float:
        assert self.env is not None
        estimated_total_steps = math.ceil(len(state.ticks_for_order) / state.ticks_per_step)
        twap_volume = state.position / (estimated_total_steps - self.env.status["cur_step"])
        return min(state.position, twap_volume * action)


# diff --git a/qlib/rl/order_execution/network.py (new file, mode 100644)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
+ +from __future__ import annotations + +from typing import cast + +import torch +import torch.nn as nn +from tianshou.data import Batch + +from qlib.typehint import Literal +from .interpreter import FullHistoryObs + +__all__ = ["Recurrent"] + + +class Recurrent(nn.Module): + """The network architecture proposed in `OPD `_. + + At every timestep the input of policy network is divided into two parts, + the public variables and the private variables. which are handled by ``raw_rnn`` + and ``pri_rnn`` in this network, respectively. + + One minor difference is that, in this implementation, we don't assume the direction to be fixed. + Thus, another ``dire_fc`` is added to produce an extra direction-related feature. + """ + + def __init__( + self, + obs_space: FullHistoryObs, + hidden_dim: int = 64, + output_dim: int = 32, + rnn_type: Literal["rnn", "lstm", "gru"] = "gru", + rnn_num_layers: int = 1, + ): + super().__init__() + + self.hidden_dim = hidden_dim + self.output_dim = output_dim + self.num_sources = 3 + + rnn_classes = {"rnn": nn.RNN, "lstm": nn.LSTM, "gru": nn.GRU} + + self.rnn_class = rnn_classes[rnn_type] + self.rnn_layers = rnn_num_layers + + self.raw_rnn = self.rnn_class(hidden_dim, hidden_dim, batch_first=True, num_layers=self.rnn_layers) + self.prev_rnn = self.rnn_class(hidden_dim, hidden_dim, batch_first=True, num_layers=self.rnn_layers) + self.pri_rnn = self.rnn_class(hidden_dim, hidden_dim, batch_first=True, num_layers=self.rnn_layers) + + self.raw_fc = nn.Sequential(nn.Linear(obs_space["data_processed"].shape[-1], hidden_dim), nn.ReLU()) + self.pri_fc = nn.Sequential(nn.Linear(2, hidden_dim), nn.ReLU()) + self.dire_fc = nn.Sequential(nn.Linear(2, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU()) + + self._init_extra_branches() + + self.fc = nn.Sequential( + nn.Linear(hidden_dim * self.num_sources, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, output_dim), + nn.ReLU(), + ) + + def _init_extra_branches(self): + pass + + def 
_source_features(self, obs: FullHistoryObs, device: torch.device) -> tuple[list[torch.Tensor], torch.Tensor]: + bs, _, data_dim = obs["data_processed"].size() + data = torch.cat((torch.zeros(bs, 1, data_dim, device=device), obs["data_processed"]), 1) + cur_step = obs["cur_step"].long() + cur_tick = obs["cur_tick"].long() + bs_indices = torch.arange(bs, device=device) + + position = obs["position_history"] / obs["target"].unsqueeze(-1) # [bs, num_step] + steps = ( + torch.arange(position.size(-1), device=device).unsqueeze(0).repeat(bs, 1).float() + / obs["num_step"].unsqueeze(-1).float() + ) # [bs, num_step] + priv = torch.stack((position.float(), steps), -1) + + data_in = self.raw_fc(data) + data_out, _ = self.raw_rnn(data_in) + # as it is padded with zero in front, this should be last minute + data_out_slice = data_out[bs_indices, cur_tick] + + priv_in = self.pri_fc(priv) + priv_out = self.pri_rnn(priv_in)[0] + priv_out = priv_out[bs_indices, cur_step] + + sources = [data_out_slice, priv_out] + + dir_out = self.dire_fc(torch.stack((obs["acquiring"], 1 - obs["acquiring"]), -1).float()) + sources.append(dir_out) + + return sources, data_out + + def forward(self, batch: Batch) -> torch.Tensor: + """ + Input should be a dict (at least) containing: + + - data_processed: [N, T, C] + - cur_step: [N] (int) + - cur_time: [N] (int) + - position_history: [N, S] (S is number of steps) + - target: [N] + - num_step: [N] (int) + - acquiring: [N] (0 or 1) + """ + + inp = cast(FullHistoryObs, batch) + device = inp["data_processed"].device + + sources, _ = self._source_features(inp, device) + assert len(sources) == self.num_sources + + out = torch.cat(sources, -1) + return self.fc(out) diff --git a/qlib/rl/order_execution/policy.py b/qlib/rl/order_execution/policy.py new file mode 100644 index 00000000000..f95a53c7588 --- /dev/null +++ b/qlib/rl/order_execution/policy.py @@ -0,0 +1,158 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from pathlib import Path +from typing import Optional, cast + +import numpy as np +import gym +import torch +import torch.nn as nn +from gym.spaces import Discrete +from tianshou.data import Batch, to_torch +from tianshou.policy import PPOPolicy, BasePolicy + +__all__ = ["AllOne", "PPO"] + + +# baselines # + + +class NonlearnablePolicy(BasePolicy): + """Tianshou's BasePolicy with empty ``learn`` and ``process_fn``. + + This could be moved outside in future. + """ + + def __init__(self, obs_space: gym.Space, action_space: gym.Space): + super().__init__() + + def learn(self, batch, batch_size, repeat): + pass + + def process_fn(self, batch, buffer, indice): + pass + + +class AllOne(NonlearnablePolicy): + """Forward returns a batch full of 1. + + Useful when implementing some baselines (e.g., TWAP). + """ + + def forward(self, batch, state=None, **kwargs): + return Batch(act=np.full(len(batch), 1.0), state=state) + + +# ppo # + + +class PPOActor(nn.Module): + def __init__(self, extractor: nn.Module, action_dim: int): + super().__init__() + self.extractor = extractor + self.layer_out = nn.Sequential(nn.Linear(cast(int, extractor.output_dim), action_dim), nn.Softmax(dim=-1)) + + def forward(self, obs, state=None, info={}): + feature = self.extractor(to_torch(obs, device=auto_device(self))) + out = self.layer_out(feature) + return out, state + + +class PPOCritic(nn.Module): + def __init__(self, extractor: nn.Module): + super().__init__() + self.extractor = extractor + self.value_out = nn.Linear(cast(int, extractor.output_dim), 1) + + def forward(self, obs, state=None, info={}): + feature = self.extractor(to_torch(obs, device=auto_device(self))) + return self.value_out(feature).squeeze(dim=-1) + + +class PPO(PPOPolicy): + """A wrapper of tianshou PPOPolicy. + + Differences: + + - Auto-create actor and critic network. Supports discrete action space only. 
+ - Dedup common parameters between actor network and critic network + (not sure whether this is included in latest tianshou or not). + - Support a ``weight_file`` that supports loading checkpoint. + - Some parameters' default values are different from original. + """ + + def __init__( + self, + network: nn.Module, + obs_space: gym.Space, + action_space: gym.Space, + lr: float, + weight_decay: float = 0.0, + discount_factor: float = 1.0, + max_grad_norm: float = 100.0, + reward_normalization: bool = True, + eps_clip: float = 0.3, + value_clip: float = True, + vf_coef: float = 1.0, + gae_lambda: float = 1.0, + max_batchsize: int = 256, + deterministic_eval: bool = True, + weight_file: Optional[Path] = None, + ): + assert isinstance(action_space, Discrete) + actor = PPOActor(network, action_space.n) + critic = PPOCritic(network) + optimizer = torch.optim.Adam( + chain_dedup(actor.parameters(), critic.parameters()), lr=lr, weight_decay=weight_decay + ) + super().__init__( + actor, + critic, + optimizer, + torch.distributions.Categorical, + discount_factor=discount_factor, + max_grad_norm=max_grad_norm, + reward_normalization=reward_normalization, + eps_clip=eps_clip, + value_clip=value_clip, + vf_coef=vf_coef, + gae_lambda=gae_lambda, + max_batchsize=max_batchsize, + deterministic_eval=deterministic_eval, + observation_space=obs_space, + action_space=action_space, + ) + if weight_file is not None: + load_weight(self, weight_file) + + +# utilities: these should be put in a separate (common) file. # + + +def auto_device(module: nn.Module) -> torch.device: + for param in module.parameters(): + return param.device + return torch.device("cpu") # fallback to cpu + + +def load_weight(policy, path): + assert isinstance(policy, nn.Module), "Policy has to be an nn.Module to load weight." 
+ loaded_weight = torch.load(path, map_location="cpu") + try: + policy.load_state_dict(loaded_weight) + except RuntimeError: + # try again by loading the converted weight + # https://github.com/thu-ml/tianshou/issues/468 + for k in list(loaded_weight): + loaded_weight["_actor_critic." + k] = loaded_weight[k] + policy.load_state_dict(loaded_weight) + + +def chain_dedup(*iterables): + seen = set() + for iterable in iterables: + for i in iterable: + if i not in seen: + seen.add(i) + yield i diff --git a/qlib/rl/order_execution/simulator_qlib.py b/qlib/rl/order_execution/simulator_qlib.py new file mode 100644 index 00000000000..71aaa222be5 --- /dev/null +++ b/qlib/rl/order_execution/simulator_qlib.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Placeholder for qlib-based simulator.""" diff --git a/qlib/rl/order_execution/simulator_simple.py b/qlib/rl/order_execution/simulator_simple.py new file mode 100644 index 00000000000..8022c34cece --- /dev/null +++ b/qlib/rl/order_execution/simulator_simple.py @@ -0,0 +1,403 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +from pathlib import Path +from typing import NamedTuple, Any, TypeVar, cast + +import numpy as np +import pandas as pd + +from qlib.backtest.decision import Order, OrderDir +from qlib.constant import EPS +from qlib.rl.simulator import Simulator +from qlib.rl.data.pickle_styled import IntradayBacktestData, load_intraday_backtest_data, DealPriceType +from qlib.rl.utils import LogLevel +from qlib.typehint import TypedDict + +__all__ = ["SAOEMetrics", "SAOEState", "SingleAssetOrderExecution"] + +ONE_SEC = pd.Timedelta("1s") # use 1 second to exclude the right interval point + + +class SAOEMetrics(TypedDict): + """Metrics for SAOE accumulated for a "period". + It could be accumulated for a day, or a period of time (e.g., 30min), or calculated separately for every minute. 
+ + Warnings + -------- + The type hints are for single elements. In lots of times, they can be vectorized. + For example, ``market_volume`` could be a list of float (or ndarray) rather tahn a single float. + """ + + stock_id: str + """Stock ID of this record.""" + datetime: pd.Timestamp + """Datetime of this record (this is index in the dataframe).""" + direction: int + """Direction of the order. 0 for sell, 1 for buy.""" + + # Market information. + market_volume: float + """(total) market volume traded in the period.""" + market_price: float + """Deal price. If it's a period of time, this is the average market deal price.""" + + # Strategy records. + + amount: float + """Total amount (volume) strategy intends to trade.""" + inner_amount: float + """Total amount that the lower-level strategy intends to trade + (might be larger than amount, e.g., to ensure ffr).""" + + deal_amount: float + """Amount that successfully takes effect (must be less than inner_amount).""" + trade_price: float + """The average deal price for this strategy.""" + trade_value: float + """Total worth of trading. In the simple simulaton, trade_value = deal_amount * price.""" + position: float + """Position left after this "period".""" + + # Accumulated metrics + + ffr: float + """Completed how much percent of the daily order.""" + + pa: float + """Price advantage compared to baseline (i.e., trade with baseline market price). + The baseline is trade price when using TWAP strategy to execute this order. + Please note that there could be data leak here). 
+ Unit is BP (basis point, 1/10000).""" + + +class SAOEState(NamedTuple): + """Data structure holding a state for SAOE simulator.""" + + order: Order + """The order we are dealing with.""" + cur_time: pd.Timestamp + """Current time, e.g., 9:30.""" + position: float + """Current remaining volume to execute.""" + history_exec: pd.DataFrame + """See :attr:`SingleAssetOrderExecution.history_exec`.""" + history_steps: pd.DataFrame + """See :attr:`SingleAssetOrderExecution.history_steps`.""" + + metrics: SAOEMetrics | None + """Daily metric, only available when the trading is in "done" state.""" + + backtest_data: IntradayBacktestData + """Backtest data is included in the state. + Actually, only the time index of this data is needed, at this moment. + I include the full data so that algorithms (e.g., VWAP) that relies on the raw data can be implemented. + Interpreter can use this as they wish, but they should be careful not to leak future data. + """ + + ticks_per_step: int + """How many ticks for each step.""" + ticks_index: pd.DatetimeIndex + """Trading ticks in all day, NOT sliced by order (defined in data). e.g., [9:30, 9:31, ..., 14:59].""" + ticks_for_order: pd.DatetimeIndex + """Trading ticks sliced by order, e.g., [9:45, 9:46, ..., 14:44].""" + + +class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]): + """Single-asset order execution (SAOE) simulator. + + As there's no "calendar" in the simple simulator, ticks are used to trade. + A tick is a record (a line) in the pickle-styled data file. + Each tick is considered as a individual trading opportunity. + If such fine granularity is not needed, use ``ticks_per_step`` to + lengthen the ticks for each step. + + In each step, the traded amount are "equally" splitted to each tick, + then bounded by volume maximum exeuction volume (i.e., ``vol_threshold``), + and if it's the last step, try to ensure all the amount to be executed. 
+ + Parameters + ---------- + initial + The seed to start an SAOE simulator is an order. + ticks_per_step + How many ticks per step. + data_dir + Path to load backtest data + vol_threshold + Maximum execution volume (divided by market execution volume). + """ + + history_exec: pd.DataFrame + """All execution history at every possible time ticks. See :class:`SAOEMetrics` for available columns.""" + + history_steps: pd.DataFrame + """Positions at each step. The position before first step is also recorded. + See :class:`SAOEMetrics` for available columns.""" + + metrics: SAOEMetrics | None + """Metrics. Only available when done.""" + + twap_price: float + """This price is used to compute price advantage. + It"s defined as the average price in the period from order"s start time to end time.""" + + ticks_index: pd.DatetimeIndex + """All available ticks for the day (not restricted to order).""" + + ticks_for_order: pd.DatetimeIndex + """Ticks that is available for trading (sliced by order).""" + + def __init__( + self, + order: Order, + data_dir: Path, + ticks_per_step: int = 30, + deal_price_type: DealPriceType = "close", + vol_threshold: float | None = None, + ) -> None: + self.order = order + self.ticks_per_step: int = ticks_per_step + self.deal_price_type = deal_price_type + self.vol_threshold = vol_threshold + self.data_dir = data_dir + self.backtest_data = load_intraday_backtest_data( + self.data_dir, order.stock_id, pd.Timestamp(order.start_time.date()), self.deal_price_type, order.direction + ) + + self.ticks_index = self.backtest_data.get_time_index() + + # Get time index available for trading + self.ticks_for_order = self._get_ticks_slice(self.order.start_time, self.order.end_time) + + self.cur_time = self.ticks_for_order[0] + # NOTE: astype(float) is necessary in some systems. 
+ # this will align the precision with `.to_numpy()` in `_split_exec_vol` + self.twap_price = float(self.backtest_data.get_deal_price().loc[self.ticks_for_order].astype(float).mean()) + + self.position = order.amount + + metric_keys = list(SAOEMetrics.__annotations__.keys()) # pylint: disable=no-member + # NOTE: can empty dataframe contain index? + self.history_exec = pd.DataFrame(columns=metric_keys).set_index("datetime") + self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime") + self.metrics = None + + self.market_price: np.ndarray | None = None + self.market_vol: np.ndarray | None = None + self.market_vol_limit: np.ndarray | None = None + + def step(self, amount: float) -> None: + """Execute one step or SAOE. + + Parameters + ---------- + amount + The amount you wish to deal. The simulator doesn't guarantee all the amount to be successfully dealt. + """ + + assert not self.done() + + self.market_price = self.market_vol = None # avoid misuse + exec_vol = self._split_exec_vol(amount) + assert self.market_price is not None and self.market_vol is not None + + ticks_position = self.position - np.cumsum(exec_vol) + + self.position -= exec_vol.sum() + if self.position < -EPS or (exec_vol < -EPS).any(): + raise ValueError(f"Execution volume is invalid: {exec_vol} (position = {self.position})") + + # Get time index available for this step + time_index = self._get_ticks_slice(self.cur_time, self._next_time()) + + self.history_exec = self._dataframe_append( + self.history_exec, + SAOEMetrics( + # It should have the same keys with SAOEMetrics, + # but the values do not necessarily have the annotated type. + # Some values could be vectorized (e.g., exec_vol). 
+ stock_id=self.order.stock_id, + datetime=time_index, + direction=self.order.direction, + market_volume=self.market_vol, + market_price=self.market_price, + amount=exec_vol, + inner_amount=exec_vol, + deal_amount=exec_vol, + trade_price=self.market_price, + trade_value=self.market_price * exec_vol, + position=ticks_position, + ffr=exec_vol / self.order.amount, + pa=price_advantage(self.market_price, self.twap_price, self.order.direction), + ), + ) + + self.history_steps = self._dataframe_append( + self.history_steps, + [self._metrics_collect(self.cur_time, self.market_vol, self.market_price, amount, exec_vol)], + ) + + if self.done(): + if self.env is not None: + self.env.logger.add_any("history_steps", self.history_steps, loglevel=LogLevel.DEBUG) + self.env.logger.add_any("history_exec", self.history_exec, loglevel=LogLevel.DEBUG) + + self.metrics = self._metrics_collect( + self.ticks_index[0], # start time + self.history_exec["market_volume"], + self.history_exec["market_price"], + self.history_steps["amount"].sum(), + self.history_exec["deal_amount"], + ) + + # NOTE (yuge): It looks to me that it's the "correct" decision to + # put all the logs here, because only components like simulators themselves + # have the knowledge about what could appear in the logs, and what's the format. + # But I admit it's not necessarily the most convenient way. + # I'll rethink about it when we have the second environment + # Maybe some APIs like self.logger.enable_auto_log() ? 
+ + if self.env is not None: + for key, value in self.metrics.items(): + if isinstance(value, float): + self.env.logger.add_scalar(key, value) + else: + self.env.logger.add_any(key, value) + + self.cur_time = self._next_time() + + def get_state(self) -> SAOEState: + return SAOEState( + order=self.order, + cur_time=self.cur_time, + position=self.position, + history_exec=self.history_exec, + history_steps=self.history_steps, + metrics=self.metrics, + backtest_data=self.backtest_data, + ticks_per_step=self.ticks_per_step, + ticks_index=self.ticks_index, + ticks_for_order=self.ticks_for_order, + ) + + def done(self) -> bool: + return self.position < EPS or self.cur_time >= self.order.end_time + + def _next_time(self) -> pd.Timestamp: + """The "current time" (``cur_time``) for next step.""" + # Look for next time on time index + current_loc = self.ticks_index.get_loc(self.cur_time) + next_loc = current_loc + self.ticks_per_step + + # Calibrate the next location to multiple of ticks_per_step. + # This is to make sure that: + # as long as ticks_per_step is a multiple of something, each step won't cross morning and afternoon. + next_loc = next_loc - next_loc % self.ticks_per_step + + if next_loc < len(self.ticks_index) and self.ticks_index[next_loc] < self.order.end_time: + return self.ticks_index[next_loc] + else: + return self.order.end_time + + def _cur_duration(self) -> pd.Timedelta: + """The "duration" of this step (step that is about to happen).""" + return self._next_time() - self.cur_time + + def _split_exec_vol(self, exec_vol_sum: float) -> np.ndarray: + """ + Split the volume in each step into minutes, considering possible constraints. + This follows TWAP strategy. 
+ """ + next_time = self._next_time() + + # get the backtest data for next interval + self.market_vol = self.backtest_data.get_volume().loc[self.cur_time : next_time - ONE_SEC].to_numpy() + self.market_price = self.backtest_data.get_deal_price().loc[self.cur_time : next_time - ONE_SEC].to_numpy() + + assert self.market_vol is not None and self.market_price is not None + + # split the volume equally into each minute + exec_vol = np.repeat(exec_vol_sum / len(self.market_price), len(self.market_price)) + + # apply the volume threshold + market_vol_limit = self.vol_threshold * self.market_vol if self.vol_threshold is not None else np.inf + exec_vol = np.minimum(exec_vol, market_vol_limit) # type: ignore + + # Complete all the order amount at the last moment. + if next_time >= self.order.end_time: + exec_vol[-1] += self.position - exec_vol.sum() + exec_vol = np.minimum(exec_vol, market_vol_limit) # type: ignore + + return exec_vol + + def _metrics_collect( + self, + datetime: pd.Timestamp, + market_vol: np.ndarray, + market_price: np.ndarray, + amount: float, # intended to trade such amount + exec_vol: np.ndarray, + ) -> SAOEMetrics: + assert len(market_vol) == len(market_price) == len(exec_vol) + + if np.abs(np.sum(exec_vol)) < EPS: + exec_avg_price = 0.0 + else: + exec_avg_price = cast(float, np.average(market_price, weights=exec_vol)) # could be nan + if hasattr(exec_avg_price, "item"): # could be numpy scalar + exec_avg_price = exec_avg_price.item() # type: ignore + + return SAOEMetrics( + stock_id=self.order.stock_id, + datetime=datetime, + direction=self.order.direction, + market_volume=market_vol.sum(), + market_price=market_price.mean(), + amount=amount, + inner_amount=exec_vol.sum(), + deal_amount=exec_vol.sum(), # in this simulator, there's no other restrictions + trade_price=exec_avg_price, + trade_value=np.sum(market_price * exec_vol), + position=self.position, + ffr=float(exec_vol.sum() / self.order.amount), + pa=price_advantage(exec_avg_price, 
self.twap_price, self.order.direction), + ) + + def _get_ticks_slice(self, start: pd.Timestamp, end: pd.Timestamp, include_end: bool = False) -> pd.DatetimeIndex: + if not include_end: + end = end - ONE_SEC + return self.ticks_index[self.ticks_index.slice_indexer(start, end)] + + @staticmethod + def _dataframe_append(df: pd.DataFrame, other: Any) -> pd.DataFrame: + # dataframe.append is deprecated + other_df = pd.DataFrame(other).set_index("datetime") + other_df.index.name = "datetime" + return pd.concat([df, other_df], axis=0) + + +_float_or_ndarray = TypeVar("_float_or_ndarray", float, np.ndarray) + + +def price_advantage( + exec_price: _float_or_ndarray, baseline_price: float, direction: OrderDir | int +) -> _float_or_ndarray: + if baseline_price == 0: # something is wrong with data. Should be nan here + if isinstance(exec_price, float): + return 0.0 + else: + return np.zeros_like(exec_price) + if direction == OrderDir.BUY: + res = (1 - exec_price / baseline_price) * 10000 + elif direction == OrderDir.SELL: + res = (exec_price / baseline_price - 1) * 10000 + else: + raise ValueError(f"Unexpected order direction: {direction}") + res_wo_nan: np.ndarray = np.nan_to_num(res, nan=0.0) + if res_wo_nan.size == 1: + return res_wo_nan.item() + else: + return cast(_float_or_ndarray, res_wo_nan) diff --git a/qlib/rl/reward.py b/qlib/rl/reward.py new file mode 100644 index 00000000000..20d98587427 --- /dev/null +++ b/qlib/rl/reward.py @@ -0,0 +1,84 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +from typing import Generic, Any, TypeVar, TYPE_CHECKING + +from qlib.typehint import final + +if TYPE_CHECKING: + from .utils.env_wrapper import EnvWrapper + +SimulatorState = TypeVar("SimulatorState") + + +class Reward(Generic[SimulatorState]): + """ + Reward calculation component that takes a single argument: state of simulator. Returns a real number: reward. 
+ + Subclass should implement ``reward(simulator_state)`` to implement their own reward calculation recipe. + """ + + env: EnvWrapper | None = None + + @final + def __call__(self, simulator_state: SimulatorState) -> float: + return self.reward(simulator_state) + + def reward(self, simulator_state: SimulatorState) -> float: + """Implement this method for your own reward.""" + raise NotImplementedError("Implement reward calculation recipe in `reward()`.") + + def log(self, name, value): + self.env.logger.add_scalar(name, value) + + +class RewardCombination(Reward): + """Combination of multiple reward.""" + + def __init__(self, rewards: dict[str, tuple[Reward, float]]): + self.rewards = rewards + + def reward(self, simulator_state: Any) -> float: + total_reward = 0.0 + for name, (reward_fn, weight) in self.rewards.items(): + rew = reward_fn(simulator_state) * weight + total_reward += rew + self.log(name, rew) + return total_reward + + +# TODO: +# reward_factory is disabled for now + +# _RegistryConfigReward = RegistryConfig[REWARDS] + + +# @configclass +# class _WeightedRewardConfig: +# weight: float +# reward: _RegistryConfigReward + + +# RewardConfig = Union[_RegistryConfigReward, Dict[str, Union[_RegistryConfigReward, _WeightedRewardConfig]]] + + +# def reward_factory(reward_config: RewardConfig) -> Reward: +# """ +# Use this factory to instantiate the reward from config. +# Simply using ``reward_config.build()`` might not work because reward can have complex combinations. +# """ +# if isinstance(reward_config, dict): +# # as reward combination +# rewards = {} +# for name, rew in reward_config.items(): +# if not isinstance(rew, _WeightedRewardConfig): +# # default weight is 1. 
+# rew = _WeightedRewardConfig(weight=1., rew=rew) +# # no recursive build in this step +# rewards[name] = (rew.reward.build(), rew.weight) +# return RewardCombination(rewards) +# else: +# # single reward +# return reward_config.build() diff --git a/qlib/rl/seed.py b/qlib/rl/seed.py new file mode 100644 index 00000000000..93d452a4a2a --- /dev/null +++ b/qlib/rl/seed.py @@ -0,0 +1,12 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Defines a set of initial state definitions and state-set definitions. + +With single-asset order execution only, the only seed is order. +""" + +from typing import TypeVar + +InitialStateType = TypeVar("InitialStateType") +"""Type of data that creates the simulator.""" diff --git a/qlib/rl/simulator.py b/qlib/rl/simulator.py new file mode 100644 index 00000000000..56fc12042c1 --- /dev/null +++ b/qlib/rl/simulator.py @@ -0,0 +1,75 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +from typing import TypeVar, Generic, Any, TYPE_CHECKING + +from .seed import InitialStateType + +if TYPE_CHECKING: + from .utils.env_wrapper import EnvWrapper + +StateType = TypeVar("StateType") +"""StateType stores all the useful data in the simulation process +(as well as utilities to generate/retrieve data when needed).""" + +ActType = TypeVar("ActType") +"""This ActType is the type of action at the simulator end.""" + + +class Simulator(Generic[InitialStateType, StateType, ActType]): + """ + Simulator that resets with ``__init__``, and transits with ``step(action)``. + + To make the data-flow clear, we make the following restrictions to Simulator: + + 1. The only way to modify the inner status of a simulator is by using ``step(action)``. + 2. External modules can *read* the status of a simulator by using ``simulator.get_state()``, + and check whether the simulator is in the ending state by calling ``simulator.done()``. 
+ + A simulator is defined to be bounded with three types: + + - *InitialStateType* that is the type of the data used to create the simulator. + - *StateType* that is the type of the **status** (state) of the simulator. + - *ActType* that is the type of the **action**, which is the input received in each step. + + Different simulators might share the same StateType. For example, when they are dealing with the same task, + but with different simulation implementation. With the same type, they can safely share other components in the MDP. + + Simulators are ephemeral. The lifecycle of a simulator starts with an initial state, and ends with the trajectory. + In another word, when the trajectory ends, simulator is recycled. + If simulators want to share context between (e.g., for speed-up purposes), + this could be done by accessing the weak reference of environment wrapper. + + Attributes + ---------- + env + A reference of env-wrapper, which could be useful in some corner cases. + Simulators are discouraged to use this, because it's prone to induce errors. + """ + + env: EnvWrapper | None = None + + def __init__(self, initial: InitialStateType, **kwargs: Any) -> None: + pass + + def step(self, action: ActType) -> None: + """Receives an action of ActType. + + Simulator should update its internal state, and return None. + The updated state can be retrieved with ``simulator.get_state()``. + """ + raise NotImplementedError() + + def get_state(self) -> StateType: + raise NotImplementedError() + + def done(self) -> bool: + """Check whether the simulator is in a "done" state. + When simulator is in a "done" state, + it should no longer receives any ``step`` request. + As simulators are ephemeral, to reset the simulator, + the old one should be destroyed and a new simulator can be created. 
+ """ + raise NotImplementedError() diff --git a/qlib/rl/utils/__init__.py b/qlib/rl/utils/__init__.py new file mode 100644 index 00000000000..4a1fa9d9058 --- /dev/null +++ b/qlib/rl/utils/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from .data_queue import * +from .env_wrapper import * +from .finite_env import * +from .log import * diff --git a/qlib/rl/utils/data_queue.py b/qlib/rl/utils/data_queue.py new file mode 100644 index 00000000000..1e9dab24da4 --- /dev/null +++ b/qlib/rl/utils/data_queue.py @@ -0,0 +1,179 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import os +import multiprocessing +import threading +import time +import warnings +from queue import Empty +from typing import TypeVar, Generic, Sequence, cast + +from qlib.log import get_module_logger + +_logger = get_module_logger(__name__) + +T = TypeVar("T") + +__all__ = ["DataQueue"] + + +class DataQueue(Generic[T]): + """Main process (producer) produces data and stores them in a queue. + Sub-processes (consumers) can retrieve the data-points from the queue. + Data-points are generated via reading items from ``dataset``. + + :class:`DataQueue` is ephemeral. You must create a new DataQueue + when the ``repeat`` is exhausted. + + See the documents of :class:`qlib.rl.utils.FiniteVectorEnv` for more background. + + Parameters + ---------- + dataset + The dataset to read data from. Must implement ``__len__`` and ``__getitem__``. + repeat + Iterate over the data-points for how many times. Use ``-1`` to iterate forever. + shuffle + If ``shuffle`` is true, the items will be read in random order. + producer_num_workers + Concurrent workers for data-loading. + queue_maxsize + Maximum items to put into queue before it jams. + + Examples + -------- + >>> data_queue = DataQueue(my_dataset) + >>> with data_queue: + ... ... + + In worker: + + >>> for data in data_queue: + ... 
print(data) + """ + + def __init__( + self, + dataset: Sequence[T], + repeat: int = 1, + shuffle: bool = True, + producer_num_workers: int = 0, + queue_maxsize: int = 0, + ): + if queue_maxsize == 0: + if os.cpu_count() is not None: + queue_maxsize = cast(int, os.cpu_count()) + _logger.info(f"Automatically set data queue maxsize to {queue_maxsize} to avoid overwhelming.") + else: + queue_maxsize = 1 + _logger.warning(f"CPU count not available. Setting queue maxsize to 1.") + + self.dataset: Sequence[T] = dataset + self.repeat: int = repeat + self.shuffle: bool = shuffle + self.producer_num_workers: int = producer_num_workers + + self._activated: bool = False + self._queue: multiprocessing.Queue = multiprocessing.Queue(maxsize=queue_maxsize) + self._done = multiprocessing.Value("i", 0) + + def __enter__(self): + self.activate() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.cleanup() + + def cleanup(self): + with self._done.get_lock(): + self._done.value += 1 + for repeat in range(500): + if repeat >= 1: + warnings.warn(f"After {repeat} cleanup, the queue is still not empty.", category=RuntimeWarning) + while not self._queue.empty(): + try: + self._queue.get(block=False) + except Empty: + pass + # Sometimes when the queue gets emptied, more data have already been sent, + # and they are on the way into the queue. + # If these data didn't get consumed, it will jam the queue and make the process hang. + # We wait a second here for potential data arriving, and check again (for ``repeat`` times). + time.sleep(1.0) + if self._queue.empty(): + break + _logger.debug(f"Remaining items in queue collection done. 
Empty: {self._queue.empty()}") + + def get(self, block=True): + if not hasattr(self, "_first_get"): + self._first_get = True + if self._first_get: + timeout = 5.0 + self._first_get = False + else: + timeout = 0.5 + while True: + try: + return self._queue.get(block=block, timeout=timeout) + except Empty: + if self._done.value: + raise StopIteration # pylint: disable=raise-missing-from + + def put(self, obj, block=True, timeout=None): + return self._queue.put(obj, block=block, timeout=timeout) + + def mark_as_done(self): + with self._done.get_lock(): + self._done.value = 1 + + def done(self): + return self._done.value + + def activate(self): + if self._activated: + raise ValueError("DataQueue can not activate twice.") + thread = threading.Thread(target=self._producer, daemon=True) + thread.start() + self._activated = True + return self + + def __del__(self): + _logger.debug(f"__del__ of {__name__}.DataQueue") + self.cleanup() + + def __iter__(self): + if not self._activated: + raise ValueError( + "Need to call activate() to launch a daemon worker " "to produce data into data queue before using it." + ) + return self._consumer() + + def _consumer(self): + while True: + try: + yield self.get() + except StopIteration: + _logger.debug("Data consumer timed-out from get.") + return + + def _producer(self): + # pytorch dataloader is used here only because we need its sampler and multi-processing + from torch.utils.data import DataLoader, Dataset # pylint: disable=import-outside-toplevel + + dataloader = DataLoader( + cast(Dataset[T], self.dataset), + batch_size=None, + num_workers=self.producer_num_workers, + shuffle=self.shuffle, + collate_fn=lambda t: t, # identity collate fn + ) + repeat = 10**18 if self.repeat == -1 else self.repeat + for _rep in range(repeat): + for data in dataloader: + if self._done.value: + # Already done. + return + self._queue.put(data) + _logger.debug(f"Dataloader loop done. 
Repeat {_rep}.") + self.mark_as_done() diff --git a/qlib/rl/utils/env_wrapper.py b/qlib/rl/utils/env_wrapper.py new file mode 100644 index 00000000000..f343e5b9b4e --- /dev/null +++ b/qlib/rl/utils/env_wrapper.py @@ -0,0 +1,249 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +import weakref +from typing import Callable, Any, Iterable, Iterator, Generic, cast + +import gym + +from qlib.rl.aux_info import AuxiliaryInfoCollector +from qlib.rl.simulator import Simulator, InitialStateType, StateType, ActType +from qlib.rl.interpreter import StateInterpreter, ActionInterpreter, PolicyActType, ObsType +from qlib.rl.reward import Reward +from qlib.typehint import TypedDict + +from .finite_env import generate_nan_observation +from .log import LogCollector, LogLevel + +__all__ = ["InfoDict", "EnvWrapperStatus", "EnvWrapper"] + +# in this case, there won't be any seed for simulator +SEED_INTERATOR_MISSING = "_missing_" + + +class InfoDict(TypedDict): + """The type of dict that is used in the 4th return value of ``env.step()``.""" + + aux_info: dict + """Any information depends on auxiliary info collector.""" + log: dict[str, Any] + """Collected by LogCollector.""" + + +class EnvWrapperStatus(TypedDict): + """ + This is the status data structure used in EnvWrapper. + The fields here are in the semantics of RL. + For example, ``obs`` means the observation fed into policy. + ``action`` means the raw action returned by policy. + """ + + cur_step: int + done: bool + initial_state: Any | None + obs_history: list + action_history: list + reward_history: list + + +class EnvWrapper( + gym.Env[ObsType, PolicyActType], Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType] +): + """Qlib-based RL environment, subclassing ``gym.Env``. + A wrapper of components, including simulator, state-interpreter, action-interpreter, reward. 
+ + This is what the framework of simulator - interpreter - policy looks like in RL training. + All the components other than policy needs to be assembled into a single object called "environment". + The "environment" are replicated into multiple workers, and (at least in tianshou's implementation), + one single policy (agent) plays against a batch of environments. + + Parameters + ---------- + simulator_fn + A callable that is the simulator factory. + When ``seed_iterator`` is present, the factory should take one argument, + that is the seed (aka initial state). + Otherwise, it should take zero argument. + state_interpreter + State-observation converter. + action_interpreter + Policy-simulator action converter. + seed_iterator + An iterable of seed. With the help of :class:`qlib.rl.utils.DataQueue`, + environment workers in different processes can share one ``seed_iterator``. + reward_fn + A callable that accepts the StateType and returns a float (at least in single-agent case). + aux_info_collector + Collect auxiliary information. Could be useful in MARL. + logger + Log collector that collects the logs. The collected logs are sent back to main process, + via the return value of ``env.step()``. + + Attributes + ---------- + status : EnvWrapperStatus + Status indicator. All terms are in *RL language*. + It can be used if users care about data on the RL side. + Can be none when no trajectory is available. 
+ """ + + simulator: Simulator[InitialStateType, StateType, ActType] + seed_iterator: str | Iterator[InitialStateType] | None + + def __init__( + self, + simulator_fn: Callable[..., Simulator[InitialStateType, StateType, ActType]], + state_interpreter: StateInterpreter[StateType, ObsType], + action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType], + seed_iterator: Iterable[InitialStateType] | None, + reward_fn: Reward | None = None, + aux_info_collector: AuxiliaryInfoCollector[StateType, Any] | None = None, + logger: LogCollector | None = None, + ): + # Assign weak reference to wrapper. + # + # Use weak reference here, because: + # 1. Logically, the other components should be able to live without an env_wrapper. + # For example, they might live in a strategy_wrapper in future. + # Therefore injecting a "hard" attribute called "env" is not appropripate. + # 2. When the environment gets destroyed, it gets destoryed. + # We don't want it to silently live inside some interpreters. + # 3. Avoid circular reference. + # 4. When the components get serialized, we can throw away the env without any burden. + # (though this part is not implemented yet) + for obj in [state_interpreter, action_interpreter, reward_fn, aux_info_collector]: + if obj is not None: + obj.env = weakref.proxy(self) # type: ignore + + self.simulator_fn = simulator_fn + self.state_interpreter = state_interpreter + self.action_interpreter = action_interpreter + + if seed_iterator is None: + # In this case, there won't be any seed for simulator + # We can't set it to None because None actually means something else. + # If `seed_iterator` is None, it means that it's exhausted. 
+ self.seed_iterator = SEED_INTERATOR_MISSING + else: + self.seed_iterator = iter(seed_iterator) + self.reward_fn = reward_fn + + self.aux_info_collector = aux_info_collector + self.logger: LogCollector = logger or LogCollector() + self.status: EnvWrapperStatus = cast(EnvWrapperStatus, None) + + @property + def action_space(self): + return self.action_interpreter.action_space + + @property + def observation_space(self): + return self.state_interpreter.observation_space + + def reset(self, **kwargs: Any) -> ObsType: + """ + Try to get a state from state queue, and init the simulator with this state. + If the queue is exhausted, generate an invalid (nan) observation. + """ + + try: + if self.seed_iterator is None: + raise RuntimeError("You can trying to get a state from a dead environment wrapper.") + + # TODO: simulator/observation might need seed to prefetch something + # as only seed has the ability to do the work beforehands + + # NOTE: though logger is reset here, logs in this function won't work, + # because we can't send them outside. + # See https://github.com/thu-ml/tianshou/issues/605 + self.logger.reset() + + if self.seed_iterator is SEED_INTERATOR_MISSING: + # no initial state + initial_state = None + self.simulator = cast(Callable[[], Simulator], self.simulator_fn)() + else: + initial_state = next(cast(Iterator[InitialStateType], self.seed_iterator)) + self.simulator = self.simulator_fn(initial_state) + + self.status = EnvWrapperStatus( + cur_step=0, + done=False, + initial_state=initial_state, + obs_history=[], + action_history=[], + reward_history=[], + ) + + self.simulator.env = cast(EnvWrapper, weakref.proxy(self)) + + sim_state = self.simulator.get_state() + obs = self.state_interpreter(sim_state) + + self.status["obs_history"].append(obs) + + return obs + + except StopIteration: + # The environment should be recycled because it's in a dead state. 
+ self.seed_iterator = None + return generate_nan_observation(self.observation_space) + + def step(self, policy_action: PolicyActType, **kwargs: Any) -> tuple[ObsType, float, bool, InfoDict]: + """Environment step. + + See the code along with comments to get a sequence of things happening here. + """ + + if self.seed_iterator is None: + raise RuntimeError("State queue is already exhausted, but the environment is still receiving action.") + + # Clear the logged information from last step + self.logger.reset() + + # Action is what we have got from policy + self.status["action_history"].append(policy_action) + action = self.action_interpreter(self.simulator.get_state(), policy_action) + + # This update must be after action interpreter and before simulator. + self.status["cur_step"] += 1 + + # Use the converted action of update the simulator + self.simulator.step(action) + + # Update "done" first, as this status might be used by reward_fn later + done = self.simulator.done() + self.status["done"] = done + + # Get state and calculate observation + sim_state = self.simulator.get_state() + obs = self.state_interpreter(sim_state) + self.status["obs_history"].append(obs) + + # Reward and extra info + if self.reward_fn is not None: + rew = self.reward_fn(sim_state) + else: + # No reward. Treated as 0. 
+ rew = 0.0 + self.status["reward_history"].append(rew) + + if self.aux_info_collector is not None: + aux_info = self.aux_info_collector(sim_state) + else: + aux_info = {} + + # Final logging stuff: RL-specific logs + if done: + self.logger.add_scalar("steps_per_episode", self.status["cur_step"]) + self.logger.add_scalar("reward", rew) + self.logger.add_any("obs", obs, loglevel=LogLevel.DEBUG) + self.logger.add_any("policy_act", policy_action, loglevel=LogLevel.DEBUG) + + info_dict = InfoDict(log=self.logger.logs(), aux_info=aux_info) + return obs, rew, done, info_dict + + def render(self): + raise NotImplementedError("Render is not implemented in EnvWrapper.") diff --git a/qlib/rl/utils/finite_env.py b/qlib/rl/utils/finite_env.py new file mode 100644 index 00000000000..fc9c2c75ea5 --- /dev/null +++ b/qlib/rl/utils/finite_env.py @@ -0,0 +1,337 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +This is to support finite env in vector env. +See https://github.com/thu-ml/tianshou/issues/322 for details. 
+""" + +from __future__ import annotations + +import copy +import warnings +from contextlib import contextmanager + +import gym +import numpy as np +from typing import Any, Set, Callable, Type + +from tianshou.env import BaseVectorEnv, DummyVectorEnv, ShmemVectorEnv, SubprocVectorEnv + +from qlib.typehint import Literal +from .log import LogWriter + +__all__ = [ + "generate_nan_observation", + "check_nan_observation", + "FiniteVectorEnv", + "FiniteDummyVectorEnv", + "FiniteSubprocVectorEnv", + "FiniteShmemVectorEnv", + "FiniteEnvType", + "vectorize_env", +] + + +FiniteEnvType = Literal["dummy", "subproc", "shmem"] + + +def fill_invalid(obj): + if isinstance(obj, (int, float, bool)): + return fill_invalid(np.array(obj)) + if hasattr(obj, "dtype"): + if isinstance(obj, np.ndarray): + if np.issubdtype(obj.dtype, np.floating): + return np.full_like(obj, np.nan) + return np.full_like(obj, np.iinfo(obj.dtype).max) + # dealing with corner cases that numpy number is not supported by tianshou's sharray + return fill_invalid(np.array(obj)) + elif isinstance(obj, dict): + return {k: fill_invalid(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [fill_invalid(v) for v in obj] + elif isinstance(obj, tuple): + return tuple(fill_invalid(v) for v in obj) + raise ValueError(f"Unsupported value to fill with invalid: {obj}") + + +def is_invalid(arr): + if hasattr(arr, "dtype"): + if np.issubdtype(arr.dtype, np.floating): + return np.isnan(arr).all() + return (np.iinfo(arr.dtype).max == arr).all() + if isinstance(arr, dict): + return all(is_invalid(o) for o in arr.values()) + if isinstance(arr, (list, tuple)): + return all(is_invalid(o) for o in arr) + if isinstance(arr, (int, float, bool, np.number)): + return is_invalid(np.array(arr)) + return True + + +def generate_nan_observation(obs_space: gym.Space) -> Any: + """The NaN observation that indicates the environment receives no seed. + + We assume that obs is complex and there must be something like float. 
+ Otherwise this logic doesn't work. + """ + + sample = obs_space.sample() + sample = fill_invalid(sample) + return sample + + +def check_nan_observation(obs: Any) -> bool: + """Check whether obs is generated by :func:`generate_nan_observation`.""" + return is_invalid(obs) + + +class FiniteVectorEnv(BaseVectorEnv): + """To allow the paralleled env workers consume a single DataQueue until it's exhausted. + + See `tianshou issue #322 `_. + + The requirement is to make every possible seed (stored in :class:`qlib.rl.utils.DataQueue` in our case) + consumed by exactly one environment. This is not possible by tianshou's native VectorEnv and Collector, + because tianshou is unaware of this "exactly one" constraint, and might launch extra workers. + + Consider a corner case, where concurrency is 2, but there is only one seed in DataQueue. + The reset of two workers must be both called according to the logic in collect. + The returned results of two workers are collected, regardless of what they are. + The problem is, one of the reset result must be invalid, or repeated, + because there's only one need in queue, and collector isn't aware of such situation. + + Luckily, we can hack the vector env, and make a protocol between single env and vector env. + The single environment (should be :class:`qlib.rl.utils.EnvWrapper` in our case) is responsible for + reading from queue, and generate a special observation when the queue is exhausted. The special obs + is called "nan observation", because simply using none causes problems in shared-memory vector env. + :class:`FiniteVectorEnv` then read the observations from all workers, and select those non-nan + observation. It also maintains an ``_alive_env_ids`` to track which workers should never be + called again. When also the environments are exhausted, it will raise StopIteration exception. + + The usage of this vector env in collector are two parts: + + 1. 
If the data queue is finite (usually when inference), collector should collect "infinity" number of + episodes, until the vector env exhausts by itself. + 2. If the data queue is infinite (usually in training), collector can set number of episodes / steps. + In this case, data would be randomly ordered, and some repetitions wouldn't matter. + + One extra function of this vector env is that it has a logger that explicitly collects logs + from child workers. See :class:`qlib.rl.utils.LogWriter`. + """ + + def __init__( + self, logger: LogWriter | list[LogWriter], env_fns: list[Callable[..., gym.Env]], **kwargs: Any + ) -> None: + super().__init__(env_fns, **kwargs) + + self._logger: list[LogWriter] = logger if isinstance(logger, list) else [logger] + self._alive_env_ids: Set[int] = set() + self._reset_alive_envs() + self._default_obs = self._default_info = self._default_rew = None + self._zombie = False + + self._collector_guarded: bool = False + + def _reset_alive_envs(self): + if not self._alive_env_ids: + # starting or running out + self._alive_env_ids = set(range(self.env_num)) + + # to workaround with tianshou's buffer and batch + def _set_default_obs(self, obs): + if obs is not None and self._default_obs is None: + self._default_obs = copy.deepcopy(obs) + + def _set_default_info(self, info): + if info is not None and self._default_info is None: + self._default_info = copy.deepcopy(info) + + def _set_default_rew(self, rew): + if rew is not None and self._default_rew is None: + self._default_rew = copy.deepcopy(rew) + + def _get_default_obs(self): + return copy.deepcopy(self._default_obs) + + def _get_default_info(self): + return copy.deepcopy(self._default_info) + + def _get_default_rew(self): + return copy.deepcopy(self._default_rew) + + # END + + @staticmethod + def _postproc_env_obs(obs): + # reserved for shmem vector env to restore empty observation + if obs is None or check_nan_observation(obs): + return None + return obs + + @contextmanager + def 
collector_guard(self): + """Guard the collector. Recommended to guard every collect. + + This guard is for two purposes. + + 1. Catch and ignore the StopIteration exception, which is the stopping signal + thrown by FiniteEnv to let tianshou know that ``collector.collect()`` should exit. + 2. Notify the loggers that the collect is done what it's done. + + Examples + -------- + >>> with finite_env.collector_guard(): + ... collector.collect(n_episode=INF) + """ + self._collector_guarded = True + + try: + yield self + except StopIteration: + pass + finally: + self._collector_guarded = False + + # At last trigger the loggers + for logger in self._logger: + logger.on_env_all_done() + + def reset(self, id=None): + assert not self._zombie + + # Check whether it's guarded by collector_guard() + if not self._collector_guarded: + warnings.warn( + "Collector is not guarded by FiniteEnv. " + "This may cause unexpected problems, like unexpected StopIteration exception, " + "or missing logs.", + RuntimeWarning, + ) + + id = self._wrap_id(id) + self._reset_alive_envs() + + # ask super to reset alive envs and remap to current index + request_id = list(filter(lambda i: i in self._alive_env_ids, id)) + obs = [None] * len(id) + id2idx = {i: k for k, i in enumerate(id)} + if request_id: + for i, o in zip(request_id, super().reset(request_id)): + obs[id2idx[i]] = self._postproc_env_obs(o) + + for i, o in zip(id, obs): + if o is None and i in self._alive_env_ids: + self._alive_env_ids.remove(i) + + # logging + for i, o in zip(id, obs): + if i in self._alive_env_ids: + for logger in self._logger: + logger.on_env_reset(i, obs) + + # fill empty observation with default(fake) observation + for o in obs: + self._set_default_obs(o) + for i, o in enumerate(obs): + if o is None: + obs[i] = self._get_default_obs() + + if not self._alive_env_ids: + # comment this line so that the env becomes indisposable + # self.reset() + self._zombie = True + raise StopIteration + + return np.stack(obs) + + def 
step(self, action, id=None): + assert not self._zombie + id = self._wrap_id(id) + id2idx = {i: k for k, i in enumerate(id)} + request_id = list(filter(lambda i: i in self._alive_env_ids, id)) + result = [[None, None, False, None] for _ in range(len(id))] + + # ask super to step alive envs and remap to current index + if request_id: + valid_act = np.stack([action[id2idx[i]] for i in request_id]) + for i, r in zip(request_id, zip(*super().step(valid_act, request_id))): + result[id2idx[i]] = list(r) + result[id2idx[i]][0] = self._postproc_env_obs(result[id2idx[i]][0]) + + # logging + for i, r in zip(id, result): + if i in self._alive_env_ids: + for logger in self._logger: + logger.on_env_step(i, *r) + + # fill empty observation/info with default(fake) + for _, r, ___, i in result: + self._set_default_info(i) + self._set_default_rew(r) + for i, r in enumerate(result): + if r[0] is None: + result[i][0] = self._get_default_obs() + if r[1] is None: + result[i][1] = self._get_default_rew() + if r[3] is None: + result[i][3] = self._get_default_info() + + return list(map(np.stack, zip(*result))) + + +class FiniteDummyVectorEnv(FiniteVectorEnv, DummyVectorEnv): + pass + + +class FiniteSubprocVectorEnv(FiniteVectorEnv, SubprocVectorEnv): + pass + + +class FiniteShmemVectorEnv(FiniteVectorEnv, ShmemVectorEnv): + pass + + +def vectorize_env( + env_factory: Callable[..., gym.Env], + env_type: FiniteEnvType, + concurrency: int, + logger: LogWriter | list[LogWriter], +) -> FiniteVectorEnv: + """Helper function to create a vector env. + + Parameters + ---------- + env_factory + Callable to instantiate one single ``gym.Env``. + All concurrent workers will have the same ``env_factory``. + env_type + dummy or subproc or shmem. Corresponding to + `parallelism in tianshou `_. + concurrency + Concurrent environment workers. + logger + Log writers. + + Warnings + -------- + Please do not use lambda expression here for ``env_factory`` as it may create incorrectly-shared instances. 
+ + Don't do: :: + + vectorize_env(lambda: EnvWrapper(...), ...) + + Please do: :: + + def env_factory(): ... + vectorize_env(env_factory, ...) + """ + env_type_cls_mapping: dict[str, Type[FiniteVectorEnv]] = { + "dummy": FiniteDummyVectorEnv, + "subproc": FiniteSubprocVectorEnv, + "shmem": FiniteShmemVectorEnv, + } + + finite_env_cls = env_type_cls_mapping[env_type] + + return finite_env_cls(logger, [env_factory for _ in range(concurrency)]) diff --git a/qlib/rl/utils/log.py b/qlib/rl/utils/log.py new file mode 100644 index 00000000000..3d495b11dac --- /dev/null +++ b/qlib/rl/utils/log.py @@ -0,0 +1,398 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Distributed logger for RL. + +:class:`LogCollector` runs in every environment workers. It collects log info from simulator states, +and add them (as a dict) to auxiliary info returned for each step. + +:class:`LogWriter` runs in the central worker. It decodes the dict collected by :class:`LogCollector` +in each worker, and writes them to console, log files, or tensorboard... + +The two modules communicate by the "log" field in "info" returned by ``env.step()``. +""" + +from __future__ import annotations + +import logging +from collections import defaultdict +from enum import IntEnum +from pathlib import Path +from typing import Any, TypeVar, Generic, Set, TYPE_CHECKING, Sequence + +import numpy as np +import pandas as pd + +from qlib.log import get_module_logger + +if TYPE_CHECKING: + from .env_wrapper import InfoDict + + +__all__ = ["LogCollector", "LogWriter", "LogLevel", "ConsoleWriter", "CsvWriter"] + +ObsType = TypeVar("ObsType") +ActType = TypeVar("ActType") + + +class LogLevel(IntEnum): + """Log-levels for RL training. + The behavior of handling each log level depends on the implementation of :class:`LogWriter`. 
+ """ + + DEBUG = 10 + """If you only want to see the metric in debug mode.""" + PERIODIC = 20 + """If you want to see the metric periodically.""" + # FIXME: I haven't given much thought about this. Let's hold it for one iteration. + + INFO = 30 + """Important log messages.""" + CRITICAL = 40 + """LogWriter should always handle CRITICAL messages""" + + +class LogCollector: + """Logs are first collected in each environment worker, + and then aggregated to stream at the central thread in vector env. + + In :class:`LogCollector`, every metric is added to a dict, which needs to be ``reset()`` at each step. + The dict is sent via the ``info`` in ``env.step()``, and decoded by the :class:`LogWriter` at vector env. + + ``min_loglevel`` is for optimization purposes: to avoid too much traffic on networks / in pipe. + """ + + _logged: dict[str, tuple[int, Any]] + _min_loglevel: int + + def __init__(self, min_loglevel: int | LogLevel = LogLevel.PERIODIC): + self._min_loglevel = int(min_loglevel) + + def reset(self): + """Clear all collected contents.""" + self._logged = {} + + def _add_metric(self, name: str, metric: Any, loglevel: int | LogLevel) -> None: + if name in self._logged: + raise ValueError(f"A metric with {name} is already added. Please change a name or reset the log collector.") + self._logged[name] = (int(loglevel), metric) + + def add_string(self, name: str, string: str, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None: + """Add a string with name into logged contents.""" + if loglevel < self._min_loglevel: + return + if not isinstance(string, str): + raise TypeError(f"{string} is not a string.") + self._add_metric(name, string, loglevel) + + def add_scalar(self, name: str, scalar: Any, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None: + """Add a scalar with name into logged contents. + Scalar will be converted into a float. 
+ """ + if loglevel < self._min_loglevel: + return + + if hasattr(scalar, "item"): + # could be single-item number + scalar = scalar.item() + if not isinstance(scalar, (float, int)): + raise TypeError(f"{scalar} is not and can not be converted into float or integer.") + scalar = float(scalar) + self._add_metric(name, scalar, loglevel) + + def add_array( + self, name: str, array: np.ndarray | pd.DataFrame | pd.Series, loglevel: int | LogLevel = LogLevel.PERIODIC + ) -> None: + """Add an array with name into logging.""" + if loglevel < self._min_loglevel: + return + + if not isinstance(array, (np.ndarray, pd.DataFrame, pd.Series)): + raise TypeError(f"{array} is not one of ndarray, DataFrame and Series.") + self._add_metric(name, array, loglevel) + + def add_any(self, name: str, obj: Any, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None: + """Log something with any type. + + As it's an "any" object, the only LogWriter accepting it is pickle. + Therefore pickle must be able to serialize it. + """ + if loglevel < self._min_loglevel: + return + + # FIXME: detect and rescue object that could be scalar or array + + self._add_metric(name, obj, loglevel) + + def logs(self) -> dict[str, np.ndarray]: + return {key: np.asanyarray(value, dtype="object") for key, value in self._logged.items()} + + +class LogWriter(Generic[ObsType, ActType]): + """Base class for log writers, triggered at every reset and step by finite env. + + What to do with a specific log depends on the implementation of subclassing :class:`LogWriter`. + The general principle is that, it should handle logs above its loglevel (inclusive), + and discard logs that are not acceptable. For instance, console loggers obviously can't handle an image. + """ + + episode_count: int + """Counter of episodes.""" + + step_count: int + """Counter of steps.""" + + global_step: int + """Counter of steps. Won"t be cleared in ``clear``.""" + + global_episode: int + """Counter of episodes. 
Won't be cleared in ``clear``."""
+ """ + + def on_env_step(self, env_id: int, obs: ObsType, rew: float, done: bool, info: InfoDict) -> None: + """Callback for finite env, on each step.""" + + # Update counter + self.global_step += 1 + self.step_count += 1 + + self.active_env_ids.add(env_id) + self.episode_lengths[env_id] += 1 + # TODO: reward can be a list of list for MARL + self.episode_rewards[env_id].append(rew) + + values: dict[str, Any] = {} + + for key, (loglevel, value) in info["log"].items(): + if loglevel >= self.loglevel: # FIXME: this is actually incorrect (see last FIXME) + values[key] = value + self.episode_logs[env_id].append(values) + + self.log_step(rew, values) + + if done: + # Update counter + self.global_episode += 1 + self.episode_count += 1 + + self.log_episode(self.episode_lengths[env_id], self.episode_rewards[env_id], self.episode_logs[env_id]) + + def on_env_reset(self, env_id: int, obs: ObsType) -> None: + """Callback for finite env. + + Reset episode statistics. Nothing task-specific is logged here because of + `a limitation of tianshou `__. + """ + self.episode_lengths[env_id] = 0 + self.episode_rewards[env_id] = [] + self.episode_logs[env_id] = [] + + def on_env_all_done(self) -> None: + """All done. Time for cleanup.""" + + +class ConsoleWriter(LogWriter): + """Write log messages to console periodically. + + It tracks an average meter for each metric, which is the average value since last ``clear()`` till now. + The display format for each metric is `` ()``. + + Non-single-number metrics are auto skipped. 
+ """ + + prefix: str + """Prefix can be set via ``writer.prefix``.""" + + def __init__( + self, + log_every_n_episode: int = 20, + total_episodes: int | None = None, + float_format: str = ":.4f", + counter_format: str = ":4d", + loglevel: int | LogLevel = LogLevel.PERIODIC, + ): + super().__init__(loglevel) + # TODO: support log_every_n_step + self.log_every_n_episode = log_every_n_episode + self.total_episodes = total_episodes + + self.counter_format = counter_format + self.float_format = float_format + + self.prefix = "" + + self.console_logger = get_module_logger(__name__, level=logging.INFO) + + def clear(self): + super().clear() + # Clear average meters + self.metric_counts: dict[str, int] = defaultdict(int) + self.metric_sums: dict[str, float] = defaultdict(float) + + def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None: + # Aggregate step-wise to episode-wise + episode_wise_contents: dict[str, list] = defaultdict(list) + + for step_contents in contents: + for name, value in step_contents.items(): + if isinstance(value, float): + episode_wise_contents[name].append(value) + + # Generate log contents and track them in average-meter. + # This should be done at every step, regardless of periodic or not. 
+ logs: dict[str, float] = {} + for name, values in episode_wise_contents.items(): + logs[name] = self.aggregation(values) # type: ignore + + for name, value in logs.items(): + self.metric_counts[name] += 1 + self.metric_sums[name] += value + + if self.episode_count % self.log_every_n_episode == 0 or self.episode_count == self.total_episodes: + # Only log periodically or at the end + self.console_logger.info(self.generate_log_message(logs)) + + def generate_log_message(self, logs: dict[str, float]) -> str: + if self.prefix: + msg_prefix = self.prefix + " " + else: + msg_prefix = "" + if self.total_episodes is None: + msg_prefix += "[Step {" + self.counter_format + "}]" + else: + msg_prefix += "[{" + self.counter_format + "}/" + str(self.total_episodes) + "]" + msg_prefix = msg_prefix.format(self.episode_count) + + msg = "" + for name, value in logs.items(): + # Double-space as delimiter + format_template = r" {} {" + self.float_format + "} ({" + self.float_format + "})" + msg += format_template.format(name, value, self.metric_sums[name] / self.metric_counts[name]) + + msg = msg_prefix + " " + msg + + return msg + + +class CsvWriter(LogWriter): + """Dump all episode metrics to a ``result.csv``. + + This is not the correct implementation. It's only used for first iteration. 
+ """ + + SUPPORTED_TYPES = (float, str, pd.Timestamp) + + all_records: list[dict[str, Any]] + + def __init__(self, output_dir: Path, loglevel: int | LogLevel = LogLevel.PERIODIC): + super().__init__(loglevel) + self.output_dir = output_dir + self.output_dir.mkdir(exist_ok=True) + + def clear(self): + super().clear() + self.all_records = [] + + def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None: + # FIXME Same as ConsoleLogger, needs a refactor to eliminate code-dup + episode_wise_contents: dict[str, list] = defaultdict(list) + + for step_contents in contents: + for name, value in step_contents.items(): + if isinstance(value, self.SUPPORTED_TYPES): + episode_wise_contents[name].append(value) + + logs: dict[str, float] = {} + for name, values in episode_wise_contents.items(): + logs[name] = self.aggregation(values) # type: ignore + + self.all_records.append(logs) + + def on_env_all_done(self) -> None: + # FIXME: this is temporary + pd.DataFrame.from_records(self.all_records).to_csv(self.output_dir / "result.csv", index=False) + + +# The following are not implemented yet. + + +class PickleWriter(LogWriter): + """Dump logs to pickle files.""" + + +class TensorboardWriter(LogWriter): + """Write logs to event files that can be visualized with tensorboard.""" + + +class MlflowWriter(LogWriter): + """Add logs to mlflow.""" + + +class LogBuffer(LogWriter): + """Keep everything in memory.""" diff --git a/qlib/typehint.py b/qlib/typehint.py new file mode 100644 index 00000000000..3a584b4b751 --- /dev/null +++ b/qlib/typehint.py @@ -0,0 +1,13 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""Commonly used types.""" + +import sys + +__all__ = ["Literal", "TypedDict", "final"] + +if sys.version_info >= (3, 8): + from typing import Literal, TypedDict, final # type: ignore # pylint: disable=no-name-in-module +else: + from typing_extensions import Literal, TypedDict, final diff --git a/setup.py b/setup.py index 8780e8be73c..d95f59a250a 100644 --- a/setup.py +++ b/setup.py @@ -134,7 +134,12 @@ def get_version(rel_path: str) -> str: "sphinx", "sphinx_rtd_theme", "pre-commit", - ] + ], + "rl": [ + "tianshou", + "gym", + "torch", + ], }, include_package_data=True, classifiers=[ diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000000..7a869da16ee --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,10 @@ +import os +import sys + +"""Ignore RL tests on non-linux platform.""" +collect_ignore = [] + +if sys.platform != "linux": + for root, dirs, files in os.walk("rl"): + for file in files: + collect_ignore.append(os.path.join(root, file)) diff --git a/tests/pytest.ini b/tests/pytest.ini new file mode 100644 index 00000000000..2ea41da8c39 --- /dev/null +++ b/tests/pytest.ini @@ -0,0 +1,4 @@ +[pytest] +filterwarnings = + ignore:.*rng.randint:DeprecationWarning + ignore:.*Casting input x to numpy array:UserWarning diff --git a/tests/rl/test_data_queue.py b/tests/rl/test_data_queue.py new file mode 100644 index 00000000000..0b0c61280de --- /dev/null +++ b/tests/rl/test_data_queue.py @@ -0,0 +1,88 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import multiprocessing +import time + +import numpy as np +import pandas as pd + +from torch.utils.data import Dataset, DataLoader +from qlib.rl.utils.data_queue import DataQueue + + +class DummyDataset(Dataset): + def __init__(self, length): + self.length = length + + def __getitem__(self, index): + assert 0 <= index < self.length + return pd.DataFrame(np.random.randint(0, 100, size=(index + 1, 4)), columns=list("ABCD")) + + def __len__(self): + return self.length + + +def _worker(dataloader, collector): + # for i in range(3): + for i, data in enumerate(dataloader): + collector.put(len(data)) + + +def _queue_to_list(queue): + result = [] + while not queue.empty(): + result.append(queue.get()) + return result + + +def test_pytorch_dataloader(): + dataset = DummyDataset(100) + dataloader = DataLoader(dataset, batch_size=None, num_workers=1) + queue = multiprocessing.Queue() + _worker(dataloader, queue) + assert len(set(_queue_to_list(queue))) == 100 + + +def test_multiprocess_shared_dataloader(): + dataset = DummyDataset(100) + with DataQueue(dataset, producer_num_workers=1) as data_queue: + queue = multiprocessing.Queue() + processes = [] + for _ in range(3): + processes.append(multiprocessing.Process(target=_worker, args=(data_queue, queue))) + processes[-1].start() + for p in processes: + p.join() + assert len(set(_queue_to_list(queue))) == 100 + + +def test_exit_on_crash_finite(): + def _exit_finite(): + dataset = DummyDataset(100) + + with DataQueue(dataset, producer_num_workers=4) as data_queue: + time.sleep(3) + raise ValueError + + # https://stackoverflow.com/questions/34506638/how-to-register-atexit-function-in-pythons-multiprocessing-subprocess + + process = multiprocessing.Process(target=_exit_finite) + process.start() + process.join() + + +def test_exit_on_crash_infinite(): + def _exit_infinite(): + dataset = DummyDataset(100) + with DataQueue(dataset, repeat=-1, queue_maxsize=100) as data_queue: + time.sleep(3) + raise ValueError + + process = 
multiprocessing.Process(target=_exit_infinite) + process.start() + process.join() + + +if __name__ == "__main__": + test_multiprocess_shared_dataloader() diff --git a/tests/rl/test_finite_env.py b/tests/rl/test_finite_env.py new file mode 100644 index 00000000000..d6f2a2ec959 --- /dev/null +++ b/tests/rl/test_finite_env.py @@ -0,0 +1,249 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from collections import Counter + +import gym +import numpy as np +from tianshou.data import Batch, Collector +from tianshou.policy import BasePolicy +from torch.utils.data import DataLoader, Dataset, DistributedSampler +from qlib.rl.utils.finite_env import ( + LogWriter, + FiniteDummyVectorEnv, + FiniteShmemVectorEnv, + FiniteSubprocVectorEnv, + check_nan_observation, + generate_nan_observation, +) + + +_test_space = gym.spaces.Dict( + { + "sensors": gym.spaces.Dict( + { + "position": gym.spaces.Box(low=-100, high=100, shape=(3,)), + "velocity": gym.spaces.Box(low=-1, high=1, shape=(3,)), + "front_cam": gym.spaces.Tuple( + (gym.spaces.Box(low=0, high=1, shape=(10, 10, 3)), gym.spaces.Box(low=0, high=1, shape=(10, 10, 3))) + ), + "rear_cam": gym.spaces.Box(low=0, high=1, shape=(10, 10, 3)), + } + ), + "ext_controller": gym.spaces.MultiDiscrete((5, 2, 2)), + "inner_state": gym.spaces.Dict( + { + "charge": gym.spaces.Discrete(100), + "system_checks": gym.spaces.MultiBinary(10), + "job_status": gym.spaces.Dict( + { + "task": gym.spaces.Discrete(5), + "progress": gym.spaces.Box(low=0, high=100, shape=()), + } + ), + } + ), + } +) + + +class FiniteEnv(gym.Env): + def __init__(self, dataset, num_replicas, rank): + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.loader = DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas, rank), batch_size=None) + self.iterator = None + self.observation_space = gym.spaces.Discrete(255) + self.action_space = gym.spaces.Discrete(2) + + def reset(self): + if self.iterator is 
None: + self.iterator = iter(self.loader) + try: + self.current_sample, self.step_count = next(self.iterator) + self.current_step = 0 + return self.current_sample + except StopIteration: + self.iterator = None + return generate_nan_observation(self.observation_space) + + def step(self, action): + self.current_step += 1 + assert self.current_step <= self.step_count + return ( + 0, + 1.0, + self.current_step >= self.step_count, + {"sample": self.current_sample, "action": action, "metric": 2.0}, + ) + + +class FiniteEnvWithComplexObs(FiniteEnv): + def __init__(self, dataset, num_replicas, rank): + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.loader = DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas, rank), batch_size=None) + self.iterator = None + self.observation_space = gym.spaces.Discrete(255) + self.action_space = gym.spaces.Discrete(2) + + def reset(self): + if self.iterator is None: + self.iterator = iter(self.loader) + try: + self.current_sample, self.step_count = next(self.iterator) + self.current_step = 0 + return _test_space.sample() + except StopIteration: + self.iterator = None + return generate_nan_observation(self.observation_space) + + def step(self, action): + self.current_step += 1 + assert self.current_step <= self.step_count + return ( + _test_space.sample(), + 1.0, + self.current_step >= self.step_count, + {"sample": _test_space.sample(), "action": action, "metric": 2.0}, + ) + + +class DummyDataset(Dataset): + def __init__(self, length): + self.length = length + self.episodes = [3 * i % 5 + 1 for i in range(self.length)] + + def __getitem__(self, index): + assert 0 <= index < self.length + return index, self.episodes[index] + + def __len__(self): + return self.length + + +class AnyPolicy(BasePolicy): + def forward(self, batch, state=None): + return Batch(act=np.stack([1] * len(batch))) + + def learn(self, batch): + pass + + +def _finite_env_factory(dataset, num_replicas, rank, 
complex=False): + if complex: + return lambda: FiniteEnvWithComplexObs(dataset, num_replicas, rank) + return lambda: FiniteEnv(dataset, num_replicas, rank) + + +class MetricTracker(LogWriter): + def __init__(self, length): + super().__init__() + self.counter = Counter() + self.finished = set() + self.length = length + + def on_env_step(self, env_id, obs, rew, done, info): + assert rew == 1.0 + index = info["sample"] + if done: + # assert index not in self.finished + self.finished.add(index) + self.counter[index] += 1 + + def validate(self): + assert len(self.finished) == self.length + for k, v in self.counter.items(): + assert v == k * 3 % 5 + 1 + + +class DoNothingTracker(LogWriter): + def on_env_step(self, *args, **kwargs): + pass + + +def test_finite_dummy_vector_env(): + length = 100 + dataset = DummyDataset(length) + envs = FiniteDummyVectorEnv(MetricTracker(length), [_finite_env_factory(dataset, 5, i) for i in range(5)]) + envs._collector_guarded = True + policy = AnyPolicy() + test_collector = Collector(policy, envs, exploration_noise=True) + + for _ in range(1): + envs._logger = [MetricTracker(length)] + try: + test_collector.collect(n_step=10**18) + except StopIteration: + envs._logger[0].validate() + + +def test_finite_shmem_vector_env(): + length = 100 + dataset = DummyDataset(length) + envs = FiniteShmemVectorEnv(MetricTracker(length), [_finite_env_factory(dataset, 5, i) for i in range(5)]) + envs._collector_guarded = True + policy = AnyPolicy() + test_collector = Collector(policy, envs, exploration_noise=True) + + for _ in range(1): + envs._logger = [MetricTracker(length)] + try: + test_collector.collect(n_step=10**18) + except StopIteration: + envs._logger[0].validate() + + +def test_finite_subproc_vector_env(): + length = 100 + dataset = DummyDataset(length) + envs = FiniteSubprocVectorEnv(MetricTracker(length), [_finite_env_factory(dataset, 5, i) for i in range(5)]) + envs._collector_guarded = True + policy = AnyPolicy() + test_collector = 
Collector(policy, envs, exploration_noise=True) + + for _ in range(1): + envs._logger = [MetricTracker(length)] + try: + test_collector.collect(n_step=10**18) + except StopIteration: + envs._logger[0].validate() + + +def test_nan(): + assert check_nan_observation(generate_nan_observation(_test_space)) + assert not check_nan_observation(_test_space.sample()) + + +def test_finite_dummy_vector_env_complex(): + length = 100 + dataset = DummyDataset(length) + envs = FiniteDummyVectorEnv( + DoNothingTracker(), [_finite_env_factory(dataset, 5, i, complex=True) for i in range(5)] + ) + envs._collector_guarded = True + policy = AnyPolicy() + test_collector = Collector(policy, envs, exploration_noise=True) + + try: + test_collector.collect(n_step=10**18) + except StopIteration: + pass + + +def test_finite_shmem_vector_env_complex(): + length = 100 + dataset = DummyDataset(length) + envs = FiniteShmemVectorEnv( + DoNothingTracker(), [_finite_env_factory(dataset, 5, i, complex=True) for i in range(5)] + ) + envs._collector_guarded = True + policy = AnyPolicy() + test_collector = Collector(policy, envs, exploration_noise=True) + + try: + test_collector.collect(n_step=10**18) + except StopIteration: + pass diff --git a/tests/rl/test_logger.py b/tests/rl/test_logger.py new file mode 100644 index 00000000000..240ffc1e142 --- /dev/null +++ b/tests/rl/test_logger.py @@ -0,0 +1,156 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from random import randint, choice +from pathlib import Path + +import re +import gym +import numpy as np +import pandas as pd +from gym import spaces +from tianshou.data import Collector, Batch +from tianshou.policy import BasePolicy + +from qlib.log import set_log_with_config +from qlib.config import C +from qlib.constant import INF +from qlib.rl.interpreter import StateInterpreter, ActionInterpreter +from qlib.rl.simulator import Simulator +from qlib.rl.utils.data_queue import DataQueue +from qlib.rl.utils.env_wrapper import InfoDict, EnvWrapper +from qlib.rl.utils.log import LogLevel, LogCollector, CsvWriter, ConsoleWriter +from qlib.rl.utils.finite_env import vectorize_env + + +class SimpleEnv(gym.Env[int, int]): + def __init__(self): + self.logger = LogCollector() + self.observation_space = gym.spaces.Discrete(2) + self.action_space = gym.spaces.Discrete(2) + + def reset(self): + self.step_count = 0 + return 0 + + def step(self, action: int): + self.logger.reset() + + self.logger.add_scalar("reward", 42.0) + + self.logger.add_scalar("a", randint(1, 10)) + self.logger.add_array("b", pd.DataFrame({"a": [1, 2], "b": [3, 4]})) + + if self.step_count >= 3: + done = choice([False, True]) + else: + done = False + + if 2 <= self.step_count <= 3: + self.logger.add_scalar("c", randint(11, 20)) + + self.step_count += 1 + + return 1, 42.0, done, InfoDict(log=self.logger.logs(), aux_info={}) + + +class AnyPolicy(BasePolicy): + def forward(self, batch, state=None): + return Batch(act=np.stack([1] * len(batch))) + + def learn(self, batch): + pass + + +def test_simple_env_logger(caplog): + set_log_with_config(C.logging_config) + for venv_cls_name in ["dummy", "shmem", "subproc"]: + writer = ConsoleWriter() + csv_writer = CsvWriter(Path(__file__).parent / ".output") + venv = vectorize_env(lambda: SimpleEnv(), venv_cls_name, 4, [writer, csv_writer]) + with venv.collector_guard(): + collector = Collector(AnyPolicy(), venv) + collector.collect(n_episode=30) + + output_file = 
pd.read_csv(Path(__file__).parent / ".output" / "result.csv") + assert output_file.columns.tolist() == ["reward", "a", "c"] + assert len(output_file) >= 30 + + line_counter = 0 + for line in caplog.text.splitlines(): + line = line.strip() + if line: + line_counter += 1 + assert re.match(r".*reward 42\.0000 \(42.0000\) a .* \((4|5|6)\.\d+\) c .* \((14|15|16)\.\d+\)", line) + assert line_counter >= 3 + + +class SimpleSimulator(Simulator[int, float, float]): + def __init__(self, initial: int, **kwargs) -> None: + self.initial = float(initial) + + def step(self, action: float) -> None: + import torch + + self.initial += action + self.env.logger.add_scalar("test_a", torch.tensor(233.0)) + self.env.logger.add_scalar("test_b", np.array(200)) + + def get_state(self) -> float: + return self.initial + + def done(self) -> bool: + return self.initial % 1 > 0.5 + + +class DummyStateInterpreter(StateInterpreter[float, float]): + def interpret(self, state: float) -> float: + return state + + @property + def observation_space(self) -> spaces.Box: + return spaces.Box(0, np.inf, shape=(), dtype=np.float32) + + +class DummyActionInterpreter(ActionInterpreter[float, int, float]): + def interpret(self, state: float, action: int) -> float: + return action / 100 + + @property + def action_space(self) -> spaces.Box: + return spaces.Discrete(5) + + +class RandomFivePolicy(BasePolicy): + def forward(self, batch, state=None): + return Batch(act=np.random.randint(5, size=len(batch))) + + def learn(self, batch): + pass + + +def test_logger_with_env_wrapper(): + with DataQueue(list(range(20)), shuffle=False) as data_iterator: + env_wrapper_factory = lambda: EnvWrapper( + SimpleSimulator, + DummyStateInterpreter(), + DummyActionInterpreter(), + data_iterator, + logger=LogCollector(LogLevel.DEBUG), + ) + + # loglevel can be debug here because metrics can all dump into csv + # otherwise, csv writer might crash + csv_writer = CsvWriter(Path(__file__).parent / ".output", loglevel=LogLevel.DEBUG) + 
venv = vectorize_env(env_wrapper_factory, "shmem", 4, csv_writer) + with venv.collector_guard(): + collector = Collector(RandomFivePolicy(), venv) + collector.collect(n_episode=INF * len(venv)) + + output_df = pd.read_csv(Path(__file__).parent / ".output" / "result.csv") + assert len(output_df) == 20 + # obs has a increasing trend + assert output_df["obs"].to_numpy()[:10].sum() < output_df["obs"].to_numpy()[10:].sum() + assert (output_df["test_a"] == 233).all() + assert (output_df["test_b"] == 200).all() + assert "steps_per_episode" in output_df and "reward" in output_df diff --git a/tests/rl/test_saoe_simple.py b/tests/rl/test_saoe_simple.py new file mode 100644 index 00000000000..2ac0d9cbd3c --- /dev/null +++ b/tests/rl/test_saoe_simple.py @@ -0,0 +1,308 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import sys +from functools import partial +from pathlib import Path +from typing import NamedTuple + +import numpy as np +import pandas as pd +import pytest + +import torch +from tianshou.data import Batch + +from qlib.backtest import Order +from qlib.config import C +from qlib.log import set_log_with_config +from qlib.rl.data import pickle_styled +from qlib.rl.entries.test import backtest +from qlib.rl.order_execution import * +from qlib.rl.utils import ConsoleWriter, CsvWriter, EnvWrapperStatus + +pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8") + + +DATA_ROOT_DIR = Path(__file__).parent.parent / ".data" / "rl" / "intraday_saoe" +DATA_DIR = DATA_ROOT_DIR / "us" +BACKTEST_DATA_DIR = DATA_DIR / "backtest" +FEATURE_DATA_DIR = DATA_DIR / "processed" +ORDER_DIR = DATA_DIR / "order" / "valid_bidir" + +CN_DATA_DIR = DATA_ROOT_DIR / "cn" +CN_BACKTEST_DATA_DIR = CN_DATA_DIR / "backtest" +CN_FEATURE_DATA_DIR = CN_DATA_DIR / "processed" +CN_ORDER_DIR = CN_DATA_DIR / "order" / "test" +CN_POLICY_WEIGHTS_DIR = CN_DATA_DIR / "weights" + + +def test_pickle_data_inspect(): + data = 
pickle_styled.load_intraday_backtest_data(BACKTEST_DATA_DIR, "AAL", "2013-12-11", "close", 0) + assert len(data) == 390 + + data = pickle_styled.load_intraday_processed_data( + DATA_DIR / "processed", "AAL", "2013-12-11", 5, data.get_time_index() + ) + assert len(data.today) == len(data.yesterday) == 390 + + +def test_simulator_first_step(): + order = Order("AAL", 30.0, 0, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59")) + + simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR) + state = simulator.get_state() + assert state.cur_time == pd.Timestamp("2013-12-11 09:30:00") + assert state.position == 30.0 + + simulator.step(15.0) + state = simulator.get_state() + assert len(state.history_exec) == 30 + assert state.history_exec.index[0] == pd.Timestamp("2013-12-11 09:30:00") + assert state.history_exec["market_volume"].iloc[0] == 450072.0 + assert abs(state.history_exec["market_price"].iloc[0] - 25.370001) < 1e-4 + assert (state.history_exec["amount"] == 0.5).all() + assert (state.history_exec["deal_amount"] == 0.5).all() + assert abs(state.history_exec["trade_price"].iloc[0] - 25.370001) < 1e-4 + assert abs(state.history_exec["trade_value"].iloc[0] - 12.68500) < 1e-4 + assert state.history_exec["position"].iloc[0] == 29.5 + assert state.history_exec["ffr"].iloc[0] == 1 / 60 + + assert state.history_steps["market_volume"].iloc[0] == 5041147.0 + assert state.history_steps["amount"].iloc[0] == 15.0 + assert state.history_steps["deal_amount"].iloc[0] == 15.0 + assert state.history_steps["ffr"].iloc[0] == 0.5 + assert ( + state.history_steps["pa"].iloc[0] + == (state.history_steps["trade_price"].iloc[0] / simulator.twap_price - 1) * 10000 + ) + + assert state.position == 15.0 + assert state.cur_time == pd.Timestamp("2013-12-11 10:00:00") + + +def test_simulator_stop_twap(): + order = Order("AAL", 13.0, 0, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59")) + + simulator = SingleAssetOrderExecution(order, 
BACKTEST_DATA_DIR) + for _ in range(13): + simulator.step(1.0) + + state = simulator.get_state() + assert len(state.history_exec) == 390 + assert (state.history_exec["deal_amount"] == 13 / 390).all() + assert state.history_steps["position"].iloc[0] == 12 and state.history_steps["position"].iloc[-1] == 0 + + assert (state.metrics["ffr"] - 1) < 1e-3 + assert abs(state.metrics["market_price"] - state.backtest_data.get_deal_price().mean()) < 1e-4 + assert np.isclose(state.metrics["market_volume"], state.backtest_data.get_volume().sum()) + assert state.position == 0.0 + assert abs(state.metrics["trade_price"] - state.metrics["market_price"]) < 1e-4 + assert abs(state.metrics["pa"]) < 1e-2 + + assert simulator.done() + + +def test_simulator_stop_early(): + order = Order("AAL", 1.0, 1, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59")) + + with pytest.raises(ValueError): + simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR) + simulator.step(2.0) + + simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR) + simulator.step(1.0) + + with pytest.raises(AssertionError): + simulator.step(1.0) + + +def test_simulator_start_middle(): + order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 10:15:00"), pd.Timestamp("2013-12-11 15:44:59")) + + simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR) + assert len(simulator.ticks_for_order) == 330 + assert simulator.cur_time == pd.Timestamp("2013-12-11 10:15:00") + simulator.step(2.0) + assert simulator.cur_time == pd.Timestamp("2013-12-11 10:30:00") + + for _ in range(10): + simulator.step(1.0) + + simulator.step(2.0) + assert len(simulator.history_exec) == 330 + assert simulator.done() + assert abs(simulator.history_exec["amount"].iloc[-1] - (1 + 2 / 15)) < 1e-4 + assert abs(simulator.metrics["ffr"] - 1) < 1e-4 + + +def test_interpreter(): + order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 10:15:00"), pd.Timestamp("2013-12-11 15:44:59")) + + simulator = 
SingleAssetOrderExecution(order, BACKTEST_DATA_DIR) + assert len(simulator.ticks_for_order) == 330 + assert simulator.cur_time == pd.Timestamp("2013-12-11 10:15:00") + + # emulate a env status + class EmulateEnvWrapper(NamedTuple): + status: EnvWrapperStatus + + interpreter = FullHistoryStateInterpreter(FEATURE_DATA_DIR, 13, 390, 5) + interpreter_step = CurrentStepStateInterpreter(13) + interpreter_action = CategoricalActionInterpreter(20) + interpreter_action_twap = TwapRelativeActionInterpreter() + + wrapper_status_kwargs = dict(initial_state=order, obs_history=[], action_history=[], reward_history=[]) + + # first step + interpreter.env = EmulateEnvWrapper(status=EnvWrapperStatus(cur_step=0, done=False, **wrapper_status_kwargs)) + + obs = interpreter(simulator.get_state()) + assert obs["cur_tick"] == 45 + assert obs["cur_step"] == 0 + assert obs["position"] == 15.0 + assert obs["position_history"][0] == 15.0 + assert all(np.sum(obs["data_processed"][i]) != 0 for i in range(45)) + assert np.sum(obs["data_processed"][45:]) == 0 + assert obs["data_processed_prev"].shape == (390, 5) + + # first step: second interpreter + interpreter_step.env = EmulateEnvWrapper(status=EnvWrapperStatus(cur_step=0, done=False, **wrapper_status_kwargs)) + + obs = interpreter_step(simulator.get_state()) + assert obs["acquiring"] == 1 + assert obs["position"] == 15.0 + + # second step + simulator.step(5.0) + interpreter.env = EmulateEnvWrapper(status=EnvWrapperStatus(cur_step=1, done=False, **wrapper_status_kwargs)) + + obs = interpreter(simulator.get_state()) + assert obs["cur_tick"] == 60 + assert obs["cur_step"] == 1 + assert obs["position"] == 10.0 + assert obs["position_history"][:2].tolist() == [15.0, 10.0] + assert all(np.sum(obs["data_processed"][i]) != 0 for i in range(60)) + assert np.sum(obs["data_processed"][60:]) == 0 + + # second step: action + action = interpreter_action(simulator.get_state(), 1) + assert action == 15 / 20 + + interpreter_action_twap.env = 
EmulateEnvWrapper( + status=EnvWrapperStatus(cur_step=1, done=False, **wrapper_status_kwargs) + ) + action = interpreter_action_twap(simulator.get_state(), 1.5) + assert action == 1.5 + + # fast-forward + for _ in range(10): + simulator.step(0.0) + + # last step + simulator.step(5.0) + interpreter.env = EmulateEnvWrapper( + status=EnvWrapperStatus(cur_step=12, done=simulator.done(), **wrapper_status_kwargs) + ) + + assert interpreter.env.status["done"] + + obs = interpreter(simulator.get_state()) + assert obs["cur_tick"] == 375 + assert obs["cur_step"] == 12 + assert obs["position"] == 0.0 + assert obs["position_history"][1:11].tolist() == [10.0] * 10 + assert all(np.sum(obs["data_processed"][i]) != 0 for i in range(375)) + assert np.sum(obs["data_processed"][375:]) == 0 + + +def test_network_sanity(): + # we won't check the correctness of networks here + order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 9:30:00"), pd.Timestamp("2013-12-11 15:59:59")) + + simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR) + assert len(simulator.ticks_for_order) == 390 + + class EmulateEnvWrapper(NamedTuple): + status: EnvWrapperStatus + + interpreter = FullHistoryStateInterpreter(FEATURE_DATA_DIR, 13, 390, 5) + action_interp = CategoricalActionInterpreter(13) + + wrapper_status_kwargs = dict(initial_state=order, obs_history=[], action_history=[], reward_history=[]) + + network = Recurrent(interpreter.observation_space) + policy = PPO(network, interpreter.observation_space, action_interp.action_space, 1e-3) + + for i in range(14): + interpreter.env = EmulateEnvWrapper(status=EnvWrapperStatus(cur_step=i, done=False, **wrapper_status_kwargs)) + obs = interpreter(simulator.get_state()) + batch = Batch(obs=[obs]) + output = policy(batch) + assert 0 <= output["act"].item() <= 13 + if i < 13: + simulator.step(1.0) + else: + assert obs["cur_tick"] == 389 + assert obs["cur_step"] == 12 + assert obs["position_history"][-1] == 3 + + 
+@pytest.mark.parametrize("finite_env_type", ["dummy", "subproc", "shmem"]) +def test_twap_strategy(finite_env_type): + set_log_with_config(C.logging_config) + orders = pickle_styled.load_orders(ORDER_DIR) + assert len(orders) == 248 + + state_interp = FullHistoryStateInterpreter(FEATURE_DATA_DIR, 13, 390, 5) + action_interp = TwapRelativeActionInterpreter() + policy = AllOne(state_interp.observation_space, action_interp.action_space) + csv_writer = CsvWriter(Path(__file__).parent / ".output") + + backtest( + partial(SingleAssetOrderExecution, data_dir=BACKTEST_DATA_DIR, ticks_per_step=30), + state_interp, + action_interp, + orders, + policy, + [ConsoleWriter(total_episodes=len(orders)), csv_writer], + concurrency=4, + finite_env_type=finite_env_type, + ) + + metrics = pd.read_csv(Path(__file__).parent / ".output" / "result.csv") + assert len(metrics) == 248 + assert np.isclose(metrics["ffr"].mean(), 1.0) + assert np.isclose(metrics["pa"].mean(), 0.0) + assert np.allclose(metrics["pa"], 0.0, atol=2e-3) + + +def test_cn_ppo_strategy(): + set_log_with_config(C.logging_config) + # The data starts with 9:31 and ends with 15:00 + orders = pickle_styled.load_orders(CN_ORDER_DIR, start_time=pd.Timestamp("9:31"), end_time=pd.Timestamp("14:58")) + assert len(orders) == 40 + + state_interp = FullHistoryStateInterpreter(CN_FEATURE_DATA_DIR, 8, 240, 6) + action_interp = CategoricalActionInterpreter(4) + network = Recurrent(state_interp.observation_space) + policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4) + policy.load_state_dict(torch.load(CN_POLICY_WEIGHTS_DIR / "ppo_recurrent_30min.pth", map_location="cpu")) + csv_writer = CsvWriter(Path(__file__).parent / ".output") + + backtest( + partial(SingleAssetOrderExecution, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30), + state_interp, + action_interp, + orders, + policy, + [ConsoleWriter(total_episodes=len(orders)), csv_writer], + concurrency=4, + ) + + metrics = 
pd.read_csv(Path(__file__).parent / ".output" / "result.csv") + assert len(metrics) == len(orders) + assert np.isclose(metrics["ffr"].mean(), 1.0) + assert np.isclose(metrics["pa"].mean(), -16.21578303474833) + assert np.isclose(metrics["market_price"].mean(), 58.68277690875527) + assert np.isclose(metrics["trade_price"].mean(), 58.76063985000002)