Skip to content

Commit 8b36c4c

Browse files
authored
Fix/35 (#37)
* Allow multiple event listeners on same ws connection * fix some logs
1 parent f39ec12 commit 8b36c4c

File tree

3 files changed

+44
-31
lines changed

3 files changed

+44
-31
lines changed

src/controller.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ impl AppState {
374374
tx_sig: &str,
375375
sub_account_id: Option<u16>,
376376
) -> GatewayResult<TxEventsResponse> {
377-
let signature = Signature::from_str(&tx_sig).map_err(|err| {
377+
let signature = Signature::from_str(tx_sig).map_err(|err| {
378378
warn!(target: LOG_TARGET, "failed to parse transaction signature: {err:?}");
379379
ControllerError::BadRequest(format!("failed to parse transaction signature: {err:?}"))
380380
})?;

src/types.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
//! - gateway request/responses
33
//! - wrappers for presenting drift program types with less implementation detail
44
//!
5-
use std::io::Empty;
6-
75
use drift_sdk::{
86
constants::{ProgramData, BASE_PRECISION, PRICE_PRECISION},
97
dlob::{self, L2Level, L2Orderbook},

src/websocket.rs

Lines changed: 43 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//! Websocket server
22
3-
use std::ops::Neg;
3+
use std::{collections::HashMap, ops::Neg, sync::Arc};
44

55
use drift_sdk::{
66
async_utils::retry_policy::{self},
@@ -14,9 +14,9 @@ use log::{info, warn};
1414
use rust_decimal::Decimal;
1515
use serde::{Deserialize, Deserializer, Serialize, Serializer};
1616
use serde_json::json;
17-
use solana_sdk::account::Account;
1817
use tokio::{
1918
net::{TcpListener, TcpStream},
19+
sync::Mutex,
2020
task::JoinHandle,
2121
};
2222
use tokio_tungstenite::{accept_async, tungstenite::Message};
@@ -58,21 +58,20 @@ async fn accept_connection(
5858
) {
5959
let addr = stream.peer_addr().expect("peer address");
6060
let ws_stream = accept_async(stream).await.expect("Ws handshake");
61-
info!("accepted Ws connection: {}", addr);
61+
info!(target: LOG_TARGET, "accepted Ws connection: {}", addr);
6262

6363
let (mut ws_out, mut ws_in) = ws_stream.split();
6464
let (message_tx, mut message_rx) = tokio::sync::mpsc::channel::<Message>(32);
65-
let mut stream_handle: Option<JoinHandle<()>> = None;
65+
let subscriptions = Arc::new(Mutex::new(HashMap::<u8, JoinHandle<()>>::default()));
6666

6767
// writes messages to the connection
6868
tokio::spawn(async move {
6969
while let Some(msg) = message_rx.recv().await {
7070
if msg.is_close() {
7171
let _ = ws_out.close().await;
7272
break;
73-
} else {
74-
ws_out.send(msg).await.expect("sent");
7573
}
74+
ws_out.send(msg).await.expect("sent");
7675
}
7776
});
7877

@@ -84,23 +83,34 @@ async fn accept_connection(
8483
match request.method {
8584
Method::Subscribe => {
8685
// TODO: support subscriptions for individual channels and/or markets
87-
if stream_handle.is_some() {
88-
// no double subs
89-
return;
86+
let mut subscription_map = subscriptions.lock().await;
87+
if subscription_map.contains_key(&request.sub_account_id) {
88+
info!(target: LOG_TARGET, "subscription already exists for: {}", request.sub_account_id);
89+
message_tx
90+
.send(Message::text(
91+
json!({
92+
"error": "bad request",
93+
"reason": "subscription already exists",
94+
})
95+
.to_string(),
96+
))
97+
.await
98+
.unwrap();
99+
continue;
90100
}
91101
info!(target: LOG_TARGET, "subscribing to events for: {}", request.sub_account_id);
92102

93-
let sub_account_address =
94-
wallet.sub_account(request.sub_account_id as u16);
95-
let mut event_stream = EventSubscriber::subscribe(
96-
PubsubClient::new(ws_endpoint.as_str())
97-
.await
98-
.expect("ws connect"),
99-
sub_account_address,
100-
retry_policy::forever(5),
101-
);
102-
103103
let join_handle = tokio::spawn({
104+
let sub_account_address =
105+
wallet.sub_account(request.sub_account_id as u16);
106+
let mut event_stream = EventSubscriber::subscribe(
107+
PubsubClient::new(ws_endpoint.as_str())
108+
.await
109+
.expect("ws connect"),
110+
sub_account_address,
111+
retry_policy::forever(5),
112+
);
113+
let subscription_map = Arc::clone(&subscriptions);
104114
let sub_account_id = request.sub_account_id;
105115
let message_tx = message_tx.clone();
106116
async move {
@@ -113,7 +123,7 @@ async fn accept_connection(
113123
if data.is_none() {
114124
continue;
115125
}
116-
message_tx
126+
if message_tx
117127
.send(Message::text(
118128
serde_json::to_string(&WsEvent {
119129
data,
@@ -123,34 +133,39 @@ async fn accept_connection(
123133
.expect("serializes"),
124134
))
125135
.await
126-
.expect("capacity");
136+
.is_err()
137+
{
138+
break;
139+
}
127140
}
128141
warn!(target: LOG_TARGET, "event stream finished: {sub_account_id:?}, sending close");
129-
let _ = message_tx.send(Message::Close(None)).await;
142+
subscription_map.lock().await.remove(&sub_account_id);
130143
}
131144
});
132145

133-
stream_handle = Some(join_handle);
146+
subscription_map.insert(request.sub_account_id, join_handle);
134147
}
135148
Method::Unsubscribe => {
136-
info!(target: LOG_TARGET, "unsubscribing: {}", request.sub_account_id);
149+
info!(target: LOG_TARGET, "unsubscribing events of: {}", request.sub_account_id);
137150
// TODO: support ending by channel, this ends all channels
138-
if let Some(task) = stream_handle.take() {
151+
let mut subscription_map = subscriptions.lock().await;
152+
if let Some(task) = subscription_map.remove(&request.sub_account_id) {
139153
task.abort();
140154
}
141155
}
142156
}
143157
}
144158
Err(err) => {
145159
message_tx
146-
.try_send(Message::text(
160+
.send(Message::text(
147161
json!({
148162
"error": "bad request",
149163
"reason": err.to_string(),
150164
})
151165
.to_string(),
152166
))
153-
.expect("capacity");
167+
.await
168+
.unwrap();
154169
}
155170
},
156171
Message::Close(frame) => {
@@ -161,7 +176,7 @@ async fn accept_connection(
161176
_ => (),
162177
}
163178
}
164-
info!("closing Ws connection: {}", addr);
179+
info!(target: LOG_TARGET, "closing Ws connection: {}", addr);
165180
}
166181

167182
#[derive(Deserialize, Debug)]

0 commit comments

Comments
 (0)