1414# See the License for the specific language governing permissions and
1515# limitations under the License.
1616
17+ from collections import defaultdict
1718from decimal import Decimal
18- from typing import Dict , Optional
19+ from typing import cast , Dict , Optional
1920import asyncio
21+ import copy
2022import dataclasses
2123import datetime
2224import logging
2931@dataclasses .dataclass
3032class PositionInfo :
3133 pair : bs .Pair
34+ pair_info : bs .PairInfo
3235 initial : Decimal
3336 initial_avg_price : Decimal
3437 target : Decimal
3538 order : backtesting_exchange .OrderInfo
3639
40+ def __post_init__ (self ):
41+ # Both initial and initial_avg_price should be set to 0, or none of them.
42+ assert (self .initial == Decimal (0 )) is (self .initial_avg_price == Decimal (0 )), \
43+ f"initial={ self .initial } , initial_avg_price={ self .initial_avg_price } "
44+
3745 @property
3846 def current (self ) -> Decimal :
3947 delta = self .order .amount_filled if self .order .operation == bs .OrderOperation .BUY else - self .order .amount_filled
4048 return self .initial + delta
4149
4250 @property
4351 def avg_price (self ) -> Decimal :
44- order_fill_price = Decimal (0 ) if self .order .fill_price is None else self .order .fill_price
45- ret = order_fill_price
52+ # If the current position is 0, then the average price is 0.
53+ current = self .current
54+ if current == Decimal (0 ):
55+ return Decimal (0 )
4656
57+ # If the current position is not 0, then the order will have a fill price.
58+ order_fill_price = cast (Decimal , self .order .fill_price )
59+
60+ # If we're going from a neutral position to a non-neutral position, the order fill price is returned.
4761 if self .initial == 0 :
4862 ret = order_fill_price
49- # Transition from long to short, or viceversa, and already on the target side.
50- elif self .initial * self .target < 0 and self .current * self .target > 0 :
51- ret = order_fill_price
63+ # If we are closing the position, going from a non-neutral position to a neutral position, the initial average
64+ # price is returned.
65+ elif self .target == 0 :
66+ ret = self .initial_avg_price
67+ # Going from long to short, or the other way around.
68+ elif self .initial * self .target < 0 :
69+ # If we are on the target side, the order fill price is returned.
70+ if current * self .target > 0 :
71+ ret = order_fill_price
72+ # If we're still on the initial side, the initial average price is returned.
73+ else :
74+ ret = self .initial_avg_price
5275 # Rebalancing on the same side.
53- elif self .initial * self .target > 0 :
76+ else :
77+ assert self .initial * self .target > 0
5478 # Reducing the position.
5579 if self .target > 0 and self .order .operation == bs .OrderOperation .SELL \
5680 or self .target < 0 and self .order .operation == bs .OrderOperation .BUY :
@@ -59,6 +83,7 @@ def avg_price(self) -> Decimal:
5983 else :
6084 ret = (abs (self .initial ) * self .initial_avg_price + self .order .amount_filled * order_fill_price ) \
6185 / (abs (self .initial ) + self .order .amount_filled )
86+
6287 return ret
6388
6489 @property
@@ -93,32 +118,26 @@ def __init__(
93118 self ._position_amount = position_amount
94119 self ._quote_symbol = quote_symbol
95120 self ._positions : Dict [bs .Pair , PositionInfo ] = {}
121+ self ._pos_mutex : Dict [bs .Pair , asyncio .Lock ] = defaultdict (asyncio .Lock )
96122 self ._stop_loss_pct = stop_loss_pct
97123 self ._borrowing_disabled = borrowing_disabled
98124 self ._last_check_loss : Optional [datetime .datetime ] = None
99125
100- async def cancel_open_orders (self , pair : bs .Pair ):
101- open_orders = await self ._exchange .get_open_orders (pair )
102- await asyncio .gather (* [
103- self ._exchange .cancel_order (open_order .id )
104- for open_order in open_orders
105- ])
106-
107126 async def get_position_info (self , pair : bs .Pair ) -> Optional [PositionInfo ]:
108127 pos_info = self ._positions .get (pair )
109128 if pos_info and pos_info .order_open :
110- pos_info .order = await self ._exchange .get_order_info (pos_info .order .id )
111- return pos_info
129+ async with self ._pos_mutex [pair ]:
130+ pos_info .order = await self ._exchange .get_order_info (pos_info .order .id )
131+ return copy .deepcopy (pos_info )
112132
113133 async def check_loss (self ):
114- # Refresh positions that have an open order.
115- refresh_pairs = [pair for pair , pos_info in self ._positions .items () if pos_info .order_open ]
116- coros = [self .get_position_info (pair ) for pair in refresh_pairs ]
134+ positions = []
135+ coros = [self .get_position_info (pair ) for pair in self ._positions .keys ()]
117136 if coros :
118- await asyncio .gather (* coros )
137+ positions . extend ( await asyncio .gather (* coros ) )
119138
120139 # Check unrealized PnL for all non-neutral positions.
121- non_neutral = [pos_info for pos_info in self . _positions . values () if pos_info .current != Decimal (0 )]
140+ non_neutral = [pos_info for pos_info in positions if pos_info .current != Decimal (0 )]
122141 if not non_neutral :
123142 return
124143
@@ -128,14 +147,17 @@ async def check_loss(self):
128147 pnl_pct = pos_info .calculate_unrealized_pnl_pct (bid , ask )
129148 logging .info (StructuredMessage (
130149 f"Position for { pos_info .pair } " , current = pos_info .current , target = pos_info .target ,
131- avg_price = pos_info .avg_price , pnl_pct = pnl_pct , order_open = pos_info .order_open
150+ order_open = pos_info .order_open ,
151+ avg_price = bs .round_decimal (pos_info .avg_price , pos_info .pair_info .quote_precision ),
152+ pnl_pct = bs .round_decimal (pnl_pct , 2 )
132153 ))
133154 if pnl_pct <= self ._stop_loss_pct * - 1 :
134155 logging .info (f"Stop loss for { pos_info .pair } " )
135156 await self .switch_position (pos_info .pair , bs .Position .NEUTRAL , force = True )
136157
137158 async def switch_position (self , pair : bs .Pair , target_position : bs .Position , force : bool = False ):
138159 current_pos_info = await self .get_position_info (pair )
160+
139161 # Unless force is set, we can ignore the request if we're already there.
140162 if not force and any ([
141163 current_pos_info is None and target_position == bs .Position .NEUTRAL ,
@@ -147,54 +169,63 @@ async def switch_position(self, pair: bs.Pair, target_position: bs.Position, for
147169 ]):
148170 return
149171
150- # Cancel the previous order .
151- if current_pos_info and current_pos_info . order_open :
152- await self . _exchange . cancel_order ( current_pos_info . order .id )
153- current_pos_info . order = await self . _exchange . get_order_info ( current_pos_info .order . id )
154-
155- ( bid , ask ), pair_info = await asyncio . gather (
156- self ._exchange .get_bid_ask ( pair ),
157- self . _exchange . get_pair_info ( pair ),
158- )
159-
160- # 1. Calculate the target balance.
161- # If the target position is neutral, the target balance is 0, otherwise we need to convert
162- # self._position_amount, which is expressed in self._quote_symbol units, into base units.
163- if target_position == bs . Position . NEUTRAL :
164- target = Decimal ( 0 )
165- else :
166- if pair . quote_symbol == self . _quote_symbol :
167- target = self . _position_amount / (( bid + ask ) / 2 )
172+ # Exclusive access to the position since we're going to modify it .
173+ async with self . _pos_mutex [ pair ] :
174+ # Cancel the previous order.
175+ if current_pos_info and current_pos_info .order_open :
176+ logging . info ( StructuredMessage ( "Canceling order" , order_ids = current_pos_info . order . id ))
177+ await self . _exchange . cancel_order ( current_pos_info . order . id )
178+ current_pos_info . order = await self ._exchange .get_order_info ( current_pos_info . order . id )
179+
180+ ( bid , ask ), pair_info = await asyncio . gather (
181+ self . _exchange . get_bid_ask ( pair ),
182+ self . _exchange . get_pair_info ( pair ),
183+ )
184+
185+ # 1. Calculate the target balance.
186+ # If the target position is neutral, the target balance is 0, otherwise we need to convert
187+ # self._position_amount, which is expressed in self._quote_symbol units, into base units.
188+ if target_position == bs . Position . NEUTRAL :
189+ target = Decimal ( 0 )
168190 else :
169- quote_bid , quote_ask = await self ._exchange .get_bid_ask (bs .Pair (pair .base_symbol , self ._quote_symbol ))
170- target = self ._position_amount / ((quote_bid + quote_ask ) / 2 )
171-
172- if target_position == bs .Position .SHORT :
173- target *= - 1
174- target = bs .truncate_decimal (target , pair_info .base_precision )
175-
176- # 2. Calculate the difference between the target balance and our current balance.
177- current = Decimal (0 ) if current_pos_info is None else current_pos_info .current
178- delta = target - current
179- logging .info (StructuredMessage ("Switch position" , pair = pair , current = current , target = target , delta = delta ))
180- if delta == 0 :
181- return
182-
183- # 3. Create the order.
184- order_size = abs (delta )
185- operation = bs .OrderOperation .BUY if delta > 0 else bs .OrderOperation .SELL
186- logging .info (StructuredMessage ("Creating market order" , operation = operation , pair = pair , order_size = order_size ))
187- created_order = await self ._exchange .create_market_order (
188- operation , pair , order_size , auto_borrow = True , auto_repay = True
189- )
190- order = await self ._exchange .get_order_info (created_order .id )
191-
192- # 4. Keep track of the position.
193- initial_avg_price = Decimal (0 ) if current_pos_info is None else current_pos_info .avg_price
194- pos_info = PositionInfo (
195- pair = pair , initial = current , initial_avg_price = initial_avg_price , target = target , order = order
196- )
197- self ._positions [pair ] = pos_info
191+ if pair .quote_symbol == self ._quote_symbol :
192+ target = self ._position_amount / ((bid + ask ) / 2 )
193+ else :
194+ quote_bid , quote_ask = await self ._exchange .get_bid_ask (
195+ bs .Pair (pair .base_symbol , self ._quote_symbol )
196+ )
197+ target = self ._position_amount / ((quote_bid + quote_ask ) / 2 )
198+
199+ if target_position == bs .Position .SHORT :
200+ target *= - 1
201+ target = bs .truncate_decimal (target , pair_info .base_precision )
202+
203+ # 2. Calculate the difference between the target balance and our current balance.
204+ current = Decimal (0 ) if current_pos_info is None else current_pos_info .current
205+ delta = target - current
206+ logging .info (StructuredMessage ("Switch position" , pair = pair , current = current , target = target , delta = delta ))
207+ if delta == 0 :
208+ return
209+
210+ # 3. Create the order.
211+ order_size = abs (delta )
212+ operation = bs .OrderOperation .BUY if delta > 0 else bs .OrderOperation .SELL
213+ logging .info (StructuredMessage (
214+ "Creating market order" , operation = operation , pair = pair , order_size = order_size
215+ ))
216+ created_order = await self ._exchange .create_market_order (
217+ operation , pair , order_size , auto_borrow = True , auto_repay = True
218+ )
219+ logging .info (StructuredMessage ("Order created" , id = created_order .id ))
220+ order = await self ._exchange .get_order_info (created_order .id )
221+
222+ # 4. Keep track of the position.
223+ initial_avg_price = Decimal (0 ) if current_pos_info is None else current_pos_info .avg_price
224+ pos_info = PositionInfo (
225+ pair = pair , pair_info = pair_info , initial = current , initial_avg_price = initial_avg_price , target = target ,
226+ order = order
227+ )
228+ self ._positions [pair ] = pos_info
198229
199230 async def on_trading_signal (self , trading_signal : bs .TradingSignal ):
200231 pairs = list (trading_signal .get_pairs ())
@@ -220,7 +251,7 @@ async def on_bar_event(self, bar_event: bs.BarEvent):
220251 async def on_order_event (self , order_event : backtesting_exchange .OrderEvent ):
221252 order = order_event .order
222253 logging .info (StructuredMessage (
223- "Order udpated " , id = order .id , is_open = order .is_open , amount = order .amount ,
254+ "Order updated " , id = order .id , is_open = order .is_open , amount = order .amount ,
224255 amount_filled = order .amount_filled , avg_fill_price = order .fill_price
225256 ))
226257
0 commit comments