diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 88df0c3e2e6..f64b4d32809 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -35,7 +35,7 @@ jobs:
pip install numpy==1.19.5 ruamel.yaml
pip install pyqlib --ignore-installed
- - name: Make html with sphnix
+ - name: Make html with sphinx
run: |
pip install -U sphinx
pip install sphinx_rtd_theme readthedocs_sphinx_ext
@@ -97,12 +97,21 @@ jobs:
run: |
pip install --upgrade pip
pip install flake8
- flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 qlib
+ flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 --per-file-ignores="__init__.py:F401,F403" qlib
+
+ # https://github.com/python/mypy/issues/10600
+ - name: Check Qlib with mypy
+ run: |
+ pip install mypy
+ mypy qlib --install-types --non-interactive || true
+ mypy qlib
- name: Test data downloads
run: |
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data_simple --interval 1d --region cn
python -c "import os; userpath=os.path.expanduser('~'); os.rename(userpath + '/.qlib/qlib_data/cn_data_simple', userpath + '/.qlib/qlib_data/cn_data')"
+ azcopy copy https://qlibpublic.blob.core.windows.net/data /tmp/qlibpublic --recursive
+ mv /tmp/qlibpublic/data tests/.data
- name: Test workflow by config (install from pip)
run: |
@@ -113,6 +122,7 @@ jobs:
- name: Install Qlib from source
run: |
pip install --upgrade cython jupyter jupyter_contrib_nbextensions numpy scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
+ pip install gym tianshou torch
pip install -e .
- name: Install test dependencies
@@ -129,4 +139,3 @@ jobs:
- name: Test workflow by config (install from source)
run: |
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
-
diff --git a/.github/workflows/test_macos.yml b/.github/workflows/test_macos.yml
index 6f40a3b1f79..0deab4d731c 100644
--- a/.github/workflows/test_macos.yml
+++ b/.github/workflows/test_macos.yml
@@ -38,7 +38,7 @@ jobs:
run: |
pip install --upgrade pip
pip install flake8
- flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 qlib
+ flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 --per-file-ignores="__init__.py:F401,F403" qlib
- name: Install Qlib with pip
run: |
@@ -65,6 +65,8 @@ jobs:
run: |
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data_simple --interval 1d --region cn
python -c "import os; userpath=os.path.expanduser('~'); os.rename(userpath + '/.qlib/qlib_data/cn_data_simple', userpath + '/.qlib/qlib_data/cn_data')"
+ azcopy copy https://qlibpublic.blob.core.windows.net/data /tmp/qlibpublic --recursive
+ mv /tmp/qlibpublic/data tests/.data
- name: Test workflow by config (install from pip)
run: |
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
@@ -75,6 +77,7 @@ jobs:
python -m pip install --upgrade cython
python -m pip install numpy jupyter jupyter_contrib_nbextensions
python -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
+ python -m pip install gym tianshou torch
pip install -e .
- name: Install test dependencies
run: |
diff --git a/.gitignore b/.gitignore
index a563ed5c7f7..3f9c159aa4c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -27,6 +27,10 @@ examples/estimator/estimator_example/
*.egg-info/
+# test related
+test-output.xml
+.output
+.data
# special software
mlruns/
@@ -34,6 +38,7 @@ mlruns/
tags
.pytest_cache/
+.mypy_cache/
.vscode/
*.swp
diff --git a/.mypy.ini b/.mypy.ini
new file mode 100644
index 00000000000..195a0505f70
--- /dev/null
+++ b/.mypy.ini
@@ -0,0 +1,17 @@
+[mypy]
+exclude = (?x)(
+ ^qlib/backtest
+ | ^qlib/contrib
+ | ^qlib/data
+ | ^qlib/model
+ | ^qlib/strategy
+ | ^qlib/tests
+ | ^qlib/utils
+ | ^qlib/workflow
+ | ^qlib/config\.py$
+ | ^qlib/log\.py$
+ | ^qlib/__init__\.py$
+ )
+ignore_missing_imports = true
+disallow_incomplete_defs = true
+follow_imports = skip
diff --git a/qlib/constant.py b/qlib/constant.py
index fa014a29570..458890957d2 100644
--- a/qlib/constant.py
+++ b/qlib/constant.py
@@ -8,3 +8,6 @@
# Epsilon for avoiding division by zero.
EPS = 1e-12
+
+# Infinity in integer
+INF = 10**18
diff --git a/qlib/log.py b/qlib/log.py
index 8533a7ba377..a2fc5deff09 100644
--- a/qlib/log.py
+++ b/qlib/log.py
@@ -61,7 +61,11 @@ def get_module_logger(module_name, level: Optional[int] = None) -> QlibLogger:
if level is None:
level = C.logging_level
- module_name = "qlib.{}".format(module_name)
+ if not module_name.startswith("qlib."):
+ # Add a prefix of qlib. when the requested ``module_name`` doesn't start with ``qlib.``.
+ # If the module_name is already qlib.xxx, we do not format here. Otherwise, it will become qlib.qlib.xxx.
+ module_name = "qlib.{}".format(module_name)
+
# Get logger.
module_logger = QlibLogger(module_name)
module_logger.setLevel(level)
diff --git a/qlib/rl/aux_info.py b/qlib/rl/aux_info.py
new file mode 100644
index 00000000000..65cd95d5dd7
--- /dev/null
+++ b/qlib/rl/aux_info.py
@@ -0,0 +1,43 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from __future__ import annotations
+
+from typing import Generic, TYPE_CHECKING, TypeVar
+
+from qlib.typehint import final
+
+from .simulator import StateType
+
+if TYPE_CHECKING:
+ from .utils.env_wrapper import EnvWrapper
+
+
+__all__ = ["AuxiliaryInfoCollector"]
+
+AuxInfoType = TypeVar("AuxInfoType")
+
+
+class AuxiliaryInfoCollector(Generic[StateType, AuxInfoType]):
+ """Override this class to collect customized auxiliary information from environment."""
+
+ env: EnvWrapper | None = None
+
+ @final
+ def __call__(self, simulator_state: StateType) -> AuxInfoType:
+ return self.collect(simulator_state)
+
+ def collect(self, simulator_state: StateType) -> AuxInfoType:
+ """Override this for customized auxiliary info.
+ Usually useful in Multi-agent RL.
+
+ Parameters
+ ----------
+ simulator_state
+ Retrieved with ``simulator.get_state()``.
+
+ Returns
+ -------
+ Auxiliary information.
+ """
+ raise NotImplementedError("collect is not implemented!")
diff --git a/qlib/rl/data/__init__.py b/qlib/rl/data/__init__.py
new file mode 100644
index 00000000000..d73517534c8
--- /dev/null
+++ b/qlib/rl/data/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Common utilities to handle ad-hoc-styled data.
+
+Most of these snippets come from research projects (paper code).
+Please take caution when using them in production.
+"""
diff --git a/qlib/rl/data/pickle_styled.py b/qlib/rl/data/pickle_styled.py
new file mode 100644
index 00000000000..6cf38680147
--- /dev/null
+++ b/qlib/rl/data/pickle_styled.py
@@ -0,0 +1,257 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""This module contains utilities to read financial data from pickle-styled files.
+
+This is the format used in the OPD paper. It is NOT the standard data format in qlib.
+
+The data here are all wrapped with ``@lru_cache``, which saves the expensive IO cost to repetitively read the data.
+We also encourage users to use ``get_xxx_yyy`` rather than ``XxxYyy`` (although they are the same thing),
+because ``get_xxx_yyy`` is cache-optimized.
+
+Note that these pickle files are dumped with Python 3.8. Python lower than 3.8 might not be able to load them.
+See `PEP 574 <https://peps.python.org/pep-0574/>`__ for details.
+
+This file shows resemblance to qlib.backtest.high_performance_ds. We might merge those two in the future.
+"""
+
+# TODO: merge with qlib/backtest/high_performance_ds.py
+
+from __future__ import annotations
+
+from functools import lru_cache
+from typing import List, Sequence, cast
+from pathlib import Path
+
+import cachetools
+import numpy as np
+import pandas as pd
+from cachetools.keys import hashkey
+
+from qlib.backtest.decision import OrderDir, Order
+from qlib.typehint import Literal
+
+
+DealPriceType = Literal["bid_or_ask", "bid_or_ask_fill", "close"]
+"""Several ad-hoc deal price types.
+``bid_or_ask``: If sell, use column ``$bid0``; if buy, use column ``$ask0``.
+``bid_or_ask_fill``: Based on ``bid_or_ask``. If price is 0, use another price (``$ask0`` / ``$bid0``) instead.
+``close``: Use close price (``$close0``) as deal price.
+"""
+
+
+def _infer_processed_data_column_names(shape: int) -> list[str]:
+ if shape == 16:
+ return [
+ "$open",
+ "$high",
+ "$low",
+ "$close",
+ "$vwap",
+ "$bid",
+ "$ask",
+ "$volume",
+ "$bidV",
+ "$bidV1",
+ "$bidV3",
+ "$bidV5",
+ "$askV",
+ "$askV1",
+ "$askV3",
+ "$askV5",
+ ]
+ if shape == 6:
+ return ["$high", "$low", "$open", "$close", "$vwap", "$volume"]
+ elif shape == 5:
+ return ["$high", "$low", "$open", "$close", "$volume"]
+ raise ValueError(f"Unrecognized data shape: {shape}")
+
+
+def _find_pickle(filename_without_suffix: Path) -> Path:
+ suffix_list = [".pkl", ".pkl.backtest"]
+ paths: List[Path] = []
+ for suffix in suffix_list:
+ path = filename_without_suffix.parent / (filename_without_suffix.name + suffix)
+ if path.exists():
+ paths.append(path)
+ if not paths:
+ raise FileNotFoundError(f"No file starting with '{filename_without_suffix}' found")
+ if len(paths) > 1:
+ raise ValueError(f"Multiple paths are found with prefix '{filename_without_suffix}': {paths}")
+ return paths[0]
+
+
+@lru_cache(maxsize=10) # 10 * 40M = 400MB
+def _read_pickle(filename_without_suffix: Path) -> pd.DataFrame:
+ return pd.read_pickle(_find_pickle(filename_without_suffix))
+
+
+class IntradayBacktestData:
+ """Raw market data that is often used in backtesting (thus called BacktestData)."""
+
+ def __init__(
+ self,
+ data_dir: Path,
+ stock_id: str,
+ date: pd.Timestamp,
+ deal_price: DealPriceType = "close",
+ order_dir: int | None = None,
+ ):
+ backtest = _read_pickle(data_dir / stock_id)
+ backtest = backtest.loc[pd.IndexSlice[stock_id, :, date]]
+
+ # No longer need for pandas >= 1.4
+ # backtest = backtest.droplevel([0, 2])
+
+ self.data: pd.DataFrame = backtest
+ self.deal_price_type: DealPriceType = deal_price
+ self.order_dir: int | None = order_dir
+
+ def __repr__(self):
+ with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
+ return f"{self.__class__.__name__}({self.data})"
+
+ def __len__(self):
+ return len(self.data)
+
+ def get_deal_price(self) -> pd.Series:
+ """Return a pandas series that can be indexed with time.
+        See ``DealPriceType`` for details."""
+ if self.deal_price_type in ("bid_or_ask", "bid_or_ask_fill"):
+ if self.order_dir is None:
+ raise ValueError("Order direction cannot be none when deal_price_type is not close.")
+ if self.order_dir == OrderDir.SELL:
+ col = "$bid0"
+ else: # BUY
+ col = "$ask0"
+ elif self.deal_price_type == "close":
+ col = "$close0"
+ else:
+ raise ValueError(f"Unsupported deal_price_type: {self.deal_price_type}")
+ price = self.data[col]
+
+ if self.deal_price_type == "bid_or_ask_fill":
+ if self.order_dir == OrderDir.SELL:
+ fill_col = "$ask0"
+ else:
+ fill_col = "$bid0"
+ price = price.replace(0, np.nan).fillna(self.data[fill_col])
+
+ return price
+
+ def get_volume(self) -> pd.Series:
+ """Return a volume series that can be indexed with time."""
+ return self.data["$volume0"]
+
+ def get_time_index(self) -> pd.DatetimeIndex:
+ return cast(pd.DatetimeIndex, self.data.index)
+
+
+class IntradayProcessedData:
+ """Processed market data after data cleanup and feature engineering.
+
+ It contains both processed data for "today" and "yesterday", as some algorithms
+ might use the market information of the previous day to assist decision making.
+ """
+
+ today: pd.DataFrame
+ """Processed data for "today".
+ Number of records must be ``time_length``, and columns must be ``feature_dim``."""
+
+ yesterday: pd.DataFrame
+ """Processed data for "yesterday".
+ Number of records must be ``time_length``, and columns must be ``feature_dim``."""
+
+ def __init__(self, data_dir: Path, stock_id: str, date: pd.Timestamp, feature_dim: int, time_index: pd.Index):
+ proc = _read_pickle(data_dir / stock_id)
+ # We have to infer the names here because,
+ # unfortunately they are not included in the original data.
+ cnames = _infer_processed_data_column_names(feature_dim)
+
+ time_length: int = len(time_index)
+
+ try:
+ # new data format
+ proc = proc.loc[pd.IndexSlice[stock_id, :, date]]
+ assert len(proc) == time_length and len(proc.columns) == feature_dim * 2
+ proc_today = proc[cnames]
+ proc_yesterday = proc[[f"{c}_1" for c in cnames]].rename(columns=lambda c: c[:-2])
+ except (IndexError, KeyError):
+ # legacy data
+ proc = proc.loc[pd.IndexSlice[stock_id, date]]
+ assert time_length * feature_dim * 2 == len(proc)
+ proc_today = proc.to_numpy()[: time_length * feature_dim].reshape((time_length, feature_dim))
+ proc_yesterday = proc.to_numpy()[time_length * feature_dim :].reshape((time_length, feature_dim))
+ proc_today = pd.DataFrame(proc_today, index=time_index, columns=cnames)
+ proc_yesterday = pd.DataFrame(proc_yesterday, index=time_index, columns=cnames)
+
+ self.today: pd.DataFrame = proc_today
+ self.yesterday: pd.DataFrame = proc_yesterday
+ assert len(self.today.columns) == len(self.yesterday.columns) == feature_dim
+ assert len(self.today) == len(self.yesterday) == time_length
+
+ def __repr__(self):
+ with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
+ return f"{self.__class__.__name__}({self.today}, {self.yesterday})"
+
+
+@lru_cache(maxsize=100) # 100 * 50K = 5MB
+def load_intraday_backtest_data(
+ data_dir: Path, stock_id: str, date: pd.Timestamp, deal_price: DealPriceType = "close", order_dir: int | None = None
+) -> IntradayBacktestData:
+ return IntradayBacktestData(data_dir, stock_id, date, deal_price, order_dir)
+
+
+@cachetools.cached( # type: ignore
+ cache=cachetools.LRUCache(100), # 100 * 50K = 5MB
+ key=lambda data_dir, stock_id, date, _, __: hashkey(data_dir, stock_id, date),
+)
+def load_intraday_processed_data(
+ data_dir: Path, stock_id: str, date: pd.Timestamp, feature_dim: int, time_index: pd.Index
+) -> IntradayProcessedData:
+ return IntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index)
+
+
+def load_orders(
+ order_path: Path, start_time: pd.Timestamp | None = None, end_time: pd.Timestamp | None = None
+) -> Sequence[Order]:
+ """Load orders, and set start time and end time for the orders."""
+
+ start_time = start_time or pd.Timestamp("0:00:00")
+ end_time = end_time or pd.Timestamp("23:59:59")
+
+ if order_path.is_file():
+ order_df = pd.read_pickle(order_path)
+ else:
+ order_df = []
+ for file in order_path.iterdir():
+ order_data = pd.read_pickle(file)
+ order_df.append(order_data)
+ order_df = pd.concat(order_df)
+
+ order_df = order_df.reset_index()
+
+ # Legacy-style orders have "date" instead of "datetime"
+ if "date" in order_df.columns:
+ order_df = order_df.rename(columns={"date": "datetime"})
+
+ # Sometimes "date" are str rather than Timestamp
+ order_df["datetime"] = pd.to_datetime(order_df["datetime"])
+
+ orders: List[Order] = []
+
+ for _, row in order_df.iterrows():
+ # filter out orders with amount == 0
+ if row["amount"] <= 0:
+ continue
+ orders.append(
+ Order(
+ row["instrument"],
+ row["amount"],
+ int(row["order_type"]),
+ row["datetime"].replace(hour=start_time.hour, minute=start_time.minute, second=start_time.second),
+ row["datetime"].replace(hour=end_time.hour, minute=end_time.minute, second=end_time.second),
+ )
+ )
+
+ return orders
diff --git a/qlib/rl/entries/__init__.py b/qlib/rl/entries/__init__.py
new file mode 100644
index 00000000000..169fa985c50
--- /dev/null
+++ b/qlib/rl/entries/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Train, test, inference utilities.
+
+The APIs in this directory are NOT considered final and are subject to change!
+"""
diff --git a/qlib/rl/entries/test.py b/qlib/rl/entries/test.py
new file mode 100644
index 00000000000..ca311407b15
--- /dev/null
+++ b/qlib/rl/entries/test.py
@@ -0,0 +1,99 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from __future__ import annotations
+
+import copy
+from typing import Callable, Sequence
+
+from tianshou.data import Collector
+from tianshou.policy import BasePolicy
+
+from qlib.constant import INF
+from qlib.log import get_module_logger
+from qlib.rl.simulator import InitialStateType, Simulator
+from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
+from qlib.rl.reward import Reward
+from qlib.rl.utils import DataQueue, EnvWrapper, FiniteEnvType, LogCollector, LogWriter, vectorize_env
+
+
+_logger = get_module_logger(__name__)
+
+
+def backtest(
+ simulator_fn: Callable[[InitialStateType], Simulator],
+ state_interpreter: StateInterpreter,
+ action_interpreter: ActionInterpreter,
+ initial_states: Sequence[InitialStateType],
+ policy: BasePolicy,
+ logger: LogWriter | list[LogWriter],
+ reward: Reward | None = None,
+ finite_env_type: FiniteEnvType = "subproc",
+ concurrency: int = 2,
+) -> None:
+ """Backtest with the parallelism provided by RL framework.
+
+ Parameters
+ ----------
+ simulator_fn
+ Callable receiving initial seed, returning a simulator.
+ state_interpreter
+ Interprets the state of simulators.
+ action_interpreter
+ Interprets the policy actions.
+ initial_states
+ Initial states to iterate over. Every state will be run exactly once.
+ policy
+ Policy to test against.
+ logger
+ Logger to record the backtest results. Logger must be present because
+ without logger, all information will be lost.
+ reward
+ Optional reward function. For backtest, this is for testing the rewards
+ and logging them only.
+ finite_env_type
+ Type of finite env implementation.
+ concurrency
+ Parallel workers.
+ """
+
+ # To save bandwidth
+ min_loglevel = min(lg.loglevel for lg in logger) if isinstance(logger, list) else logger.loglevel
+
+ def env_factory():
+ # FIXME: state_interpreter and action_interpreter are stateful (having a weakref of env),
+ # and could be thread unsafe.
+ # I'm not sure whether it's a design flaw.
+ # I'll rethink about this when designing the trainer.
+
+ if finite_env_type == "dummy":
+ # We could only experience the "threading-unsafe" problem in dummy.
+ state = copy.deepcopy(state_interpreter)
+ action = copy.deepcopy(action_interpreter)
+ rew = copy.deepcopy(reward)
+ else:
+ state, action, rew = state_interpreter, action_interpreter, reward
+
+ return EnvWrapper(
+ simulator_fn,
+ state,
+ action,
+ seed_iterator,
+ rew,
+ logger=LogCollector(min_loglevel=min_loglevel),
+ )
+
+ with DataQueue(initial_states) as seed_iterator:
+ vector_env = vectorize_env(
+ env_factory,
+ finite_env_type,
+ concurrency,
+ logger,
+ )
+
+ policy.eval()
+
+ with vector_env.collector_guard():
+ test_collector = Collector(policy, vector_env)
+ _logger.info("All ready. Start backtest.")
+ test_collector.collect(n_step=INF * len(vector_env))
diff --git a/qlib/rl/entries/train.py b/qlib/rl/entries/train.py
new file mode 100644
index 00000000000..c852e623528
--- /dev/null
+++ b/qlib/rl/entries/train.py
@@ -0,0 +1,4 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+# TBD
diff --git a/qlib/rl/env.py b/qlib/rl/env.py
deleted file mode 100644
index 6173a27a869..00000000000
--- a/qlib/rl/env.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-from typing import Union
-
-
-from ..backtest.executor import BaseExecutor
-from .interpreter import StateInterpreter, ActionInterpreter
-from ..utils import init_instance_by_config
-
-
-class BaseRLEnv:
- """Base environment for reinforcement learning"""
-
- def reset(self, **kwargs):
- raise NotImplementedError("reset is not implemented!")
-
- def step(self, action):
- """
- step method of rl env
- Parameters
- ----------
- action :
- action from rl policy
-
- Returns
- -------
- env state to rl policy
- """
- raise NotImplementedError("step is not implemented!")
-
-
-class QlibRLEnv:
- """qlib-based RL env"""
-
- def __init__(
- self,
- executor: BaseExecutor,
- ):
- """
- Parameters
- ----------
- executor : BaseExecutor
- qlib multi-level/single-level executor, which can be regarded as gamecore in RL
- """
- self.executor = executor
-
- def reset(self, **kwargs):
- self.executor.reset(**kwargs)
-
-
-class QlibIntRLEnv(QlibRLEnv):
- """(Qlib)-based RL (Env) with (Interpreter)"""
-
- def __init__(
- self,
- executor: BaseExecutor,
- state_interpreter: Union[dict, StateInterpreter],
- action_interpreter: Union[dict, ActionInterpreter],
- ):
- """
-
- Parameters
- ----------
- state_interpreter : Union[dict, StateInterpreter]
- interpreter that interprets the qlib execute result into rl env state.
-
- action_interpreter : Union[dict, ActionInterpreter]
- interpreter that interprets the rl agent action into qlib order list
- """
- super(QlibIntRLEnv, self).__init__(executor=executor)
- self.state_interpreter = init_instance_by_config(state_interpreter, accept_types=StateInterpreter)
- self.action_interpreter = init_instance_by_config(action_interpreter, accept_types=ActionInterpreter)
-
- def step(self, action):
- """
- step method of rl env, it run as following step:
- - Use `action_interpreter.interpret` method to interpret the agent action into order list
- - Execute the order list with qlib executor, and get the executed result
- - Use `state_interpreter.interpret` method to interpret the executed result into env state
-
- Parameters
- ----------
- action :
- action from rl policy
-
- Returns
- -------
- env state to rl policy
- """
- _interpret_decision = self.action_interpreter.interpret(action=action)
- _execute_result = self.executor.execute(trade_decision=_interpret_decision)
- _interpret_state = self.state_interpreter.interpret(execute_result=_execute_result)
- return _interpret_state
diff --git a/qlib/rl/interpreter.py b/qlib/rl/interpreter.py
index c711b838089..3835b5b9238 100644
--- a/qlib/rl/interpreter.py
+++ b/qlib/rl/interpreter.py
@@ -1,47 +1,150 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
+from __future__ import annotations
-class BaseInterpreter:
- """Base Interpreter"""
+from typing import TYPE_CHECKING, TypeVar, Generic, Any
- def interpret(self, **kwargs):
- raise NotImplementedError("interpret is not implemented!")
+import numpy as np
+from qlib.typehint import final
-class ActionInterpreter(BaseInterpreter):
- """Action Interpreter that interpret rl agent action into qlib orders"""
+from .simulator import StateType, ActType
+
+if TYPE_CHECKING:
+ from .utils.env_wrapper import EnvWrapper
+
+import gym
+from gym import spaces
+
+ObsType = TypeVar("ObsType")
+PolicyActType = TypeVar("PolicyActType")
+
+
+class Interpreter:
+ """Interpreter is a media between states produced by simulators and states needed by RL policies.
+ Interpreters are two-way:
+
+ 1. From simulator state to policy state (aka observation), see :class:`StateInterpreter`.
+ 2. From policy action to action accepted by simulator, see :class:`ActionInterpreter`.
+
+ Inherit one of the two sub-classes to define your own interpreter.
+ This super-class is only used for isinstance check.
- def interpret(self, action, **kwargs):
- """interpret method
+ Interpreters are recommended to be stateless, meaning that storing temporary information with ``self.xxx``
+ in interpreter is anti-pattern. In future, we might support register some interpreter-related
+ states by calling ``self.env.register_state()``, but it's not planned for first iteration.
+ """
+
+
+class StateInterpreter(Generic[StateType, ObsType], Interpreter):
+ """State Interpreter that interpret execution result of qlib executor into rl env state"""
+
+ env: EnvWrapper | None = None
+
+ @property
+ def observation_space(self) -> gym.Space:
+ raise NotImplementedError()
+
+    @final  # should not be overridden
+ def __call__(self, simulator_state: StateType) -> ObsType:
+ obs = self.interpret(simulator_state)
+ self.validate(obs)
+ return obs
+
+ def validate(self, obs: ObsType) -> None:
+ """Validate whether an observation belongs to the pre-defined observation space."""
+ _gym_space_contains(self.observation_space, obs)
+
+ def interpret(self, simulator_state: StateType) -> ObsType:
+ """Interpret the state of simulator.
Parameters
----------
- action :
- rl agent action
+ simulator_state
+ Retrieved with ``simulator.get_state()``.
Returns
-------
- qlib orders
-
+ State needed by policy. Should conform with the state space defined in ``observation_space``.
"""
-
raise NotImplementedError("interpret is not implemented!")
-class StateInterpreter(BaseInterpreter):
- """State Interpreter that interpret execution result of qlib executor into rl env state"""
+class ActionInterpreter(Generic[StateType, PolicyActType, ActType], Interpreter):
+ """Action Interpreter that interpret rl agent action into qlib orders"""
+
+ env: "EnvWrapper" | None = None
+
+ @property
+ def action_space(self) -> gym.Space:
+ raise NotImplementedError()
- def interpret(self, execute_result, **kwargs):
- """interpret method
+    @final  # should not be overridden
+ def __call__(self, simulator_state: StateType, action: PolicyActType) -> ActType:
+ self.validate(action)
+ obs = self.interpret(simulator_state, action)
+ return obs
+
+ def validate(self, action: PolicyActType) -> None:
+ """Validate whether an action belongs to the pre-defined action space."""
+ _gym_space_contains(self.action_space, action)
+
+ def interpret(self, simulator_state: StateType, action: PolicyActType) -> ActType:
+ """Convert the policy action to simulator action.
Parameters
----------
- execute_result :
- qlib execution result
+ simulator_state
+ Retrieved with ``simulator.get_state()``.
+ action
+ Raw action given by policy.
Returns
- ----------
- rl env state
+ -------
+        The action needed by the simulator.
"""
raise NotImplementedError("interpret is not implemented!")
+
+
+def _gym_space_contains(space: gym.Space, x: Any) -> None:
+ """Strengthened version of gym.Space.contains.
+ Giving more diagnostic information on why validation fails.
+
+ Throw exception rather than returning true or false.
+ """
+ if isinstance(space, spaces.Dict):
+ if not isinstance(x, dict) or len(x) != len(space):
+ raise GymSpaceValidationError("Sample must be a dict with same length as space.", space, x)
+ for k, subspace in space.spaces.items():
+ if k not in x:
+ raise GymSpaceValidationError(f"Key {k} not found in sample.", space, x)
+ try:
+ _gym_space_contains(subspace, x[k])
+ except GymSpaceValidationError as e:
+ raise GymSpaceValidationError(f"Subspace of key {k} validation error.", space, x) from e
+
+ elif isinstance(space, spaces.Tuple):
+ if isinstance(x, (list, np.ndarray)):
+ x = tuple(x) # Promote list and ndarray to tuple for contains check
+ if not isinstance(x, tuple) or len(x) != len(space):
+ raise GymSpaceValidationError("Sample must be a tuple with same length as space.", space, x)
+ for i, (subspace, part) in enumerate(zip(space, x)):
+ try:
+ _gym_space_contains(subspace, part)
+ except GymSpaceValidationError as e:
+ raise GymSpaceValidationError(f"Subspace of index {i} validation error.", space, x) from e
+
+ else:
+ if not space.contains(x):
+ raise GymSpaceValidationError("Validation error reported by gym.", space, x)
+
+
+class GymSpaceValidationError(Exception):
+ def __init__(self, message: str, space: gym.Space, x: Any):
+ self.message = message
+ self.space = space
+ self.x = x
+
+ def __str__(self):
+ return f"{self.message}\n Space: {self.space}\n Sample: {self.x}"
diff --git a/qlib/rl/order_execution/__init__.py b/qlib/rl/order_execution/__init__.py
new file mode 100644
index 00000000000..048dfecac9d
--- /dev/null
+++ b/qlib/rl/order_execution/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""
+Currently it supports single-asset order execution.
+Multi-asset is on the way.
+"""
+
+from .interpreter import *
+from .network import *
+from .policy import *
+from .simulator_simple import *
diff --git a/qlib/rl/order_execution/interpreter.py b/qlib/rl/order_execution/interpreter.py
new file mode 100644
index 00000000000..9bb5dc2cf13
--- /dev/null
+++ b/qlib/rl/order_execution/interpreter.py
@@ -0,0 +1,222 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from __future__ import annotations
+
+import math
+from pathlib import Path
+from typing import Any, cast
+
+import numpy as np
+import pandas as pd
+from gym import spaces
+
+from qlib.constant import EPS
+from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
+from qlib.rl.data import pickle_styled
+from qlib.typehint import TypedDict
+
+from .simulator_simple import SAOEState
+
+__all__ = [
+ "FullHistoryStateInterpreter",
+ "CurrentStepStateInterpreter",
+ "CategoricalActionInterpreter",
+ "TwapRelativeActionInterpreter",
+]
+
+
+def canonicalize(value: int | float | np.ndarray | pd.DataFrame | dict) -> np.ndarray | dict:
+ """To 32-bit numeric types. Recursively."""
+ if isinstance(value, pd.DataFrame):
+ return value.to_numpy()
+ if isinstance(value, (float, np.floating)) or (isinstance(value, np.ndarray) and value.dtype.kind == "f"):
+ return np.array(value, dtype=np.float32)
+ elif isinstance(value, (int, bool, np.integer)) or (isinstance(value, np.ndarray) and value.dtype.kind == "i"):
+ return np.array(value, dtype=np.int32)
+ elif isinstance(value, dict):
+ return {k: canonicalize(v) for k, v in value.items()}
+ else:
+ return value
+
+
+class FullHistoryObs(TypedDict):
+ data_processed: Any
+ data_processed_prev: Any
+ acquiring: Any
+ cur_tick: Any
+ cur_step: Any
+ num_step: Any
+ target: Any
+ position: Any
+ position_history: Any
+
+
+class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]):
+ """The observation of all the history, including today (until this moment), and yesterday.
+
+ Parameters
+ ----------
+ data_dir
+ Path to load data after feature engineering.
+ max_step
+ Total number of steps (an upper-bound estimation). For example, 390min / 30min-per-step = 13 steps.
+ data_ticks
+ Equal to the total number of records. For example, in SAOE per minute,
+ the total ticks is the length of day in minutes.
+ data_dim
+ Number of dimensions in data.
+ """
+
+ def __init__(self, data_dir: Path, max_step: int, data_ticks: int, data_dim: int) -> None:
+ self.data_dir = data_dir
+ self.max_step = max_step
+ self.data_ticks = data_ticks
+ self.data_dim = data_dim
+
+ def interpret(self, state: SAOEState) -> FullHistoryObs:
+ processed = pickle_styled.load_intraday_processed_data(
+ self.data_dir,
+ state.order.stock_id,
+ pd.Timestamp(state.order.start_time.date()),
+ self.data_dim,
+ state.ticks_index,
+ )
+
+ position_history = np.full(self.max_step + 1, 0.0, dtype=np.float32)
+ position_history[0] = state.order.amount
+ position_history[1 : len(state.history_steps) + 1] = state.history_steps["position"].to_numpy()
+
+ assert self.env is not None
+
+ # The min, slice here are to make sure that indices fit into the range,
+ # even after the final step of the simulator (in the done step),
+ # to make network in policy happy.
+ return cast(
+ FullHistoryObs,
+ canonicalize(
+ {
+ "data_processed": self._mask_future_info(processed.today, state.cur_time),
+ "data_processed_prev": processed.yesterday,
+ "acquiring": state.order.direction == state.order.BUY,
+ "cur_tick": min(np.sum(state.ticks_index < state.cur_time), self.data_ticks - 1),
+ "cur_step": min(self.env.status["cur_step"], self.max_step - 1),
+ "num_step": self.max_step,
+ "target": state.order.amount,
+ "position": state.position,
+ "position_history": position_history[: self.max_step],
+ }
+ ),
+ )
+
+ @property
+ def observation_space(self):
+ space = {
+ "data_processed": spaces.Box(-np.inf, np.inf, shape=(self.data_ticks, self.data_dim)),
+ "data_processed_prev": spaces.Box(-np.inf, np.inf, shape=(self.data_ticks, self.data_dim)),
+ "acquiring": spaces.Discrete(2),
+ "cur_tick": spaces.Box(0, self.data_ticks - 1, shape=(), dtype=np.int32),
+ "cur_step": spaces.Box(0, self.max_step - 1, shape=(), dtype=np.int32),
+ # TODO: support arbitrary length index
+ "num_step": spaces.Box(self.max_step, self.max_step, shape=(), dtype=np.int32),
+ "target": spaces.Box(-EPS, np.inf, shape=()),
+ "position": spaces.Box(-EPS, np.inf, shape=()),
+ "position_history": spaces.Box(-EPS, np.inf, shape=(self.max_step,)),
+ }
+ return spaces.Dict(space)
+
+ @staticmethod
+ def _mask_future_info(arr: pd.DataFrame, current: pd.Timestamp) -> pd.DataFrame:
+ arr = arr.copy(deep=True)
+ arr.loc[current:] = 0.0 # mask out data after this moment (inclusive)
+ return arr
+
+
+class CurrentStateObs(TypedDict):
+ acquiring: bool
+ cur_step: int
+ num_step: int
+ target: float
+ position: float
+
+
+class CurrentStepStateInterpreter(StateInterpreter[SAOEState, CurrentStateObs]):
+ """The observation of current step.
+
+ Used when policy only depends on the latest state, but not history.
+ The key list is not full. You can add more if more information is needed by your policy.
+ """
+
+ def __init__(self, max_step: int):
+ self.max_step = max_step
+
+ @property
+ def observation_space(self):
+ space = {
+ "acquiring": spaces.Discrete(2),
+ "cur_step": spaces.Box(0, self.max_step - 1, shape=(), dtype=np.int32),
+ "num_step": spaces.Box(self.max_step, self.max_step, shape=(), dtype=np.int32),
+ "target": spaces.Box(-EPS, np.inf, shape=()),
+ "position": spaces.Box(-EPS, np.inf, shape=()),
+ }
+ return spaces.Dict(space)
+
+ def interpret(self, state: SAOEState) -> CurrentStateObs:
+ assert self.env is not None
+ assert self.env.status["cur_step"] <= self.max_step
+ obs = CurrentStateObs(
+ {
+ "acquiring": state.order.direction == state.order.BUY,
+ "cur_step": self.env.status["cur_step"],
+ "num_step": self.max_step,
+ "target": state.order.amount,
+ "position": state.position,
+ }
+ )
+ return obs
+
+
+class CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]):
+ """Convert a discrete policy action to a continuous action, then multiplied by ``order.amount``.
+
+ Parameters
+ ----------
+ values
+ It can be a list of length $L$: $[a_1, a_2, \\ldots, a_L]$.
+ Then when policy gives decision $x$, $a_x$ times order amount is the output.
+ It can also be an integer $n$, in which case the list of length $n+1$ is auto-generated,
+ i.e., $[0, 1/n, 2/n, \\ldots, n/n]$.
+ """
+
+ def __init__(self, values: int | list[float]):
+ if isinstance(values, int):
+ values = [i / values for i in range(0, values + 1)]
+ self.action_values = values
+
+ @property
+ def action_space(self) -> spaces.Discrete:
+ return spaces.Discrete(len(self.action_values))
+
+ def interpret(self, state: SAOEState, action: int) -> float:
+ assert 0 <= action < len(self.action_values)
+ return min(state.position, state.order.amount * self.action_values[action])
+
+
+class TwapRelativeActionInterpreter(ActionInterpreter[SAOEState, float, float]):
+ """Convert a continuous ratio to deal amount.
+
+ The ratio is relative to TWAP on the remainder of the day.
+ For example, there are 5 steps left, and the left position is 300.
+ With TWAP strategy, in each position, 60 should be traded.
+ When this interpreter receives action $a$, its output is $60 \\cdot a$.
+ """
+
+ @property
+ def action_space(self) -> spaces.Box:
+ return spaces.Box(0, np.inf, shape=(), dtype=np.float32)
+
+ def interpret(self, state: SAOEState, action: float) -> float:
+ assert self.env is not None
+ estimated_total_steps = math.ceil(len(state.ticks_for_order) / state.ticks_per_step)
+ twap_volume = state.position / (estimated_total_steps - self.env.status["cur_step"])
+ return min(state.position, twap_volume * action)
diff --git a/qlib/rl/order_execution/network.py b/qlib/rl/order_execution/network.py
new file mode 100644
index 00000000000..908f96130fb
--- /dev/null
+++ b/qlib/rl/order_execution/network.py
@@ -0,0 +1,118 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from __future__ import annotations
+
+from typing import cast
+
+import torch
+import torch.nn as nn
+from tianshou.data import Batch
+
+from qlib.typehint import Literal
+from .interpreter import FullHistoryObs
+
+__all__ = ["Recurrent"]
+
+
+class Recurrent(nn.Module):
+ """The network architecture proposed in `OPD <https://seqml.github.io/opd/opd_aaai21_supplement.pdf>`_.
+
+ At every timestep the input of policy network is divided into two parts,
+ the public variables and the private variables, which are handled by ``raw_rnn``
+ and ``pri_rnn`` in this network, respectively.
+
+ One minor difference is that, in this implementation, we don't assume the direction to be fixed.
+ Thus, another ``dire_fc`` is added to produce an extra direction-related feature.
+ """
+
+ def __init__(
+ self,
+ obs_space: FullHistoryObs,
+ hidden_dim: int = 64,
+ output_dim: int = 32,
+ rnn_type: Literal["rnn", "lstm", "gru"] = "gru",
+ rnn_num_layers: int = 1,
+ ):
+ super().__init__()
+
+ self.hidden_dim = hidden_dim
+ self.output_dim = output_dim
+ self.num_sources = 3
+
+ rnn_classes = {"rnn": nn.RNN, "lstm": nn.LSTM, "gru": nn.GRU}
+
+ self.rnn_class = rnn_classes[rnn_type]
+ self.rnn_layers = rnn_num_layers
+
+ self.raw_rnn = self.rnn_class(hidden_dim, hidden_dim, batch_first=True, num_layers=self.rnn_layers)
+ self.prev_rnn = self.rnn_class(hidden_dim, hidden_dim, batch_first=True, num_layers=self.rnn_layers)
+ self.pri_rnn = self.rnn_class(hidden_dim, hidden_dim, batch_first=True, num_layers=self.rnn_layers)
+
+ self.raw_fc = nn.Sequential(nn.Linear(obs_space["data_processed"].shape[-1], hidden_dim), nn.ReLU())
+ self.pri_fc = nn.Sequential(nn.Linear(2, hidden_dim), nn.ReLU())
+ self.dire_fc = nn.Sequential(nn.Linear(2, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU())
+
+ self._init_extra_branches()
+
+ self.fc = nn.Sequential(
+ nn.Linear(hidden_dim * self.num_sources, hidden_dim),
+ nn.ReLU(),
+ nn.Linear(hidden_dim, output_dim),
+ nn.ReLU(),
+ )
+
+ def _init_extra_branches(self):
+ pass
+
+ def _source_features(self, obs: FullHistoryObs, device: torch.device) -> tuple[list[torch.Tensor], torch.Tensor]:
+ bs, _, data_dim = obs["data_processed"].size()
+ data = torch.cat((torch.zeros(bs, 1, data_dim, device=device), obs["data_processed"]), 1)
+ cur_step = obs["cur_step"].long()
+ cur_tick = obs["cur_tick"].long()
+ bs_indices = torch.arange(bs, device=device)
+
+ position = obs["position_history"] / obs["target"].unsqueeze(-1) # [bs, num_step]
+ steps = (
+ torch.arange(position.size(-1), device=device).unsqueeze(0).repeat(bs, 1).float()
+ / obs["num_step"].unsqueeze(-1).float()
+ ) # [bs, num_step]
+ priv = torch.stack((position.float(), steps), -1)
+
+ data_in = self.raw_fc(data)
+ data_out, _ = self.raw_rnn(data_in)
+ # as it is padded with zero in front, this should be last minute
+ data_out_slice = data_out[bs_indices, cur_tick]
+
+ priv_in = self.pri_fc(priv)
+ priv_out = self.pri_rnn(priv_in)[0]
+ priv_out = priv_out[bs_indices, cur_step]
+
+ sources = [data_out_slice, priv_out]
+
+ dir_out = self.dire_fc(torch.stack((obs["acquiring"], 1 - obs["acquiring"]), -1).float())
+ sources.append(dir_out)
+
+ return sources, data_out
+
+ def forward(self, batch: Batch) -> torch.Tensor:
+ """
+ Input should be a dict (at least) containing:
+
+ - data_processed: [N, T, C]
+ - cur_step: [N] (int)
+ - cur_tick: [N] (int)
+ - position_history: [N, S] (S is number of steps)
+ - target: [N]
+ - num_step: [N] (int)
+ - acquiring: [N] (0 or 1)
+ """
+
+ inp = cast(FullHistoryObs, batch)
+ device = inp["data_processed"].device
+
+ sources, _ = self._source_features(inp, device)
+ assert len(sources) == self.num_sources
+
+ out = torch.cat(sources, -1)
+ return self.fc(out)
diff --git a/qlib/rl/order_execution/policy.py b/qlib/rl/order_execution/policy.py
new file mode 100644
index 00000000000..f95a53c7588
--- /dev/null
+++ b/qlib/rl/order_execution/policy.py
@@ -0,0 +1,158 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from pathlib import Path
+from typing import Optional, cast
+
+import numpy as np
+import gym
+import torch
+import torch.nn as nn
+from gym.spaces import Discrete
+from tianshou.data import Batch, to_torch
+from tianshou.policy import PPOPolicy, BasePolicy
+
+__all__ = ["AllOne", "PPO"]
+
+
+# baselines #
+
+
+class NonlearnablePolicy(BasePolicy):
+ """Tianshou's BasePolicy with empty ``learn`` and ``process_fn``.
+
+ This could be moved outside in future.
+ """
+
+ def __init__(self, obs_space: gym.Space, action_space: gym.Space):
+ super().__init__()
+
+ def learn(self, batch, batch_size, repeat):
+ pass
+
+ def process_fn(self, batch, buffer, indice):
+ pass
+
+
+class AllOne(NonlearnablePolicy):
+ """Forward returns a batch full of 1.
+
+ Useful when implementing some baselines (e.g., TWAP).
+ """
+
+ def forward(self, batch, state=None, **kwargs):
+ return Batch(act=np.full(len(batch), 1.0), state=state)
+
+
+# ppo #
+
+
+class PPOActor(nn.Module):
+ def __init__(self, extractor: nn.Module, action_dim: int):
+ super().__init__()
+ self.extractor = extractor
+ self.layer_out = nn.Sequential(nn.Linear(cast(int, extractor.output_dim), action_dim), nn.Softmax(dim=-1))
+
+ def forward(self, obs, state=None, info={}):
+ feature = self.extractor(to_torch(obs, device=auto_device(self)))
+ out = self.layer_out(feature)
+ return out, state
+
+
+class PPOCritic(nn.Module):
+ def __init__(self, extractor: nn.Module):
+ super().__init__()
+ self.extractor = extractor
+ self.value_out = nn.Linear(cast(int, extractor.output_dim), 1)
+
+ def forward(self, obs, state=None, info={}):
+ feature = self.extractor(to_torch(obs, device=auto_device(self)))
+ return self.value_out(feature).squeeze(dim=-1)
+
+
+class PPO(PPOPolicy):
+ """A wrapper of tianshou PPOPolicy.
+
+ Differences:
+
+ - Auto-create actor and critic network. Supports discrete action space only.
+ - Dedup common parameters between actor network and critic network
+ (not sure whether this is included in latest tianshou or not).
+ - Support a ``weight_file`` that supports loading checkpoint.
+ - Some parameters' default values are different from original.
+ """
+
+ def __init__(
+ self,
+ network: nn.Module,
+ obs_space: gym.Space,
+ action_space: gym.Space,
+ lr: float,
+ weight_decay: float = 0.0,
+ discount_factor: float = 1.0,
+ max_grad_norm: float = 100.0,
+ reward_normalization: bool = True,
+ eps_clip: float = 0.3,
+ value_clip: bool = True,
+ vf_coef: float = 1.0,
+ gae_lambda: float = 1.0,
+ max_batchsize: int = 256,
+ deterministic_eval: bool = True,
+ weight_file: Optional[Path] = None,
+ ):
+ assert isinstance(action_space, Discrete)
+ actor = PPOActor(network, action_space.n)
+ critic = PPOCritic(network)
+ optimizer = torch.optim.Adam(
+ chain_dedup(actor.parameters(), critic.parameters()), lr=lr, weight_decay=weight_decay
+ )
+ super().__init__(
+ actor,
+ critic,
+ optimizer,
+ torch.distributions.Categorical,
+ discount_factor=discount_factor,
+ max_grad_norm=max_grad_norm,
+ reward_normalization=reward_normalization,
+ eps_clip=eps_clip,
+ value_clip=value_clip,
+ vf_coef=vf_coef,
+ gae_lambda=gae_lambda,
+ max_batchsize=max_batchsize,
+ deterministic_eval=deterministic_eval,
+ observation_space=obs_space,
+ action_space=action_space,
+ )
+ if weight_file is not None:
+ load_weight(self, weight_file)
+
+
+# utilities: these should be put in a separate (common) file. #
+
+
+def auto_device(module: nn.Module) -> torch.device:
+ for param in module.parameters():
+ return param.device
+ return torch.device("cpu") # fallback to cpu
+
+
+def load_weight(policy, path):
+ assert isinstance(policy, nn.Module), "Policy has to be an nn.Module to load weight."
+ loaded_weight = torch.load(path, map_location="cpu")
+ try:
+ policy.load_state_dict(loaded_weight)
+ except RuntimeError:
+ # try again by loading the converted weight
+ # https://github.com/thu-ml/tianshou/issues/468
+ for k in list(loaded_weight):
+ loaded_weight["_actor_critic." + k] = loaded_weight[k]
+ policy.load_state_dict(loaded_weight)
+
+
+def chain_dedup(*iterables):
+ seen = set()
+ for iterable in iterables:
+ for i in iterable:
+ if i not in seen:
+ seen.add(i)
+ yield i
diff --git a/qlib/rl/order_execution/simulator_qlib.py b/qlib/rl/order_execution/simulator_qlib.py
new file mode 100644
index 00000000000..71aaa222be5
--- /dev/null
+++ b/qlib/rl/order_execution/simulator_qlib.py
@@ -0,0 +1,4 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Placeholder for qlib-based simulator."""
diff --git a/qlib/rl/order_execution/simulator_simple.py b/qlib/rl/order_execution/simulator_simple.py
new file mode 100644
index 00000000000..8022c34cece
--- /dev/null
+++ b/qlib/rl/order_execution/simulator_simple.py
@@ -0,0 +1,403 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import NamedTuple, Any, TypeVar, cast
+
+import numpy as np
+import pandas as pd
+
+from qlib.backtest.decision import Order, OrderDir
+from qlib.constant import EPS
+from qlib.rl.simulator import Simulator
+from qlib.rl.data.pickle_styled import IntradayBacktestData, load_intraday_backtest_data, DealPriceType
+from qlib.rl.utils import LogLevel
+from qlib.typehint import TypedDict
+
+__all__ = ["SAOEMetrics", "SAOEState", "SingleAssetOrderExecution"]
+
+ONE_SEC = pd.Timedelta("1s") # use 1 second to exclude the right interval point
+
+
+class SAOEMetrics(TypedDict):
+ """Metrics for SAOE accumulated for a "period".
+ It could be accumulated for a day, or a period of time (e.g., 30min), or calculated separately for every minute.
+
+ Warnings
+ --------
+ The type hints are for single elements. In lots of times, they can be vectorized.
+ For example, ``market_volume`` could be a list of float (or ndarray) rather than a single float.
+ """
+
+ stock_id: str
+ """Stock ID of this record."""
+ datetime: pd.Timestamp
+ """Datetime of this record (this is index in the dataframe)."""
+ direction: int
+ """Direction of the order. 0 for sell, 1 for buy."""
+
+ # Market information.
+ market_volume: float
+ """(total) market volume traded in the period."""
+ market_price: float
+ """Deal price. If it's a period of time, this is the average market deal price."""
+
+ # Strategy records.
+
+ amount: float
+ """Total amount (volume) strategy intends to trade."""
+ inner_amount: float
+ """Total amount that the lower-level strategy intends to trade
+ (might be larger than amount, e.g., to ensure ffr)."""
+
+ deal_amount: float
+ """Amount that successfully takes effect (must be less than inner_amount)."""
+ trade_price: float
+ """The average deal price for this strategy."""
+ trade_value: float
+ """Total worth of trading. In the simple simulation, trade_value = deal_amount * price."""
+ position: float
+ """Position left after this "period"."""
+
+ # Accumulated metrics
+
+ ffr: float
+ """Completed how much percent of the daily order."""
+
+ pa: float
+ """Price advantage compared to baseline (i.e., trade with baseline market price).
+ The baseline is trade price when using TWAP strategy to execute this order.
+ (Please note that there could be data leak here.)
+ Unit is BP (basis point, 1/10000)."""
+
+
+class SAOEState(NamedTuple):
+ """Data structure holding a state for SAOE simulator."""
+
+ order: Order
+ """The order we are dealing with."""
+ cur_time: pd.Timestamp
+ """Current time, e.g., 9:30."""
+ position: float
+ """Current remaining volume to execute."""
+ history_exec: pd.DataFrame
+ """See :attr:`SingleAssetOrderExecution.history_exec`."""
+ history_steps: pd.DataFrame
+ """See :attr:`SingleAssetOrderExecution.history_steps`."""
+
+ metrics: SAOEMetrics | None
+ """Daily metric, only available when the trading is in "done" state."""
+
+ backtest_data: IntradayBacktestData
+ """Backtest data is included in the state.
+ Actually, only the time index of this data is needed, at this moment.
+ I include the full data so that algorithms (e.g., VWAP) that relies on the raw data can be implemented.
+ Interpreter can use this as they wish, but they should be careful not to leak future data.
+ """
+
+ ticks_per_step: int
+ """How many ticks for each step."""
+ ticks_index: pd.DatetimeIndex
+ """Trading ticks in all day, NOT sliced by order (defined in data). e.g., [9:30, 9:31, ..., 14:59]."""
+ ticks_for_order: pd.DatetimeIndex
+ """Trading ticks sliced by order, e.g., [9:45, 9:46, ..., 14:44]."""
+
+
+class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
+ """Single-asset order execution (SAOE) simulator.
+
+ As there's no "calendar" in the simple simulator, ticks are used to trade.
+ A tick is a record (a line) in the pickle-styled data file.
+ Each tick is considered as an individual trading opportunity.
+ If such fine granularity is not needed, use ``ticks_per_step`` to
+ lengthen the ticks for each step.
+
+ In each step, the traded amount is "equally" split to each tick,
+ then bounded by the maximum execution volume (i.e., ``vol_threshold``),
+ and if it's the last step, try to ensure all the amount to be executed.
+
+ Parameters
+ ----------
+ initial
+ The seed to start an SAOE simulator is an order.
+ ticks_per_step
+ How many ticks per step.
+ data_dir
+ Path to load backtest data
+ vol_threshold
+ Maximum execution volume (divided by market execution volume).
+ """
+
+ history_exec: pd.DataFrame
+ """All execution history at every possible time ticks. See :class:`SAOEMetrics` for available columns."""
+
+ history_steps: pd.DataFrame
+ """Positions at each step. The position before first step is also recorded.
+ See :class:`SAOEMetrics` for available columns."""
+
+ metrics: SAOEMetrics | None
+ """Metrics. Only available when done."""
+
+ twap_price: float
+ """This price is used to compute price advantage.
+ It's defined as the average price in the period from order's start time to end time."""
+
+ ticks_index: pd.DatetimeIndex
+ """All available ticks for the day (not restricted to order)."""
+
+ ticks_for_order: pd.DatetimeIndex
+ """Ticks that is available for trading (sliced by order)."""
+
+ def __init__(
+ self,
+ order: Order,
+ data_dir: Path,
+ ticks_per_step: int = 30,
+ deal_price_type: DealPriceType = "close",
+ vol_threshold: float | None = None,
+ ) -> None:
+ self.order = order
+ self.ticks_per_step: int = ticks_per_step
+ self.deal_price_type = deal_price_type
+ self.vol_threshold = vol_threshold
+ self.data_dir = data_dir
+ self.backtest_data = load_intraday_backtest_data(
+ self.data_dir, order.stock_id, pd.Timestamp(order.start_time.date()), self.deal_price_type, order.direction
+ )
+
+ self.ticks_index = self.backtest_data.get_time_index()
+
+ # Get time index available for trading
+ self.ticks_for_order = self._get_ticks_slice(self.order.start_time, self.order.end_time)
+
+ self.cur_time = self.ticks_for_order[0]
+ # NOTE: astype(float) is necessary in some systems.
+ # this will align the precision with `.to_numpy()` in `_split_exec_vol`
+ self.twap_price = float(self.backtest_data.get_deal_price().loc[self.ticks_for_order].astype(float).mean())
+
+ self.position = order.amount
+
+ metric_keys = list(SAOEMetrics.__annotations__.keys()) # pylint: disable=no-member
+ # NOTE: can empty dataframe contain index?
+ self.history_exec = pd.DataFrame(columns=metric_keys).set_index("datetime")
+ self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime")
+ self.metrics = None
+
+ self.market_price: np.ndarray | None = None
+ self.market_vol: np.ndarray | None = None
+ self.market_vol_limit: np.ndarray | None = None
+
+ def step(self, amount: float) -> None:
+ """Execute one step of SAOE.
+
+ Parameters
+ ----------
+ amount
+ The amount you wish to deal. The simulator doesn't guarantee all the amount to be successfully dealt.
+ """
+
+ assert not self.done()
+
+ self.market_price = self.market_vol = None # avoid misuse
+ exec_vol = self._split_exec_vol(amount)
+ assert self.market_price is not None and self.market_vol is not None
+
+ ticks_position = self.position - np.cumsum(exec_vol)
+
+ self.position -= exec_vol.sum()
+ if self.position < -EPS or (exec_vol < -EPS).any():
+ raise ValueError(f"Execution volume is invalid: {exec_vol} (position = {self.position})")
+
+ # Get time index available for this step
+ time_index = self._get_ticks_slice(self.cur_time, self._next_time())
+
+ self.history_exec = self._dataframe_append(
+ self.history_exec,
+ SAOEMetrics(
+ # It should have the same keys with SAOEMetrics,
+ # but the values do not necessarily have the annotated type.
+ # Some values could be vectorized (e.g., exec_vol).
+ stock_id=self.order.stock_id,
+ datetime=time_index,
+ direction=self.order.direction,
+ market_volume=self.market_vol,
+ market_price=self.market_price,
+ amount=exec_vol,
+ inner_amount=exec_vol,
+ deal_amount=exec_vol,
+ trade_price=self.market_price,
+ trade_value=self.market_price * exec_vol,
+ position=ticks_position,
+ ffr=exec_vol / self.order.amount,
+ pa=price_advantage(self.market_price, self.twap_price, self.order.direction),
+ ),
+ )
+
+ self.history_steps = self._dataframe_append(
+ self.history_steps,
+ [self._metrics_collect(self.cur_time, self.market_vol, self.market_price, amount, exec_vol)],
+ )
+
+ if self.done():
+ if self.env is not None:
+ self.env.logger.add_any("history_steps", self.history_steps, loglevel=LogLevel.DEBUG)
+ self.env.logger.add_any("history_exec", self.history_exec, loglevel=LogLevel.DEBUG)
+
+ self.metrics = self._metrics_collect(
+ self.ticks_index[0], # start time
+ self.history_exec["market_volume"],
+ self.history_exec["market_price"],
+ self.history_steps["amount"].sum(),
+ self.history_exec["deal_amount"],
+ )
+
+ # NOTE (yuge): It looks to me that it's the "correct" decision to
+ # put all the logs here, because only components like simulators themselves
+ # have the knowledge about what could appear in the logs, and what's the format.
+ # But I admit it's not necessarily the most convenient way.
+ # I'll rethink about it when we have the second environment
+ # Maybe some APIs like self.logger.enable_auto_log() ?
+
+ if self.env is not None:
+ for key, value in self.metrics.items():
+ if isinstance(value, float):
+ self.env.logger.add_scalar(key, value)
+ else:
+ self.env.logger.add_any(key, value)
+
+ self.cur_time = self._next_time()
+
+ def get_state(self) -> SAOEState:
+ return SAOEState(
+ order=self.order,
+ cur_time=self.cur_time,
+ position=self.position,
+ history_exec=self.history_exec,
+ history_steps=self.history_steps,
+ metrics=self.metrics,
+ backtest_data=self.backtest_data,
+ ticks_per_step=self.ticks_per_step,
+ ticks_index=self.ticks_index,
+ ticks_for_order=self.ticks_for_order,
+ )
+
+ def done(self) -> bool:
+ return self.position < EPS or self.cur_time >= self.order.end_time
+
+ def _next_time(self) -> pd.Timestamp:
+ """The "current time" (``cur_time``) for next step."""
+ # Look for next time on time index
+ current_loc = self.ticks_index.get_loc(self.cur_time)
+ next_loc = current_loc + self.ticks_per_step
+
+ # Calibrate the next location to multiple of ticks_per_step.
+ # This is to make sure that:
+ # as long as ticks_per_step is a multiple of something, each step won't cross morning and afternoon.
+ next_loc = next_loc - next_loc % self.ticks_per_step
+
+ if next_loc < len(self.ticks_index) and self.ticks_index[next_loc] < self.order.end_time:
+ return self.ticks_index[next_loc]
+ else:
+ return self.order.end_time
+
+ def _cur_duration(self) -> pd.Timedelta:
+ """The "duration" of this step (step that is about to happen)."""
+ return self._next_time() - self.cur_time
+
+ def _split_exec_vol(self, exec_vol_sum: float) -> np.ndarray:
+ """
+ Split the volume in each step into minutes, considering possible constraints.
+ This follows TWAP strategy.
+ """
+ next_time = self._next_time()
+
+ # get the backtest data for next interval
+ self.market_vol = self.backtest_data.get_volume().loc[self.cur_time : next_time - ONE_SEC].to_numpy()
+ self.market_price = self.backtest_data.get_deal_price().loc[self.cur_time : next_time - ONE_SEC].to_numpy()
+
+ assert self.market_vol is not None and self.market_price is not None
+
+ # split the volume equally into each minute
+ exec_vol = np.repeat(exec_vol_sum / len(self.market_price), len(self.market_price))
+
+ # apply the volume threshold
+ market_vol_limit = self.vol_threshold * self.market_vol if self.vol_threshold is not None else np.inf
+ exec_vol = np.minimum(exec_vol, market_vol_limit) # type: ignore
+
+ # Complete all the order amount at the last moment.
+ if next_time >= self.order.end_time:
+ exec_vol[-1] += self.position - exec_vol.sum()
+ exec_vol = np.minimum(exec_vol, market_vol_limit) # type: ignore
+
+ return exec_vol
+
+ def _metrics_collect(
+ self,
+ datetime: pd.Timestamp,
+ market_vol: np.ndarray,
+ market_price: np.ndarray,
+ amount: float, # intended to trade such amount
+ exec_vol: np.ndarray,
+ ) -> SAOEMetrics:
+ assert len(market_vol) == len(market_price) == len(exec_vol)
+
+ if np.abs(np.sum(exec_vol)) < EPS:
+ exec_avg_price = 0.0
+ else:
+ exec_avg_price = cast(float, np.average(market_price, weights=exec_vol)) # could be nan
+ if hasattr(exec_avg_price, "item"): # could be numpy scalar
+ exec_avg_price = exec_avg_price.item() # type: ignore
+
+ return SAOEMetrics(
+ stock_id=self.order.stock_id,
+ datetime=datetime,
+ direction=self.order.direction,
+ market_volume=market_vol.sum(),
+ market_price=market_price.mean(),
+ amount=amount,
+ inner_amount=exec_vol.sum(),
+ deal_amount=exec_vol.sum(), # in this simulator, there's no other restrictions
+ trade_price=exec_avg_price,
+ trade_value=np.sum(market_price * exec_vol),
+ position=self.position,
+ ffr=float(exec_vol.sum() / self.order.amount),
+ pa=price_advantage(exec_avg_price, self.twap_price, self.order.direction),
+ )
+
+ def _get_ticks_slice(self, start: pd.Timestamp, end: pd.Timestamp, include_end: bool = False) -> pd.DatetimeIndex:
+ if not include_end:
+ end = end - ONE_SEC
+ return self.ticks_index[self.ticks_index.slice_indexer(start, end)]
+
+ @staticmethod
+ def _dataframe_append(df: pd.DataFrame, other: Any) -> pd.DataFrame:
+ # dataframe.append is deprecated
+ other_df = pd.DataFrame(other).set_index("datetime")
+ other_df.index.name = "datetime"
+ return pd.concat([df, other_df], axis=0)
+
+
+_float_or_ndarray = TypeVar("_float_or_ndarray", float, np.ndarray)
+
+
+def price_advantage(
+ exec_price: _float_or_ndarray, baseline_price: float, direction: OrderDir | int
+) -> _float_or_ndarray:
+ if baseline_price == 0: # something is wrong with data. Should be nan here
+ if isinstance(exec_price, float):
+ return 0.0
+ else:
+ return np.zeros_like(exec_price)
+ if direction == OrderDir.BUY:
+ res = (1 - exec_price / baseline_price) * 10000
+ elif direction == OrderDir.SELL:
+ res = (exec_price / baseline_price - 1) * 10000
+ else:
+ raise ValueError(f"Unexpected order direction: {direction}")
+ res_wo_nan: np.ndarray = np.nan_to_num(res, nan=0.0)
+ if res_wo_nan.size == 1:
+ return res_wo_nan.item()
+ else:
+ return cast(_float_or_ndarray, res_wo_nan)
diff --git a/qlib/rl/reward.py b/qlib/rl/reward.py
new file mode 100644
index 00000000000..20d98587427
--- /dev/null
+++ b/qlib/rl/reward.py
@@ -0,0 +1,84 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from __future__ import annotations
+
+from typing import Generic, Any, TypeVar, TYPE_CHECKING
+
+from qlib.typehint import final
+
+if TYPE_CHECKING:
+ from .utils.env_wrapper import EnvWrapper
+
+SimulatorState = TypeVar("SimulatorState")
+
+
+class Reward(Generic[SimulatorState]):
+ """
+ Reward calculation component that takes a single argument: state of simulator. Returns a real number: reward.
+
+ Subclass should implement ``reward(simulator_state)`` to implement their own reward calculation recipe.
+ """
+
+ env: EnvWrapper | None = None
+
+ @final
+ def __call__(self, simulator_state: SimulatorState) -> float:
+ return self.reward(simulator_state)
+
+ def reward(self, simulator_state: SimulatorState) -> float:
+ """Implement this method for your own reward."""
+ raise NotImplementedError("Implement reward calculation recipe in `reward()`.")
+
+ def log(self, name, value):
+ self.env.logger.add_scalar(name, value)
+
+
+class RewardCombination(Reward):
+ """Combination of multiple reward."""
+
+ def __init__(self, rewards: dict[str, tuple[Reward, float]]):
+ self.rewards = rewards
+
+ def reward(self, simulator_state: Any) -> float:
+ total_reward = 0.0
+ for name, (reward_fn, weight) in self.rewards.items():
+ rew = reward_fn(simulator_state) * weight
+ total_reward += rew
+ self.log(name, rew)
+ return total_reward
+
+
+# TODO:
+# reward_factory is disabled for now
+
+# _RegistryConfigReward = RegistryConfig[REWARDS]
+
+
+# @configclass
+# class _WeightedRewardConfig:
+# weight: float
+# reward: _RegistryConfigReward
+
+
+# RewardConfig = Union[_RegistryConfigReward, Dict[str, Union[_RegistryConfigReward, _WeightedRewardConfig]]]
+
+
+# def reward_factory(reward_config: RewardConfig) -> Reward:
+# """
+# Use this factory to instantiate the reward from config.
+# Simply using ``reward_config.build()`` might not work because reward can have complex combinations.
+# """
+# if isinstance(reward_config, dict):
+# # as reward combination
+# rewards = {}
+# for name, rew in reward_config.items():
+# if not isinstance(rew, _WeightedRewardConfig):
+# # default weight is 1.
+# rew = _WeightedRewardConfig(weight=1., reward=rew)
+# # no recursive build in this step
+# rewards[name] = (rew.reward.build(), rew.weight)
+# return RewardCombination(rewards)
+# else:
+# # single reward
+# return reward_config.build()
diff --git a/qlib/rl/seed.py b/qlib/rl/seed.py
new file mode 100644
index 00000000000..93d452a4a2a
--- /dev/null
+++ b/qlib/rl/seed.py
@@ -0,0 +1,12 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Defines a set of initial state definitions and state-set definitions.
+
+With single-asset order execution only, the only seed is order.
+"""
+
+from typing import TypeVar
+
+InitialStateType = TypeVar("InitialStateType")
+"""Type of data that creates the simulator."""
diff --git a/qlib/rl/simulator.py b/qlib/rl/simulator.py
new file mode 100644
index 00000000000..56fc12042c1
--- /dev/null
+++ b/qlib/rl/simulator.py
@@ -0,0 +1,75 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from __future__ import annotations
+
+from typing import TypeVar, Generic, Any, TYPE_CHECKING
+
+from .seed import InitialStateType
+
+if TYPE_CHECKING:
+ from .utils.env_wrapper import EnvWrapper
+
+StateType = TypeVar("StateType")
+"""StateType stores all the useful data in the simulation process
+(as well as utilities to generate/retrieve data when needed)."""
+
+ActType = TypeVar("ActType")
+"""This ActType is the type of action at the simulator end."""
+
+
+class Simulator(Generic[InitialStateType, StateType, ActType]):
+ """
+ Simulator that resets with ``__init__``, and transits with ``step(action)``.
+
+ To make the data-flow clear, we make the following restrictions to Simulator:
+
+ 1. The only way to modify the inner status of a simulator is by using ``step(action)``.
+ 2. External modules can *read* the status of a simulator by using ``simulator.get_state()``,
+ and check whether the simulator is in the ending state by calling ``simulator.done()``.
+
+ A simulator is defined to be bounded with three types:
+
+ - *InitialStateType* that is the type of the data used to create the simulator.
+ - *StateType* that is the type of the **status** (state) of the simulator.
+ - *ActType* that is the type of the **action**, which is the input received in each step.
+
+ Different simulators might share the same StateType. For example, when they are dealing with the same task,
+ but with different simulation implementation. With the same type, they can safely share other components in the MDP.
+
+ Simulators are ephemeral. The lifecycle of a simulator starts with an initial state, and ends with the trajectory.
+    In other words, when the trajectory ends, the simulator is recycled.
+ If simulators want to share context between (e.g., for speed-up purposes),
+ this could be done by accessing the weak reference of environment wrapper.
+
+ Attributes
+ ----------
+ env
+ A reference of env-wrapper, which could be useful in some corner cases.
+        Simulators are discouraged from using this, because it's prone to induce errors.
+ """
+
+ env: EnvWrapper | None = None
+
+ def __init__(self, initial: InitialStateType, **kwargs: Any) -> None:
+ pass
+
+ def step(self, action: ActType) -> None:
+ """Receives an action of ActType.
+
+ Simulator should update its internal state, and return None.
+ The updated state can be retrieved with ``simulator.get_state()``.
+ """
+ raise NotImplementedError()
+
+ def get_state(self) -> StateType:
+ raise NotImplementedError()
+
+ def done(self) -> bool:
+ """Check whether the simulator is in a "done" state.
+ When simulator is in a "done" state,
+        it should no longer receive any ``step`` request.
+ As simulators are ephemeral, to reset the simulator,
+ the old one should be destroyed and a new simulator can be created.
+ """
+ raise NotImplementedError()
diff --git a/qlib/rl/utils/__init__.py b/qlib/rl/utils/__init__.py
new file mode 100644
index 00000000000..4a1fa9d9058
--- /dev/null
+++ b/qlib/rl/utils/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from .data_queue import *
+from .env_wrapper import *
+from .finite_env import *
+from .log import *
diff --git a/qlib/rl/utils/data_queue.py b/qlib/rl/utils/data_queue.py
new file mode 100644
index 00000000000..1e9dab24da4
--- /dev/null
+++ b/qlib/rl/utils/data_queue.py
@@ -0,0 +1,179 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import os
+import multiprocessing
+import threading
+import time
+import warnings
+from queue import Empty
+from typing import TypeVar, Generic, Sequence, cast
+
+from qlib.log import get_module_logger
+
+_logger = get_module_logger(__name__)
+
+T = TypeVar("T")
+
+__all__ = ["DataQueue"]
+
+
+class DataQueue(Generic[T]):
+ """Main process (producer) produces data and stores them in a queue.
+ Sub-processes (consumers) can retrieve the data-points from the queue.
+ Data-points are generated via reading items from ``dataset``.
+
+ :class:`DataQueue` is ephemeral. You must create a new DataQueue
+ when the ``repeat`` is exhausted.
+
+ See the documents of :class:`qlib.rl.utils.FiniteVectorEnv` for more background.
+
+ Parameters
+ ----------
+ dataset
+ The dataset to read data from. Must implement ``__len__`` and ``__getitem__``.
+ repeat
+ Iterate over the data-points for how many times. Use ``-1`` to iterate forever.
+ shuffle
+ If ``shuffle`` is true, the items will be read in random order.
+ producer_num_workers
+ Concurrent workers for data-loading.
+ queue_maxsize
+ Maximum items to put into queue before it jams.
+
+ Examples
+ --------
+ >>> data_queue = DataQueue(my_dataset)
+ >>> with data_queue:
+ ... ...
+
+ In worker:
+
+ >>> for data in data_queue:
+ ... print(data)
+ """
+
+ def __init__(
+ self,
+ dataset: Sequence[T],
+ repeat: int = 1,
+ shuffle: bool = True,
+ producer_num_workers: int = 0,
+ queue_maxsize: int = 0,
+ ):
+ if queue_maxsize == 0:
+ if os.cpu_count() is not None:
+ queue_maxsize = cast(int, os.cpu_count())
+ _logger.info(f"Automatically set data queue maxsize to {queue_maxsize} to avoid overwhelming.")
+ else:
+ queue_maxsize = 1
+ _logger.warning(f"CPU count not available. Setting queue maxsize to 1.")
+
+ self.dataset: Sequence[T] = dataset
+ self.repeat: int = repeat
+ self.shuffle: bool = shuffle
+ self.producer_num_workers: int = producer_num_workers
+
+ self._activated: bool = False
+ self._queue: multiprocessing.Queue = multiprocessing.Queue(maxsize=queue_maxsize)
+ self._done = multiprocessing.Value("i", 0)
+
+ def __enter__(self):
+ self.activate()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.cleanup()
+
+ def cleanup(self):
+ with self._done.get_lock():
+ self._done.value += 1
+ for repeat in range(500):
+ if repeat >= 1:
+ warnings.warn(f"After {repeat} cleanup, the queue is still not empty.", category=RuntimeWarning)
+ while not self._queue.empty():
+ try:
+ self._queue.get(block=False)
+ except Empty:
+ pass
+ # Sometimes when the queue gets emptied, more data have already been sent,
+ # and they are on the way into the queue.
+ # If these data didn't get consumed, it will jam the queue and make the process hang.
+ # We wait a second here for potential data arriving, and check again (for ``repeat`` times).
+ time.sleep(1.0)
+ if self._queue.empty():
+ break
+ _logger.debug(f"Remaining items in queue collection done. Empty: {self._queue.empty()}")
+
+ def get(self, block=True):
+ if not hasattr(self, "_first_get"):
+ self._first_get = True
+ if self._first_get:
+ timeout = 5.0
+ self._first_get = False
+ else:
+ timeout = 0.5
+ while True:
+ try:
+ return self._queue.get(block=block, timeout=timeout)
+ except Empty:
+ if self._done.value:
+ raise StopIteration # pylint: disable=raise-missing-from
+
+ def put(self, obj, block=True, timeout=None):
+ return self._queue.put(obj, block=block, timeout=timeout)
+
+ def mark_as_done(self):
+ with self._done.get_lock():
+ self._done.value = 1
+
+ def done(self):
+ return self._done.value
+
+ def activate(self):
+ if self._activated:
+ raise ValueError("DataQueue can not activate twice.")
+ thread = threading.Thread(target=self._producer, daemon=True)
+ thread.start()
+ self._activated = True
+ return self
+
+ def __del__(self):
+ _logger.debug(f"__del__ of {__name__}.DataQueue")
+ self.cleanup()
+
+ def __iter__(self):
+ if not self._activated:
+ raise ValueError(
+ "Need to call activate() to launch a daemon worker " "to produce data into data queue before using it."
+ )
+ return self._consumer()
+
+ def _consumer(self):
+ while True:
+ try:
+ yield self.get()
+ except StopIteration:
+ _logger.debug("Data consumer timed-out from get.")
+ return
+
+ def _producer(self):
+ # pytorch dataloader is used here only because we need its sampler and multi-processing
+ from torch.utils.data import DataLoader, Dataset # pylint: disable=import-outside-toplevel
+
+ dataloader = DataLoader(
+ cast(Dataset[T], self.dataset),
+ batch_size=None,
+ num_workers=self.producer_num_workers,
+ shuffle=self.shuffle,
+ collate_fn=lambda t: t, # identity collate fn
+ )
+ repeat = 10**18 if self.repeat == -1 else self.repeat
+ for _rep in range(repeat):
+ for data in dataloader:
+ if self._done.value:
+ # Already done.
+ return
+ self._queue.put(data)
+ _logger.debug(f"Dataloader loop done. Repeat {_rep}.")
+ self.mark_as_done()
diff --git a/qlib/rl/utils/env_wrapper.py b/qlib/rl/utils/env_wrapper.py
new file mode 100644
index 00000000000..f343e5b9b4e
--- /dev/null
+++ b/qlib/rl/utils/env_wrapper.py
@@ -0,0 +1,249 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from __future__ import annotations
+
+import weakref
+from typing import Callable, Any, Iterable, Iterator, Generic, cast
+
+import gym
+
+from qlib.rl.aux_info import AuxiliaryInfoCollector
+from qlib.rl.simulator import Simulator, InitialStateType, StateType, ActType
+from qlib.rl.interpreter import StateInterpreter, ActionInterpreter, PolicyActType, ObsType
+from qlib.rl.reward import Reward
+from qlib.typehint import TypedDict
+
+from .finite_env import generate_nan_observation
+from .log import LogCollector, LogLevel
+
+__all__ = ["InfoDict", "EnvWrapperStatus", "EnvWrapper"]
+
+# in this case, there won't be any seed for simulator
+SEED_INTERATOR_MISSING = "_missing_"
+
+
+class InfoDict(TypedDict):
+ """The type of dict that is used in the 4th return value of ``env.step()``."""
+
+ aux_info: dict
+ """Any information depends on auxiliary info collector."""
+ log: dict[str, Any]
+ """Collected by LogCollector."""
+
+
+class EnvWrapperStatus(TypedDict):
+ """
+ This is the status data structure used in EnvWrapper.
+ The fields here are in the semantics of RL.
+ For example, ``obs`` means the observation fed into policy.
+ ``action`` means the raw action returned by policy.
+ """
+
+ cur_step: int
+ done: bool
+ initial_state: Any | None
+ obs_history: list
+ action_history: list
+ reward_history: list
+
+
+class EnvWrapper(
+ gym.Env[ObsType, PolicyActType], Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType]
+):
+ """Qlib-based RL environment, subclassing ``gym.Env``.
+ A wrapper of components, including simulator, state-interpreter, action-interpreter, reward.
+
+ This is what the framework of simulator - interpreter - policy looks like in RL training.
+    All the components other than the policy need to be assembled into a single object called "environment".
+    The "environment" is replicated into multiple workers, and (at least in tianshou's implementation),
+ one single policy (agent) plays against a batch of environments.
+
+ Parameters
+ ----------
+ simulator_fn
+ A callable that is the simulator factory.
+ When ``seed_iterator`` is present, the factory should take one argument,
+ that is the seed (aka initial state).
+ Otherwise, it should take zero argument.
+ state_interpreter
+ State-observation converter.
+ action_interpreter
+ Policy-simulator action converter.
+ seed_iterator
+ An iterable of seed. With the help of :class:`qlib.rl.utils.DataQueue`,
+ environment workers in different processes can share one ``seed_iterator``.
+ reward_fn
+ A callable that accepts the StateType and returns a float (at least in single-agent case).
+ aux_info_collector
+ Collect auxiliary information. Could be useful in MARL.
+ logger
+ Log collector that collects the logs. The collected logs are sent back to main process,
+ via the return value of ``env.step()``.
+
+ Attributes
+ ----------
+ status : EnvWrapperStatus
+ Status indicator. All terms are in *RL language*.
+ It can be used if users care about data on the RL side.
+ Can be none when no trajectory is available.
+ """
+
+ simulator: Simulator[InitialStateType, StateType, ActType]
+ seed_iterator: str | Iterator[InitialStateType] | None
+
+ def __init__(
+ self,
+ simulator_fn: Callable[..., Simulator[InitialStateType, StateType, ActType]],
+ state_interpreter: StateInterpreter[StateType, ObsType],
+ action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType],
+ seed_iterator: Iterable[InitialStateType] | None,
+ reward_fn: Reward | None = None,
+ aux_info_collector: AuxiliaryInfoCollector[StateType, Any] | None = None,
+ logger: LogCollector | None = None,
+ ):
+ # Assign weak reference to wrapper.
+ #
+ # Use weak reference here, because:
+ # 1. Logically, the other components should be able to live without an env_wrapper.
+ # For example, they might live in a strategy_wrapper in future.
+        #    Therefore injecting a "hard" attribute called "env" is not appropriate.
+        # 2. When the environment gets destroyed, it gets destroyed.
+ # We don't want it to silently live inside some interpreters.
+ # 3. Avoid circular reference.
+ # 4. When the components get serialized, we can throw away the env without any burden.
+ # (though this part is not implemented yet)
+ for obj in [state_interpreter, action_interpreter, reward_fn, aux_info_collector]:
+ if obj is not None:
+ obj.env = weakref.proxy(self) # type: ignore
+
+ self.simulator_fn = simulator_fn
+ self.state_interpreter = state_interpreter
+ self.action_interpreter = action_interpreter
+
+ if seed_iterator is None:
+ # In this case, there won't be any seed for simulator
+ # We can't set it to None because None actually means something else.
+ # If `seed_iterator` is None, it means that it's exhausted.
+ self.seed_iterator = SEED_INTERATOR_MISSING
+ else:
+ self.seed_iterator = iter(seed_iterator)
+ self.reward_fn = reward_fn
+
+ self.aux_info_collector = aux_info_collector
+ self.logger: LogCollector = logger or LogCollector()
+ self.status: EnvWrapperStatus = cast(EnvWrapperStatus, None)
+
+ @property
+ def action_space(self):
+ return self.action_interpreter.action_space
+
+ @property
+ def observation_space(self):
+ return self.state_interpreter.observation_space
+
+ def reset(self, **kwargs: Any) -> ObsType:
+ """
+ Try to get a state from state queue, and init the simulator with this state.
+ If the queue is exhausted, generate an invalid (nan) observation.
+ """
+
+ try:
+ if self.seed_iterator is None:
+ raise RuntimeError("You can trying to get a state from a dead environment wrapper.")
+
+ # TODO: simulator/observation might need seed to prefetch something
+            # as only the seed has the ability to do the work beforehand
+
+ # NOTE: though logger is reset here, logs in this function won't work,
+ # because we can't send them outside.
+ # See https://github.com/thu-ml/tianshou/issues/605
+ self.logger.reset()
+
+ if self.seed_iterator is SEED_INTERATOR_MISSING:
+ # no initial state
+ initial_state = None
+ self.simulator = cast(Callable[[], Simulator], self.simulator_fn)()
+ else:
+ initial_state = next(cast(Iterator[InitialStateType], self.seed_iterator))
+ self.simulator = self.simulator_fn(initial_state)
+
+ self.status = EnvWrapperStatus(
+ cur_step=0,
+ done=False,
+ initial_state=initial_state,
+ obs_history=[],
+ action_history=[],
+ reward_history=[],
+ )
+
+ self.simulator.env = cast(EnvWrapper, weakref.proxy(self))
+
+ sim_state = self.simulator.get_state()
+ obs = self.state_interpreter(sim_state)
+
+ self.status["obs_history"].append(obs)
+
+ return obs
+
+ except StopIteration:
+ # The environment should be recycled because it's in a dead state.
+ self.seed_iterator = None
+ return generate_nan_observation(self.observation_space)
+
+ def step(self, policy_action: PolicyActType, **kwargs: Any) -> tuple[ObsType, float, bool, InfoDict]:
+ """Environment step.
+
+ See the code along with comments to get a sequence of things happening here.
+ """
+
+ if self.seed_iterator is None:
+ raise RuntimeError("State queue is already exhausted, but the environment is still receiving action.")
+
+ # Clear the logged information from last step
+ self.logger.reset()
+
+ # Action is what we have got from policy
+ self.status["action_history"].append(policy_action)
+ action = self.action_interpreter(self.simulator.get_state(), policy_action)
+
+ # This update must be after action interpreter and before simulator.
+ self.status["cur_step"] += 1
+
+        # Use the converted action to update the simulator
+ self.simulator.step(action)
+
+ # Update "done" first, as this status might be used by reward_fn later
+ done = self.simulator.done()
+ self.status["done"] = done
+
+ # Get state and calculate observation
+ sim_state = self.simulator.get_state()
+ obs = self.state_interpreter(sim_state)
+ self.status["obs_history"].append(obs)
+
+ # Reward and extra info
+ if self.reward_fn is not None:
+ rew = self.reward_fn(sim_state)
+ else:
+ # No reward. Treated as 0.
+ rew = 0.0
+ self.status["reward_history"].append(rew)
+
+ if self.aux_info_collector is not None:
+ aux_info = self.aux_info_collector(sim_state)
+ else:
+ aux_info = {}
+
+ # Final logging stuff: RL-specific logs
+ if done:
+ self.logger.add_scalar("steps_per_episode", self.status["cur_step"])
+ self.logger.add_scalar("reward", rew)
+ self.logger.add_any("obs", obs, loglevel=LogLevel.DEBUG)
+ self.logger.add_any("policy_act", policy_action, loglevel=LogLevel.DEBUG)
+
+ info_dict = InfoDict(log=self.logger.logs(), aux_info=aux_info)
+ return obs, rew, done, info_dict
+
+ def render(self):
+ raise NotImplementedError("Render is not implemented in EnvWrapper.")
diff --git a/qlib/rl/utils/finite_env.py b/qlib/rl/utils/finite_env.py
new file mode 100644
index 00000000000..fc9c2c75ea5
--- /dev/null
+++ b/qlib/rl/utils/finite_env.py
@@ -0,0 +1,337 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""
+This is to support finite env in vector env.
+See https://github.com/thu-ml/tianshou/issues/322 for details.
+"""
+
+from __future__ import annotations
+
+import copy
+import warnings
+from contextlib import contextmanager
+
+import gym
+import numpy as np
+from typing import Any, Set, Callable, Type
+
+from tianshou.env import BaseVectorEnv, DummyVectorEnv, ShmemVectorEnv, SubprocVectorEnv
+
+from qlib.typehint import Literal
+from .log import LogWriter
+
+__all__ = [
+ "generate_nan_observation",
+ "check_nan_observation",
+ "FiniteVectorEnv",
+ "FiniteDummyVectorEnv",
+ "FiniteSubprocVectorEnv",
+ "FiniteShmemVectorEnv",
+ "FiniteEnvType",
+ "vectorize_env",
+]
+
+
+FiniteEnvType = Literal["dummy", "subproc", "shmem"]
+
+
+def fill_invalid(obj):
+ if isinstance(obj, (int, float, bool)):
+ return fill_invalid(np.array(obj))
+ if hasattr(obj, "dtype"):
+ if isinstance(obj, np.ndarray):
+ if np.issubdtype(obj.dtype, np.floating):
+ return np.full_like(obj, np.nan)
+ return np.full_like(obj, np.iinfo(obj.dtype).max)
+ # dealing with corner cases that numpy number is not supported by tianshou's sharray
+ return fill_invalid(np.array(obj))
+ elif isinstance(obj, dict):
+ return {k: fill_invalid(v) for k, v in obj.items()}
+ elif isinstance(obj, list):
+ return [fill_invalid(v) for v in obj]
+ elif isinstance(obj, tuple):
+ return tuple(fill_invalid(v) for v in obj)
+ raise ValueError(f"Unsupported value to fill with invalid: {obj}")
+
+
+def is_invalid(arr):
+ if hasattr(arr, "dtype"):
+ if np.issubdtype(arr.dtype, np.floating):
+ return np.isnan(arr).all()
+ return (np.iinfo(arr.dtype).max == arr).all()
+ if isinstance(arr, dict):
+ return all(is_invalid(o) for o in arr.values())
+ if isinstance(arr, (list, tuple)):
+ return all(is_invalid(o) for o in arr)
+ if isinstance(arr, (int, float, bool, np.number)):
+ return is_invalid(np.array(arr))
+ return True
+
+
+def generate_nan_observation(obs_space: gym.Space) -> Any:
+ """The NaN observation that indicates the environment receives no seed.
+
+ We assume that obs is complex and there must be something like float.
+ Otherwise this logic doesn't work.
+ """
+
+ sample = obs_space.sample()
+ sample = fill_invalid(sample)
+ return sample
+
+
+def check_nan_observation(obs: Any) -> bool:
+ """Check whether obs is generated by :func:`generate_nan_observation`."""
+ return is_invalid(obs)
+
+
+class FiniteVectorEnv(BaseVectorEnv):
+ """To allow the paralleled env workers consume a single DataQueue until it's exhausted.
+
+    See `tianshou issue #322 <https://github.com/thu-ml/tianshou/issues/322>`_.
+
+ The requirement is to make every possible seed (stored in :class:`qlib.rl.utils.DataQueue` in our case)
+ consumed by exactly one environment. This is not possible by tianshou's native VectorEnv and Collector,
+ because tianshou is unaware of this "exactly one" constraint, and might launch extra workers.
+
+ Consider a corner case, where concurrency is 2, but there is only one seed in DataQueue.
+ The reset of two workers must be both called according to the logic in collect.
+ The returned results of two workers are collected, regardless of what they are.
+ The problem is, one of the reset result must be invalid, or repeated,
+    because there's only one seed in queue, and the collector isn't aware of such a situation.
+
+ Luckily, we can hack the vector env, and make a protocol between single env and vector env.
+ The single environment (should be :class:`qlib.rl.utils.EnvWrapper` in our case) is responsible for
+ reading from queue, and generate a special observation when the queue is exhausted. The special obs
+ is called "nan observation", because simply using none causes problems in shared-memory vector env.
+ :class:`FiniteVectorEnv` then read the observations from all workers, and select those non-nan
+ observation. It also maintains an ``_alive_env_ids`` to track which workers should never be
+    called again. When all the environments are exhausted, it will raise a StopIteration exception.
+
+ The usage of this vector env in collector are two parts:
+
+ 1. If the data queue is finite (usually when inference), collector should collect "infinity" number of
+ episodes, until the vector env exhausts by itself.
+ 2. If the data queue is infinite (usually in training), collector can set number of episodes / steps.
+ In this case, data would be randomly ordered, and some repetitions wouldn't matter.
+
+ One extra function of this vector env is that it has a logger that explicitly collects logs
+ from child workers. See :class:`qlib.rl.utils.LogWriter`.
+ """
+
+ def __init__(
+ self, logger: LogWriter | list[LogWriter], env_fns: list[Callable[..., gym.Env]], **kwargs: Any
+ ) -> None:
+ super().__init__(env_fns, **kwargs)
+
+ self._logger: list[LogWriter] = logger if isinstance(logger, list) else [logger]
+ self._alive_env_ids: Set[int] = set()
+ self._reset_alive_envs()
+ self._default_obs = self._default_info = self._default_rew = None
+ self._zombie = False
+
+ self._collector_guarded: bool = False
+
+ def _reset_alive_envs(self):
+ if not self._alive_env_ids:
+ # starting or running out
+ self._alive_env_ids = set(range(self.env_num))
+
+ # to workaround with tianshou's buffer and batch
+ def _set_default_obs(self, obs):
+ if obs is not None and self._default_obs is None:
+ self._default_obs = copy.deepcopy(obs)
+
+ def _set_default_info(self, info):
+ if info is not None and self._default_info is None:
+ self._default_info = copy.deepcopy(info)
+
+ def _set_default_rew(self, rew):
+ if rew is not None and self._default_rew is None:
+ self._default_rew = copy.deepcopy(rew)
+
+ def _get_default_obs(self):
+ return copy.deepcopy(self._default_obs)
+
+ def _get_default_info(self):
+ return copy.deepcopy(self._default_info)
+
+ def _get_default_rew(self):
+ return copy.deepcopy(self._default_rew)
+
+ # END
+
+ @staticmethod
+ def _postproc_env_obs(obs):
+ # reserved for shmem vector env to restore empty observation
+ if obs is None or check_nan_observation(obs):
+ return None
+ return obs
+
+ @contextmanager
+ def collector_guard(self):
+ """Guard the collector. Recommended to guard every collect.
+
+ This guard is for two purposes.
+
+ 1. Catch and ignore the StopIteration exception, which is the stopping signal
+ thrown by FiniteEnv to let tianshou know that ``collector.collect()`` should exit.
+        2. Notify the loggers that the collect is done when it's done.
+
+ Examples
+ --------
+ >>> with finite_env.collector_guard():
+ ... collector.collect(n_episode=INF)
+ """
+ self._collector_guarded = True
+
+ try:
+ yield self
+ except StopIteration:
+ pass
+ finally:
+ self._collector_guarded = False
+
+ # At last trigger the loggers
+ for logger in self._logger:
+ logger.on_env_all_done()
+
+ def reset(self, id=None):
+ assert not self._zombie
+
+ # Check whether it's guarded by collector_guard()
+ if not self._collector_guarded:
+ warnings.warn(
+ "Collector is not guarded by FiniteEnv. "
+ "This may cause unexpected problems, like unexpected StopIteration exception, "
+ "or missing logs.",
+ RuntimeWarning,
+ )
+
+ id = self._wrap_id(id)
+ self._reset_alive_envs()
+
+ # ask super to reset alive envs and remap to current index
+ request_id = list(filter(lambda i: i in self._alive_env_ids, id))
+ obs = [None] * len(id)
+ id2idx = {i: k for k, i in enumerate(id)}
+ if request_id:
+ for i, o in zip(request_id, super().reset(request_id)):
+ obs[id2idx[i]] = self._postproc_env_obs(o)
+
+ for i, o in zip(id, obs):
+ if o is None and i in self._alive_env_ids:
+ self._alive_env_ids.remove(i)
+
+ # logging
+ for i, o in zip(id, obs):
+ if i in self._alive_env_ids:
+ for logger in self._logger:
+ logger.on_env_reset(i, obs)
+
+ # fill empty observation with default(fake) observation
+ for o in obs:
+ self._set_default_obs(o)
+ for i, o in enumerate(obs):
+ if o is None:
+ obs[i] = self._get_default_obs()
+
+ if not self._alive_env_ids:
+ # comment this line so that the env becomes indisposable
+ # self.reset()
+ self._zombie = True
+ raise StopIteration
+
+ return np.stack(obs)
+
+ def step(self, action, id=None):
+ assert not self._zombie
+ id = self._wrap_id(id)
+ id2idx = {i: k for k, i in enumerate(id)}
+ request_id = list(filter(lambda i: i in self._alive_env_ids, id))
+ result = [[None, None, False, None] for _ in range(len(id))]
+
+ # ask super to step alive envs and remap to current index
+ if request_id:
+ valid_act = np.stack([action[id2idx[i]] for i in request_id])
+ for i, r in zip(request_id, zip(*super().step(valid_act, request_id))):
+ result[id2idx[i]] = list(r)
+ result[id2idx[i]][0] = self._postproc_env_obs(result[id2idx[i]][0])
+
+ # logging
+ for i, r in zip(id, result):
+ if i in self._alive_env_ids:
+ for logger in self._logger:
+ logger.on_env_step(i, *r)
+
+ # fill empty observation/info with default(fake)
+ for _, r, ___, i in result:
+ self._set_default_info(i)
+ self._set_default_rew(r)
+ for i, r in enumerate(result):
+ if r[0] is None:
+ result[i][0] = self._get_default_obs()
+ if r[1] is None:
+ result[i][1] = self._get_default_rew()
+ if r[3] is None:
+ result[i][3] = self._get_default_info()
+
+ return list(map(np.stack, zip(*result)))
+
+
+class FiniteDummyVectorEnv(FiniteVectorEnv, DummyVectorEnv):
+ pass
+
+
+class FiniteSubprocVectorEnv(FiniteVectorEnv, SubprocVectorEnv):
+ pass
+
+
+class FiniteShmemVectorEnv(FiniteVectorEnv, ShmemVectorEnv):
+ pass
+
+
+def vectorize_env(
+ env_factory: Callable[..., gym.Env],
+ env_type: FiniteEnvType,
+ concurrency: int,
+ logger: LogWriter | list[LogWriter],
+) -> FiniteVectorEnv:
+ """Helper function to create a vector env.
+
+ Parameters
+ ----------
+ env_factory
+ Callable to instantiate one single ``gym.Env``.
+ All concurrent workers will have the same ``env_factory``.
+ env_type
+ dummy or subproc or shmem. Corresponding to
+ `parallelism in tianshou `_.
+ concurrency
+ Concurrent environment workers.
+ logger
+ Log writers.
+
+ Warnings
+ --------
+ Please do not use lambda expression here for ``env_factory`` as it may create incorrectly-shared instances.
+
+ Don't do: ::
+
+ vectorize_env(lambda: EnvWrapper(...), ...)
+
+ Please do: ::
+
+ def env_factory(): ...
+ vectorize_env(env_factory, ...)
+ """
+ env_type_cls_mapping: dict[str, Type[FiniteVectorEnv]] = {
+ "dummy": FiniteDummyVectorEnv,
+ "subproc": FiniteSubprocVectorEnv,
+ "shmem": FiniteShmemVectorEnv,
+ }
+
+ finite_env_cls = env_type_cls_mapping[env_type]
+
+ return finite_env_cls(logger, [env_factory for _ in range(concurrency)])
diff --git a/qlib/rl/utils/log.py b/qlib/rl/utils/log.py
new file mode 100644
index 00000000000..3d495b11dac
--- /dev/null
+++ b/qlib/rl/utils/log.py
@@ -0,0 +1,398 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Distributed logger for RL.
+
+:class:`LogCollector` runs in every environment workers. It collects log info from simulator states,
+and add them (as a dict) to auxiliary info returned for each step.
+
+:class:`LogWriter` runs in the central worker. It decodes the dict collected by :class:`LogCollector`
+in each worker, and writes them to console, log files, or tensorboard...
+
+The two modules communicate by the "log" field in "info" returned by ``env.step()``.
+"""
+
+from __future__ import annotations
+
+import logging
+from collections import defaultdict
+from enum import IntEnum
+from pathlib import Path
+from typing import Any, TypeVar, Generic, Set, TYPE_CHECKING, Sequence
+
+import numpy as np
+import pandas as pd
+
+from qlib.log import get_module_logger
+
+if TYPE_CHECKING:
+ from .env_wrapper import InfoDict
+
+
+__all__ = ["LogCollector", "LogWriter", "LogLevel", "ConsoleWriter", "CsvWriter"]
+
+ObsType = TypeVar("ObsType")
+ActType = TypeVar("ActType")
+
+
+class LogLevel(IntEnum):
+ """Log-levels for RL training.
+ The behavior of handling each log level depends on the implementation of :class:`LogWriter`.
+ """
+
+ DEBUG = 10
+ """If you only want to see the metric in debug mode."""
+ PERIODIC = 20
+ """If you want to see the metric periodically."""
+ # FIXME: I haven't given much thought about this. Let's hold it for one iteration.
+
+ INFO = 30
+ """Important log messages."""
+ CRITICAL = 40
+ """LogWriter should always handle CRITICAL messages"""
+
+
+class LogCollector:
+ """Logs are first collected in each environment worker,
+ and then aggregated to stream at the central thread in vector env.
+
+ In :class:`LogCollector`, every metric is added to a dict, which needs to be ``reset()`` at each step.
+ The dict is sent via the ``info`` in ``env.step()``, and decoded by the :class:`LogWriter` at vector env.
+
+ ``min_loglevel`` is for optimization purposes: to avoid too much traffic on networks / in pipe.
+ """
+
+ _logged: dict[str, tuple[int, Any]]
+ _min_loglevel: int
+
+ def __init__(self, min_loglevel: int | LogLevel = LogLevel.PERIODIC):
+ self._min_loglevel = int(min_loglevel)
+
+ def reset(self):
+ """Clear all collected contents."""
+ self._logged = {}
+
+ def _add_metric(self, name: str, metric: Any, loglevel: int | LogLevel) -> None:
+ if name in self._logged:
+ raise ValueError(f"A metric with {name} is already added. Please change a name or reset the log collector.")
+ self._logged[name] = (int(loglevel), metric)
+
+ def add_string(self, name: str, string: str, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
+ """Add a string with name into logged contents."""
+ if loglevel < self._min_loglevel:
+ return
+ if not isinstance(string, str):
+ raise TypeError(f"{string} is not a string.")
+ self._add_metric(name, string, loglevel)
+
+ def add_scalar(self, name: str, scalar: Any, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
+ """Add a scalar with name into logged contents.
+ Scalar will be converted into a float.
+ """
+ if loglevel < self._min_loglevel:
+ return
+
+ if hasattr(scalar, "item"):
+ # could be single-item number
+ scalar = scalar.item()
+ if not isinstance(scalar, (float, int)):
+ raise TypeError(f"{scalar} is not and can not be converted into float or integer.")
+ scalar = float(scalar)
+ self._add_metric(name, scalar, loglevel)
+
+ def add_array(
+ self, name: str, array: np.ndarray | pd.DataFrame | pd.Series, loglevel: int | LogLevel = LogLevel.PERIODIC
+ ) -> None:
+ """Add an array with name into logging."""
+ if loglevel < self._min_loglevel:
+ return
+
+ if not isinstance(array, (np.ndarray, pd.DataFrame, pd.Series)):
+ raise TypeError(f"{array} is not one of ndarray, DataFrame and Series.")
+ self._add_metric(name, array, loglevel)
+
+ def add_any(self, name: str, obj: Any, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
+ """Log something with any type.
+
+ As it's an "any" object, the only LogWriter accepting it is pickle.
+ Therefore pickle must be able to serialize it.
+ """
+ if loglevel < self._min_loglevel:
+ return
+
+ # FIXME: detect and rescue object that could be scalar or array
+
+ self._add_metric(name, obj, loglevel)
+
+ def logs(self) -> dict[str, np.ndarray]:
+ return {key: np.asanyarray(value, dtype="object") for key, value in self._logged.items()}
+
+
+class LogWriter(Generic[ObsType, ActType]):
+ """Base class for log writers, triggered at every reset and step by finite env.
+
+ What to do with a specific log depends on the implementation of subclassing :class:`LogWriter`.
+ The general principle is that, it should handle logs above its loglevel (inclusive),
+ and discard logs that are not acceptable. For instance, console loggers obviously can't handle an image.
+ """
+
+ episode_count: int
+ """Counter of episodes."""
+
+ step_count: int
+ """Counter of steps."""
+
+ global_step: int
+ """Counter of steps. Won"t be cleared in ``clear``."""
+
+ global_episode: int
+ """Counter of episodes. Won"t be cleared in ``clear``."""
+
+ active_env_ids: Set[int]
+ """Active environment ids in vector env."""
+
+ episode_lengths: dict[int, int]
+ """Map from environment id to episode length."""
+
+ episode_rewards: dict[int, list[float]]
+ """Map from environment id to episode total reward."""
+
+ episode_logs: dict[int, list]
+ """Map from environment id to episode logs."""
+
+ def __init__(self, loglevel: int | LogLevel = LogLevel.PERIODIC):
+ self.loglevel = loglevel
+
+ self.global_step = 0
+ self.global_episode = 0
+
+ # Information, logs of one episode is stored here.
+ # This assumes that episode is not too long to fit into the memory.
+ self.episode_lengths = dict()
+ self.episode_rewards = dict()
+ self.episode_logs = dict()
+
+ self.clear()
+
+ def clear(self):
+ self.episode_count = self.step_count = 0
+ self.active_env_ids = set()
+ self.logs = []
+
+ def aggregation(self, array: Sequence[Any]) -> Any:
+ """Aggregation function from step-wise to episode-wise.
+
+ If it's a sequence of float, take the mean.
+ Otherwise, take the first element.
+ """
+ assert len(array) > 0, "The aggregated array must be not empty."
+ if all(isinstance(v, float) for v in array):
+ return np.mean(array)
+ else:
+ return array[0]
+
+ def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None:
+ """This is triggered at the end of each trajectory.
+
+ Parameters
+ ----------
+ length
+ Length of this trajectory.
+ rewards
+ A list of rewards at each step of this episode.
+ contents
+            Logged contents for every step.
+ """
+
+ def log_step(self, reward: float, contents: dict[str, Any]) -> None:
+ """This is triggered at each step.
+
+ Parameters
+ ----------
+ reward
+ Reward for this step.
+ contents
+ Logged contents for this step.
+ """
+
+ def on_env_step(self, env_id: int, obs: ObsType, rew: float, done: bool, info: InfoDict) -> None:
+ """Callback for finite env, on each step."""
+
+ # Update counter
+ self.global_step += 1
+ self.step_count += 1
+
+ self.active_env_ids.add(env_id)
+ self.episode_lengths[env_id] += 1
+ # TODO: reward can be a list of list for MARL
+ self.episode_rewards[env_id].append(rew)
+
+ values: dict[str, Any] = {}
+
+ for key, (loglevel, value) in info["log"].items():
+ if loglevel >= self.loglevel: # FIXME: this is actually incorrect (see last FIXME)
+ values[key] = value
+ self.episode_logs[env_id].append(values)
+
+ self.log_step(rew, values)
+
+ if done:
+ # Update counter
+ self.global_episode += 1
+ self.episode_count += 1
+
+ self.log_episode(self.episode_lengths[env_id], self.episode_rewards[env_id], self.episode_logs[env_id])
+
+ def on_env_reset(self, env_id: int, obs: ObsType) -> None:
+ """Callback for finite env.
+
+ Reset episode statistics. Nothing task-specific is logged here because of
+        `a limitation of tianshou <https://github.com/thu-ml/tianshou/issues/605>`__.
+ """
+ self.episode_lengths[env_id] = 0
+ self.episode_rewards[env_id] = []
+ self.episode_logs[env_id] = []
+
+ def on_env_all_done(self) -> None:
+ """All done. Time for cleanup."""
+
+
+class ConsoleWriter(LogWriter):
+ """Write log messages to console periodically.
+
+ It tracks an average meter for each metric, which is the average value since last ``clear()`` till now.
+    The display format for each metric is ``<name> <value> (<average>)``.
+
+ Non-single-number metrics are auto skipped.
+ """
+
+ prefix: str
+ """Prefix can be set via ``writer.prefix``."""
+
+ def __init__(
+ self,
+ log_every_n_episode: int = 20,
+ total_episodes: int | None = None,
+ float_format: str = ":.4f",
+ counter_format: str = ":4d",
+ loglevel: int | LogLevel = LogLevel.PERIODIC,
+ ):
+ super().__init__(loglevel)
+ # TODO: support log_every_n_step
+ self.log_every_n_episode = log_every_n_episode
+ self.total_episodes = total_episodes
+
+ self.counter_format = counter_format
+ self.float_format = float_format
+
+ self.prefix = ""
+
+ self.console_logger = get_module_logger(__name__, level=logging.INFO)
+
+ def clear(self):
+ super().clear()
+ # Clear average meters
+ self.metric_counts: dict[str, int] = defaultdict(int)
+ self.metric_sums: dict[str, float] = defaultdict(float)
+
+ def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None:
+ # Aggregate step-wise to episode-wise
+ episode_wise_contents: dict[str, list] = defaultdict(list)
+
+ for step_contents in contents:
+ for name, value in step_contents.items():
+ if isinstance(value, float):
+ episode_wise_contents[name].append(value)
+
+ # Generate log contents and track them in average-meter.
+ # This should be done at every step, regardless of periodic or not.
+ logs: dict[str, float] = {}
+ for name, values in episode_wise_contents.items():
+ logs[name] = self.aggregation(values) # type: ignore
+
+ for name, value in logs.items():
+ self.metric_counts[name] += 1
+ self.metric_sums[name] += value
+
+ if self.episode_count % self.log_every_n_episode == 0 or self.episode_count == self.total_episodes:
+ # Only log periodically or at the end
+ self.console_logger.info(self.generate_log_message(logs))
+
+ def generate_log_message(self, logs: dict[str, float]) -> str:
+ if self.prefix:
+ msg_prefix = self.prefix + " "
+ else:
+ msg_prefix = ""
+ if self.total_episodes is None:
+ msg_prefix += "[Step {" + self.counter_format + "}]"
+ else:
+ msg_prefix += "[{" + self.counter_format + "}/" + str(self.total_episodes) + "]"
+ msg_prefix = msg_prefix.format(self.episode_count)
+
+ msg = ""
+ for name, value in logs.items():
+ # Double-space as delimiter
+ format_template = r" {} {" + self.float_format + "} ({" + self.float_format + "})"
+ msg += format_template.format(name, value, self.metric_sums[name] / self.metric_counts[name])
+
+ msg = msg_prefix + " " + msg
+
+ return msg
+
+
+class CsvWriter(LogWriter):
+ """Dump all episode metrics to a ``result.csv``.
+
+ This is not the correct implementation. It's only used for first iteration.
+ """
+
+ SUPPORTED_TYPES = (float, str, pd.Timestamp)
+
+ all_records: list[dict[str, Any]]
+
+ def __init__(self, output_dir: Path, loglevel: int | LogLevel = LogLevel.PERIODIC):
+ super().__init__(loglevel)
+ self.output_dir = output_dir
+ self.output_dir.mkdir(exist_ok=True)
+
+ def clear(self):
+ super().clear()
+ self.all_records = []
+
+ def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None:
+ # FIXME Same as ConsoleLogger, needs a refactor to eliminate code-dup
+ episode_wise_contents: dict[str, list] = defaultdict(list)
+
+ for step_contents in contents:
+ for name, value in step_contents.items():
+ if isinstance(value, self.SUPPORTED_TYPES):
+ episode_wise_contents[name].append(value)
+
+ logs: dict[str, float] = {}
+ for name, values in episode_wise_contents.items():
+ logs[name] = self.aggregation(values) # type: ignore
+
+ self.all_records.append(logs)
+
+ def on_env_all_done(self) -> None:
+ # FIXME: this is temporary
+ pd.DataFrame.from_records(self.all_records).to_csv(self.output_dir / "result.csv", index=False)
+
+
+# The following are not implemented yet.
+
+
+class PickleWriter(LogWriter):
+ """Dump logs to pickle files."""
+
+
+class TensorboardWriter(LogWriter):
+ """Write logs to event files that can be visualized with tensorboard."""
+
+
+class MlflowWriter(LogWriter):
+ """Add logs to mlflow."""
+
+
+class LogBuffer(LogWriter):
+ """Keep everything in memory."""
diff --git a/qlib/typehint.py b/qlib/typehint.py
new file mode 100644
index 00000000000..3a584b4b751
--- /dev/null
+++ b/qlib/typehint.py
@@ -0,0 +1,13 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Commonly used types."""
+
+import sys
+
+__all__ = ["Literal", "TypedDict", "final"]
+
+if sys.version_info >= (3, 8):
+ from typing import Literal, TypedDict, final # type: ignore # pylint: disable=no-name-in-module
+else:
+ from typing_extensions import Literal, TypedDict, final
diff --git a/setup.py b/setup.py
index 8780e8be73c..d95f59a250a 100644
--- a/setup.py
+++ b/setup.py
@@ -134,7 +134,12 @@ def get_version(rel_path: str) -> str:
"sphinx",
"sphinx_rtd_theme",
"pre-commit",
- ]
+ ],
+ "rl": [
+ "tianshou",
+ "gym",
+ "torch",
+ ],
},
include_package_data=True,
classifiers=[
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 00000000000..7a869da16ee
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,10 @@
+import os
+import sys
+
+"""Ignore RL tests on non-linux platform."""
+collect_ignore = []
+
+if sys.platform != "linux":
+ for root, dirs, files in os.walk("rl"):
+ for file in files:
+ collect_ignore.append(os.path.join(root, file))
diff --git a/tests/pytest.ini b/tests/pytest.ini
new file mode 100644
index 00000000000..2ea41da8c39
--- /dev/null
+++ b/tests/pytest.ini
@@ -0,0 +1,4 @@
+[pytest]
+filterwarnings =
+ ignore:.*rng.randint:DeprecationWarning
+ ignore:.*Casting input x to numpy array:UserWarning
diff --git a/tests/rl/test_data_queue.py b/tests/rl/test_data_queue.py
new file mode 100644
index 00000000000..0b0c61280de
--- /dev/null
+++ b/tests/rl/test_data_queue.py
@@ -0,0 +1,88 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import multiprocessing
+import time
+
+import numpy as np
+import pandas as pd
+
+from torch.utils.data import Dataset, DataLoader
+from qlib.rl.utils.data_queue import DataQueue
+
+
+class DummyDataset(Dataset):
+ def __init__(self, length):
+ self.length = length
+
+ def __getitem__(self, index):
+ assert 0 <= index < self.length
+ return pd.DataFrame(np.random.randint(0, 100, size=(index + 1, 4)), columns=list("ABCD"))
+
+ def __len__(self):
+ return self.length
+
+
+def _worker(dataloader, collector):
+ # for i in range(3):
+ for i, data in enumerate(dataloader):
+ collector.put(len(data))
+
+
+def _queue_to_list(queue):
+ result = []
+ while not queue.empty():
+ result.append(queue.get())
+ return result
+
+
+def test_pytorch_dataloader():
+ dataset = DummyDataset(100)
+ dataloader = DataLoader(dataset, batch_size=None, num_workers=1)
+ queue = multiprocessing.Queue()
+ _worker(dataloader, queue)
+ assert len(set(_queue_to_list(queue))) == 100
+
+
+def test_multiprocess_shared_dataloader():
+ dataset = DummyDataset(100)
+ with DataQueue(dataset, producer_num_workers=1) as data_queue:
+ queue = multiprocessing.Queue()
+ processes = []
+ for _ in range(3):
+ processes.append(multiprocessing.Process(target=_worker, args=(data_queue, queue)))
+ processes[-1].start()
+ for p in processes:
+ p.join()
+ assert len(set(_queue_to_list(queue))) == 100
+
+
+def test_exit_on_crash_finite():
+ def _exit_finite():
+ dataset = DummyDataset(100)
+
+ with DataQueue(dataset, producer_num_workers=4) as data_queue:
+ time.sleep(3)
+ raise ValueError
+
+ # https://stackoverflow.com/questions/34506638/how-to-register-atexit-function-in-pythons-multiprocessing-subprocess
+
+ process = multiprocessing.Process(target=_exit_finite)
+ process.start()
+ process.join()
+
+
+def test_exit_on_crash_infinite():
+ def _exit_infinite():
+ dataset = DummyDataset(100)
+ with DataQueue(dataset, repeat=-1, queue_maxsize=100) as data_queue:
+ time.sleep(3)
+ raise ValueError
+
+ process = multiprocessing.Process(target=_exit_infinite)
+ process.start()
+ process.join()
+
+
+if __name__ == "__main__":
+ test_multiprocess_shared_dataloader()
diff --git a/tests/rl/test_finite_env.py b/tests/rl/test_finite_env.py
new file mode 100644
index 00000000000..d6f2a2ec959
--- /dev/null
+++ b/tests/rl/test_finite_env.py
@@ -0,0 +1,249 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from collections import Counter
+
+import gym
+import numpy as np
+from tianshou.data import Batch, Collector
+from tianshou.policy import BasePolicy
+from torch.utils.data import DataLoader, Dataset, DistributedSampler
+from qlib.rl.utils.finite_env import (
+ LogWriter,
+ FiniteDummyVectorEnv,
+ FiniteShmemVectorEnv,
+ FiniteSubprocVectorEnv,
+ check_nan_observation,
+ generate_nan_observation,
+)
+
+
+_test_space = gym.spaces.Dict(
+ {
+ "sensors": gym.spaces.Dict(
+ {
+ "position": gym.spaces.Box(low=-100, high=100, shape=(3,)),
+ "velocity": gym.spaces.Box(low=-1, high=1, shape=(3,)),
+ "front_cam": gym.spaces.Tuple(
+ (gym.spaces.Box(low=0, high=1, shape=(10, 10, 3)), gym.spaces.Box(low=0, high=1, shape=(10, 10, 3)))
+ ),
+ "rear_cam": gym.spaces.Box(low=0, high=1, shape=(10, 10, 3)),
+ }
+ ),
+ "ext_controller": gym.spaces.MultiDiscrete((5, 2, 2)),
+ "inner_state": gym.spaces.Dict(
+ {
+ "charge": gym.spaces.Discrete(100),
+ "system_checks": gym.spaces.MultiBinary(10),
+ "job_status": gym.spaces.Dict(
+ {
+ "task": gym.spaces.Discrete(5),
+ "progress": gym.spaces.Box(low=0, high=100, shape=()),
+ }
+ ),
+ }
+ ),
+ }
+)
+
+
+class FiniteEnv(gym.Env):
+ def __init__(self, dataset, num_replicas, rank):
+ self.dataset = dataset
+ self.num_replicas = num_replicas
+ self.rank = rank
+ self.loader = DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas, rank), batch_size=None)
+ self.iterator = None
+ self.observation_space = gym.spaces.Discrete(255)
+ self.action_space = gym.spaces.Discrete(2)
+
+ def reset(self):
+ if self.iterator is None:
+ self.iterator = iter(self.loader)
+ try:
+ self.current_sample, self.step_count = next(self.iterator)
+ self.current_step = 0
+ return self.current_sample
+ except StopIteration:
+ self.iterator = None
+ return generate_nan_observation(self.observation_space)
+
+ def step(self, action):
+ self.current_step += 1
+ assert self.current_step <= self.step_count
+ return (
+ 0,
+ 1.0,
+ self.current_step >= self.step_count,
+ {"sample": self.current_sample, "action": action, "metric": 2.0},
+ )
+
+
+class FiniteEnvWithComplexObs(FiniteEnv):
+ def __init__(self, dataset, num_replicas, rank):
+ self.dataset = dataset
+ self.num_replicas = num_replicas
+ self.rank = rank
+ self.loader = DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas, rank), batch_size=None)
+ self.iterator = None
+ self.observation_space = gym.spaces.Discrete(255)
+ self.action_space = gym.spaces.Discrete(2)
+
+ def reset(self):
+ if self.iterator is None:
+ self.iterator = iter(self.loader)
+ try:
+ self.current_sample, self.step_count = next(self.iterator)
+ self.current_step = 0
+ return _test_space.sample()
+ except StopIteration:
+ self.iterator = None
+ return generate_nan_observation(self.observation_space)
+
+ def step(self, action):
+ self.current_step += 1
+ assert self.current_step <= self.step_count
+ return (
+ _test_space.sample(),
+ 1.0,
+ self.current_step >= self.step_count,
+ {"sample": _test_space.sample(), "action": action, "metric": 2.0},
+ )
+
+
+class DummyDataset(Dataset):
+ def __init__(self, length):
+ self.length = length
+ self.episodes = [3 * i % 5 + 1 for i in range(self.length)]
+
+ def __getitem__(self, index):
+ assert 0 <= index < self.length
+ return index, self.episodes[index]
+
+ def __len__(self):
+ return self.length
+
+
+class AnyPolicy(BasePolicy):
+ def forward(self, batch, state=None):
+ return Batch(act=np.stack([1] * len(batch)))
+
+ def learn(self, batch):
+ pass
+
+
+def _finite_env_factory(dataset, num_replicas, rank, complex=False):
+ if complex:
+ return lambda: FiniteEnvWithComplexObs(dataset, num_replicas, rank)
+ return lambda: FiniteEnv(dataset, num_replicas, rank)
+
+
+class MetricTracker(LogWriter):
+ def __init__(self, length):
+ super().__init__()
+ self.counter = Counter()
+ self.finished = set()
+ self.length = length
+
+ def on_env_step(self, env_id, obs, rew, done, info):
+ assert rew == 1.0
+ index = info["sample"]
+ if done:
+ # assert index not in self.finished
+ self.finished.add(index)
+ self.counter[index] += 1
+
+ def validate(self):
+ assert len(self.finished) == self.length
+ for k, v in self.counter.items():
+ assert v == k * 3 % 5 + 1
+
+
+class DoNothingTracker(LogWriter):
+ def on_env_step(self, *args, **kwargs):
+ pass
+
+
+def test_finite_dummy_vector_env():
+ length = 100
+ dataset = DummyDataset(length)
+ envs = FiniteDummyVectorEnv(MetricTracker(length), [_finite_env_factory(dataset, 5, i) for i in range(5)])
+ envs._collector_guarded = True
+ policy = AnyPolicy()
+ test_collector = Collector(policy, envs, exploration_noise=True)
+
+ for _ in range(1):
+ envs._logger = [MetricTracker(length)]
+ try:
+ test_collector.collect(n_step=10**18)
+ except StopIteration:
+ envs._logger[0].validate()
+
+
+def test_finite_shmem_vector_env():
+ length = 100
+ dataset = DummyDataset(length)
+ envs = FiniteShmemVectorEnv(MetricTracker(length), [_finite_env_factory(dataset, 5, i) for i in range(5)])
+ envs._collector_guarded = True
+ policy = AnyPolicy()
+ test_collector = Collector(policy, envs, exploration_noise=True)
+
+ for _ in range(1):
+ envs._logger = [MetricTracker(length)]
+ try:
+ test_collector.collect(n_step=10**18)
+ except StopIteration:
+ envs._logger[0].validate()
+
+
+def test_finite_subproc_vector_env():
+ length = 100
+ dataset = DummyDataset(length)
+ envs = FiniteSubprocVectorEnv(MetricTracker(length), [_finite_env_factory(dataset, 5, i) for i in range(5)])
+ envs._collector_guarded = True
+ policy = AnyPolicy()
+ test_collector = Collector(policy, envs, exploration_noise=True)
+
+ for _ in range(1):
+ envs._logger = [MetricTracker(length)]
+ try:
+ test_collector.collect(n_step=10**18)
+ except StopIteration:
+ envs._logger[0].validate()
+
+
+def test_nan():
+ assert check_nan_observation(generate_nan_observation(_test_space))
+ assert not check_nan_observation(_test_space.sample())
+
+
+def test_finite_dummy_vector_env_complex():
+ length = 100
+ dataset = DummyDataset(length)
+ envs = FiniteDummyVectorEnv(
+ DoNothingTracker(), [_finite_env_factory(dataset, 5, i, complex=True) for i in range(5)]
+ )
+ envs._collector_guarded = True
+ policy = AnyPolicy()
+ test_collector = Collector(policy, envs, exploration_noise=True)
+
+ try:
+ test_collector.collect(n_step=10**18)
+ except StopIteration:
+ pass
+
+
+def test_finite_shmem_vector_env_complex():
+ length = 100
+ dataset = DummyDataset(length)
+ envs = FiniteShmemVectorEnv(
+ DoNothingTracker(), [_finite_env_factory(dataset, 5, i, complex=True) for i in range(5)]
+ )
+ envs._collector_guarded = True
+ policy = AnyPolicy()
+ test_collector = Collector(policy, envs, exploration_noise=True)
+
+ try:
+ test_collector.collect(n_step=10**18)
+ except StopIteration:
+ pass
diff --git a/tests/rl/test_logger.py b/tests/rl/test_logger.py
new file mode 100644
index 00000000000..240ffc1e142
--- /dev/null
+++ b/tests/rl/test_logger.py
@@ -0,0 +1,156 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from random import randint, choice
+from pathlib import Path
+
+import re
+import gym
+import numpy as np
+import pandas as pd
+from gym import spaces
+from tianshou.data import Collector, Batch
+from tianshou.policy import BasePolicy
+
+from qlib.log import set_log_with_config
+from qlib.config import C
+from qlib.constant import INF
+from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
+from qlib.rl.simulator import Simulator
+from qlib.rl.utils.data_queue import DataQueue
+from qlib.rl.utils.env_wrapper import InfoDict, EnvWrapper
+from qlib.rl.utils.log import LogLevel, LogCollector, CsvWriter, ConsoleWriter
+from qlib.rl.utils.finite_env import vectorize_env
+
+
+class SimpleEnv(gym.Env[int, int]):
+ def __init__(self):
+ self.logger = LogCollector()
+ self.observation_space = gym.spaces.Discrete(2)
+ self.action_space = gym.spaces.Discrete(2)
+
+ def reset(self):
+ self.step_count = 0
+ return 0
+
+ def step(self, action: int):
+ self.logger.reset()
+
+ self.logger.add_scalar("reward", 42.0)
+
+ self.logger.add_scalar("a", randint(1, 10))
+ self.logger.add_array("b", pd.DataFrame({"a": [1, 2], "b": [3, 4]}))
+
+ if self.step_count >= 3:
+ done = choice([False, True])
+ else:
+ done = False
+
+ if 2 <= self.step_count <= 3:
+ self.logger.add_scalar("c", randint(11, 20))
+
+ self.step_count += 1
+
+ return 1, 42.0, done, InfoDict(log=self.logger.logs(), aux_info={})
+
+
+class AnyPolicy(BasePolicy):
+ def forward(self, batch, state=None):
+ return Batch(act=np.stack([1] * len(batch)))
+
+ def learn(self, batch):
+ pass
+
+
+def test_simple_env_logger(caplog):
+ set_log_with_config(C.logging_config)
+ for venv_cls_name in ["dummy", "shmem", "subproc"]:
+ writer = ConsoleWriter()
+ csv_writer = CsvWriter(Path(__file__).parent / ".output")
+ venv = vectorize_env(lambda: SimpleEnv(), venv_cls_name, 4, [writer, csv_writer])
+ with venv.collector_guard():
+ collector = Collector(AnyPolicy(), venv)
+ collector.collect(n_episode=30)
+
+ output_file = pd.read_csv(Path(__file__).parent / ".output" / "result.csv")
+ assert output_file.columns.tolist() == ["reward", "a", "c"]
+ assert len(output_file) >= 30
+
+ line_counter = 0
+ for line in caplog.text.splitlines():
+ line = line.strip()
+ if line:
+ line_counter += 1
+ assert re.match(r".*reward 42\.0000 \(42.0000\) a .* \((4|5|6)\.\d+\) c .* \((14|15|16)\.\d+\)", line)
+ assert line_counter >= 3
+
+
+class SimpleSimulator(Simulator[int, float, float]):
+ def __init__(self, initial: int, **kwargs) -> None:
+ self.initial = float(initial)
+
+ def step(self, action: float) -> None:
+ import torch
+
+ self.initial += action
+ self.env.logger.add_scalar("test_a", torch.tensor(233.0))
+ self.env.logger.add_scalar("test_b", np.array(200))
+
+ def get_state(self) -> float:
+ return self.initial
+
+ def done(self) -> bool:
+ return self.initial % 1 > 0.5
+
+
+class DummyStateInterpreter(StateInterpreter[float, float]):
+ def interpret(self, state: float) -> float:
+ return state
+
+ @property
+ def observation_space(self) -> spaces.Box:
+ return spaces.Box(0, np.inf, shape=(), dtype=np.float32)
+
+
+class DummyActionInterpreter(ActionInterpreter[float, int, float]):
+ def interpret(self, state: float, action: int) -> float:
+ return action / 100
+
+ @property
+    def action_space(self) -> spaces.Discrete:
+ return spaces.Discrete(5)
+
+
+class RandomFivePolicy(BasePolicy):
+ def forward(self, batch, state=None):
+ return Batch(act=np.random.randint(5, size=len(batch)))
+
+ def learn(self, batch):
+ pass
+
+
+def test_logger_with_env_wrapper():
+ with DataQueue(list(range(20)), shuffle=False) as data_iterator:
+ env_wrapper_factory = lambda: EnvWrapper(
+ SimpleSimulator,
+ DummyStateInterpreter(),
+ DummyActionInterpreter(),
+ data_iterator,
+ logger=LogCollector(LogLevel.DEBUG),
+ )
+
+ # loglevel can be debug here because metrics can all dump into csv
+ # otherwise, csv writer might crash
+ csv_writer = CsvWriter(Path(__file__).parent / ".output", loglevel=LogLevel.DEBUG)
+ venv = vectorize_env(env_wrapper_factory, "shmem", 4, csv_writer)
+ with venv.collector_guard():
+ collector = Collector(RandomFivePolicy(), venv)
+ collector.collect(n_episode=INF * len(venv))
+
+ output_df = pd.read_csv(Path(__file__).parent / ".output" / "result.csv")
+ assert len(output_df) == 20
+    # obs has an increasing trend
+ assert output_df["obs"].to_numpy()[:10].sum() < output_df["obs"].to_numpy()[10:].sum()
+ assert (output_df["test_a"] == 233).all()
+ assert (output_df["test_b"] == 200).all()
+ assert "steps_per_episode" in output_df and "reward" in output_df
diff --git a/tests/rl/test_saoe_simple.py b/tests/rl/test_saoe_simple.py
new file mode 100644
index 00000000000..2ac0d9cbd3c
--- /dev/null
+++ b/tests/rl/test_saoe_simple.py
@@ -0,0 +1,308 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import sys
+from functools import partial
+from pathlib import Path
+from typing import NamedTuple
+
+import numpy as np
+import pandas as pd
+import pytest
+
+import torch
+from tianshou.data import Batch
+
+from qlib.backtest import Order
+from qlib.config import C
+from qlib.log import set_log_with_config
+from qlib.rl.data import pickle_styled
+from qlib.rl.entries.test import backtest
+from qlib.rl.order_execution import *
+from qlib.rl.utils import ConsoleWriter, CsvWriter, EnvWrapperStatus
+
+pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8")
+
+
+DATA_ROOT_DIR = Path(__file__).parent.parent / ".data" / "rl" / "intraday_saoe"
+DATA_DIR = DATA_ROOT_DIR / "us"
+BACKTEST_DATA_DIR = DATA_DIR / "backtest"
+FEATURE_DATA_DIR = DATA_DIR / "processed"
+ORDER_DIR = DATA_DIR / "order" / "valid_bidir"
+
+CN_DATA_DIR = DATA_ROOT_DIR / "cn"
+CN_BACKTEST_DATA_DIR = CN_DATA_DIR / "backtest"
+CN_FEATURE_DATA_DIR = CN_DATA_DIR / "processed"
+CN_ORDER_DIR = CN_DATA_DIR / "order" / "test"
+CN_POLICY_WEIGHTS_DIR = CN_DATA_DIR / "weights"
+
+
+def test_pickle_data_inspect():
+ data = pickle_styled.load_intraday_backtest_data(BACKTEST_DATA_DIR, "AAL", "2013-12-11", "close", 0)
+ assert len(data) == 390
+
+ data = pickle_styled.load_intraday_processed_data(
+ DATA_DIR / "processed", "AAL", "2013-12-11", 5, data.get_time_index()
+ )
+ assert len(data.today) == len(data.yesterday) == 390
+
+
+def test_simulator_first_step():
+ order = Order("AAL", 30.0, 0, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
+
+ simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
+ state = simulator.get_state()
+ assert state.cur_time == pd.Timestamp("2013-12-11 09:30:00")
+ assert state.position == 30.0
+
+ simulator.step(15.0)
+ state = simulator.get_state()
+ assert len(state.history_exec) == 30
+ assert state.history_exec.index[0] == pd.Timestamp("2013-12-11 09:30:00")
+ assert state.history_exec["market_volume"].iloc[0] == 450072.0
+ assert abs(state.history_exec["market_price"].iloc[0] - 25.370001) < 1e-4
+ assert (state.history_exec["amount"] == 0.5).all()
+ assert (state.history_exec["deal_amount"] == 0.5).all()
+ assert abs(state.history_exec["trade_price"].iloc[0] - 25.370001) < 1e-4
+ assert abs(state.history_exec["trade_value"].iloc[0] - 12.68500) < 1e-4
+ assert state.history_exec["position"].iloc[0] == 29.5
+ assert state.history_exec["ffr"].iloc[0] == 1 / 60
+
+ assert state.history_steps["market_volume"].iloc[0] == 5041147.0
+ assert state.history_steps["amount"].iloc[0] == 15.0
+ assert state.history_steps["deal_amount"].iloc[0] == 15.0
+ assert state.history_steps["ffr"].iloc[0] == 0.5
+ assert (
+ state.history_steps["pa"].iloc[0]
+ == (state.history_steps["trade_price"].iloc[0] / simulator.twap_price - 1) * 10000
+ )
+
+ assert state.position == 15.0
+ assert state.cur_time == pd.Timestamp("2013-12-11 10:00:00")
+
+
+def test_simulator_stop_twap():
+ order = Order("AAL", 13.0, 0, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
+
+ simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
+ for _ in range(13):
+ simulator.step(1.0)
+
+ state = simulator.get_state()
+ assert len(state.history_exec) == 390
+ assert (state.history_exec["deal_amount"] == 13 / 390).all()
+ assert state.history_steps["position"].iloc[0] == 12 and state.history_steps["position"].iloc[-1] == 0
+
+ assert (state.metrics["ffr"] - 1) < 1e-3
+ assert abs(state.metrics["market_price"] - state.backtest_data.get_deal_price().mean()) < 1e-4
+ assert np.isclose(state.metrics["market_volume"], state.backtest_data.get_volume().sum())
+ assert state.position == 0.0
+ assert abs(state.metrics["trade_price"] - state.metrics["market_price"]) < 1e-4
+ assert abs(state.metrics["pa"]) < 1e-2
+
+ assert simulator.done()
+
+
+def test_simulator_stop_early():
+ order = Order("AAL", 1.0, 1, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
+
+ with pytest.raises(ValueError):
+ simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
+ simulator.step(2.0)
+
+ simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
+ simulator.step(1.0)
+
+ with pytest.raises(AssertionError):
+ simulator.step(1.0)
+
+
+def test_simulator_start_middle():
+ order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 10:15:00"), pd.Timestamp("2013-12-11 15:44:59"))
+
+ simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
+ assert len(simulator.ticks_for_order) == 330
+ assert simulator.cur_time == pd.Timestamp("2013-12-11 10:15:00")
+ simulator.step(2.0)
+ assert simulator.cur_time == pd.Timestamp("2013-12-11 10:30:00")
+
+ for _ in range(10):
+ simulator.step(1.0)
+
+ simulator.step(2.0)
+ assert len(simulator.history_exec) == 330
+ assert simulator.done()
+ assert abs(simulator.history_exec["amount"].iloc[-1] - (1 + 2 / 15)) < 1e-4
+ assert abs(simulator.metrics["ffr"] - 1) < 1e-4
+
+
+def test_interpreter():
+ order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 10:15:00"), pd.Timestamp("2013-12-11 15:44:59"))
+
+ simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
+ assert len(simulator.ticks_for_order) == 330
+ assert simulator.cur_time == pd.Timestamp("2013-12-11 10:15:00")
+
+    # emulate an env status
+ class EmulateEnvWrapper(NamedTuple):
+ status: EnvWrapperStatus
+
+ interpreter = FullHistoryStateInterpreter(FEATURE_DATA_DIR, 13, 390, 5)
+ interpreter_step = CurrentStepStateInterpreter(13)
+ interpreter_action = CategoricalActionInterpreter(20)
+ interpreter_action_twap = TwapRelativeActionInterpreter()
+
+ wrapper_status_kwargs = dict(initial_state=order, obs_history=[], action_history=[], reward_history=[])
+
+ # first step
+ interpreter.env = EmulateEnvWrapper(status=EnvWrapperStatus(cur_step=0, done=False, **wrapper_status_kwargs))
+
+ obs = interpreter(simulator.get_state())
+ assert obs["cur_tick"] == 45
+ assert obs["cur_step"] == 0
+ assert obs["position"] == 15.0
+ assert obs["position_history"][0] == 15.0
+ assert all(np.sum(obs["data_processed"][i]) != 0 for i in range(45))
+ assert np.sum(obs["data_processed"][45:]) == 0
+ assert obs["data_processed_prev"].shape == (390, 5)
+
+ # first step: second interpreter
+ interpreter_step.env = EmulateEnvWrapper(status=EnvWrapperStatus(cur_step=0, done=False, **wrapper_status_kwargs))
+
+ obs = interpreter_step(simulator.get_state())
+ assert obs["acquiring"] == 1
+ assert obs["position"] == 15.0
+
+ # second step
+ simulator.step(5.0)
+ interpreter.env = EmulateEnvWrapper(status=EnvWrapperStatus(cur_step=1, done=False, **wrapper_status_kwargs))
+
+ obs = interpreter(simulator.get_state())
+ assert obs["cur_tick"] == 60
+ assert obs["cur_step"] == 1
+ assert obs["position"] == 10.0
+ assert obs["position_history"][:2].tolist() == [15.0, 10.0]
+ assert all(np.sum(obs["data_processed"][i]) != 0 for i in range(60))
+ assert np.sum(obs["data_processed"][60:]) == 0
+
+ # second step: action
+ action = interpreter_action(simulator.get_state(), 1)
+ assert action == 15 / 20
+
+ interpreter_action_twap.env = EmulateEnvWrapper(
+ status=EnvWrapperStatus(cur_step=1, done=False, **wrapper_status_kwargs)
+ )
+ action = interpreter_action_twap(simulator.get_state(), 1.5)
+ assert action == 1.5
+
+ # fast-forward
+ for _ in range(10):
+ simulator.step(0.0)
+
+ # last step
+ simulator.step(5.0)
+ interpreter.env = EmulateEnvWrapper(
+ status=EnvWrapperStatus(cur_step=12, done=simulator.done(), **wrapper_status_kwargs)
+ )
+
+ assert interpreter.env.status["done"]
+
+ obs = interpreter(simulator.get_state())
+ assert obs["cur_tick"] == 375
+ assert obs["cur_step"] == 12
+ assert obs["position"] == 0.0
+ assert obs["position_history"][1:11].tolist() == [10.0] * 10
+ assert all(np.sum(obs["data_processed"][i]) != 0 for i in range(375))
+ assert np.sum(obs["data_processed"][375:]) == 0
+
+
+def test_network_sanity():
+ # we won't check the correctness of networks here
+ order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 9:30:00"), pd.Timestamp("2013-12-11 15:59:59"))
+
+ simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
+ assert len(simulator.ticks_for_order) == 390
+
+ class EmulateEnvWrapper(NamedTuple):
+ status: EnvWrapperStatus
+
+ interpreter = FullHistoryStateInterpreter(FEATURE_DATA_DIR, 13, 390, 5)
+ action_interp = CategoricalActionInterpreter(13)
+
+ wrapper_status_kwargs = dict(initial_state=order, obs_history=[], action_history=[], reward_history=[])
+
+ network = Recurrent(interpreter.observation_space)
+ policy = PPO(network, interpreter.observation_space, action_interp.action_space, 1e-3)
+
+ for i in range(14):
+ interpreter.env = EmulateEnvWrapper(status=EnvWrapperStatus(cur_step=i, done=False, **wrapper_status_kwargs))
+ obs = interpreter(simulator.get_state())
+ batch = Batch(obs=[obs])
+ output = policy(batch)
+ assert 0 <= output["act"].item() <= 13
+ if i < 13:
+ simulator.step(1.0)
+ else:
+ assert obs["cur_tick"] == 389
+ assert obs["cur_step"] == 12
+ assert obs["position_history"][-1] == 3
+
+
+@pytest.mark.parametrize("finite_env_type", ["dummy", "subproc", "shmem"])
+def test_twap_strategy(finite_env_type):
+ set_log_with_config(C.logging_config)
+ orders = pickle_styled.load_orders(ORDER_DIR)
+ assert len(orders) == 248
+
+ state_interp = FullHistoryStateInterpreter(FEATURE_DATA_DIR, 13, 390, 5)
+ action_interp = TwapRelativeActionInterpreter()
+ policy = AllOne(state_interp.observation_space, action_interp.action_space)
+ csv_writer = CsvWriter(Path(__file__).parent / ".output")
+
+ backtest(
+ partial(SingleAssetOrderExecution, data_dir=BACKTEST_DATA_DIR, ticks_per_step=30),
+ state_interp,
+ action_interp,
+ orders,
+ policy,
+ [ConsoleWriter(total_episodes=len(orders)), csv_writer],
+ concurrency=4,
+ finite_env_type=finite_env_type,
+ )
+
+ metrics = pd.read_csv(Path(__file__).parent / ".output" / "result.csv")
+ assert len(metrics) == 248
+ assert np.isclose(metrics["ffr"].mean(), 1.0)
+ assert np.isclose(metrics["pa"].mean(), 0.0)
+ assert np.allclose(metrics["pa"], 0.0, atol=2e-3)
+
+
+def test_cn_ppo_strategy():
+ set_log_with_config(C.logging_config)
+ # The data starts with 9:31 and ends with 15:00
+ orders = pickle_styled.load_orders(CN_ORDER_DIR, start_time=pd.Timestamp("9:31"), end_time=pd.Timestamp("14:58"))
+ assert len(orders) == 40
+
+ state_interp = FullHistoryStateInterpreter(CN_FEATURE_DATA_DIR, 8, 240, 6)
+ action_interp = CategoricalActionInterpreter(4)
+ network = Recurrent(state_interp.observation_space)
+ policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)
+ policy.load_state_dict(torch.load(CN_POLICY_WEIGHTS_DIR / "ppo_recurrent_30min.pth", map_location="cpu"))
+ csv_writer = CsvWriter(Path(__file__).parent / ".output")
+
+ backtest(
+ partial(SingleAssetOrderExecution, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30),
+ state_interp,
+ action_interp,
+ orders,
+ policy,
+ [ConsoleWriter(total_episodes=len(orders)), csv_writer],
+ concurrency=4,
+ )
+
+ metrics = pd.read_csv(Path(__file__).parent / ".output" / "result.csv")
+ assert len(metrics) == len(orders)
+ assert np.isclose(metrics["ffr"].mean(), 1.0)
+ assert np.isclose(metrics["pa"].mean(), -16.21578303474833)
+ assert np.isclose(metrics["market_price"].mean(), 58.68277690875527)
+ assert np.isclose(metrics["trade_price"].mean(), 58.76063985000002)