From 011e83b1d3ae2a1ea9b4260623cc949536826acb Mon Sep 17 00:00:00 2001 From: lif <> Date: Thu, 30 Jun 2022 23:03:54 +0000 Subject: [PATCH] Implement websockets in Dropshot as an Extractor and `#[channel]` macro This allows endpoints meant for handling continuous websocket streams to be easily constructed by providing an async handler that accepts the upgraded websocket connection as an argument. This change also refactors the previous `paginated` bool to instead be a dedicated type describing any 'extension' modes we may add atop the basic OpenAPI. --- Cargo.lock | 121 +++++++- dropshot/Cargo.toml | 5 + dropshot/examples/websocket.rs | 95 ++++++ dropshot/src/api_description.rs | 39 ++- dropshot/src/handler.rs | 61 ++-- dropshot/src/lib.rs | 8 + dropshot/src/router.rs | 2 +- dropshot/src/websocket.rs | 374 +++++++++++++++++++++++ dropshot/tests/fail/bad_endpoint3.stderr | 1 + dropshot/tests/fail/bad_endpoint4.stderr | 2 +- dropshot/tests/test_demo.rs | 52 ++++ dropshot_endpoint/src/lib.rs | 160 +++++++++- 12 files changed, 876 insertions(+), 44 deletions(-) create mode 100644 dropshot/examples/websocket.rs create mode 100644 dropshot/src/websocket.rs diff --git a/Cargo.lock b/Cargo.lock index 20db8c321..93a5b0757 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -81,7 +81,16 @@ dependencies = [ "block-padding", "byte-tools", "byteorder", - "generic-array", + "generic-array 0.12.4", +] + +[[package]] +name = "block-buffer" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf7fe51849ea569fd452f37822f606a5cabb684dc918707a0193fd4664ff324" +dependencies = [ + "generic-array 0.14.5", ] [[package]] @@ -176,6 +185,15 @@ version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5827cebf4670468b8772dd191856768aedcb1b0278a04f989f7766351917b9dc" +[[package]] +name = "cpufeatures" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59a6001667ab124aebae2a495118e11d30984c3a653e99d86d58971708cf5e4b" +dependencies = [ + "libc", +] + [[package]] name = "crossbeam-channel" version = "0.5.1" @@ -196,13 +214,33 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array 0.14.5", + "typenum", +] + [[package]] name = "digest" version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f3d0c8c8752312f9713efd397ff63acb9f85585afbf179282e720e7704954dd5" dependencies = [ - "generic-array", + "generic-array 0.12.4", +] + +[[package]] +name = "digest" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2fb860ca6fafa5552fb6d0e816a69c8e49f0908bf524e30a90d97c85892d506" +dependencies = [ + "block-buffer 0.10.2", + "crypto-common", ] [[package]] @@ -248,6 +286,7 @@ dependencies = [ "dropshot_endpoint", "expectorate", "futures", + "futures-util", "hostname", "http", "hyper", @@ -269,6 +308,7 @@ dependencies = [ "serde", "serde_json", "serde_urlencoded", + "sha1", "slog", "slog-async", "slog-bunyan", @@ -278,8 +318,10 @@ dependencies = [ "tempfile", "tokio", "tokio-rustls", + "tokio-tungstenite", "toml", "trybuild", + "tungstenite", "usdt", "uuid", "version_check", @@ -460,6 +502,16 @@ dependencies = [ "typenum", ] +[[package]] +name = "generic-array" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd48d33ec7f05fbfa152300fdad764757cbded343c1aa1cff2fbaf4134851803" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getrandom" version = "0.2.3" @@ -939,7 +991,7 @@ checksum = "54be6e404f5317079812fc8f9f5279de376d8856929e21c184ecf6bbd692a11d" dependencies = [ "maplit", "pest", - "sha-1", + "sha-1 0.8.2", ] [[package]] @@ -1312,12 +1364,34 @@ version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f7d94d0bede923b3cea61f3f1ff57ff8cdfd77b400fb8f9998949e0cf04163df" dependencies = [ - "block-buffer", - "digest", + "block-buffer 0.7.3", + "digest 0.8.1", "fake-simd", "opaque-debug", ] +[[package]] +name = "sha-1" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "028f48d513f9678cda28f6e4064755b3fbb2af6acd672f2c209b62323f7aea0f" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest 0.10.3", +] + +[[package]] +name = "sha1" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c77f4e7f65455545c2153c1253d25056825e77ee2533f0e41deb65a93a34852f" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest 0.10.3", +] + [[package]] name = "signal-hook-registry" version = "1.4.0" @@ -1623,6 +1697,18 @@ dependencies = [ "webpki", ] +[[package]] +name = "tokio-tungstenite" +version = "0.17.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f714dd15bead90401d77e04243611caec13726c2408afd5b31901dfcdcb3b181" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.6.8" @@ -1693,6 +1779,25 @@ dependencies = [ "toml", ] +[[package]] +name = "tungstenite" +version = "0.17.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e27992fd6a8c29ee7eef28fc78349aa244134e10ad447ce3b9f0ac0ed0fa4ce0" +dependencies = [ + "base64", + "byteorder", + "bytes", + "http", + "httparse", + "log", + "rand", + "sha-1 0.10.0", + "thiserror", + "url", + "utf-8", +] + [[package]] name = "typenum" version = "1.14.0" @@ -1827,6 +1932,12 @@ dependencies = [ "usdt-impl", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "uuid" version = "1.1.2" diff --git a/dropshot/Cargo.toml b/dropshot/Cargo.toml index 3fe074e3d..4f0a566a1 100644 --- a/dropshot/Cargo.toml +++ b/dropshot/Cargo.toml @@ -27,6 +27,7 @@ rustls = "0.20.6" rustls-pemfile = "1.0.1" serde_json = "1.0.83" serde_urlencoded = "0.7.1" +sha1 = "0.10.1" slog-async = "2.4.0" slog-bunyan = "2.4.0" slog-json = "2.6.1" @@ -84,6 +85,10 @@ trybuild = "1.0.64" # Used by the https examples and tests pem = "1.1" rcgen = "0.9.3" +# Used in a doc-test demonstrating the WebsocketUpgrade extractor. +tungstenite = "0.17.3" +tokio-tungstenite = "0.17.2" +futures-util = "0.3.21" [dev-dependencies.rustls] version = "0.20.6" diff --git a/dropshot/examples/websocket.rs b/dropshot/examples/websocket.rs new file mode 100644 index 000000000..a05c4b3c8 --- /dev/null +++ b/dropshot/examples/websocket.rs @@ -0,0 +1,95 @@ +// Copyright 2022 Oxide Computer Company +/*! + * Example use of Dropshot with a websocket endpoint. + */ + +use dropshot::channel; +use dropshot::ApiDescription; +use dropshot::ConfigDropshot; +use dropshot::ConfigLogging; +use dropshot::ConfigLoggingLevel; +use dropshot::HttpServerStarter; +use dropshot::Query; +use dropshot::RequestContext; +use dropshot::WebsocketConnection; +use futures_util::SinkExt; +use schemars::JsonSchema; +use serde::Deserialize; +use std::sync::Arc; +use tungstenite::protocol::Role; +use tungstenite::Message; + +#[tokio::main] +async fn main() -> Result<(), String> { + /* + * We must specify a configuration with a bind address. We'll use 127.0.0.1 + * since it's available and won't expose this server outside the host. We + * request port 0, which allows the operating system to pick any available + * port. + */ + let config_dropshot: ConfigDropshot = Default::default(); + + /* + * For simplicity, we'll configure an "info"-level logger that writes to + * stderr assuming that it's a terminal. + */ + let config_logging = + ConfigLogging::StderrTerminal { level: ConfigLoggingLevel::Info }; + let log = config_logging + .to_logger("example-basic") + .map_err(|error| format!("failed to create logger: {}", error))?; + + /* + * Build a description of the API. + */ + let mut api = ApiDescription::new(); + api.register(example_api_websocket_counter).unwrap(); + + /* + * Set up the server. + */ + let server = HttpServerStarter::new(&config_dropshot, api, (), &log) + .map_err(|error| format!("failed to create server: {}", error))? + .start(); + + /* + * Wait for the server to stop. Note that there's not any code to shut down + * this server, so we should never get past this point. + */ + server.await +} + +/* + * HTTP API interface + */ + +#[derive(Deserialize, JsonSchema)] +struct QueryParams { + start: Option, +} + +/** + * An eternally-increasing sequence of bytes, wrapping on overflow, starting + * from the value given for the query parameter "start." + */ +#[channel { + protocol = WEBSOCKETS, + path = "/counter", +}] +async fn example_api_websocket_counter( + _rqctx: Arc>, + upgraded: WebsocketConnection, + qp: Query, +) -> dropshot::WebsocketChannelResult { + let mut ws = tokio_tungstenite::WebSocketStream::from_raw_socket( + upgraded.into_inner(), + Role::Server, + None, + ) + .await; + let mut count = qp.into_inner().start.unwrap_or(0); + while ws.send(Message::Binary(vec![count])).await.is_ok() { + count = count.wrapping_add(1); + } + Ok(()) +} diff --git a/dropshot/src/api_description.rs b/dropshot/src/api_description.rs index fe6c9e668..e472c1879 100644 --- a/dropshot/src/api_description.rs +++ b/dropshot/src/api_description.rs @@ -45,7 +45,7 @@ pub struct ApiEndpoint { pub summary: Option, pub description: Option, pub tags: Vec, - pub paginated: bool, + pub extension_mode: ExtensionMode, pub visible: bool, } @@ -78,7 +78,7 @@ impl<'a, Context: ServerContext> ApiEndpoint { summary: None, description: None, tags: vec![], - paginated: func_parameters.paginated, + extension_mode: func_parameters.extension_mode, visible: true, } } @@ -688,11 +688,20 @@ impl ApiDescription { }) .next(); - if endpoint.paginated { - operation.extensions.insert( - crate::pagination::PAGINATION_EXTENSION.to_string(), - serde_json::json! {true}, - ); + match endpoint.extension_mode { + ExtensionMode::None => {} + ExtensionMode::Paginated => { + operation.extensions.insert( + crate::pagination::PAGINATION_EXTENSION.to_string(), + serde_json::json! {true}, + ); + } + ExtensionMode::Websocket => { + operation.extensions.insert( + crate::websocket::WEBSOCKET_EXTENSION.to_string(), + serde_json::json!({}), + ); + } } let response = if let Some(schema) = &endpoint.response.schema { @@ -1579,6 +1588,22 @@ pub struct TagExternalDocs { pub url: String, } +/** + * Dropshot/Progenitor features used by endpoints which are not a part of the base OpenAPI spec. + */ +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub enum ExtensionMode { + None, + Paginated, + Websocket, +} + +impl Default for ExtensionMode { + fn default() -> Self { + ExtensionMode::None + } +} + #[cfg(test)] mod test { use super::j2oas_schema; diff --git a/dropshot/src/handler.rs b/dropshot/src/handler.rs index 88669bbe2..5f7575ec6 100644 --- a/dropshot/src/handler.rs +++ b/dropshot/src/handler.rs @@ -40,16 +40,17 @@ use super::http_util::CONTENT_TYPE_JSON; use super::http_util::CONTENT_TYPE_OCTET_STREAM; use super::server::DropshotState; use super::server::ServerContext; -use crate::api_description::ApiEndpointBodyContentType; use crate::api_description::ApiEndpointHeader; use crate::api_description::ApiEndpointParameter; use crate::api_description::ApiEndpointParameterLocation; use crate::api_description::ApiEndpointResponse; use crate::api_description::ApiSchemaGenerator; +use crate::api_description::{ApiEndpointBodyContentType, ExtensionMode}; use crate::pagination::PaginationParams; use crate::pagination::PAGINATION_PARAM_SENTINEL; use crate::router::VariableSet; use crate::to_map::to_map; +use crate::websocket::WEBSOCKET_PARAM_SENTINEL; use async_trait::async_trait; use bytes::Bytes; @@ -173,8 +174,8 @@ impl RequestContextArgument * `RequestContext`. Unlike most traits, `Extractor` essentially defines only a * constructor function, not instance functions. * - * The extractors that we provide (`Query`, `Path`, `TypedBody`, and - * `UntypedBody`) implement `Extractor` in order to construct themselves from + * The extractors that we provide (`Query`, `Path`, `TypedBody`, `UntypedBody`, and + * `WebsocketUpgrade`) implement `Extractor` in order to construct themselves from * the request. For example, `Extractor` is implemented for `Query` with a * function that reads the query string from the request, parses it, and * constructs a `Query` with it. @@ -202,7 +203,7 @@ pub trait Extractor: Send + Sync + Sized { * the associated endpoint is paginated. */ pub struct ExtractorMetadata { - pub paginated: bool, + pub extension_mode: ExtensionMode, pub parameters: Vec, } @@ -223,15 +224,21 @@ macro_rules! impl_extractor_for_tuple { fn metadata(_body_content_type: ApiEndpointBodyContentType) -> ExtractorMetadata { #[allow(unused_mut)] - let mut paginated = false; + let mut extension_mode = ExtensionMode::None; #[allow(unused_mut)] let mut parameters = vec![]; $( let mut metadata = $T::metadata(_body_content_type.clone()); - paginated = paginated | metadata.paginated; + extension_mode = match (extension_mode, metadata.extension_mode) { + (ExtensionMode::None, x) | (x, ExtensionMode::None) => x, + (x, y) if x != y => { + panic!("incompatible extension modes in tuple: {:?} != {:?}", x, y); + } + (_, x) => x, + }; parameters.append(&mut metadata.parameters); )* - ExtractorMetadata { paginated, parameters } + ExtractorMetadata { extension_mode, parameters } } } }} @@ -685,11 +692,23 @@ where ); let schema = generator.root_schema_for::().schema.into(); - let paginated = match schema_extensions(&schema) { + let extension_mode = match schema_extensions(&schema) { Some(extensions) => { - extensions.get(&PAGINATION_PARAM_SENTINEL.to_string()).is_some() + let paginated = extensions + .get(&PAGINATION_PARAM_SENTINEL.to_string()) + .is_some(); + let websocket = + extensions.get(&WEBSOCKET_PARAM_SENTINEL.to_string()).is_some(); + match (paginated, websocket) { + (false, false) => ExtensionMode::None, + (false, true) => ExtensionMode::Websocket, + (true, false) => ExtensionMode::Paginated, + (true, true) => panic!( + "Cannot use websocket and pagination in the same endpoint!" + ), + } } - None => false, + None => ExtensionMode::None, }; /* @@ -716,7 +735,7 @@ where }) .collect::>(); - ExtractorMetadata { paginated, parameters } + ExtractorMetadata { extension_mode, parameters } } fn schema_extensions( @@ -1039,7 +1058,10 @@ where }, vec![], ); - ExtractorMetadata { paginated: false, parameters: vec![body] } + ExtractorMetadata { + extension_mode: ExtensionMode::None, + parameters: vec![body], + } } } @@ -1116,7 +1138,7 @@ impl Extractor for UntypedBody { }, vec![], )], - paginated: false, + extension_mode: ExtensionMode::None, } } } @@ -1591,6 +1613,7 @@ fn schema_extract_description( #[cfg(test)] mod test { + use crate::api_description::ExtensionMode; use crate::{ api_description::ApiEndpointParameterMetadata, ApiEndpointParameter, ApiEndpointParameterLocation, PaginationParams, @@ -1628,10 +1651,10 @@ mod test { fn compare( actual: ExtractorMetadata, - paginated: bool, + extension_mode: ExtensionMode, parameters: Vec<(&str, bool)>, ) { - assert_eq!(actual.paginated, paginated); + assert_eq!(actual.extension_mode, extension_mode); /* * This is order-dependent. We might not really care if the order @@ -1659,7 +1682,7 @@ mod test { let params = get_metadata::(&ApiEndpointParameterLocation::Path); let expected = vec![("bar", true), ("baz", false), ("foo", true)]; - compare(params, false, expected); + compare(params, ExtensionMode::None, expected); } #[test] @@ -1672,7 +1695,7 @@ mod test { ("limit", false), ]; - compare(params, false, expected); + compare(params, ExtensionMode::None, expected); } #[test] @@ -1687,7 +1710,7 @@ mod test { ("page_token", false), ]; - compare(params, false, expected); + compare(params, ExtensionMode::None, expected); } #[test] @@ -1703,6 +1726,6 @@ mod test { ("page_token", false), ]; - compare(params, true, expected); + compare(params, ExtensionMode::Paginated, expected); } } diff --git a/dropshot/src/lib.rs b/dropshot/src/lib.rs index d758c5988..6d29c4d37 100644 --- a/dropshot/src/lib.rs +++ b/dropshot/src/lib.rs @@ -611,6 +611,7 @@ mod router; mod server; mod to_map; mod type_util; +mod websocket; pub mod test_util; @@ -624,6 +625,7 @@ pub use api_description::ApiEndpointParameter; pub use api_description::ApiEndpointParameterLocation; pub use api_description::ApiEndpointResponse; pub use api_description::EndpointTagPolicy; +pub use api_description::ExtensionMode; pub use api_description::OpenApiDefinition; pub use api_description::TagConfig; pub use api_description::TagDetails; @@ -664,6 +666,11 @@ pub use pagination::ResultsPage; pub use pagination::WhichPage; pub use server::ServerContext; pub use server::{HttpServer, HttpServerStarter}; +pub use websocket::WebsocketChannelResult; +pub use websocket::WebsocketConnection; +pub use websocket::WebsocketConnectionRaw; +pub use websocket::WebsocketEndpointResult; +pub use websocket::WebsocketUpgrade; /* * Users of the `endpoint` macro need the following macros: @@ -672,4 +679,5 @@ pub use handler::RequestContextArgument; pub use http::Method; extern crate dropshot_endpoint; +pub use dropshot_endpoint::channel; pub use dropshot_endpoint::endpoint; diff --git a/dropshot/src/router.rs b/dropshot/src/router.rs index 73e53cbb8..06e90200c 100644 --- a/dropshot/src/router.rs +++ b/dropshot/src/router.rs @@ -822,7 +822,7 @@ mod test { summary: None, description: None, tags: vec![], - paginated: false, + extension_mode: Default::default(), visible: true, } } diff --git a/dropshot/src/websocket.rs b/dropshot/src/websocket.rs new file mode 100644 index 000000000..14247a7cd --- /dev/null +++ b/dropshot/src/websocket.rs @@ -0,0 +1,374 @@ +// Copyright 2022 Oxide Computer Company +/*! + * Implements websocket upgrades as an Extractor for use in API route handler + * parameters to indicate that the given endpoint is meant to be upgraded to + * a websocket. + * + * This exposes a raw upgraded HTTP connection to a user-provided async future, + * which will be spawned to handle the incoming connection. + */ + +use crate::api_description::ExtensionMode; +use crate::{ + ApiEndpointBodyContentType, Extractor, ExtractorMetadata, HttpError, + RequestContext, ServerContext, +}; +use async_trait::async_trait; +use http::header; +use http::Response; +use http::StatusCode; +use hyper::upgrade::OnUpgrade; +use hyper::Body; +use schemars::JsonSchema; +use serde_json::json; +use sha1::{Digest, Sha1}; +use slog::Logger; +use std::future::Future; +use std::sync::Arc; + +/** + * WebsocketUpgrade is an Extractor used to upgrade and handle an HTTP request + * as a websocket when present in a Dropshot endpoint's function arguments. + * + * The consumer of this must call [WebsocketUpgrade::handle] for the connection + * to be upgraded. (This is done for you by `#[channel]`.) + */ +#[derive(Debug)] +pub struct WebsocketUpgrade(Option); + +/** + * This is the return type of the websocket-handling future provided to + * [`dropshot_endpoint::channel`] + * (which in turn provides it to [WebsocketUpgrade::handle]). + */ +pub type WebsocketChannelResult = + Result<(), Box>; + +/** + * [WebsocketUpgrade::handle]'s return type. + * The `#[endpoint]` handler must return the value returned by + * [WebsocketUpgrade::handle]. (This is done for you by `#[channel]`.) + */ +pub type WebsocketEndpointResult = Result, HttpError>; + +/** + * The upgraded connection passed as the second argument to the websocket + * handler function. [`WebsocketConnection::into_inner`] can be used to + * access the raw upgraded connection, for passing to any implementation + * of the websockets protocol. + */ +pub struct WebsocketConnection(WebsocketConnectionRaw); + +/// A type that implements [tokio::io::AsyncRead] + [tokio::io::AsyncWrite]. +pub type WebsocketConnectionRaw = hyper::upgrade::Upgraded; + +impl WebsocketConnection { + /// Consumes `self` and returns the held raw connection. + pub fn into_inner(self) -> WebsocketConnectionRaw { + self.0 + } +} + +#[derive(Debug)] +struct WebsocketUpgradeInner { + upgrade_fut: OnUpgrade, + accept_key: String, + route: String, + ws_log: Logger, +} + +// Borrowed from tungstenite-0.17.3 (rather than taking a whole dependency for this one function) +fn derive_accept_key(request_key: &[u8]) -> String { + // ... field is constructed by concatenating /key/ ... + // ... with the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (RFC 6455) + const WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + let mut sha1 = Sha1::default(); + sha1.update(request_key); + sha1.update(WS_GUID); + base64::encode(&sha1.finalize()) +} + +/** + * This `Extractor` implementation constructs an instance of `WebsocketUpgrade` + * from an HTTP request, and returns an error if the given request does not + * contain websocket upgrade headers. + */ +#[async_trait] +impl Extractor for WebsocketUpgrade { + async fn from_request( + rqctx: Arc>, + ) -> Result { + let request = &mut *rqctx.request.lock().await; + + if !request + .headers() + .get(header::CONNECTION) + .and_then(|hv| hv.to_str().ok()) + .map(|hv| { + hv.split(|c| c == ',' || c == ' ') + .any(|vs| vs.eq_ignore_ascii_case("upgrade")) + }) + .unwrap_or(false) + { + return Err(HttpError::for_bad_request( + None, + "expected connection upgrade".to_string(), + )); + } + + if !request + .headers() + .get(header::UPGRADE) + .and_then(|v| v.to_str().ok()) + .map(|v| { + v.split(|c| c == ',' || c == ' ') + .any(|v| v.eq_ignore_ascii_case("websocket")) + }) + .unwrap_or(false) + { + return Err(HttpError::for_bad_request( + None, + "unexpected protocol for upgrade".to_string(), + )); + } + + if request + .headers() + .get(header::SEC_WEBSOCKET_VERSION) + .map(|v| v.as_bytes()) + != Some(b"13") + { + return Err(HttpError::for_bad_request( + None, + "missing or invalid websocket version".to_string(), + )); + } + + let accept_key = request + .headers() + .get(header::SEC_WEBSOCKET_KEY) + .map(|hv| hv.as_bytes()) + .map(|key| derive_accept_key(key)) + .ok_or_else(|| { + HttpError::for_bad_request( + None, + "missing websocket key".to_string(), + ) + })?; + + let route = request.uri().to_string(); + let upgrade_fut = hyper::upgrade::on(request); + // note: this is just used in our wrapper in `handle`; if a user wants + // to slog in their future, they can obtain it from rqctx the same way + // they do in any other endpoint & let it get `move`d into the closure + let ws_log = rqctx.log.new(o!( + "upgrade" => "websocket".to_string(), + )); + + Ok(Self(Some(WebsocketUpgradeInner { + upgrade_fut, + accept_key, + ws_log, + route, + }))) + } + + fn metadata( + _content_type: ApiEndpointBodyContentType, + ) -> ExtractorMetadata { + ExtractorMetadata { + parameters: vec![], + extension_mode: ExtensionMode::Websocket, + } + } +} + +impl WebsocketUpgrade { + /** + * Upgrade the HTTP connection to a websocket and spawn a user-provided + * async handler to service it. + * + * This function's return value should be the basis of the return value of + * your endpoint's function, as it sends the headers to tell the HTTP + * client that we are accepting the upgrade. + * + * `handler` is a closure that accepts a [`WebsocketConnection`] + * and returns a future that will be spawned by this function, + * in which the `WebsocketConnection`'s inner `Upgraded` connection may be + * used with your choice of websocket-handling code operating over an + * [`tokio::io::AsyncRead`] + [`tokio::io::AsyncWrite`] type + * (e.g. `tokio_tungstenite`). + * + * ``` + #[dropshot::endpoint { method = GET, path = "/my/ws/endpoint/{id}" }] + async fn my_ws_endpoint( + rqctx: std::sync::Arc>, + websock: dropshot::WebsocketUpgrade, + id: dropshot::Path, + ) -> dropshot::WebsocketEndpointResult { + let logger = rqctx.log.new(slog::o!()); + websock.handle(move |upgraded| async move { + slog::info!(logger, "Entered handler for ID {}", id.into_inner()); + use futures_util::stream::StreamExt; + let mut ws_stream = tokio_tungstenite::WebSocketStream::from_raw_socket( + upgraded.into_inner(), tungstenite::protocol::Role::Server, None + ).await; + slog::info!(logger, "Received from websocket: {:?}", ws_stream.next().await); + Ok(()) + }) + } + * ``` + * + * Note that as a consumer of this crate, you most likely do not want to + * call this function directly; rather, prefer to annotate your function + * with [`dropshot_endpoint::channel`] instead of `endpoint`. + */ + pub fn handle(mut self, handler: C) -> WebsocketEndpointResult + where + C: FnOnce(WebsocketConnection) -> F + Send + 'static, + F: Future + Send + 'static, + { + // we .take() here to tell Drop::drop that we handled the request. + match self.0.take() { + None => Err(HttpError::for_internal_error( + "Tried to handle websocket twice".to_string(), + )), + Some(WebsocketUpgradeInner { + upgrade_fut, + accept_key, + ws_log, + .. + }) => { + tokio::spawn(async move { + match upgrade_fut.await { + Ok(upgrade) => { + match handler(WebsocketConnection(upgrade)).await { + Ok(x) => Ok(x), + Err(e) => { + error!( + ws_log, + "Error returned from handler: {:?}", e + ); + Err(e) + } + } + } + Err(e) => { + error!( + ws_log, + "Error upgrading connection: {:?}", e + ); + Err(e.into()) + } + } + }); + Response::builder() + .status(StatusCode::SWITCHING_PROTOCOLS) + .header(header::CONNECTION, "Upgrade") + .header(header::UPGRADE, "websocket") + .header(header::SEC_WEBSOCKET_ACCEPT, accept_key) + .body(Body::empty()) + .map_err(Into::into) + } + } + } +} + +impl Drop for WebsocketUpgrade { + fn drop(&mut self) { + if let Some(inner) = self.0.take() { + debug!( + inner.ws_log, + "Didn't handle websocket in route {}", inner.route + ); + } + } +} + +// To indicate websocket usage by the endpoint to code generators (i.e. Progenitor) +pub(crate) const WEBSOCKET_EXTENSION: &str = "x-dropshot-websocket"; +pub(crate) const WEBSOCKET_PARAM_SENTINEL: &str = "x-dropshot-websocket-param"; + +impl JsonSchema for WebsocketUpgrade { + fn schema_name() -> String { + "WebsocketUpgrade".to_string() + } + + fn json_schema( + _gen: &mut schemars::gen::SchemaGenerator, + ) -> schemars::schema::Schema { + let mut schema = schemars::schema::SchemaObject::default(); + schema + .extensions + .insert(WEBSOCKET_PARAM_SENTINEL.to_string(), json!(true)); + schemars::schema::Schema::Object(schema) + } +} + +#[cfg(test)] +mod tests { + use crate::router::HttpRouter; + use crate::server::{DropshotState, ServerConfig}; + use crate::{Extractor, HttpError, RequestContext, WebsocketUpgrade}; + use futures::lock::Mutex; + use http::Request; + use hyper::Body; + use std::net::{IpAddr, Ipv6Addr, SocketAddr}; + use std::num::NonZeroU32; + use std::sync::Arc; + use std::time::Duration; + + async fn ws_upg_from_mock_rqctx() -> Result { + let log = slog::Logger::root(slog::Discard, slog::o!()).new(slog::o!()); + let fut = WebsocketUpgrade::from_request(Arc::new(RequestContext { + server: Arc::new(DropshotState { + private: (), + config: ServerConfig { + request_body_max_bytes: 0, + page_max_nitems: NonZeroU32::new(1).unwrap(), + page_default_nitems: NonZeroU32::new(1).unwrap(), + }, + router: HttpRouter::new(), + log: log.clone(), + local_addr: SocketAddr::new( + IpAddr::V6(Ipv6Addr::LOCALHOST), + 8080, + ), + tls: false, + }), + request: Arc::new(Mutex::new( + Request::builder() + .header(http::header::CONNECTION, "Upgrade") + .header(http::header::UPGRADE, "websocket") + .header(http::header::SEC_WEBSOCKET_VERSION, "13") + .header( + http::header::SEC_WEBSOCKET_KEY, + "aGFjayB0aGUgcGxhbmV0IQ==", + ) + .body(Body::empty()) + .unwrap(), + )), + path_variables: Default::default(), + body_content_type: Default::default(), + request_id: "".to_string(), + log: log.clone(), + })); + tokio::time::timeout(Duration::from_secs(1), fut) + .await + .expect("Deadlocked in WebsocketUpgrade constructor") + } + + #[tokio::test] + async fn test_ws_upg_task_is_spawned() { + let (send, recv) = tokio::sync::oneshot::channel(); + ws_upg_from_mock_rqctx() + .await + .unwrap() + .handle(move |_upgrade| async move { Ok(send.send(()).unwrap()) }) + .unwrap(); + // note: not a real connection, so we don't get our future's Ok, but we *do* spawn the task + let _ = tokio::time::timeout(Duration::from_secs(1), recv) + .await + .expect("Task not spawned or never completed"); + } +} diff --git a/dropshot/tests/fail/bad_endpoint3.stderr b/dropshot/tests/fail/bad_endpoint3.stderr index 6e7e4c552..290bb48c6 100644 --- a/dropshot/tests/fail/bad_endpoint3.stderr +++ b/dropshot/tests/fail/bad_endpoint3.stderr @@ -11,6 +11,7 @@ error[E0277]: the trait bound `String: Extractor` is not satisfied (T1,) TypedBody UntypedBody + WebsocketUpgrade dropshot::Path dropshot::Query note: required by a bound in `need_extractor` diff --git a/dropshot/tests/fail/bad_endpoint4.stderr b/dropshot/tests/fail/bad_endpoint4.stderr index 048332640..137ae560a 100644 --- a/dropshot/tests/fail/bad_endpoint4.stderr +++ b/dropshot/tests/fail/bad_endpoint4.stderr @@ -13,7 +13,7 @@ error[E0277]: the trait bound `QueryParams: schemars::JsonSchema` is not satisfi (T0, T1, T2, T3) (T0, T1, T2, T3, T4) (T0, T1, T2, T3, T4, T5) - and 145 others + and 146 others note: required by a bound in `dropshot::Query` --> src/handler.rs | diff --git a/dropshot/tests/test_demo.rs b/dropshot/tests/test_demo.rs index 1a67c2d4e..f30882586 100644 --- a/dropshot/tests/test_demo.rs +++ b/dropshot/tests/test_demo.rs @@ -15,6 +15,7 @@ * JSON body length) */ +use dropshot::channel; use dropshot::endpoint; use dropshot::test_util::object_delete; use dropshot::test_util::read_json; @@ -32,7 +33,11 @@ use dropshot::Query; use dropshot::RequestContext; use dropshot::TypedBody; use dropshot::UntypedBody; +use dropshot::WebsocketChannelResult; +use dropshot::WebsocketConnection; use dropshot::CONTENT_TYPE_JSON; +use futures_util::SinkExt; +use futures_util::StreamExt; use http::StatusCode; use hyper::Body; use hyper::Method; @@ -41,6 +46,9 @@ use schemars::JsonSchema; use serde::Deserialize; use serde::Serialize; use std::sync::Arc; +use tokio_tungstenite::WebSocketStream; +use tungstenite::protocol::Role; +use tungstenite::Message; use uuid::Uuid; extern crate slog; @@ -60,6 +68,7 @@ fn demo_api() -> ApiDescription { api.register(demo_handler_untyped_body).unwrap(); api.register(demo_handler_delete).unwrap(); api.register(demo_handler_headers).unwrap(); + api.register(demo_handler_websocket).unwrap(); /* * We don't need to exhaustively test these cases, as they're tested by unit @@ -734,6 +743,28 @@ async fn test_header_request() { assert_eq!(headers, vec!["hi", "howdy"]); } +/* + * The "test_demo_websocket" handler upgrades to a websocket and exchanges + * greetings with the client. + */ +#[tokio::test] +async fn test_demo_websocket() { + let api = demo_api(); + let testctx = common::test_setup("demo_websocket", api); + + let path = format!( + "ws://{}/testing/websocket", + testctx.client_testctx.bind_address + ); + let (mut ws, _resp) = tokio_tungstenite::connect_async(path).await.unwrap(); + + ws.send(Message::Text("hello server".to_string())).await.unwrap(); + let msg = ws.next().await.unwrap().unwrap(); + assert_eq!(msg, Message::Text("hello client".to_string())); + + testctx.teardown().await; +} + /* * Demo handler functions */ @@ -933,6 +964,27 @@ async fn demo_handler_headers( Ok(response) } +#[channel { + protocol = WEBSOCKETS, + path = "/testing/websocket" +}] +async fn demo_handler_websocket( + rqctx: RequestCtx, + upgraded: WebsocketConnection, +) -> WebsocketChannelResult { + use futures_util::stream::StreamExt; + let mut ws_stream = WebSocketStream::from_raw_socket( + upgraded.into_inner(), + Role::Server, + None, + ) + .await; + ws_stream.send(Message::Text("hello client".to_string())).await.unwrap(); + let msg = ws_stream.next().await.unwrap().unwrap(); + slog::info!(rqctx.log, "{}", msg); + Ok(()) +} + fn http_echo(t: &T) -> Result, HttpError> { Ok(Response::builder() .header(http::header::CONTENT_TYPE, CONTENT_TYPE_JSON) diff --git a/dropshot_endpoint/src/lib.rs b/dropshot_endpoint/src/lib.rs index 1cfe4d3a2..ea49a0cd0 100644 --- a/dropshot_endpoint/src/lib.rs +++ b/dropshot_endpoint/src/lib.rs @@ -10,13 +10,13 @@ */ #![allow(clippy::style)] -use proc_macro2::TokenStream; use quote::format_ident; use quote::quote; use quote::{quote_spanned, ToTokens}; use serde::Deserialize; use serde_tokenstream::from_tokenstream; use serde_tokenstream::Error; +use std::ops::DerefMut; use syn::spanned::Spanned; use syn_parsing::ItemFnForSignature; @@ -48,7 +48,7 @@ impl MethodType { } #[derive(Deserialize, Debug)] -struct Metadata { +struct EndpointMetadata { method: MethodType, path: String, tags: Option>, @@ -57,6 +57,21 @@ struct Metadata { _dropshot_crate: Option, } +#[allow(non_snake_case)] +#[derive(Deserialize, Debug)] +enum ChannelProtocol { + WEBSOCKETS, +} + +#[derive(Deserialize, Debug)] +struct ChannelMetadata { + protocol: ChannelProtocol, + path: String, + tags: Option>, + unpublished: Option, + _dropshot_crate: Option, +} + const DROPSHOT: &str = "dropshot"; const USAGE: &str = "Endpoint handlers must have the following signature: async fn( @@ -97,7 +112,131 @@ pub fn endpoint( attr: proc_macro::TokenStream, item: proc_macro::TokenStream, ) -> proc_macro::TokenStream { - match do_endpoint(attr.into(), item.into()) { + do_output(do_endpoint(attr.into(), item.into())) +} + +fn do_endpoint( + attr: proc_macro2::TokenStream, + item: proc_macro2::TokenStream, +) -> Result<(proc_macro2::TokenStream, Vec), Error> { + let metadata = from_tokenstream(&attr)?; + // factored this way for now so #[channel] can use it too + do_endpoint_inner(metadata, attr, item) +} + +/// As with [`endpoint`], this attribute turns a handler function into a +/// Dropshot endpoint, but first wraps the handler function in such a way +/// that is spawned asynchronously and given the upgraded connection of +/// the given `protocol` (i.e. `WEBSOCKETS`). +/// +/// The first argument still must be an `Arc>`. +/// +/// The second argument passed to the handler function must be a +/// [`dropshot::WebsocketConnection`]. +/// +/// The function must return a [`dropshot::WebsocketChannelResult`] (which is +/// a general-purpose `Result<(), Box>`). +/// Returned error values will be written to the RequestContext's log. +/// +/// ```ignore +/// #[dropshot::channel { protocol = WEBSOCKETS, path = "/my/ws/channel/{id}" }] +/// ``` +#[proc_macro_attribute] +pub fn channel( + attr: proc_macro::TokenStream, + item: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + do_output(do_channel(attr.into(), item.into())) +} + +fn do_channel( + attr: proc_macro2::TokenStream, + item: proc_macro2::TokenStream, +) -> Result<(proc_macro2::TokenStream, Vec), Error> { + let ChannelMetadata { protocol, path, tags, unpublished, _dropshot_crate } = + from_tokenstream(&attr)?; + match protocol { + ChannelProtocol::WEBSOCKETS => { + // here we construct a wrapper function and mutate the arguments a bit + // for the outer layer: we replace WebsocketConnection, which is not + // an extractor, with WebsocketUpgrade, which is. + let ItemFnForSignature { attrs, vis, mut sig, _block: body } = + syn::parse2(item)?; + + let inner_args = sig.inputs.clone(); + let inner_output = sig.output.clone(); + + let arg_names: Vec<_> = inner_args + .iter() + .map(|arg: &syn::FnArg| match arg { + syn::FnArg::Receiver(r) => r.self_token.to_token_stream(), + syn::FnArg::Typed(syn::PatType { pat, .. }) => { + pat.to_token_stream() + } + }) + .collect(); + let found = sig.inputs.iter_mut().nth(1).and_then(|arg| { + if let syn::FnArg::Typed(syn::PatType { pat, ty, .. }) = arg { + if let syn::Pat::Ident(syn::PatIdent { + ident, + by_ref: None, + .. + }) = pat.deref_mut() + { + let conn_type = ty.clone(); + let conn_name = ident.clone(); + let span = ident.span(); + *ident = syn::Ident::new( + "__dropshot_websocket_upgrade", + span, + ); + *ty = Box::new(syn::Type::Verbatim( + quote! { dropshot::WebsocketUpgrade }, + )); + return Some((conn_name, conn_type)); + } + } + return None; + }); + if found.is_none() { + return Err(Error::new_spanned( + &attr, + "An argument of type dropshot::WebsocketConnection must be provided immediately following Arc>.", + )); + } + + sig.output = + syn::parse2(quote!(-> dropshot::WebsocketEndpointResult))?; + + let (conn_name, conn_type) = found.unwrap(); + + let new_item = quote! { + #(#attrs)* + #vis #sig { + async fn __dropshot_websocket_handler(#inner_args) #inner_output #body + __dropshot_websocket_upgrade.handle(move | #conn_name: #conn_type | async move { + __dropshot_websocket_handler(#(#arg_names),*).await + }) + } + }; + + let metadata = EndpointMetadata { + method: MethodType::GET, + path, + tags, + unpublished, + content_type: Some("application/json".to_string()), + _dropshot_crate, + }; + do_endpoint_inner(metadata, attr, new_item) + } + } +} + +fn do_output( + res: Result<(proc_macro2::TokenStream, Vec), Error>, +) -> proc_macro::TokenStream { + match res { Err(err) => err.to_compile_error().into(), Ok((endpoint, errors)) => { let compiler_errors = @@ -113,11 +252,12 @@ pub fn endpoint( } } -fn do_endpoint( - attr: TokenStream, - item: TokenStream, -) -> Result<(TokenStream, Vec), Error> { - let metadata = from_tokenstream::(&attr)?; +fn do_endpoint_inner( + metadata: EndpointMetadata, + attr: proc_macro2::TokenStream, + item: proc_macro2::TokenStream, +) -> Result<(proc_macro2::TokenStream, Vec), Error> { + let ast: ItemFnForSignature = syn::parse2(item.clone())?; let method = metadata.method.as_str(); let path = metadata.path; let content_type = @@ -132,8 +272,6 @@ fn do_endpoint( )); } - let ast: ItemFnForSignature = syn::parse2(item.clone())?; - let mut errors = Vec::new(); if ast.sig.constness.is_some() { @@ -472,7 +610,7 @@ fn do_endpoint( Ok((stream, errors)) } -fn get_crate(var: Option) -> TokenStream { +fn get_crate(var: Option) -> proc_macro2::TokenStream { if let Some(s) = var { if let Ok(ts) = syn::parse_str(s.as_str()) { return ts;