Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions crates/wasi-http/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,72 @@ mod content_length_tests {
}
}

/// Resolve the rustls [`ServerName`] used for TLS certificate verification from
/// an outbound request `authority`.
///
/// `authority` is in `host:port` form, where an IPv6 `host` is wrapped in
/// brackets (for example `[::1]:443`). An IP literal is recognized by parsing
/// the whole authority as a [`SocketAddr`]; this handles the bracketed IPv6
/// form, which splitting on the first `:` would truncate. Anything else is
/// treated as a host name, with the port stripped off before it is handed to
/// rustls.
///
/// [`ServerName`]: rustls::pki_types::ServerName
/// [`SocketAddr`]: std::net::SocketAddr
#[cfg(all(feature = "default-send-request", any(feature = "p2", feature = "p3")))]
fn tls_server_name(
authority: &str,
) -> Result<rustls::pki_types::ServerName<'static>, rustls::pki_types::InvalidDnsNameError> {
use rustls::pki_types::ServerName;

if let Ok(addr) = authority.parse::<std::net::SocketAddr>() {
return Ok(ServerName::from(addr.ip()));
}
let host = match authority.split_once(':') {
Some((host, _port)) => host,
None => authority,
};
Ok(ServerName::try_from(host)?.to_owned())
}

#[cfg(all(
test,
feature = "default-send-request",
any(feature = "p2", feature = "p3")
))]
mod tls_server_name_tests {
use super::tls_server_name;
use rustls::pki_types::ServerName;

#[test]
fn resolves_server_name_from_authority() {
// Host names keep their host and drop the port.
assert_eq!(
tls_server_name("example.com:443").unwrap(),
ServerName::try_from("example.com").unwrap()
);
assert_eq!(
tls_server_name("example.com").unwrap(),
ServerName::try_from("example.com").unwrap()
);

// IP literals resolve to an `IpAddress` server name. The bracketed IPv6
// form must not be truncated at the first `:`.
assert_eq!(
tls_server_name("127.0.0.1:80").unwrap(),
ServerName::from(std::net::Ipv4Addr::LOCALHOST)
);
assert_eq!(
tls_server_name("[::1]:443").unwrap(),
ServerName::from(std::net::Ipv6Addr::LOCALHOST)
);
assert_eq!(
tls_server_name("[2001:db8::1]:8443").unwrap(),
ServerName::from("2001:db8::1".parse::<std::net::Ipv6Addr>().unwrap())
);
}
}

/// Set of [http::header::HeaderName], that are forbidden by default
/// for requests and responses originating in the guest.
pub const DEFAULT_FORBIDDEN_HEADERS: [HeaderName; 9] = [
Expand Down
14 changes: 4 additions & 10 deletions crates/wasi-http/src/p2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -620,8 +620,6 @@ pub async fn default_send_request_handler(
})?;

let (mut sender, worker) = if use_tls {
use rustls::pki_types::ServerName;

// derived from https://github.com/rustls/rustls/blob/main/examples/src/bin/simpleclient.rs
let root_cert_store = rustls::RootCertStore {
roots: webpki_roots::TLS_SERVER_ROOTS.into(),
Expand All @@ -630,14 +628,10 @@ pub async fn default_send_request_handler(
.with_root_certificates(root_cert_store)
.with_no_client_auth();
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
let mut parts = authority.split(":");
let host = parts.next().unwrap_or(&authority);
let domain = ServerName::try_from(host)
.map_err(|e| {
tracing::warn!("dns lookup error: {e:?}");
dns_error("invalid dns name".to_string(), 0)
})?
.to_owned();
let domain = crate::tls_server_name(&authority).map_err(|e| {
tracing::warn!("dns lookup error: {e:?}");
dns_error("invalid dns name".to_string(), 0)
})?;
let stream = connector.connect(domain, tcp_stream).await.map_err(|e| {
tracing::warn!("tls protocol error: {e:?}");
ErrorCode::TlsProtocolError
Expand Down
14 changes: 4 additions & 10 deletions crates/wasi-http/src/p3/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -346,8 +346,6 @@ pub async fn default_send_request(
Err(..) => return Err(ErrorCode::ConnectionTimeout),
};
let stream = if use_tls {
use rustls::pki_types::ServerName;

// derived from https://github.com/rustls/rustls/blob/main/examples/src/bin/simpleclient.rs
let root_cert_store = rustls::RootCertStore {
roots: webpki_roots::TLS_SERVER_ROOTS.into(),
Expand All @@ -356,14 +354,10 @@ pub async fn default_send_request(
.with_root_certificates(root_cert_store)
.with_no_client_auth();
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
let mut parts = authority.split(":");
let host = parts.next().unwrap_or(&authority);
let domain = ServerName::try_from(host)
.map_err(|e| {
tracing::warn!("dns lookup error: {e:?}");
dns_error("invalid dns name".to_string(), 0)
})?
.to_owned();
let domain = crate::tls_server_name(&authority).map_err(|e| {
tracing::warn!("dns lookup error: {e:?}");
dns_error("invalid dns name".to_string(), 0)
})?;
let stream = connector.connect(domain, stream).await.map_err(|e| {
tracing::warn!("tls protocol error: {e:?}");
ErrorCode::TlsProtocolError
Expand Down
Loading