diff --git a/crates/wasi-http/src/lib.rs b/crates/wasi-http/src/lib.rs index 0b40aa8694a2..874ccbe5c3ed 100644 --- a/crates/wasi-http/src/lib.rs +++ b/crates/wasi-http/src/lib.rs @@ -71,6 +71,72 @@ mod content_length_tests { } } +/// Resolve the rustls [`ServerName`] used for TLS certificate verification from +/// an outbound request `authority`. +/// +/// `authority` is in `host:port` form, where an IPv6 `host` is wrapped in +/// brackets (for example `[::1]:443`). An IP literal is recognized by parsing +/// the whole authority as a [`SocketAddr`]; this handles the bracketed IPv6 +/// form, which splitting on the first `:` would truncate. Anything else is +/// treated as a host name, with the port stripped off before it is handed to +/// rustls. +/// +/// [`ServerName`]: rustls::pki_types::ServerName +/// [`SocketAddr`]: std::net::SocketAddr +#[cfg(all(feature = "default-send-request", any(feature = "p2", feature = "p3")))] +fn tls_server_name( + authority: &str, +) -> Result, rustls::pki_types::InvalidDnsNameError> { + use rustls::pki_types::ServerName; + + if let Ok(addr) = authority.parse::() { + return Ok(ServerName::from(addr.ip())); + } + let host = match authority.split_once(':') { + Some((host, _port)) => host, + None => authority, + }; + Ok(ServerName::try_from(host)?.to_owned()) +} + +#[cfg(all( + test, + feature = "default-send-request", + any(feature = "p2", feature = "p3") +))] +mod tls_server_name_tests { + use super::tls_server_name; + use rustls::pki_types::ServerName; + + #[test] + fn resolves_server_name_from_authority() { + // Host names keep their host and drop the port. + assert_eq!( + tls_server_name("example.com:443").unwrap(), + ServerName::try_from("example.com").unwrap() + ); + assert_eq!( + tls_server_name("example.com").unwrap(), + ServerName::try_from("example.com").unwrap() + ); + + // IP literals resolve to an `IpAddress` server name. The bracketed IPv6 + // form must not be truncated at the first `:`. + assert_eq!( + tls_server_name("127.0.0.1:80").unwrap(), + ServerName::from(std::net::Ipv4Addr::LOCALHOST) + ); + assert_eq!( + tls_server_name("[::1]:443").unwrap(), + ServerName::from(std::net::Ipv6Addr::LOCALHOST) + ); + assert_eq!( + tls_server_name("[2001:db8::1]:8443").unwrap(), + ServerName::from("2001:db8::1".parse::().unwrap()) + ); + } +} + /// Set of [http::header::HeaderName], that are forbidden by default /// for requests and responses originating in the guest. pub const DEFAULT_FORBIDDEN_HEADERS: [HeaderName; 9] = [ diff --git a/crates/wasi-http/src/p2/mod.rs b/crates/wasi-http/src/p2/mod.rs index 0e6a4f268416..962c37291158 100644 --- a/crates/wasi-http/src/p2/mod.rs +++ b/crates/wasi-http/src/p2/mod.rs @@ -620,8 +620,6 @@ pub async fn default_send_request_handler( })?; let (mut sender, worker) = if use_tls { - use rustls::pki_types::ServerName; - // derived from https://github.com/rustls/rustls/blob/main/examples/src/bin/simpleclient.rs let root_cert_store = rustls::RootCertStore { roots: webpki_roots::TLS_SERVER_ROOTS.into(), @@ -630,14 +628,10 @@ pub async fn default_send_request_handler( .with_root_certificates(root_cert_store) .with_no_client_auth(); let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config)); - let mut parts = authority.split(":"); - let host = parts.next().unwrap_or(&authority); - let domain = ServerName::try_from(host) - .map_err(|e| { - tracing::warn!("dns lookup error: {e:?}"); - dns_error("invalid dns name".to_string(), 0) - })? - .to_owned(); + let domain = crate::tls_server_name(&authority).map_err(|e| { + tracing::warn!("dns lookup error: {e:?}"); + dns_error("invalid dns name".to_string(), 0) + })?; let stream = connector.connect(domain, tcp_stream).await.map_err(|e| { tracing::warn!("tls protocol error: {e:?}"); ErrorCode::TlsProtocolError diff --git a/crates/wasi-http/src/p3/request.rs b/crates/wasi-http/src/p3/request.rs index aa3ed706abbb..8bdf9d4b7d8b 100644 --- a/crates/wasi-http/src/p3/request.rs +++ b/crates/wasi-http/src/p3/request.rs @@ -346,8 +346,6 @@ pub async fn default_send_request( Err(..) => return Err(ErrorCode::ConnectionTimeout), }; let stream = if use_tls { - use rustls::pki_types::ServerName; - // derived from https://github.com/rustls/rustls/blob/main/examples/src/bin/simpleclient.rs let root_cert_store = rustls::RootCertStore { roots: webpki_roots::TLS_SERVER_ROOTS.into(), @@ -356,14 +354,10 @@ pub async fn default_send_request( .with_root_certificates(root_cert_store) .with_no_client_auth(); let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config)); - let mut parts = authority.split(":"); - let host = parts.next().unwrap_or(&authority); - let domain = ServerName::try_from(host) - .map_err(|e| { - tracing::warn!("dns lookup error: {e:?}"); - dns_error("invalid dns name".to_string(), 0) - })? - .to_owned(); + let domain = crate::tls_server_name(&authority).map_err(|e| { + tracing::warn!("dns lookup error: {e:?}"); + dns_error("invalid dns name".to_string(), 0) + })?; let stream = connector.connect(domain, stream).await.map_err(|e| { tracing::warn!("tls protocol error: {e:?}"); ErrorCode::TlsProtocolError