Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 147 additions & 1 deletion codex-rs/exec-server/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ pub(crate) struct SessionState {
events: ExecProcessEventLog,
ordered_events: StdMutex<OrderedSessionEvents>,
recoverable: AtomicBool,
next_write_id: AtomicU64,
}

#[derive(Default)]
Expand Down Expand Up @@ -421,12 +422,14 @@ impl ExecServerClient {
&self,
process_id: &ProcessId,
chunk: Vec<u8>,
write_id: String,
) -> Result<WriteResponse, ExecServerError> {
self.call(
EXEC_WRITE_METHOD,
&WriteParams {
process_id: process_id.clone(),
chunk: chunk.into(),
write_id,
},
)
.await
Expand Down Expand Up @@ -730,6 +733,7 @@ impl SessionState {
),
ordered_events: StdMutex::new(OrderedSessionEvents::default()),
recoverable: AtomicBool::new(recoverable),
next_write_id: AtomicU64::new(1),
}
}

Expand Down Expand Up @@ -829,6 +833,12 @@ impl SessionState {
failure: Some(message),
}
}

fn next_write_id(&self) -> String {
self.next_write_id
.fetch_add(1, Ordering::Relaxed)
.to_string()
}
}

impl Session {
Expand Down Expand Up @@ -885,7 +895,22 @@ impl Session {
}

pub(crate) async fn write(&self, chunk: Vec<u8>) -> Result<WriteResponse, ExecServerError> {
self.client.write(&self.process_id, chunk).await
let write_id = self.state.next_write_id();
loop {
match self
.client
.write(&self.process_id, chunk.clone(), write_id.clone())
.await
{
Ok(response) => return Ok(response),
Err(error)
if is_transport_closed_error(&error) && !self.client.inner.is_failed() =>
{
continue;
}
Err(error) => return Err(error),
}
}
}

pub(crate) async fn signal(&self, signal: ProcessSignal) -> Result<(), ExecServerError> {
Expand Down Expand Up @@ -1110,6 +1135,8 @@ mod tests {
use crate::protocol::EXEC_CLOSED_METHOD;
use crate::protocol::EXEC_EXITED_METHOD;
use crate::protocol::EXEC_OUTPUT_DELTA_METHOD;
use crate::protocol::EXEC_READ_METHOD;
use crate::protocol::EXEC_WRITE_METHOD;
use crate::protocol::ExecClosedNotification;
use crate::protocol::ExecExitedNotification;
use crate::protocol::ExecOutputDeltaNotification;
Expand All @@ -1118,6 +1145,10 @@ mod tests {
use crate::protocol::INITIALIZED_METHOD;
use crate::protocol::InitializeResponse;
use crate::protocol::ProcessOutputChunk;
use crate::protocol::ReadResponse;
use crate::protocol::WriteParams;
use crate::protocol::WriteResponse;
use crate::protocol::WriteStatus;

async fn read_jsonrpc_line<R>(lines: &mut tokio::io::Lines<BufReader<R>>) -> JSONRPCMessage
where
Expand Down Expand Up @@ -1685,6 +1716,121 @@ mod tests {
server.await.expect("server task should finish");
}

#[tokio::test]
async fn session_write_retries_same_write_id_after_recovery() {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("listener should bind");
let websocket_url = format!(
"ws://{}",
listener.local_addr().expect("listener should have address")
);
let (finish_tx, finish_rx) = oneshot::channel();
let server = tokio::spawn(async move {
let mut first = accept_websocket(&listener).await;
complete_websocket_initialize(
&mut first,
"session-1",
/*expected_resume_session_id*/ None,
)
.await;

let first_write = read_jsonrpc_websocket(&mut first).await;
let first_write = match first_write {
JSONRPCMessage::Request(request) if request.method == EXEC_WRITE_METHOD => request,
other => panic!("expected first process/write request, got {other:?}"),
};
let first_write_params: WriteParams =
serde_json::from_value(first_write.params.expect("write params should exist"))
.expect("write params should deserialize");
assert_eq!(first_write_params.process_id.as_str(), "proc-write");
assert_eq!(first_write_params.chunk.into_inner(), b"hello\n".to_vec());
let write_id = first_write_params.write_id;
assert!(!write_id.is_empty());
drop(first);

let mut resumed = accept_websocket(&listener).await;
complete_websocket_initialize(
&mut resumed,
"session-1",
/*expected_resume_session_id*/ Some("session-1"),
)
.await;

let recovery_read = read_jsonrpc_websocket(&mut resumed).await;
let recovery_read = match recovery_read {
JSONRPCMessage::Request(request) if request.method == EXEC_READ_METHOD => request,
other => panic!("expected recovery process/read request, got {other:?}"),
};
write_jsonrpc_websocket(
&mut resumed,
JSONRPCMessage::Response(JSONRPCResponse {
id: recovery_read.id,
result: serde_json::to_value(ReadResponse {
chunks: Vec::new(),
next_seq: 1,
exited: false,
exit_code: None,
closed: false,
failure: None,
})
.expect("read response should serialize"),
}),
)
.await;

let retried_write = read_jsonrpc_websocket(&mut resumed).await;
let retried_write = match retried_write {
JSONRPCMessage::Request(request) if request.method == EXEC_WRITE_METHOD => request,
other => panic!("expected retried process/write request, got {other:?}"),
};
let retried_write_params: WriteParams =
serde_json::from_value(retried_write.params.expect("write params should exist"))
.expect("write params should deserialize");
assert_eq!(retried_write_params.process_id.as_str(), "proc-write");
assert_eq!(retried_write_params.chunk.into_inner(), b"hello\n".to_vec());
assert_eq!(retried_write_params.write_id, write_id);
write_jsonrpc_websocket(
&mut resumed,
JSONRPCMessage::Response(JSONRPCResponse {
id: retried_write.id,
result: serde_json::to_value(WriteResponse {
status: WriteStatus::Accepted,
})
.expect("write response should serialize"),
}),
)
.await;

finish_rx.await.expect("test should finish");
});

let client = LazyRemoteExecServerClient::new(ExecServerTransportParams::WebSocketUrl {
websocket_url,
connect_timeout: Duration::from_secs(1),
initialize_timeout: Duration::from_secs(1),
});
let stable_client = client.get().await.expect("client should connect");
let session = stable_client
.register_session(&ProcessId::from("proc-write"))
.await
.expect("session should register");

let response = timeout(Duration::from_secs(2), session.write(b"hello\n".to_vec()))
.await
.expect("write should not time out")
.expect("write should recover");
assert_eq!(
response,
WriteResponse {
status: WriteStatus::Accepted
}
);

finish_tx.send(()).expect("test should finish");
server.await.expect("server task should finish");
}

#[tokio::test]
async fn explicit_resume_drains_notifications_before_initialize_response() {
let listener = TcpListener::bind("127.0.0.1:0")
Expand Down
81 changes: 77 additions & 4 deletions codex-rs/exec-server/src/local_process.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::collections::HashMap;
use std::collections::HashSet;
use std::collections::VecDeque;
use std::collections::hash_map::Entry;
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering;
use std::time::Duration;

use codex_app_server_protocol::JSONRPCErrorError;
Expand Down Expand Up @@ -54,6 +57,8 @@ use crate::rpc::invalid_request;
const RETAINED_OUTPUT_BYTES_PER_PROCESS: usize = 1024 * 1024;
const NOTIFICATION_CHANNEL_CAPACITY: usize = 256;
const PROCESS_EVENT_CHANNEL_CAPACITY: usize = 256;
const RETAINED_STDIN_WRITE_IDS_PER_PROCESS: usize = 4096;
static NEXT_LOCAL_STDIN_WRITE_ID: AtomicU64 = AtomicU64::new(1);
#[cfg(test)]
const EXITED_PROCESS_RETENTION: Duration = Duration::from_millis(25);
#[cfg(not(test))]
Expand All @@ -70,6 +75,7 @@ struct RunningProcess {
session: ExecCommandSession,
tty: bool,
pipe_stdin: bool,
accepted_stdin_write_ids: Arc<Mutex<AcceptedStdinWriteIds>>,
output: VecDeque<RetainedOutputChunk>,
retained_bytes: usize,
next_seq: u64,
Expand All @@ -81,6 +87,37 @@ struct RunningProcess {
closed: bool,
}

/// Bounded cache of stdin write ids that have already been accepted for one process.
///
/// A remote client can retry `process/write` after reconnecting. Remembering accepted
/// ids lets the server acknowledge the retried request without writing the same bytes
/// to child stdin twice.
#[derive(Default)]
struct AcceptedStdinWriteIds {
ids: HashSet<String>,
order: VecDeque<String>,
Comment thread
jif-oai marked this conversation as resolved.
}

impl AcceptedStdinWriteIds {
fn contains(&self, write_id: &str) -> bool {
self.ids.contains(write_id)
}

fn remember(&mut self, write_id: String) {
if !self.ids.insert(write_id.clone()) {
return;
}

self.order.push_back(write_id);
while self.order.len() > RETAINED_STDIN_WRITE_IDS_PER_PROCESS {
let Some(evicted) = self.order.pop_front() else {
break;
};
self.ids.remove(&evicted);
}
}
}

struct ProcessStart;

enum ProcessEntry {
Expand Down Expand Up @@ -247,6 +284,9 @@ impl LocalProcess {
session: spawned.session,
tty: params.tty,
pipe_stdin: params.pipe_stdin,
accepted_stdin_write_ids: Arc::new(
Mutex::new(AcceptedStdinWriteIds::default()),
),
output: VecDeque::new(),
retained_bytes: 0,
next_seq: 1,
Expand Down Expand Up @@ -383,7 +423,11 @@ impl LocalProcess {
params: WriteParams,
) -> Result<WriteResponse, JSONRPCErrorError> {
let _input_bytes = params.chunk.0.len();
let writer_tx = {
if params.write_id.is_empty() {
return Err(invalid_params("writeId must not be empty".to_string()));
}

let (writer_tx, accepted_stdin_write_ids) = {
let process_map = self.inner.processes.lock().await;
let Some(process) = process_map.get(&params.process_id) else {
return Ok(WriteResponse {
Expand All @@ -400,13 +444,37 @@ impl LocalProcess {
status: WriteStatus::StdinClosed,
});
}
process.session.writer_sender()
(
process.session.writer_sender(),
Arc::clone(&process.accepted_stdin_write_ids),
)
};

writer_tx
.send(params.chunk.into_inner())
if accepted_stdin_write_ids
.lock()
.await
.contains(&params.write_id)
{
return Ok(WriteResponse {
status: WriteStatus::Accepted,
});
}

let permit = writer_tx
.reserve()
.await
.map_err(|_| internal_error("failed to write to process stdin".to_string()))?;
let mut accepted_stdin_write_ids = accepted_stdin_write_ids.lock().await;
if accepted_stdin_write_ids.contains(&params.write_id) {
return Ok(WriteResponse {
status: WriteStatus::Accepted,
});
}

// After this synchronous send, record the write id before any further await.
// Otherwise a cancelled RPC handler could retry and write the same bytes again.
permit.send(params.chunk.into_inner());
accepted_stdin_write_ids.remember(params.write_id);

Ok(WriteResponse {
status: WriteStatus::Accepted,
Expand Down Expand Up @@ -601,6 +669,10 @@ impl LocalProcess {
self.exec_write(WriteParams {
process_id: process_id.clone(),
chunk: chunk.into(),
write_id: format!(
"local-{}",
NEXT_LOCAL_STDIN_WRITE_ID.fetch_add(1, Ordering::Relaxed)
),
})
.await
.map_err(map_handler_error)
Expand Down Expand Up @@ -1023,6 +1095,7 @@ mod tests {
session: dummy_session(),
tty: false,
pipe_stdin: false,
accepted_stdin_write_ids: Arc::new(Mutex::new(AcceptedStdinWriteIds::default())),
output: VecDeque::new(),
retained_bytes: 0,
next_seq: 1,
Expand Down
1 change: 1 addition & 0 deletions codex-rs/exec-server/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ pub struct ReadResponse {
pub struct WriteParams {
pub process_id: ProcessId,
pub chunk: ByteChunk,
pub write_id: String,
Comment thread
jif-oai marked this conversation as resolved.
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
Expand Down
Loading
Loading