11//! Websocket server
22
3- use std:: ops:: Neg ;
3+ use std:: { collections :: HashMap , ops:: Neg , sync :: Arc } ;
44
55use drift_sdk:: {
66 async_utils:: retry_policy:: { self } ,
@@ -14,9 +14,9 @@ use log::{info, warn};
1414use rust_decimal:: Decimal ;
1515use serde:: { Deserialize , Deserializer , Serialize , Serializer } ;
1616use serde_json:: json;
17- use solana_sdk:: account:: Account ;
1817use tokio:: {
1918 net:: { TcpListener , TcpStream } ,
19+ sync:: Mutex ,
2020 task:: JoinHandle ,
2121} ;
2222use tokio_tungstenite:: { accept_async, tungstenite:: Message } ;
@@ -58,21 +58,20 @@ async fn accept_connection(
5858) {
5959 let addr = stream. peer_addr ( ) . expect ( "peer address" ) ;
6060 let ws_stream = accept_async ( stream) . await . expect ( "Ws handshake" ) ;
61- info ! ( "accepted Ws connection: {}" , addr) ;
61+ info ! ( target : LOG_TARGET , "accepted Ws connection: {}" , addr) ;
6262
6363 let ( mut ws_out, mut ws_in) = ws_stream. split ( ) ;
6464 let ( message_tx, mut message_rx) = tokio:: sync:: mpsc:: channel :: < Message > ( 32 ) ;
65- let mut stream_handle : Option < JoinHandle < ( ) > > = None ;
65+ let subscriptions = Arc :: new ( Mutex :: new ( HashMap :: < u8 , JoinHandle < ( ) > > :: default ( ) ) ) ;
6666
6767 // writes messages to the connection
6868 tokio:: spawn ( async move {
6969 while let Some ( msg) = message_rx. recv ( ) . await {
7070 if msg. is_close ( ) {
7171 let _ = ws_out. close ( ) . await ;
7272 break ;
73- } else {
74- ws_out. send ( msg) . await . expect ( "sent" ) ;
7573 }
74+ ws_out. send ( msg) . await . expect ( "sent" ) ;
7675 }
7776 } ) ;
7877
@@ -84,23 +83,34 @@ async fn accept_connection(
8483 match request. method {
8584 Method :: Subscribe => {
8685 // TODO: support subscriptions for individual channels and/or markets
87- if stream_handle. is_some ( ) {
88- // no double subs
89- return ;
86+ let mut subscription_map = subscriptions. lock ( ) . await ;
87+ if subscription_map. contains_key ( & request. sub_account_id ) {
88+ info ! ( target: LOG_TARGET , "subscription already exists for: {}" , request. sub_account_id) ;
89+ message_tx
90+ . send ( Message :: text (
91+ json ! ( {
92+ "error" : "bad request" ,
93+ "reason" : "subscription already exists" ,
94+ } )
95+ . to_string ( ) ,
96+ ) )
97+ . await
98+ . unwrap ( ) ;
99+ continue ;
90100 }
91101 info ! ( target: LOG_TARGET , "subscribing to events for: {}" , request. sub_account_id) ;
92102
93- let sub_account_address =
94- wallet. sub_account ( request. sub_account_id as u16 ) ;
95- let mut event_stream = EventSubscriber :: subscribe (
96- PubsubClient :: new ( ws_endpoint. as_str ( ) )
97- . await
98- . expect ( "ws connect" ) ,
99- sub_account_address,
100- retry_policy:: forever ( 5 ) ,
101- ) ;
102-
103103 let join_handle = tokio:: spawn ( {
104+ let sub_account_address =
105+ wallet. sub_account ( request. sub_account_id as u16 ) ;
106+ let mut event_stream = EventSubscriber :: subscribe (
107+ PubsubClient :: new ( ws_endpoint. as_str ( ) )
108+ . await
109+ . expect ( "ws connect" ) ,
110+ sub_account_address,
111+ retry_policy:: forever ( 5 ) ,
112+ ) ;
113+ let subscription_map = Arc :: clone ( & subscriptions) ;
104114 let sub_account_id = request. sub_account_id ;
105115 let message_tx = message_tx. clone ( ) ;
106116 async move {
@@ -113,7 +123,7 @@ async fn accept_connection(
113123 if data. is_none ( ) {
114124 continue ;
115125 }
116- message_tx
126+ if message_tx
117127 . send ( Message :: text (
118128 serde_json:: to_string ( & WsEvent {
119129 data,
@@ -123,34 +133,39 @@ async fn accept_connection(
123133 . expect ( "serializes" ) ,
124134 ) )
125135 . await
126- . expect ( "capacity" ) ;
136+ . is_err ( )
137+ {
138+ break ;
139+ }
127140 }
128141 warn ! ( target: LOG_TARGET , "event stream finished: {sub_account_id:?}, sending close" ) ;
129- let _ = message_tx . send ( Message :: Close ( None ) ) . await ;
142+ subscription_map . lock ( ) . await . remove ( & sub_account_id ) ;
130143 }
131144 } ) ;
132145
133- stream_handle = Some ( join_handle) ;
146+ subscription_map . insert ( request . sub_account_id , join_handle) ;
134147 }
135148 Method :: Unsubscribe => {
136- info ! ( target: LOG_TARGET , "unsubscribing: {}" , request. sub_account_id) ;
149+ info ! ( target: LOG_TARGET , "unsubscribing events of : {}" , request. sub_account_id) ;
137150 // TODO: support ending by channel, this ends all channels
138- if let Some ( task) = stream_handle. take ( ) {
151+ let mut subscription_map = subscriptions. lock ( ) . await ;
152+ if let Some ( task) = subscription_map. remove ( & request. sub_account_id ) {
139153 task. abort ( ) ;
140154 }
141155 }
142156 }
143157 }
144158 Err ( err) => {
145159 message_tx
146- . try_send ( Message :: text (
160+ . send ( Message :: text (
147161 json ! ( {
148162 "error" : "bad request" ,
149163 "reason" : err. to_string( ) ,
150164 } )
151165 . to_string ( ) ,
152166 ) )
153- . expect ( "capacity" ) ;
167+ . await
168+ . unwrap ( ) ;
154169 }
155170 } ,
156171 Message :: Close ( frame) => {
@@ -161,7 +176,7 @@ async fn accept_connection(
161176 _ => ( ) ,
162177 }
163178 }
164- info ! ( "closing Ws connection: {}" , addr) ;
179+ info ! ( target : LOG_TARGET , "closing Ws connection: {}" , addr) ;
165180}
166181
167182#[ derive( Deserialize , Debug ) ]
0 commit comments