Skip to content

Commit aa481bc

Browse files
authored
Merge pull request #1244 from microsoft/huoran/qlib_rl_followup
Qlib simulator refinement
2 parents e78fe48 + cb2b214 commit aa481bc

26 files changed

+995
-758
lines changed

qlib/backtest/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,4 +345,4 @@ def format_decisions(
345345
return res
346346

347347

348-
__all__ = ["Order", "backtest"]
348+
__all__ = ["Order", "backtest", "get_strategy_executor"]

qlib/backtest/backtest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,9 @@ def collect_data_loop(
8383
while not trade_executor.finished():
8484
_trade_decision: BaseTradeDecision = trade_strategy.generate_trade_decision(_execute_result)
8585
_execute_result = yield from trade_executor.collect_data(_trade_decision, level=0)
86+
trade_strategy.post_exe_step(_execute_result)
8687
bar.update(1)
88+
trade_strategy.post_upper_level_exe_step()
8789

8890
if return_value is not None:
8991
all_executors = trade_executor.get_all_executors()

qlib/backtest/decision.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,21 @@ def parse_dir(direction: Union[str, int, np.integer, OrderDir, np.ndarray]) -> U
135135
else:
136136
raise NotImplementedError(f"This type of input is not supported")
137137

138+
@property
def key_by_day(self) -> tuple:
    """Return a hashable key identifying this order at day granularity.

    The key is ``(stock_id, date, direction)``, so it depends only on the
    stock, the calendar day of ``start_time`` and the direction.
    """
    return self.stock_id, self.date, self.direction

@property
def key(self) -> tuple:
    """Return a hashable key identifying this order.

    The key is ``(stock_id, start_time, end_time, direction)``.
    """
    return self.stock_id, self.start_time, self.end_time, self.direction

@property
def date(self) -> pd.Timestamp:
    """Return the calendar day of the order as a midnight timestamp.

    The value is derived from ``start_time``. Every sub-day component
    (hours down to nanoseconds) is cleared, so two orders starting on the
    same day always yield equal dates and therefore equal ``key_by_day``.
    """
    # `normalize` also zeroes micro/nanoseconds, which a
    # `replace(hour=0, minute=0, second=0)` would leave untouched.
    return pd.Timestamp(self.start_time).normalize()
152+
138153

139154
class OrderHelper:
140155
"""

qlib/backtest/executor.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def __init__(
114114
self.track_data = track_data
115115
self._trade_exchange = trade_exchange
116116
self.level_infra = LevelInfrastructure()
117-
self.level_infra.reset_infra(common_infra=common_infra)
117+
self.level_infra.reset_infra(common_infra=common_infra, executor=self)
118118
self._settle_type = settle_type
119119
self.reset(start_time=start_time, end_time=end_time, common_infra=common_infra)
120120
if common_infra is None:
@@ -134,6 +134,8 @@ def reset_common_infra(self, common_infra: CommonInfrastructure, copy_trade_acco
134134
else:
135135
self.common_infra.update(common_infra)
136136

137+
self.level_infra.reset_infra(common_infra=self.common_infra)
138+
137139
if common_infra.has("trade_account"):
138140
# NOTE: there is a trick in the code.
139141
# shallow copy is used instead of deepcopy.
@@ -256,6 +258,7 @@ def collect_data(
256258
object
257259
trade decision
258260
"""
261+
259262
if self.track_data:
260263
yield trade_decision
261264

@@ -296,6 +299,7 @@ def collect_data(
296299

297300
if return_value is not None:
298301
return_value.update({"execute_result": res})
302+
299303
return res
300304

301305
def get_all_executors(self) -> List[BaseExecutor]:
@@ -396,7 +400,7 @@ def _update_trade_decision(self, trade_decision: BaseTradeDecision) -> BaseTrade
396400
trade_decision = updated_trade_decision
397401
# NEW UPDATE
398402
# create a hook for inner strategy to update outer decision
399-
self.inner_strategy.alter_outer_trade_decision(trade_decision)
403+
trade_decision = self.inner_strategy.alter_outer_trade_decision(trade_decision)
400404
return trade_decision
401405

402406
def _collect_data(
@@ -473,6 +477,9 @@ def _collect_data(
473477
# do nothing and just step forward
474478
sub_cal.step()
475479

480+
# Let inner strategy know that the outer level execution is done.
481+
self.inner_strategy.post_upper_level_exe_step()
482+
476483
return execute_result, {"inner_order_indicators": inner_order_indicators, "decision_list": decision_list}
477484

478485
def post_inner_exe_step(self, inner_exe_res: List[object]) -> None:

qlib/backtest/utils.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@
33

44
from __future__ import annotations
55

6-
import bisect
76
from abc import abstractmethod
8-
from typing import TYPE_CHECKING, Any, Set, Tuple, Union
7+
from typing import Any, Set, Tuple, TYPE_CHECKING, Union
98

109
import numpy as np
1110

@@ -184,8 +183,8 @@ def get_range_idx(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tup
184183
Tuple[int, int]:
185184
the index of the range. **the left and right are closed**
186185
"""
187-
left = bisect.bisect_right(list(self._calendar), start_time) - 1
188-
right = bisect.bisect_right(list(self._calendar), end_time) - 1
186+
left = np.searchsorted(self._calendar, start_time, side="right") - 1
187+
right = np.searchsorted(self._calendar, end_time, side="right") - 1
189188
left -= self.start_index
190189
right -= self.start_index
191190

@@ -248,7 +247,7 @@ def get_support_infra(self) -> Set[str]:
248247
sub_level_infra:
249248
- **NOTE**: this will only work after _init_sub_trading !!!
250249
"""
251-
return {"trade_calendar", "sub_level_infra", "common_infra"}
250+
return {"trade_calendar", "sub_level_infra", "common_infra", "executor"}
252251

253252
def reset_cal(
254253
self,

qlib/constant.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
# Licensed under the MIT License.
33

44
# REGION CONST
5+
from typing import TypeVar
6+
7+
import numpy as np
8+
import pandas as pd
9+
510
REG_CN = "cn"
611
REG_US = "us"
712
REG_TW = "tw"
@@ -10,4 +15,8 @@
1015
EPS = 1e-12
1116

1217
# Infinity in integer
13-
INF = 10**18
18+
INF = int(1e18)
19+
ONE_DAY = pd.Timedelta("1day")
20+
ONE_MIN = pd.Timedelta("1min")
21+
EPS_T = pd.Timedelta("1s") # use 1 second to exclude the right interval point
22+
float_or_ndarray = TypeVar("float_or_ndarray", float, np.ndarray)

qlib/data/dataset/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -615,4 +615,4 @@ def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
615615
return tsds
616616

617617

618-
__all__ = ["Optional"]
618+
__all__ = ["Optional", "Dataset", "DatasetH"]

qlib/rl/aux_info.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from __future__ import annotations
55

6-
from typing import Optional, TYPE_CHECKING, Generic, TypeVar
6+
from typing import TYPE_CHECKING, Generic, Optional, TypeVar
77

88
from qlib.typehint import final
99

qlib/rl/data/exchange_wrapper.py

Lines changed: 58 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,33 @@
33

44
from typing import cast
55

6+
import cachetools
67
import pandas as pd
78

89
from qlib.backtest import Exchange, Order
9-
from .pickle_styled import IntradayBacktestData
10+
from qlib.backtest.decision import TradeRange, TradeRangeByTime
11+
from qlib.constant import ONE_DAY, EPS_T
12+
from qlib.rl.order_execution.utils import get_ticks_slice
13+
from qlib.utils.index_data import IndexData
14+
from .pickle_styled import BaseIntradayBacktestData
1015

1116

12-
class QlibIntradayBacktestData(IntradayBacktestData):
17+
class IntradayBacktestData(BaseIntradayBacktestData):
1318
"""Backtest data for Qlib simulator"""
1419

15-
def __init__(self, order: Order, exchange: Exchange, start_time: pd.Timestamp, end_time: pd.Timestamp) -> None:
16-
super(QlibIntradayBacktestData, self).__init__()
20+
def __init__(
21+
self,
22+
order: Order,
23+
exchange: Exchange,
24+
ticks_index: pd.DatetimeIndex,
25+
ticks_for_order: pd.DatetimeIndex,
26+
) -> None:
1727
self._order = order
1828
self._exchange = exchange
19-
self._start_time = start_time
20-
self._end_time = end_time
29+
self._start_time = ticks_for_order[0]
30+
self._end_time = ticks_for_order[-1]
31+
self.ticks_index = ticks_index
32+
self.ticks_for_order = ticks_for_order
2133

2234
self._deal_price = cast(
2335
pd.Series,
@@ -56,3 +68,43 @@ def get_volume(self) -> pd.Series:
5668

5769
def get_time_index(self) -> pd.DatetimeIndex:
5870
return pd.DatetimeIndex([e[1] for e in list(self._exchange.quote_df.index)])
71+
72+
73+
@cachetools.cached(  # type: ignore
    cache=cachetools.LRUCache(100),
    # Only `order` is inspected. The remaining arguments are accepted in any
    # form (positional or keyword) so keyword-style calls do not break the
    # key function.
    key=lambda order, *_, **__: order.key_by_day,
)
def load_qlib_backtest_data(
    order: Order,
    trade_exchange: Exchange,
    trade_range: TradeRange,
) -> IntradayBacktestData:
    """Build the intraday backtest data for ``order`` from ``trade_exchange``.

    Deal prices for the whole day of the order are queried from the exchange
    to obtain the tick index of that day; the ticks that fall inside
    ``trade_range`` (end included) are the ticks available to the order.

    Results are kept in an LRU cache of 100 entries keyed by
    ``order.key_by_day``.
    NOTE(review): the cache key ignores ``trade_exchange`` and ``trade_range``,
    so a later call with the same day-level order key returns the first result
    even if those arguments differ -- confirm this is intended.

    Parameters
    ----------
    order
        The order to load data for.
    trade_exchange
        Exchange that provides the deal prices.
    trade_range
        Trading range of the order; only ``TradeRangeByTime`` is supported.

    Returns
    -------
    IntradayBacktestData
        Backtest data wrapping the order, the exchange and the tick indices.

    Raises
    ------
    NotImplementedError
        If ``trade_range`` is not a ``TradeRangeByTime``.
    """
    # Fail fast: without a time-based range there are no ticks for the order,
    # and IntradayBacktestData would otherwise crash on `ticks_for_order[0]`.
    if not isinstance(trade_range, TradeRangeByTime):
        raise NotImplementedError(
            f"Unsupported trade range type {type(trade_range).__name__}; only TradeRangeByTime is supported"
        )

    data = cast(
        IndexData,
        trade_exchange.get_deal_price(
            stock_id=order.stock_id,
            start_time=order.date,
            # Subtract EPS_T so the next day's first tick is excluded.
            end_time=order.date + ONE_DAY - EPS_T,
            direction=order.direction,
            method=None,
        ),
    )

    ticks_index = pd.DatetimeIndex(data.index)
    ticks_for_order = get_ticks_slice(
        ticks_index,
        trade_range.start_time,
        trade_range.end_time,
        include_end=True,
    )

    return IntradayBacktestData(
        order=order,
        exchange=trade_exchange,
        ticks_index=ticks_index,
        ticks_for_order=ticks_for_order,
    )

qlib/rl/data/pickle_styled.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def _read_pickle(filename_without_suffix: Path) -> pd.DataFrame:
8686
return pd.read_pickle(_find_pickle(filename_without_suffix))
8787

8888

89-
class IntradayBacktestData:
89+
class BaseIntradayBacktestData:
9090
"""
9191
Raw market data that is often used in backtesting (thus called BacktestData).
9292
@@ -115,7 +115,7 @@ def get_time_index(self) -> pd.DatetimeIndex:
115115
raise NotImplementedError
116116

117117

118-
class SimpleIntradayBacktestData(IntradayBacktestData):
118+
class SimpleIntradayBacktestData(BaseIntradayBacktestData):
119119
"""Backtest data for simple simulator"""
120120

121121
def __init__(

0 commit comments

Comments
 (0)