Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 136 additions & 31 deletions src/listen_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ use tokio::sync::{broadcast, mpsc};
use tokio::time::timeout;
use uuid::Uuid;

use crate::worker_request::{RequestWorkerError, DEFAULT_REQUEST_TIMEOUT};

const LISTEN_API_SEND_TIMEOUT: Duration = Duration::from_secs(30);

// ---------------------------------------------------------------------------
Expand Down Expand Up @@ -80,10 +82,27 @@ pub enum ListenApiRequest {
cols: u16,
reply: tokio::sync::oneshot::Sender<Result<Value, String>>,
},
Snapshot {
/// Generic worker request/response RPC: park a oneshot in the
/// broker's `pending_requests` map keyed by a fresh `request_id`,
/// frame the request, and ship it to the named worker over its
/// stdin pipe. The reply fires when the worker echoes a matching
/// `*_response` frame or the deadline elapses (whichever first).
///
/// Used by request/response routes like `GET /api/spawned/{name}/snapshot`.
/// Fire-and-forget routes (`send_input`, `resize_pty`) keep their
/// existing single-arm channel pattern.
WorkerRequest {
name: String,
format: SnapshotFormat,
reply: tokio::sync::oneshot::Sender<Result<Value, String>>,
/// Outbound frame `type`, e.g. `"snapshot_pty"`. The worker is
/// expected to reply with `"{kind}_response"`.
kind: String,
/// Worker stdin frame payload — must match the worker-side
/// schema for `kind`.
payload: Value,
/// Max wall-clock duration the broker will wait for the worker's
/// response before sending [`RequestWorkerError::Timeout`].
timeout: Duration,
reply: tokio::sync::oneshot::Sender<Result<Value, RequestWorkerError>>,
},
GetMetrics {
agent: Option<String>,
Expand Down Expand Up @@ -911,7 +930,15 @@ fn api_error(
)
}

/// Parse an error string like "agent_not_found: worker-a" into a (code, status) pair.
/// Parse an error string like "agent_not_found: worker-a" into a
/// (code, status) pair.
///
/// Used by routes that still surface stringly-typed errors (e.g.
/// `send_input`, `resize_pty`). Routes built on `WorkerRequest` go
/// through [`worker_request_error_to_response`] instead, which
/// preserves typed-error code/status mappings but falls back here for
/// the structured `RequestWorkerError::WorkerError` envelope so
/// worker-side codes like `invalid_format` keep producing 400s.
fn classify_error(err: &str) -> (axum::http::StatusCode, &str) {
if err.starts_with("agent_not_found") {
(axum::http::StatusCode::NOT_FOUND, "agent_not_found")
Expand All @@ -923,10 +950,10 @@ fn classify_error(err: &str) -> (axum::http::StatusCode, &str) {
// 409 Conflict — the request itself is well-formed; the conflict
// is with the resource's current capabilities.
(axum::http::StatusCode::CONFLICT, "unsupported_runtime")
} else if err.starts_with("snapshot_timeout") {
} else if err.starts_with("worker_timeout") {
// Worker died or stalled between accepting the frame and
// replying. This is a server-side fault, not a bad request.
(axum::http::StatusCode::GATEWAY_TIMEOUT, "snapshot_timeout")
(axum::http::StatusCode::GATEWAY_TIMEOUT, "worker_timeout")
} else if err.starts_with("invalid_") {
(axum::http::StatusCode::BAD_REQUEST, "invalid_request")
} else {
Expand Down Expand Up @@ -1043,9 +1070,11 @@ async fn listen_api_snapshot(
let (reply_tx, reply_rx) = tokio::sync::oneshot::channel();
if state
.tx
.send(ListenApiRequest::Snapshot {
.send(ListenApiRequest::WorkerRequest {
name: name.clone(),
format,
kind: "snapshot_pty".to_string(),
payload: json!({ "format": format.as_wire_str() }),
timeout: DEFAULT_REQUEST_TIMEOUT,
reply: reply_tx,
})
.await
Expand All @@ -1055,14 +1084,50 @@ async fn listen_api_snapshot(
}
match reply_rx.await {
Ok(Ok(val)) => (axum::http::StatusCode::OK, axum::Json(val)),
Ok(Err(ref err)) => {
let (status, code) = classify_error(err);
api_error(status, code, err.clone())
}
Ok(Err(err)) => worker_request_error_to_response(&err),
Err(_) => internal_error(),
}
}

/// Translate a [`RequestWorkerError`] into the `(status, JSON body)` pair
/// returned by HTTP handlers.
///
/// Every route built on `WorkerRequest` should funnel its error arm through
/// this function so the same failure always yields the same status and
/// error code:
///
/// | variant                        | status | code                  |
/// |--------------------------------|--------|-----------------------|
/// | `WorkerNotFound`, `SendFailed` | 404    | `agent_not_found`     |
/// | `UnsupportedRuntime`           | 409    | `unsupported_runtime` |
/// | `Timeout`                      | 504    | `worker_timeout`      |
/// | `WorkerDisappeared`            | 503    | `worker_disappeared`  |
/// | `WorkerError { .. }`           | decided by `classify_error`    |
/// | `ChannelClosed`                | whatever `internal_error()` returns |
fn worker_request_error_to_response(
    err: &RequestWorkerError,
) -> (axum::http::StatusCode, axum::Json<Value>) {
    use axum::http::StatusCode;

    // Most variants differ only in their status/code pair; the message is
    // always the error's `Display` output. The two variants that need
    // bespoke handling return early from inside the match.
    let (status, code) = match err {
        RequestWorkerError::ChannelClosed => return internal_error(),
        RequestWorkerError::WorkerError {
            code: worker_code,
            message,
        } => {
            // Rebuild the "code: message" string shape that classify_error
            // parses, so worker-side codes (e.g. the "invalid_" prefix ->
            // 400 / "invalid_request") keep their canonical mapping.
            let detail = format!("{worker_code}: {message}");
            let (status, canonical) = classify_error(&detail);
            // `canonical` borrows from `detail`; take an owned copy so
            // `detail` can be moved into the response as the message.
            let canonical = canonical.to_string();
            return api_error(status, &canonical, detail);
        }
        RequestWorkerError::WorkerNotFound(_) | RequestWorkerError::SendFailed(_) => {
            (StatusCode::NOT_FOUND, "agent_not_found")
        }
        RequestWorkerError::UnsupportedRuntime(_) => {
            (StatusCode::CONFLICT, "unsupported_runtime")
        }
        RequestWorkerError::Timeout => (StatusCode::GATEWAY_TIMEOUT, "worker_timeout"),
        RequestWorkerError::WorkerDisappeared(_) => {
            (StatusCode::SERVICE_UNAVAILABLE, "worker_disappeared")
        }
    };

    api_error(status, code, err.to_string())
}

// ---------------------------------------------------------------------------
// Observability
// ---------------------------------------------------------------------------
Expand Down Expand Up @@ -1567,6 +1632,7 @@ mod auth_tests {
use tower::ServiceExt;

use super::{listen_api_router_with_auth, ListenApiConfig, ListenApiRequest};
use crate::worker_request::RequestWorkerError;

fn test_router(
broker_api_key: Option<&str>,
Expand Down Expand Up @@ -2300,13 +2366,16 @@ mod auth_tests {
let (router, mut rx) = test_router(Some("secret"));
let replier = tokio::spawn(async move {
match rx.recv().await {
Some(ListenApiRequest::Snapshot {
Some(ListenApiRequest::WorkerRequest {
name,
format,
kind,
payload,
reply,
..
}) => {
assert_eq!(name, "worker-a");
assert_eq!(format, super::SnapshotFormat::Plain);
assert_eq!(kind, "snapshot_pty");
assert_eq!(payload["format"], json!("plain"));
let _ = reply.send(Ok(json!({
"format": "plain",
"rows": 4,
Expand Down Expand Up @@ -2344,13 +2413,16 @@ mod auth_tests {
let (router, mut rx) = test_router(Some("secret"));
let replier = tokio::spawn(async move {
match rx.recv().await {
Some(ListenApiRequest::Snapshot {
Some(ListenApiRequest::WorkerRequest {
name,
format,
kind,
payload,
reply,
..
}) => {
assert_eq!(name, "worker-a");
assert_eq!(format, super::SnapshotFormat::Ansi);
assert_eq!(kind, "snapshot_pty");
assert_eq!(payload["format"], json!("ansi"));
let _ = reply.send(Ok(json!({
"format": "ansi",
"rows": 2,
Expand Down Expand Up @@ -2401,16 +2473,18 @@ mod auth_tests {
let body = response_json(response).await;
assert_eq!(body["code"], json!("invalid_format"));

// The broker channel must not have received a Snapshot request.
// The broker channel must not have received a WorkerRequest.
assert!(rx.try_recv().is_err());
}

#[tokio::test]
async fn snapshot_route_propagates_agent_not_found_as_404() {
let (router, mut rx) = test_router(Some("secret"));
let replier = tokio::spawn(async move {
if let Some(ListenApiRequest::Snapshot { reply, .. }) = rx.recv().await {
let _ = reply.send(Err("agent_not_found: no worker named 'ghost'".to_string()));
if let Some(ListenApiRequest::WorkerRequest { reply, .. }) = rx.recv().await {
let _ = reply.send(Err(RequestWorkerError::WorkerNotFound(
"no worker named 'ghost'".to_string(),
)));
}
});

Expand All @@ -2427,18 +2501,20 @@ mod auth_tests {
.expect("request should succeed");

assert_eq!(response.status(), StatusCode::NOT_FOUND);
let body = response_json(response).await;
assert_eq!(body["code"], json!("agent_not_found"));
replier.await.expect("replier should complete");
}

#[tokio::test]
async fn snapshot_route_maps_unsupported_runtime_to_409() {
let (router, mut rx) = test_router(Some("secret"));
let replier = tokio::spawn(async move {
if let Some(ListenApiRequest::Snapshot { reply, .. }) = rx.recv().await {
let _ = reply.send(Err(
"unsupported_runtime: worker 'h' is headless; snapshot_pty is only supported on PTY workers"
if let Some(ListenApiRequest::WorkerRequest { reply, .. }) = rx.recv().await {
let _ = reply.send(Err(RequestWorkerError::UnsupportedRuntime(
"worker 'h' is headless; snapshot_pty is only supported on PTY workers"
.to_string(),
));
)));
}
});

Expand All @@ -2461,13 +2537,11 @@ mod auth_tests {
}

#[tokio::test]
async fn snapshot_route_maps_snapshot_timeout_to_504() {
async fn snapshot_route_maps_worker_timeout_to_504() {
let (router, mut rx) = test_router(Some("secret"));
let replier = tokio::spawn(async move {
if let Some(ListenApiRequest::Snapshot { reply, .. }) = rx.recv().await {
let _ = reply.send(Err(
"snapshot_timeout: worker did not respond in time".to_string()
));
if let Some(ListenApiRequest::WorkerRequest { reply, .. }) = rx.recv().await {
let _ = reply.send(Err(RequestWorkerError::Timeout));
}
});

Expand All @@ -2485,7 +2559,38 @@ mod auth_tests {

assert_eq!(response.status(), StatusCode::GATEWAY_TIMEOUT);
let body = response_json(response).await;
assert_eq!(body["code"], json!("snapshot_timeout"));
assert_eq!(body["code"], json!("worker_timeout"));
replier.await.expect("replier should complete");
}

#[tokio::test]
async fn snapshot_route_propagates_worker_error_envelope() {
    let (router, mut rx) = test_router(Some("secret"));

    // Stand-in for the broker: answer the first WorkerRequest with a
    // structured worker-side error envelope.
    let replier = tokio::spawn(async move {
        if let Some(ListenApiRequest::WorkerRequest { reply, .. }) = rx.recv().await {
            let worker_failure = RequestWorkerError::WorkerError {
                code: "invalid_format".to_string(),
                message: "unsupported format 'qoi'".to_string(),
            };
            let _ = reply.send(Err(worker_failure));
        }
    });

    let request = Request::builder()
        .method("GET")
        .uri("/api/spawned/worker-a/snapshot")
        .header("x-api-key", "secret")
        .body(Body::empty())
        .expect("request should build");

    let response = router
        .oneshot(request)
        .await
        .expect("request should succeed");

    // The envelope is routed through classify_error, which maps any
    // "invalid_*" prefix to 400 with the generic "invalid_request" code.
    assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    let body = response_json(response).await;
    assert_eq!(body["code"], json!("invalid_request"));

    replier.await.expect("replier should complete");
}
}
Loading
Loading