diff --git a/Cargo.lock b/Cargo.lock index 20db8c321..93a5b0757 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -81,7 +81,16 @@ dependencies = [ "block-padding", "byte-tools", "byteorder", - "generic-array", + "generic-array 0.12.4", +] + +[[package]] +name = "block-buffer" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf7fe51849ea569fd452f37822f606a5cabb684dc918707a0193fd4664ff324" +dependencies = [ + "generic-array 0.14.5", ] [[package]] @@ -176,6 +185,15 @@ version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5827cebf4670468b8772dd191856768aedcb1b0278a04f989f7766351917b9dc" +[[package]] +name = "cpufeatures" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59a6001667ab124aebae2a495118e11d30984c3a653e99d86d58971708cf5e4b" +dependencies = [ + "libc", +] + [[package]] name = "crossbeam-channel" version = "0.5.1" @@ -196,13 +214,33 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array 0.14.5", + "typenum", +] + [[package]] name = "digest" version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f3d0c8c8752312f9713efd397ff63acb9f85585afbf179282e720e7704954dd5" dependencies = [ - "generic-array", + "generic-array 0.12.4", +] + +[[package]] +name = "digest" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2fb860ca6fafa5552fb6d0e816a69c8e49f0908bf524e30a90d97c85892d506" +dependencies = [ + "block-buffer 0.10.2", + "crypto-common", ] [[package]] @@ -248,6 +286,7 @@ dependencies = [ "dropshot_endpoint", "expectorate", "futures", + "futures-util", "hostname", "http", "hyper", @@ -269,6 +308,7 @@ dependencies = [ "serde", "serde_json", "serde_urlencoded", + "sha1", "slog", "slog-async", "slog-bunyan", @@ -278,8 +318,10 @@ dependencies = [ "tempfile", "tokio", "tokio-rustls", + "tokio-tungstenite", "toml", "trybuild", + "tungstenite", "usdt", "uuid", "version_check", @@ -460,6 +502,16 @@ dependencies = [ "typenum", ] +[[package]] +name = "generic-array" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd48d33ec7f05fbfa152300fdad764757cbded343c1aa1cff2fbaf4134851803" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getrandom" version = "0.2.3" @@ -939,7 +991,7 @@ checksum = "54be6e404f5317079812fc8f9f5279de376d8856929e21c184ecf6bbd692a11d" dependencies = [ "maplit", "pest", - "sha-1", + "sha-1 0.8.2", ] [[package]] @@ -1312,12 +1364,34 @@ version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f7d94d0bede923b3cea61f3f1ff57ff8cdfd77b400fb8f9998949e0cf04163df" dependencies = [ - "block-buffer", - "digest", + "block-buffer 0.7.3", + "digest 0.8.1", "fake-simd", "opaque-debug", ] +[[package]] +name = "sha-1" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "028f48d513f9678cda28f6e4064755b3fbb2af6acd672f2c209b62323f7aea0f" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest 0.10.3", +] + +[[package]] +name = "sha1" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c77f4e7f65455545c2153c1253d25056825e77ee2533f0e41deb65a93a34852f" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest 0.10.3", +] + [[package]] name = "signal-hook-registry" version = "1.4.0" @@ -1623,6 +1697,18 @@ dependencies = [ "webpki", ] +[[package]] +name = "tokio-tungstenite" +version = "0.17.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f714dd15bead90401d77e04243611caec13726c2408afd5b31901dfcdcb3b181" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.6.8" @@ -1693,6 +1779,25 @@ dependencies = [ "toml", ] +[[package]] +name = "tungstenite" +version = "0.17.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e27992fd6a8c29ee7eef28fc78349aa244134e10ad447ce3b9f0ac0ed0fa4ce0" +dependencies = [ + "base64", + "byteorder", + "bytes", + "http", + "httparse", + "log", + "rand", + "sha-1 0.10.0", + "thiserror", + "url", + "utf-8", +] + [[package]] name = "typenum" version = "1.14.0" @@ -1827,6 +1932,12 @@ dependencies = [ "usdt-impl", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "uuid" version = "1.1.2" diff --git a/dropshot/Cargo.toml b/dropshot/Cargo.toml index 3fe074e3d..4f0a566a1 100644 --- a/dropshot/Cargo.toml +++ b/dropshot/Cargo.toml @@ -27,6 +27,7 @@ rustls = "0.20.6" rustls-pemfile = "1.0.1" serde_json = "1.0.83" serde_urlencoded = "0.7.1" +sha1 = "0.10.1" slog-async = "2.4.0" slog-bunyan = "2.4.0" slog-json = "2.6.1" @@ -84,6 +85,10 @@ trybuild = "1.0.64" # Used by the https examples and tests pem = "1.1" rcgen = "0.9.3" +# Used in a doc-test demonstrating the WebsocketUpgrade extractor. +tungstenite = "0.17.3" +tokio-tungstenite = "0.17.2" +futures-util = "0.3.21" [dev-dependencies.rustls] version = "0.20.6" diff --git a/dropshot/examples/websocket.rs b/dropshot/examples/websocket.rs new file mode 100644 index 000000000..a05c4b3c8 --- /dev/null +++ b/dropshot/examples/websocket.rs @@ -0,0 +1,95 @@ +// Copyright 2022 Oxide Computer Company +/*! + * Example use of Dropshot with a websocket endpoint. + */ + +use dropshot::channel; +use dropshot::ApiDescription; +use dropshot::ConfigDropshot; +use dropshot::ConfigLogging; +use dropshot::ConfigLoggingLevel; +use dropshot::HttpServerStarter; +use dropshot::Query; +use dropshot::RequestContext; +use dropshot::WebsocketConnection; +use futures_util::SinkExt; +use schemars::JsonSchema; +use serde::Deserialize; +use std::sync::Arc; +use tungstenite::protocol::Role; +use tungstenite::Message; + +#[tokio::main] +async fn main() -> Result<(), String> { + /* + * We must specify a configuration with a bind address. We'll use 127.0.0.1 + * since it's available and won't expose this server outside the host. We + * request port 0, which allows the operating system to pick any available + * port. + */ + let config_dropshot: ConfigDropshot = Default::default(); + + /* + * For simplicity, we'll configure an "info"-level logger that writes to + * stderr assuming that it's a terminal. + */ + let config_logging = + ConfigLogging::StderrTerminal { level: ConfigLoggingLevel::Info }; + let log = config_logging + .to_logger("example-basic") + .map_err(|error| format!("failed to create logger: {}", error))?; + + /* + * Build a description of the API. + */ + let mut api = ApiDescription::new(); + api.register(example_api_websocket_counter).unwrap(); + + /* + * Set up the server. + */ + let server = HttpServerStarter::new(&config_dropshot, api, (), &log) + .map_err(|error| format!("failed to create server: {}", error))? + .start(); + + /* + * Wait for the server to stop. Note that there's not any code to shut down + * this server, so we should never get past this point. + */ + server.await +} + +/* + * HTTP API interface + */ + +#[derive(Deserialize, JsonSchema)] +struct QueryParams { + start: Option, +} + +/** + * An eternally-increasing sequence of bytes, wrapping on overflow, starting + * from the value given for the query parameter "start." + */ +#[channel { + protocol = WEBSOCKETS, + path = "/counter", +}] +async fn example_api_websocket_counter( + _rqctx: Arc>, + upgraded: WebsocketConnection, + qp: Query, +) -> dropshot::WebsocketChannelResult { + let mut ws = tokio_tungstenite::WebSocketStream::from_raw_socket( + upgraded.into_inner(), + Role::Server, + None, + ) + .await; + let mut count = qp.into_inner().start.unwrap_or(0); + while ws.send(Message::Binary(vec![count])).await.is_ok() { + count = count.wrapping_add(1); + } + Ok(()) +} diff --git a/dropshot/src/api_description.rs b/dropshot/src/api_description.rs index fe6c9e668..e472c1879 100644 --- a/dropshot/src/api_description.rs +++ b/dropshot/src/api_description.rs @@ -45,7 +45,7 @@ pub struct ApiEndpoint { pub summary: Option, pub description: Option, pub tags: Vec, - pub paginated: bool, + pub extension_mode: ExtensionMode, pub visible: bool, } @@ -78,7 +78,7 @@ impl<'a, Context: ServerContext> ApiEndpoint { summary: None, description: None, tags: vec![], - paginated: func_parameters.paginated, + extension_mode: func_parameters.extension_mode, visible: true, } } @@ -688,11 +688,20 @@ impl ApiDescription { }) .next(); - if endpoint.paginated { - operation.extensions.insert( - crate::pagination::PAGINATION_EXTENSION.to_string(), - serde_json::json! {true}, - ); + match endpoint.extension_mode { + ExtensionMode::None => {} + ExtensionMode::Paginated => { + operation.extensions.insert( + crate::pagination::PAGINATION_EXTENSION.to_string(), + serde_json::json! {true}, + ); + } + ExtensionMode::Websocket => { + operation.extensions.insert( + crate::websocket::WEBSOCKET_EXTENSION.to_string(), + serde_json::json!({}), + ); + } } let response = if let Some(schema) = &endpoint.response.schema { @@ -1579,6 +1588,22 @@ pub struct TagExternalDocs { pub url: String, } +/** + * Dropshot/Progenitor features used by endpoints which are not a part of the base OpenAPI spec. + */ +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub enum ExtensionMode { + None, + Paginated, + Websocket, +} + +impl Default for ExtensionMode { + fn default() -> Self { + ExtensionMode::None + } +} + #[cfg(test)] mod test { use super::j2oas_schema; diff --git a/dropshot/src/handler.rs b/dropshot/src/handler.rs index 88669bbe2..5f7575ec6 100644 --- a/dropshot/src/handler.rs +++ b/dropshot/src/handler.rs @@ -40,16 +40,17 @@ use super::http_util::CONTENT_TYPE_JSON; use super::http_util::CONTENT_TYPE_OCTET_STREAM; use super::server::DropshotState; use super::server::ServerContext; -use crate::api_description::ApiEndpointBodyContentType; use crate::api_description::ApiEndpointHeader; use crate::api_description::ApiEndpointParameter; use crate::api_description::ApiEndpointParameterLocation; use crate::api_description::ApiEndpointResponse; use crate::api_description::ApiSchemaGenerator; +use crate::api_description::{ApiEndpointBodyContentType, ExtensionMode}; use crate::pagination::PaginationParams; use crate::pagination::PAGINATION_PARAM_SENTINEL; use crate::router::VariableSet; use crate::to_map::to_map; +use crate::websocket::WEBSOCKET_PARAM_SENTINEL; use async_trait::async_trait; use bytes::Bytes; @@ -173,8 +174,8 @@ impl RequestContextArgument * `RequestContext`. Unlike most traits, `Extractor` essentially defines only a * constructor function, not instance functions. * - * The extractors that we provide (`Query`, `Path`, `TypedBody`, and - * `UntypedBody`) implement `Extractor` in order to construct themselves from + * The extractors that we provide (`Query`, `Path`, `TypedBody`, `UntypedBody`, and + * `WebsocketUpgrade`) implement `Extractor` in order to construct themselves from * the request. For example, `Extractor` is implemented for `Query` with a * function that reads the query string from the request, parses it, and * constructs a `Query` with it. @@ -202,7 +203,7 @@ pub trait Extractor: Send + Sync + Sized { * the associated endpoint is paginated. */ pub struct ExtractorMetadata { - pub paginated: bool, + pub extension_mode: ExtensionMode, pub parameters: Vec, } @@ -223,15 +224,21 @@ macro_rules! impl_extractor_for_tuple { fn metadata(_body_content_type: ApiEndpointBodyContentType) -> ExtractorMetadata { #[allow(unused_mut)] - let mut paginated = false; + let mut extension_mode = ExtensionMode::None; #[allow(unused_mut)] let mut parameters = vec![]; $( let mut metadata = $T::metadata(_body_content_type.clone()); - paginated = paginated | metadata.paginated; + extension_mode = match (extension_mode, metadata.extension_mode) { + (ExtensionMode::None, x) | (x, ExtensionMode::None) => x, + (x, y) if x != y => { + panic!("incompatible extension modes in tuple: {:?} != {:?}", x, y); + } + (_, x) => x, + }; parameters.append(&mut metadata.parameters); )* - ExtractorMetadata { paginated, parameters } + ExtractorMetadata { extension_mode, parameters } } } }} @@ -685,11 +692,23 @@ where ); let schema = generator.root_schema_for::().schema.into(); - let paginated = match schema_extensions(&schema) { + let extension_mode = match schema_extensions(&schema) { Some(extensions) => { - extensions.get(&PAGINATION_PARAM_SENTINEL.to_string()).is_some() + let paginated = extensions + .get(&PAGINATION_PARAM_SENTINEL.to_string()) + .is_some(); + let websocket = + extensions.get(&WEBSOCKET_PARAM_SENTINEL.to_string()).is_some(); + match (paginated, websocket) { + (false, false) => ExtensionMode::None, + (false, true) => ExtensionMode::Websocket, + (true, false) => ExtensionMode::Paginated, + (true, true) => panic!( + "Cannot use websocket and pagination in the same endpoint!" + ), + } } - None => false, + None => ExtensionMode::None, }; /* @@ -716,7 +735,7 @@ where }) .collect::>(); - ExtractorMetadata { paginated, parameters } + ExtractorMetadata { extension_mode, parameters } } fn schema_extensions( @@ -1039,7 +1058,10 @@ where }, vec![], ); - ExtractorMetadata { paginated: false, parameters: vec![body] } + ExtractorMetadata { + extension_mode: ExtensionMode::None, + parameters: vec![body], + } } } @@ -1116,7 +1138,7 @@ impl Extractor for UntypedBody { }, vec![], )], - paginated: false, + extension_mode: ExtensionMode::None, } } } @@ -1591,6 +1613,7 @@ fn schema_extract_description( #[cfg(test)] mod test { + use crate::api_description::ExtensionMode; use crate::{ api_description::ApiEndpointParameterMetadata, ApiEndpointParameter, ApiEndpointParameterLocation, PaginationParams, @@ -1628,10 +1651,10 @@ mod test { fn compare( actual: ExtractorMetadata, - paginated: bool, + extension_mode: ExtensionMode, parameters: Vec<(&str, bool)>, ) { - assert_eq!(actual.paginated, paginated); + assert_eq!(actual.extension_mode, extension_mode); /* * This is order-dependent. We might not really care if the order @@ -1659,7 +1682,7 @@ mod test { let params = get_metadata::(&ApiEndpointParameterLocation::Path); let expected = vec![("bar", true), ("baz", false), ("foo", true)]; - compare(params, false, expected); + compare(params, ExtensionMode::None, expected); } #[test] @@ -1672,7 +1695,7 @@ mod test { ("limit", false), ]; - compare(params, false, expected); + compare(params, ExtensionMode::None, expected); } #[test] @@ -1687,7 +1710,7 @@ mod test { ("page_token", false), ]; - compare(params, false, expected); + compare(params, ExtensionMode::None, expected); } #[test] @@ -1703,6 +1726,6 @@ mod test { ("page_token", false), ]; - compare(params, true, expected); + compare(params, ExtensionMode::Paginated, expected); } } diff --git a/dropshot/src/lib.rs b/dropshot/src/lib.rs index d758c5988..6d29c4d37 100644 --- a/dropshot/src/lib.rs +++ b/dropshot/src/lib.rs @@ -611,6 +611,7 @@ mod router; mod server; mod to_map; mod type_util; +mod websocket; pub mod test_util; @@ -624,6 +625,7 @@ pub use api_description::ApiEndpointParameter; pub use api_description::ApiEndpointParameterLocation; pub use api_description::ApiEndpointResponse; pub use api_description::EndpointTagPolicy; +pub use api_description::ExtensionMode; pub use api_description::OpenApiDefinition; pub use api_description::TagConfig; pub use api_description::TagDetails; @@ -664,6 +666,11 @@ pub use pagination::ResultsPage; pub use pagination::WhichPage; pub use server::ServerContext; pub use server::{HttpServer, HttpServerStarter}; +pub use websocket::WebsocketChannelResult; +pub use websocket::WebsocketConnection; +pub use websocket::WebsocketConnectionRaw; +pub use websocket::WebsocketEndpointResult; +pub use websocket::WebsocketUpgrade; /* * Users of the `endpoint` macro need the following macros: @@ -672,4 +679,5 @@ pub use handler::RequestContextArgument; pub use http::Method; extern crate dropshot_endpoint; +pub use dropshot_endpoint::channel; pub use dropshot_endpoint::endpoint; diff --git a/dropshot/src/router.rs b/dropshot/src/router.rs index 73e53cbb8..06e90200c 100644 --- a/dropshot/src/router.rs +++ b/dropshot/src/router.rs @@ -822,7 +822,7 @@ mod test { summary: None, description: None, tags: vec![], - paginated: false, + extension_mode: Default::default(), visible: true, } } diff --git a/dropshot/src/websocket.rs b/dropshot/src/websocket.rs new file mode 100644 index 000000000..14247a7cd --- /dev/null +++ b/dropshot/src/websocket.rs @@ -0,0 +1,374 @@ +// Copyright 2022 Oxide Computer Company +/*! + * Implements websocket upgrades as an Extractor for use in API route handler + * parameters to indicate that the given endpoint is meant to be upgraded to + * a websocket. + * + * This exposes a raw upgraded HTTP connection to a user-provided async future, + * which will be spawned to handle the incoming connection. + */ + +use crate::api_description::ExtensionMode; +use crate::{ + ApiEndpointBodyContentType, Extractor, ExtractorMetadata, HttpError, + RequestContext, ServerContext, +}; +use async_trait::async_trait; +use http::header; +use http::Response; +use http::StatusCode; +use hyper::upgrade::OnUpgrade; +use hyper::Body; +use schemars::JsonSchema; +use serde_json::json; +use sha1::{Digest, Sha1}; +use slog::Logger; +use std::future::Future; +use std::sync::Arc; + +/** + * WebsocketUpgrade is an Extractor used to upgrade and handle an HTTP request + * as a websocket when present in a Dropshot endpoint's function arguments. + * + * The consumer of this must call [WebsocketUpgrade::handle] for the connection + * to be upgraded. (This is done for you by `#[channel]`.) + */ +#[derive(Debug)] +pub struct WebsocketUpgrade(Option); + +/** + * This is the return type of the websocket-handling future provided to + * [`dropshot_endpoint::channel`] + * (which in turn provides it to [WebsocketUpgrade::handle]). + */ +pub type WebsocketChannelResult = + Result<(), Box>; + +/** + * [WebsocketUpgrade::handle]'s return type. + * The `#[endpoint]` handler must return the value returned by + * [WebsocketUpgrade::handle]. (This is done for you by `#[channel]`.) + */ +pub type WebsocketEndpointResult = Result, HttpError>; + +/** + * The upgraded connection passed as the second argument to the websocket + * handler function. [`WebsocketConnection::into_inner`] can be used to + * access the raw upgraded connection, for passing to any implementation + * of the websockets protocol. + */ +pub struct WebsocketConnection(WebsocketConnectionRaw); + +/// A type that implements [tokio::io::AsyncRead] + [tokio::io::AsyncWrite]. +pub type WebsocketConnectionRaw = hyper::upgrade::Upgraded; + +impl WebsocketConnection { + /// Consumes `self` and returns the held raw connection. + pub fn into_inner(self) -> WebsocketConnectionRaw { + self.0 + } +} + +#[derive(Debug)] +struct WebsocketUpgradeInner { + upgrade_fut: OnUpgrade, + accept_key: String, + route: String, + ws_log: Logger, +} + +// Borrowed from tungstenite-0.17.3 (rather than taking a whole dependency for this one function) +fn derive_accept_key(request_key: &[u8]) -> String { + // ... field is constructed by concatenating /key/ ... + // ... with the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (RFC 6455) + const WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + let mut sha1 = Sha1::default(); + sha1.update(request_key); + sha1.update(WS_GUID); + base64::encode(&sha1.finalize()) +} + +/** + * This `Extractor` implementation constructs an instance of `WebsocketUpgrade` + * from an HTTP request, and returns an error if the given request does not + * contain websocket upgrade headers. + */ +#[async_trait] +impl Extractor for WebsocketUpgrade { + async fn from_request( + rqctx: Arc>, + ) -> Result { + let request = &mut *rqctx.request.lock().await; + + if !request + .headers() + .get(header::CONNECTION) + .and_then(|hv| hv.to_str().ok()) + .map(|hv| { + hv.split(|c| c == ',' || c == ' ') + .any(|vs| vs.eq_ignore_ascii_case("upgrade")) + }) + .unwrap_or(false) + { + return Err(HttpError::for_bad_request( + None, + "expected connection upgrade".to_string(), + )); + } + + if !request + .headers() + .get(header::UPGRADE) + .and_then(|v| v.to_str().ok()) + .map(|v| { + v.split(|c| c == ',' || c == ' ') + .any(|v| v.eq_ignore_ascii_case("websocket")) + }) + .unwrap_or(false) + { + return Err(HttpError::for_bad_request( + None, + "unexpected protocol for upgrade".to_string(), + )); + } + + if request + .headers() + .get(header::SEC_WEBSOCKET_VERSION) + .map(|v| v.as_bytes()) + != Some(b"13") + { + return Err(HttpError::for_bad_request( + None, + "missing or invalid websocket version".to_string(), + )); + } + + let accept_key = request + .headers() + .get(header::SEC_WEBSOCKET_KEY) + .map(|hv| hv.as_bytes()) + .map(|key| derive_accept_key(key)) + .ok_or_else(|| { + HttpError::for_bad_request( + None, + "missing websocket key".to_string(), + ) + })?; + + let route = request.uri().to_string(); + let upgrade_fut = hyper::upgrade::on(request); + // note: this is just used in our wrapper in `handle`; if a user wants + // to slog in their future, they can obtain it from rqctx the same way + // they do in any other endpoint & let it get `move`d into the closure + let ws_log = rqctx.log.new(o!( + "upgrade" => "websocket".to_string(), + )); + + Ok(Self(Some(WebsocketUpgradeInner { + upgrade_fut, + accept_key, + ws_log, + route, + }))) + } + + fn metadata( + _content_type: ApiEndpointBodyContentType, + ) -> ExtractorMetadata { + ExtractorMetadata { + parameters: vec![], + extension_mode: ExtensionMode::Websocket, + } + } +} + +impl WebsocketUpgrade { + /** + * Upgrade the HTTP connection to a websocket and spawn a user-provided + * async handler to service it. + * + * This function's return value should be the basis of the return value of + * your endpoint's function, as it sends the headers to tell the HTTP + * client that we are accepting the upgrade. + * + * `handler` is a closure that accepts a [`WebsocketConnection`] + * and returns a future that will be spawned by this function, + * in which the `WebsocketConnection`'s inner `Upgraded` connection may be + * used with your choice of websocket-handling code operating over an + * [`tokio::io::AsyncRead`] + [`tokio::io::AsyncWrite`] type + * (e.g. `tokio_tungstenite`). + * + * ``` + #[dropshot::endpoint { method = GET, path = "/my/ws/endpoint/{id}" }] + async fn my_ws_endpoint( + rqctx: std::sync::Arc>, + websock: dropshot::WebsocketUpgrade, + id: dropshot::Path, + ) -> dropshot::WebsocketEndpointResult { + let logger = rqctx.log.new(slog::o!()); + websock.handle(move |upgraded| async move { + slog::info!(logger, "Entered handler for ID {}", id.into_inner()); + use futures_util::stream::StreamExt; + let mut ws_stream = tokio_tungstenite::WebSocketStream::from_raw_socket( + upgraded.into_inner(), tungstenite::protocol::Role::Server, None + ).await; + slog::info!(logger, "Received from websocket: {:?}", ws_stream.next().await); + Ok(()) + }) + } + * ``` + * + * Note that as a consumer of this crate, you most likely do not want to + * call this function directly; rather, prefer to annotate your function + * with [`dropshot_endpoint::channel`] instead of `endpoint`. + */ + pub fn handle(mut self, handler: C) -> WebsocketEndpointResult + where + C: FnOnce(WebsocketConnection) -> F + Send + 'static, + F: Future + Send + 'static, + { + // we .take() here to tell Drop::drop that we handled the request. + match self.0.take() { + None => Err(HttpError::for_internal_error( + "Tried to handle websocket twice".to_string(), + )), + Some(WebsocketUpgradeInner { + upgrade_fut, + accept_key, + ws_log, + .. + }) => { + tokio::spawn(async move { + match upgrade_fut.await { + Ok(upgrade) => { + match handler(WebsocketConnection(upgrade)).await { + Ok(x) => Ok(x), + Err(e) => { + error!( + ws_log, + "Error returned from handler: {:?}", e + ); + Err(e) + } + } + } + Err(e) => { + error!( + ws_log, + "Error upgrading connection: {:?}", e + ); + Err(e.into()) + } + } + }); + Response::builder() + .status(StatusCode::SWITCHING_PROTOCOLS) + .header(header::CONNECTION, "Upgrade") + .header(header::UPGRADE, "websocket") + .header(header::SEC_WEBSOCKET_ACCEPT, accept_key) + .body(Body::empty()) + .map_err(Into::into) + } + } + } +} + +impl Drop for WebsocketUpgrade { + fn drop(&mut self) { + if let Some(inner) = self.0.take() { + debug!( + inner.ws_log, + "Didn't handle websocket in route {}", inner.route + ); + } + } +} + +// To indicate websocket usage by the endpoint to code generators (i.e. Progenitor) +pub(crate) const WEBSOCKET_EXTENSION: &str = "x-dropshot-websocket"; +pub(crate) const WEBSOCKET_PARAM_SENTINEL: &str = "x-dropshot-websocket-param"; + +impl JsonSchema for WebsocketUpgrade { + fn schema_name() -> String { + "WebsocketUpgrade".to_string() + } + + fn json_schema( + _gen: &mut schemars::gen::SchemaGenerator, + ) -> schemars::schema::Schema { + let mut schema = schemars::schema::SchemaObject::default(); + schema + .extensions + .insert(WEBSOCKET_PARAM_SENTINEL.to_string(), json!(true)); + schemars::schema::Schema::Object(schema) + } +} + +#[cfg(test)] +mod tests { + use crate::router::HttpRouter; + use crate::server::{DropshotState, ServerConfig}; + use crate::{Extractor, HttpError, RequestContext, WebsocketUpgrade}; + use futures::lock::Mutex; + use http::Request; + use hyper::Body; + use std::net::{IpAddr, Ipv6Addr, SocketAddr}; + use std::num::NonZeroU32; + use std::sync::Arc; + use std::time::Duration; + + async fn ws_upg_from_mock_rqctx() -> Result { + let log = slog::Logger::root(slog::Discard, slog::o!()).new(slog::o!()); + let fut = WebsocketUpgrade::from_request(Arc::new(RequestContext { + server: Arc::new(DropshotState { + private: (), + config: ServerConfig { + request_body_max_bytes: 0, + page_max_nitems: NonZeroU32::new(1).unwrap(), + page_default_nitems: NonZeroU32::new(1).unwrap(), + }, + router: HttpRouter::new(), + log: log.clone(), + local_addr: SocketAddr::new( + IpAddr::V6(Ipv6Addr::LOCALHOST), + 8080, + ), + tls: false, + }), + request: Arc::new(Mutex::new( + Request::builder() + .header(http::header::CONNECTION, "Upgrade") + .header(http::header::UPGRADE, "websocket") + .header(http::header::SEC_WEBSOCKET_VERSION, "13") + .header( + http::header::SEC_WEBSOCKET_KEY, + "aGFjayB0aGUgcGxhbmV0IQ==", + ) + .body(Body::empty()) + .unwrap(), + )), + path_variables: Default::default(), + body_content_type: Default::default(), + request_id: "".to_string(), + log: log.clone(), + })); + tokio::time::timeout(Duration::from_secs(1), fut) + .await + .expect("Deadlocked in WebsocketUpgrade constructor") + } + + #[tokio::test] + async fn test_ws_upg_task_is_spawned() { + let (send, recv) = tokio::sync::oneshot::channel(); + ws_upg_from_mock_rqctx() + .await + .unwrap() + .handle(move |_upgrade| async move { Ok(send.send(()).unwrap()) }) + .unwrap(); + // note: not a real connection, so we don't get our future's Ok, but we *do* spawn the task + let _ = tokio::time::timeout(Duration::from_secs(1), recv) + .await + .expect("Task not spawned or never completed"); + } +} diff --git a/dropshot/tests/fail/bad_endpoint3.stderr b/dropshot/tests/fail/bad_endpoint3.stderr index 6e7e4c552..290bb48c6 100644 --- a/dropshot/tests/fail/bad_endpoint3.stderr +++ b/dropshot/tests/fail/bad_endpoint3.stderr @@ -11,6 +11,7 @@ error[E0277]: the trait bound `String: Extractor` is not satisfied (T1,) TypedBody UntypedBody + WebsocketUpgrade dropshot::Path dropshot::Query note: required by a bound in `need_extractor` diff --git a/dropshot/tests/fail/bad_endpoint4.stderr b/dropshot/tests/fail/bad_endpoint4.stderr index 048332640..137ae560a 100644 --- a/dropshot/tests/fail/bad_endpoint4.stderr +++ b/dropshot/tests/fail/bad_endpoint4.stderr @@ -13,7 +13,7 @@ error[E0277]: the trait bound `QueryParams: schemars::JsonSchema` is not satisfi (T0, T1, T2, T3) (T0, T1, T2, T3, T4) (T0, T1, T2, T3, T4, T5) - and 145 others + and 146 others note: required by a bound in `dropshot::Query` --> src/handler.rs | diff --git a/dropshot/tests/test_demo.rs b/dropshot/tests/test_demo.rs index 1a67c2d4e..f30882586 100644 --- a/dropshot/tests/test_demo.rs +++ b/dropshot/tests/test_demo.rs @@ -15,6 +15,7 @@ * JSON body length) */ +use dropshot::channel; use dropshot::endpoint; use dropshot::test_util::object_delete; use dropshot::test_util::read_json; @@ -32,7 +33,11 @@ use dropshot::Query; use dropshot::RequestContext; use dropshot::TypedBody; use dropshot::UntypedBody; +use dropshot::WebsocketChannelResult; +use dropshot::WebsocketConnection; use dropshot::CONTENT_TYPE_JSON; +use futures_util::SinkExt; +use futures_util::StreamExt; use http::StatusCode; use hyper::Body; use hyper::Method; @@ -41,6 +46,9 @@ use schemars::JsonSchema; use serde::Deserialize; use serde::Serialize; use std::sync::Arc; +use tokio_tungstenite::WebSocketStream; +use tungstenite::protocol::Role; +use tungstenite::Message; use uuid::Uuid; extern crate slog; @@ -60,6 +68,7 @@ fn demo_api() -> ApiDescription { api.register(demo_handler_untyped_body).unwrap(); api.register(demo_handler_delete).unwrap(); api.register(demo_handler_headers).unwrap(); + api.register(demo_handler_websocket).unwrap(); /* * We don't need to exhaustively test these cases, as they're tested by unit @@ -734,6 +743,28 @@ async fn test_header_request() { assert_eq!(headers, vec!["hi", "howdy"]); } +/* + * The "test_demo_websocket" handler upgrades to a websocket and exchanges + * greetings with the client. + */ +#[tokio::test] +async fn test_demo_websocket() { + let api = demo_api(); + let testctx = common::test_setup("demo_websocket", api); + + let path = format!( + "ws://{}/testing/websocket", + testctx.client_testctx.bind_address + ); + let (mut ws, _resp) = tokio_tungstenite::connect_async(path).await.unwrap(); + + ws.send(Message::Text("hello server".to_string())).await.unwrap(); + let msg = ws.next().await.unwrap().unwrap(); + assert_eq!(msg, Message::Text("hello client".to_string())); + + testctx.teardown().await; +} + /* * Demo handler functions */ @@ -933,6 +964,27 @@ async fn demo_handler_headers( Ok(response) } +#[channel { + protocol = WEBSOCKETS, + path = "/testing/websocket" +}] +async fn demo_handler_websocket( + rqctx: RequestCtx, + upgraded: WebsocketConnection, +) -> WebsocketChannelResult { + use futures_util::stream::StreamExt; + let mut ws_stream = WebSocketStream::from_raw_socket( + upgraded.into_inner(), + Role::Server, + None, + ) + .await; + ws_stream.send(Message::Text("hello client".to_string())).await.unwrap(); + let msg = ws_stream.next().await.unwrap().unwrap(); + slog::info!(rqctx.log, "{}", msg); + Ok(()) +} + fn http_echo(t: &T) -> Result, HttpError> { Ok(Response::builder() .header(http::header::CONTENT_TYPE, CONTENT_TYPE_JSON) diff --git a/dropshot_endpoint/src/lib.rs b/dropshot_endpoint/src/lib.rs index 1cfe4d3a2..ea49a0cd0 100644 --- a/dropshot_endpoint/src/lib.rs +++ b/dropshot_endpoint/src/lib.rs @@ -10,13 +10,13 @@ */ #![allow(clippy::style)] -use proc_macro2::TokenStream; use quote::format_ident; use quote::quote; use quote::{quote_spanned, ToTokens}; use serde::Deserialize; use serde_tokenstream::from_tokenstream; use serde_tokenstream::Error; +use std::ops::DerefMut; use syn::spanned::Spanned; use syn_parsing::ItemFnForSignature; @@ -48,7 +48,7 @@ impl MethodType { } #[derive(Deserialize, Debug)] -struct Metadata { +struct EndpointMetadata { method: MethodType, path: String, tags: Option>, @@ -57,6 +57,21 @@ struct Metadata { _dropshot_crate: Option, } +#[allow(non_snake_case)] +#[derive(Deserialize, Debug)] +enum ChannelProtocol { + WEBSOCKETS, +} + +#[derive(Deserialize, Debug)] +struct ChannelMetadata { + protocol: ChannelProtocol, + path: String, + tags: Option>, + unpublished: Option, + _dropshot_crate: Option, +} + const DROPSHOT: &str = "dropshot"; const USAGE: &str = "Endpoint handlers must have the following signature: async fn( @@ -97,7 +112,131 @@ pub fn endpoint( attr: proc_macro::TokenStream, item: proc_macro::TokenStream, ) -> proc_macro::TokenStream { - match do_endpoint(attr.into(), item.into()) { + do_output(do_endpoint(attr.into(), item.into())) +} + +fn do_endpoint( + attr: proc_macro2::TokenStream, + item: proc_macro2::TokenStream, +) -> Result<(proc_macro2::TokenStream, Vec), Error> { + let metadata = from_tokenstream(&attr)?; + // factored this way for now so #[channel] can use it too + do_endpoint_inner(metadata, attr, item) +} + +/// As with [`endpoint`], this attribute turns a handler function into a +/// Dropshot endpoint, but first wraps the handler function in such a way +/// that is spawned asynchronously and given the upgraded connection of +/// the given `protocol` (i.e. `WEBSOCKETS`). +/// +/// The first argument still must be an `Arc>`. +/// +/// The second argument passed to the handler function must be a +/// [`dropshot::WebsocketConnection`]. +/// +/// The function must return a [`dropshot::WebsocketChannelResult`] (which is +/// a general-purpose `Result<(), Box>`). +/// Returned error values will be written to the RequestContext's log. +/// +/// ```ignore +/// #[dropshot::channel { protocol = WEBSOCKETS, path = "/my/ws/channel/{id}" }] +/// ``` +#[proc_macro_attribute] +pub fn channel( + attr: proc_macro::TokenStream, + item: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + do_output(do_channel(attr.into(), item.into())) +} + +fn do_channel( + attr: proc_macro2::TokenStream, + item: proc_macro2::TokenStream, +) -> Result<(proc_macro2::TokenStream, Vec), Error> { + let ChannelMetadata { protocol, path, tags, unpublished, _dropshot_crate } = + from_tokenstream(&attr)?; + match protocol { + ChannelProtocol::WEBSOCKETS => { + // here we construct a wrapper function and mutate the arguments a bit + // for the outer layer: we replace WebsocketConnection, which is not + // an extractor, with WebsocketUpgrade, which is. + let ItemFnForSignature { attrs, vis, mut sig, _block: body } = + syn::parse2(item)?; + + let inner_args = sig.inputs.clone(); + let inner_output = sig.output.clone(); + + let arg_names: Vec<_> = inner_args + .iter() + .map(|arg: &syn::FnArg| match arg { + syn::FnArg::Receiver(r) => r.self_token.to_token_stream(), + syn::FnArg::Typed(syn::PatType { pat, .. }) => { + pat.to_token_stream() + } + }) + .collect(); + let found = sig.inputs.iter_mut().nth(1).and_then(|arg| { + if let syn::FnArg::Typed(syn::PatType { pat, ty, .. }) = arg { + if let syn::Pat::Ident(syn::PatIdent { + ident, + by_ref: None, + .. + }) = pat.deref_mut() + { + let conn_type = ty.clone(); + let conn_name = ident.clone(); + let span = ident.span(); + *ident = syn::Ident::new( + "__dropshot_websocket_upgrade", + span, + ); + *ty = Box::new(syn::Type::Verbatim( + quote! { dropshot::WebsocketUpgrade }, + )); + return Some((conn_name, conn_type)); + } + } + return None; + }); + if found.is_none() { + return Err(Error::new_spanned( + &attr, + "An argument of type dropshot::WebsocketConnection must be provided immediately following Arc>.", + )); + } + + sig.output = + syn::parse2(quote!(-> dropshot::WebsocketEndpointResult))?; + + let (conn_name, conn_type) = found.unwrap(); + + let new_item = quote! { + #(#attrs)* + #vis #sig { + async fn __dropshot_websocket_handler(#inner_args) #inner_output #body + __dropshot_websocket_upgrade.handle(move | #conn_name: #conn_type | async move { + __dropshot_websocket_handler(#(#arg_names),*).await + }) + } + }; + + let metadata = EndpointMetadata { + method: MethodType::GET, + path, + tags, + unpublished, + content_type: Some("application/json".to_string()), + _dropshot_crate, + }; + do_endpoint_inner(metadata, attr, new_item) + } + } +} + +fn do_output( + res: Result<(proc_macro2::TokenStream, Vec), Error>, +) -> proc_macro::TokenStream { + match res { Err(err) => err.to_compile_error().into(), Ok((endpoint, errors)) => { let compiler_errors = @@ -113,11 +252,12 @@ pub fn endpoint( } } -fn do_endpoint( - attr: TokenStream, - item: TokenStream, -) -> Result<(TokenStream, Vec), Error> { - let metadata = from_tokenstream::(&attr)?; +fn do_endpoint_inner( + metadata: EndpointMetadata, + attr: proc_macro2::TokenStream, + item: proc_macro2::TokenStream, +) -> Result<(proc_macro2::TokenStream, Vec), Error> { + let ast: ItemFnForSignature = syn::parse2(item.clone())?; let method = metadata.method.as_str(); let path = metadata.path; let content_type = @@ -132,8 +272,6 @@ fn do_endpoint( )); } - let ast: ItemFnForSignature = syn::parse2(item.clone())?; - let mut errors = Vec::new(); if ast.sig.constness.is_some() { @@ -472,7 +610,7 @@ fn do_endpoint( Ok((stream, errors)) } -fn get_crate(var: Option) -> TokenStream { +fn get_crate(var: Option) -> proc_macro2::TokenStream { if let Some(s) = var { if let Ok(ts) = syn::parse_str(s.as_str()) { return ts;