Skip to content

Commit 1a8fc1f

Browse files
committed
Merge branch 'hotfix/1.7.1_samples'
2 parents e5767d4 + 45e33db commit 1a8fc1f

File tree

3 files changed

+214
-168
lines changed

3 files changed

+214
-168
lines changed

samples/backtesting/position_manager.py

Lines changed: 101 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
from collections import defaultdict
1718
from decimal import Decimal
18-
from typing import Dict, Optional
19+
from typing import cast, Dict, Optional
1920
import asyncio
21+
import copy
2022
import dataclasses
2123
import datetime
2224
import logging
@@ -29,28 +31,50 @@
2931
@dataclasses.dataclass
3032
class PositionInfo:
3133
pair: bs.Pair
34+
pair_info: bs.PairInfo
3235
initial: Decimal
3336
initial_avg_price: Decimal
3437
target: Decimal
3538
order: backtesting_exchange.OrderInfo
3639

40+
def __post_init__(self):
41+
# Both initial and initial_avg_price should be set to 0, or none of them.
42+
assert (self.initial == Decimal(0)) is (self.initial_avg_price == Decimal(0)), \
43+
f"initial={self.initial}, initial_avg_price={self.initial_avg_price}"
44+
3745
@property
3846
def current(self) -> Decimal:
3947
delta = self.order.amount_filled if self.order.operation == bs.OrderOperation.BUY else -self.order.amount_filled
4048
return self.initial + delta
4149

4250
@property
4351
def avg_price(self) -> Decimal:
44-
order_fill_price = Decimal(0) if self.order.fill_price is None else self.order.fill_price
45-
ret = order_fill_price
52+
# If the current position is 0, then the average price is 0.
53+
current = self.current
54+
if current == Decimal(0):
55+
return Decimal(0)
4656

57+
# If the current position is not 0, then the order will have a fill price.
58+
order_fill_price = cast(Decimal, self.order.fill_price)
59+
60+
# If we're going from a neutral position to a non-neutral position, the order fill price is returned.
4761
if self.initial == 0:
4862
ret = order_fill_price
49-
# Transition from long to short, or viceversa, and already on the target side.
50-
elif self.initial * self.target < 0 and self.current * self.target > 0:
51-
ret = order_fill_price
63+
# If we are closing the position, going from a non-neutral position to a neutral position, the initial average
64+
# price is returned.
65+
elif self.target == 0:
66+
ret = self.initial_avg_price
67+
# Going from long to short, or the other way around.
68+
elif self.initial * self.target < 0:
69+
# If we are on the target side, the order fill price is returned.
70+
if current * self.target > 0:
71+
ret = order_fill_price
72+
# If we're still on the initial side, the initial average price is returned.
73+
else:
74+
ret = self.initial_avg_price
5275
# Rebalancing on the same side.
53-
elif self.initial * self.target > 0:
76+
else:
77+
assert self.initial * self.target > 0
5478
# Reducing the position.
5579
if self.target > 0 and self.order.operation == bs.OrderOperation.SELL \
5680
or self.target < 0 and self.order.operation == bs.OrderOperation.BUY:
@@ -59,6 +83,7 @@ def avg_price(self) -> Decimal:
5983
else:
6084
ret = (abs(self.initial) * self.initial_avg_price + self.order.amount_filled * order_fill_price) \
6185
/ (abs(self.initial) + self.order.amount_filled)
86+
6287
return ret
6388

6489
@property
@@ -93,32 +118,26 @@ def __init__(
93118
self._position_amount = position_amount
94119
self._quote_symbol = quote_symbol
95120
self._positions: Dict[bs.Pair, PositionInfo] = {}
121+
self._pos_mutex: Dict[bs.Pair, asyncio.Lock] = defaultdict(asyncio.Lock)
96122
self._stop_loss_pct = stop_loss_pct
97123
self._borrowing_disabled = borrowing_disabled
98124
self._last_check_loss: Optional[datetime.datetime] = None
99125

100-
async def cancel_open_orders(self, pair: bs.Pair):
101-
open_orders = await self._exchange.get_open_orders(pair)
102-
await asyncio.gather(*[
103-
self._exchange.cancel_order(open_order.id)
104-
for open_order in open_orders
105-
])
106-
107126
async def get_position_info(self, pair: bs.Pair) -> Optional[PositionInfo]:
108127
pos_info = self._positions.get(pair)
109128
if pos_info and pos_info.order_open:
110-
pos_info.order = await self._exchange.get_order_info(pos_info.order.id)
111-
return pos_info
129+
async with self._pos_mutex[pair]:
130+
pos_info.order = await self._exchange.get_order_info(pos_info.order.id)
131+
return copy.deepcopy(pos_info)
112132

113133
async def check_loss(self):
114-
# Refresh positions that have an open order.
115-
refresh_pairs = [pair for pair, pos_info in self._positions.items() if pos_info.order_open]
116-
coros = [self.get_position_info(pair) for pair in refresh_pairs]
134+
positions = []
135+
coros = [self.get_position_info(pair) for pair in self._positions.keys()]
117136
if coros:
118-
await asyncio.gather(*coros)
137+
positions.extend(await asyncio.gather(*coros))
119138

120139
# Check unrealized PnL for all non-neutral positions.
121-
non_neutral = [pos_info for pos_info in self._positions.values() if pos_info.current != Decimal(0)]
140+
non_neutral = [pos_info for pos_info in positions if pos_info.current != Decimal(0)]
122141
if not non_neutral:
123142
return
124143

@@ -128,14 +147,17 @@ async def check_loss(self):
128147
pnl_pct = pos_info.calculate_unrealized_pnl_pct(bid, ask)
129148
logging.info(StructuredMessage(
130149
f"Position for {pos_info.pair}", current=pos_info.current, target=pos_info.target,
131-
avg_price=pos_info.avg_price, pnl_pct=pnl_pct, order_open=pos_info.order_open
150+
order_open=pos_info.order_open,
151+
avg_price=bs.round_decimal(pos_info.avg_price, pos_info.pair_info.quote_precision),
152+
pnl_pct=bs.round_decimal(pnl_pct, 2)
132153
))
133154
if pnl_pct <= self._stop_loss_pct * -1:
134155
logging.info(f"Stop loss for {pos_info.pair}")
135156
await self.switch_position(pos_info.pair, bs.Position.NEUTRAL, force=True)
136157

137158
async def switch_position(self, pair: bs.Pair, target_position: bs.Position, force: bool = False):
138159
current_pos_info = await self.get_position_info(pair)
160+
139161
# Unless force is set, we can ignore the request if we're already there.
140162
if not force and any([
141163
current_pos_info is None and target_position == bs.Position.NEUTRAL,
@@ -147,54 +169,63 @@ async def switch_position(self, pair: bs.Pair, target_position: bs.Position, for
147169
]):
148170
return
149171

150-
# Cancel the previous order.
151-
if current_pos_info and current_pos_info.order_open:
152-
await self._exchange.cancel_order(current_pos_info.order.id)
153-
current_pos_info.order = await self._exchange.get_order_info(current_pos_info.order.id)
154-
155-
(bid, ask), pair_info = await asyncio.gather(
156-
self._exchange.get_bid_ask(pair),
157-
self._exchange.get_pair_info(pair),
158-
)
159-
160-
# 1. Calculate the target balance.
161-
# If the target position is neutral, the target balance is 0, otherwise we need to convert
162-
# self._position_amount, which is expressed in self._quote_symbol units, into base units.
163-
if target_position == bs.Position.NEUTRAL:
164-
target = Decimal(0)
165-
else:
166-
if pair.quote_symbol == self._quote_symbol:
167-
target = self._position_amount / ((bid + ask) / 2)
172+
# Exclusive access to the position since we're going to modify it.
173+
async with self._pos_mutex[pair]:
174+
# Cancel the previous order.
175+
if current_pos_info and current_pos_info.order_open:
176+
logging.info(StructuredMessage("Canceling order", order_ids=current_pos_info.order.id))
177+
await self._exchange.cancel_order(current_pos_info.order.id)
178+
current_pos_info.order = await self._exchange.get_order_info(current_pos_info.order.id)
179+
180+
(bid, ask), pair_info = await asyncio.gather(
181+
self._exchange.get_bid_ask(pair),
182+
self._exchange.get_pair_info(pair),
183+
)
184+
185+
# 1. Calculate the target balance.
186+
# If the target position is neutral, the target balance is 0, otherwise we need to convert
187+
# self._position_amount, which is expressed in self._quote_symbol units, into base units.
188+
if target_position == bs.Position.NEUTRAL:
189+
target = Decimal(0)
168190
else:
169-
quote_bid, quote_ask = await self._exchange.get_bid_ask(bs.Pair(pair.base_symbol, self._quote_symbol))
170-
target = self._position_amount / ((quote_bid + quote_ask) / 2)
171-
172-
if target_position == bs.Position.SHORT:
173-
target *= -1
174-
target = bs.truncate_decimal(target, pair_info.base_precision)
175-
176-
# 2. Calculate the difference between the target balance and our current balance.
177-
current = Decimal(0) if current_pos_info is None else current_pos_info.current
178-
delta = target - current
179-
logging.info(StructuredMessage("Switch position", pair=pair, current=current, target=target, delta=delta))
180-
if delta == 0:
181-
return
182-
183-
# 3. Create the order.
184-
order_size = abs(delta)
185-
operation = bs.OrderOperation.BUY if delta > 0 else bs.OrderOperation.SELL
186-
logging.info(StructuredMessage("Creating market order", operation=operation, pair=pair, order_size=order_size))
187-
created_order = await self._exchange.create_market_order(
188-
operation, pair, order_size, auto_borrow=True, auto_repay=True
189-
)
190-
order = await self._exchange.get_order_info(created_order.id)
191-
192-
# 4. Keep track of the position.
193-
initial_avg_price = Decimal(0) if current_pos_info is None else current_pos_info.avg_price
194-
pos_info = PositionInfo(
195-
pair=pair, initial=current, initial_avg_price=initial_avg_price, target=target, order=order
196-
)
197-
self._positions[pair] = pos_info
191+
if pair.quote_symbol == self._quote_symbol:
192+
target = self._position_amount / ((bid + ask) / 2)
193+
else:
194+
quote_bid, quote_ask = await self._exchange.get_bid_ask(
195+
bs.Pair(pair.base_symbol, self._quote_symbol)
196+
)
197+
target = self._position_amount / ((quote_bid + quote_ask) / 2)
198+
199+
if target_position == bs.Position.SHORT:
200+
target *= -1
201+
target = bs.truncate_decimal(target, pair_info.base_precision)
202+
203+
# 2. Calculate the difference between the target balance and our current balance.
204+
current = Decimal(0) if current_pos_info is None else current_pos_info.current
205+
delta = target - current
206+
logging.info(StructuredMessage("Switch position", pair=pair, current=current, target=target, delta=delta))
207+
if delta == 0:
208+
return
209+
210+
# 3. Create the order.
211+
order_size = abs(delta)
212+
operation = bs.OrderOperation.BUY if delta > 0 else bs.OrderOperation.SELL
213+
logging.info(StructuredMessage(
214+
"Creating market order", operation=operation, pair=pair, order_size=order_size
215+
))
216+
created_order = await self._exchange.create_market_order(
217+
operation, pair, order_size, auto_borrow=True, auto_repay=True
218+
)
219+
logging.info(StructuredMessage("Order created", id=created_order.id))
220+
order = await self._exchange.get_order_info(created_order.id)
221+
222+
# 4. Keep track of the position.
223+
initial_avg_price = Decimal(0) if current_pos_info is None else current_pos_info.avg_price
224+
pos_info = PositionInfo(
225+
pair=pair, pair_info=pair_info, initial=current, initial_avg_price=initial_avg_price, target=target,
226+
order=order
227+
)
228+
self._positions[pair] = pos_info
198229

199230
async def on_trading_signal(self, trading_signal: bs.TradingSignal):
200231
pairs = list(trading_signal.get_pairs())
@@ -220,7 +251,7 @@ async def on_bar_event(self, bar_event: bs.BarEvent):
220251
async def on_order_event(self, order_event: backtesting_exchange.OrderEvent):
221252
order = order_event.order
222253
logging.info(StructuredMessage(
223-
"Order udpated", id=order.id, is_open=order.is_open, amount=order.amount,
254+
"Order updated", id=order.id, is_open=order.is_open, amount=order.amount,
224255
amount_filled=order.amount_filled, avg_fill_price=order.fill_price
225256
))
226257

0 commit comments

Comments
 (0)