diff --git a/Cargo.lock b/Cargo.lock index 4735b94c..e5d58f45 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -425,7 +425,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fac4744fb15ae8337dc853fee7fb3f4e48c0fbaa23d0afe49c447b4fab126118" dependencies = [ "iana-time-zone", + "js-sys", "num-traits", + "wasm-bindgen", "windows-link", ] @@ -659,6 +661,7 @@ version = "2.0.0" dependencies = [ "axum", "base64", + "chrono", "clap", "defguard_certs", "defguard_version", @@ -683,6 +686,7 @@ dependencies = [ "tonic-prost-build", "tower", "tracing", + "tracing-subscriber", "vergen-git2", "x25519-dalek", ] @@ -731,7 +735,7 @@ dependencies = [ [[package]] name = "defguard_version" version = "0.0.0" -source = "git+https://github.com/DefGuard/defguard.git?rev=640bae9a0aea1e11395f0a29fb8c84eeefd7f115#640bae9a0aea1e11395f0a29fb8c84eeefd7f115" +source = "git+https://github.com/DefGuard/defguard.git?rev=5be16525f5208739fd79384b30d8ac5056ffdb2f#5be16525f5208739fd79384b30d8ac5056ffdb2f" dependencies = [ "axum", "http", diff --git a/Cargo.toml b/Cargo.toml index 2cf78883..eaef5a4a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ edition = "2024" axum = "0.8" base64 = "0.22" clap = { version = "4.5", features = ["derive", "env"] } -defguard_version = { git = "https://github.com/DefGuard/defguard.git", rev = "640bae9a0aea1e11395f0a29fb8c84eeefd7f115" } +defguard_version = { git = "https://github.com/DefGuard/defguard.git", rev = "5be16525f5208739fd79384b30d8ac5056ffdb2f" } defguard_wireguard_rs = { git = "https://github.com/DefGuard/wireguard-rs", rev = "d0b01eabca015ea6c7ddf4e255a0228074684e96" } defguard_certs = { git = "https://github.com/DefGuard/defguard.git", rev = "290bdee718f51179c71e07f3bce3f8a0cbfb9379" } env_logger = "0.11" @@ -35,6 +35,8 @@ tonic = { version = "0.14", default-features = false, features = [ tracing = "0.1" tonic-prost = "0.14" tower = "0.5" +chrono = "0.4.43" +tracing-subscriber = "0.3.22" [target.'cfg(target_os = "linux")'.dependencies] nftnl = { git = "https://github.com/DefGuard/nftnl-rs.git", rev = "1a1147271f43b9d7182a114bb056a5224c35d38f" } diff --git a/Dockerfile b/Dockerfile index 289d4dc3..867278f4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,7 +7,7 @@ RUN cargo build --release FROM public.ecr.aws/docker/library/debian:13-slim RUN apt-get update && apt-get -y --no-install-recommends install \ - iproute2 wireguard-tools sudo ca-certificates iptables ebtables nftables && \ + iproute2 wireguard-tools sudo ca-certificates iptables ebtables nftables lsb-release && \ apt-get clean && rm -rf /var/lib/apt/lists/* WORKDIR /app COPY --from=builder /app/target/release/defguard-gateway /usr/local/bin diff --git a/proto b/proto index 906412eb..fdbe98ca 160000 --- a/proto +++ b/proto @@ -1 +1 @@ -Subproject commit 906412eb50ac605f4904a355c2f325f3645c117e +Subproject commit fdbe98caa9413b626833da210b5b588b287bb146 diff --git a/src/lib.rs b/src/lib.rs index 902c4ec6..9c661dea 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,6 +27,7 @@ use syslog::{BasicLogger, Facility, Formatter3164}; use tokio::sync::oneshot; pub mod enterprise; +pub mod logging; pub mod setup; pub const VERSION: &str = concat!(env!("CARGO_PKG_VERSION"), "+", env!("VERGEN_GIT_SHA")); diff --git a/src/logging.rs b/src/logging.rs new file mode 100644 index 00000000..a7945407 --- /dev/null +++ b/src/logging.rs @@ -0,0 +1,75 @@ +use defguard_version::Version; +use tokio::sync::mpsc::Sender; +use tracing::{Event, Subscriber}; +use tracing_subscriber::{Layer, layer::SubscriberExt, util::SubscriberInitExt}; + +use crate::proto::gateway::LogEntry; + +pub fn init_tracing(own_version: &Version, level: &str, logs_tx: Option>) { + let subscriber = tracing_subscriber::registry(); + let subscriber = + defguard_version::tracing::with_version_formatters(own_version, level, subscriber); + + if let Some(tx) = logs_tx { + let sender_layer = LogSenderLayer::new(tx); + subscriber.with(sender_layer).init(); + } else { + subscriber.init(); + } + + info!("Tracing initialized"); +} + +/// A tracing layer that sends log entries to a gRPC logs channel. +pub struct LogSenderLayer { + logs_tx: Sender, +} + +impl LogSenderLayer { + #[must_use] + pub const fn new(logs_tx: Sender) -> Self { + Self { logs_tx } + } +} + +impl Layer for LogSenderLayer +where + S: Subscriber, +{ + fn on_event(&self, event: &Event<'_>, _ctx: tracing_subscriber::layer::Context<'_, S>) { + if self.logs_tx.is_closed() { + return; + } + + let mut visitor = LogVisitor::default(); + event.record(&mut visitor); + + let entry = LogEntry { + level: format!("{:?}", event.metadata().level()), + target: event.metadata().target().to_string(), + message: visitor.message, + timestamp: chrono::Utc::now().to_rfc3339(), + fields: visitor.fields, + }; + + // Drop the buffer overflow error for now + let _ = self.logs_tx.try_send(entry); + } +} + +#[derive(Default)] +struct LogVisitor { + message: String, + fields: std::collections::HashMap, +} + +impl tracing::field::Visit for LogVisitor { + fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) { + if field.name() == "message" { + self.message = format!("{value:?}"); + } else { + self.fields + .insert(field.name().to_string(), format!("{value:?}")); + } + } +} diff --git a/src/main.rs b/src/main.rs index 1e1c7af9..04e9e0e8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,6 +13,7 @@ use defguard_gateway::{ execute_command, gateway::{Gateway, GatewayServer, TlsConfig, run_stats}, init_syslog, + logging::init_tracing, server::run_server, setup::GatewaySetupServer, }; @@ -20,7 +21,7 @@ use defguard_version::Version; #[cfg(not(any(target_os = "macos", target_os = "netbsd")))] use defguard_wireguard_rs::Kernel; use defguard_wireguard_rs::{Userspace, WGApi}; -use tokio::task::JoinSet; +use tokio::{sync::mpsc, task::JoinSet}; #[tokio::main] async fn main() -> Result<(), GatewayError> { @@ -35,6 +36,27 @@ async fn main() -> Result<(), GatewayError> { file.write_all(pid.to_string().as_bytes())?; } + let cert_dir = &config.cert_dir; + if !cert_dir.exists() { + tokio::fs::create_dir_all(cert_dir).await?; + } + + let (grpc_cert, grpc_key) = ( + read_to_string(cert_dir.join(GRPC_CERT_NAME)).ok(), + read_to_string(cert_dir.join(GRPC_KEY_NAME)).ok(), + ); + + let needs_setup = grpc_cert.is_none() || grpc_key.is_none(); + + // TODO: The channel size may need to be adjusted or some other approach should be used + // to avoid dropping log messages. + let (logs_tx, logs_rx) = if needs_setup { + let (logs_tx, logs_rx) = mpsc::channel(200); + (Some(logs_tx), Some(logs_rx)) + } else { + (None, None) + }; + // setup logging if config.use_syslog { if let Err(error) = init_syslog(&config, pid) { @@ -42,8 +64,7 @@ async fn main() -> Result<(), GatewayError> { return Err(error); } } else { - let version = Version::parse(VERSION)?; - defguard_version::tracing::init(version, &config.log_level)?; + init_tracing(&Version::parse(VERSION)?, &config.log_level, logs_tx); } if let Some(pre_up) = &config.pre_up { @@ -86,28 +107,19 @@ async fn main() -> Result<(), GatewayError> { let gateway = Arc::new(Mutex::new(gateway)); tasks.spawn(run_stats(Arc::clone(&gateway), config.stats_period())); - let cert_dir = &config.cert_dir; - if !cert_dir.exists() { - tokio::fs::create_dir_all(cert_dir).await?; - } - let tls_config = if let (Some(cert), Some(key)) = ( - read_to_string(cert_dir.join(GRPC_CERT_NAME)).ok(), - read_to_string(cert_dir.join(GRPC_KEY_NAME)).ok(), - ) { - log::info!( - "Using existing gRPC TLS certificates from {}", - cert_dir.display() - ); - TlsConfig { - grpc_cert_pem: cert, - grpc_key_pem: key, - } - } else { + let tls_config = if needs_setup { log::info!( "gRPC TLS certificates not found in {}. They will be generated during setup.", cert_dir.display() ); - let setup_server = GatewaySetupServer::default(); + + let Some(logs_rx) = logs_rx else { + return Err(GatewayError::SetupError( + "Logs receiver channel is missing during gateway setup".to_string(), + )); + }; + + let setup_server = GatewaySetupServer::new(Arc::new(tokio::sync::Mutex::new(logs_rx))); let tls_config = setup_server.await_setup(config.clone()).await?; let cert_path = cert_dir.join(GRPC_CERT_NAME); @@ -120,6 +132,19 @@ async fn main() -> Result<(), GatewayError> { ); tls_config + } else if let (Some(cert), Some(key)) = (grpc_cert, grpc_key) { + log::info!( + "Using existing gRPC TLS certificates from {}", + cert_dir.display() + ); + TlsConfig { + grpc_cert_pem: cert, + grpc_key_pem: key, + } + } else { + return Err(GatewayError::SetupError( + "gRPC TLS certificates are missing after setup".to_string(), + )); }; // Launch gRPC server. diff --git a/src/setup.rs b/src/setup.rs index 0c650acd..b309c821 100644 --- a/src/setup.rs +++ b/src/setup.rs @@ -1,13 +1,11 @@ use std::{ net::{IpAddr, Ipv4Addr, SocketAddr}, - sync::{ - Arc, LazyLock, Mutex, - atomic::{AtomicBool, Ordering}, - }, + sync::{Arc, LazyLock, Mutex}, }; use defguard_version::{Version, server::DefguardVersionLayer}; -use tokio::sync::oneshot; +use tokio::sync::{mpsc, oneshot}; +use tokio_stream::wrappers::UnboundedReceiverStream; use tonic::{Request, Response, Status, transport::Server}; use tower::ServiceBuilder; use tracing::instrument; @@ -17,9 +15,12 @@ use crate::{ config::Config, error::GatewayError, gateway::TlsConfig, - proto::gateway::{DerPayload, InitialSetupInfo, gateway_setup_server}, + proto::gateway::{CertificateInfo, DerPayload, LogEntry, gateway_setup_server}, }; +const AUTH_HEADER: &str = "authorization"; +type LogsReceiver = Arc>>; + static SETUP_CHANNEL: LazyLock> = LazyLock::new(|| { let (tx, rx) = oneshot::channel(); ( @@ -30,30 +31,27 @@ static SETUP_CHANNEL: LazyLock> = LazyLock::new(|| { pub struct GatewaySetupServer { key_pair: Arc>>, - setup_in_progress: Arc, + logs_rx: LogsReceiver, + current_session_token: Arc>>, } impl Clone for GatewaySetupServer { fn clone(&self) -> Self { Self { key_pair: Arc::clone(&self.key_pair), - setup_in_progress: Arc::clone(&self.setup_in_progress), + logs_rx: Arc::clone(&self.logs_rx), + current_session_token: Arc::clone(&self.current_session_token), } } } -impl Default for GatewaySetupServer { - fn default() -> Self { - Self::new() - } -} - impl GatewaySetupServer { #[must_use] - pub fn new() -> Self { + pub fn new(logs_rx: LogsReceiver) -> Self { Self { key_pair: Arc::new(Mutex::new(None)), - setup_in_progress: Arc::new(AtomicBool::new(false)), + logs_rx, + current_session_token: Arc::new(Mutex::new(None)), } } @@ -93,39 +91,154 @@ impl GatewaySetupServer { GatewayError::SetupError("Failed to receive setup configuration from Core".into()) }) } + + fn is_setup_in_progress(&self) -> bool { + let in_progress = self + .current_session_token + .lock() + .expect("Failed to acquire lock on current session token during gateway setup") + .is_some(); + debug!("Setup in progress check: {in_progress}"); + in_progress + } + + fn clear_setup_session(&self) { + debug!("Terminating setup session"); + self.current_session_token + .lock() + .expect("Failed to acquire lock on current session token during gateway setup") + .take(); + debug!("Setup session terminated"); + } + + fn initialize_setup_session(&self, token: String) { + debug!("Establishing new setup session with Core"); + self.current_session_token + .lock() + .expect("Failed to acquire lock on current session token during gateway setup") + .replace(token); + debug!("Setup session established"); + } + + fn verify_session_token(&self, token: &str) -> bool { + debug!("Validating setup session authorization"); + let is_valid = (*self + .current_session_token + .lock() + .expect("Failed to acquire lock on current session token during gateway setup")) + .as_ref() + .is_some_and(|t| t == token); + debug!("Authorization validation result: {is_valid}"); + is_valid + } } #[tonic::async_trait] impl gateway_setup_server::GatewaySetup for GatewaySetupServer { + type StartStream = UnboundedReceiverStream>; + #[instrument(skip(self, request))] - async fn start( + async fn start(&self, request: Request<()>) -> Result, Status> { + debug!("Core initiated setup process, preparing to stream logs"); + if self.is_setup_in_progress() { + error!("Setup already in progress, rejecting new setup request"); + return Err(Status::resource_exhausted("Setup already in progress")); + } + + debug!("Authenticating setup session with Core"); + let token = request + .metadata() + .get(AUTH_HEADER) + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.strip_prefix("Bearer ")) + .ok_or_else(|| Status::unauthenticated("Missing or invalid authorization token"))?; + + debug!("Setup session authenticated successfully"); + self.initialize_setup_session(token.to_string()); + + debug!("Preparing to forward Gateway logs to Core in real-time"); + let logs_rx = self.logs_rx.clone(); + + let (tx, rx) = mpsc::unbounded_channel(); + let self_clone = self.clone(); + + debug!("Starting log streaming to Core"); + tokio::spawn(async move { + loop { + let maybe_log_entry = logs_rx.lock().await.try_recv(); + match maybe_log_entry { + Ok(log_entry) => { + if tx.send(Ok(log_entry)).is_err() { + debug!( + "Failed to send log entry to gRPC stream: receiver disconnected" + ); + break; + } + } + Err(tokio::sync::mpsc::error::TryRecvError::Empty) => { + if tx.is_closed() { + debug!("gRPC stream receiver disconnected"); + break; + } + tokio::task::yield_now().await; + } + Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => { + debug!("Logs receiver disconnected"); + break; + } + } + } + self_clone.clear_setup_session(); + }); + + debug!("Log stream established, Core will now receive real-time Gateway logs"); + Ok(Response::new(UnboundedReceiverStream::new(rx))) + } + + #[instrument(skip(self, request))] + async fn get_csr( &self, - request: Request, + request: Request, ) -> Result, Status> { - if self.setup_in_progress.load(Ordering::SeqCst) { - return Err(Status::already_exists("Setup is already in progress")); + debug!("Core requested Certificate Signing Request (CSR) generation"); + let token = request + .metadata() + .get(AUTH_HEADER) + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.strip_prefix("Bearer ")) + .ok_or_else(|| Status::unauthenticated("Missing or invalid authorization token"))?; + + debug!("Validating Core's authorization for this setup step"); + if !self.verify_session_token(token) { + error!("Invalid session token in get_csr request"); + return Err(Status::unauthenticated("Invalid session token")); } - self.setup_in_progress.store(true, Ordering::SeqCst); - let initial_info = request.into_inner(); + let setup_info = request.into_inner(); + debug!( + "Will generate certificate for hostname: {}", + setup_info.cert_hostname + ); - let new_key_pair = match defguard_certs::generate_key_pair() { + debug!("Generating key pair"); + let key_pair = match defguard_certs::generate_key_pair() { Ok(kp) => kp, Err(err) => { error!("Failed to generate key pair: {err}"); - self.setup_in_progress.store(false, Ordering::SeqCst); - return Err(Status::internal(format!( - "Failed to generate key pair: {err}" - ))); + self.clear_setup_session(); + return Err(Status::internal("Failed to generate key pair")); } }; + debug!("Key pair created"); - let subject_alt_names = vec![initial_info.cert_hostname]; + let subject_alt_names = vec![setup_info.cert_hostname]; + debug!("Preparing Certificate Signing Request for hostname: {subject_alt_names:?}",); let csr = match defguard_certs::Csr::new( - &new_key_pair, + &key_pair, &subject_alt_names, vec![ + // TODO: Change it? (defguard_certs::DnType::CommonName, "Defguard Gateway"), (defguard_certs::DnType::OrganizationName, "Defguard"), ], @@ -133,83 +246,105 @@ impl gateway_setup_server::GatewaySetup for GatewaySetupServer { Ok(csr) => csr, Err(err) => { error!("Failed to generate CSR: {err}"); - self.setup_in_progress.store(false, Ordering::SeqCst); + self.clear_setup_session(); return Err(Status::internal(format!("Failed to generate CSR: {err}"))); } }; + debug!("Certificate Signing Request prepared"); - let response = DerPayload { - der_data: csr.to_der().to_vec(), - }; + self.key_pair + .lock() + .expect("Failed to acquire lock on key pair during gateway setup when trying to store generated key pair") + .replace(key_pair); - { - let mut key_pair_lock = self.key_pair.lock().expect("Failed to lock key_pair mutex"); - *key_pair_lock = Some(new_key_pair); - } + debug!("Encoding Certificate Signing Request for transmission"); + let csr_der = csr.to_der(); + let csr_request = DerPayload { + der_data: csr_der.to_vec(), + }; + debug!( + "Sending Certificate Signing Request to Core for signing ({} bytes)", + csr_request.der_data.len() + ); - Ok(Response::new(response)) + Ok(Response::new(csr_request)) } #[instrument(skip(self, request))] async fn send_cert(&self, request: Request) -> Result, Status> { + debug!("Core sending back signed certificate for installation"); + let token = request + .metadata() + .get(AUTH_HEADER) + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.strip_prefix("Bearer ")) + .ok_or_else(|| Status::unauthenticated("Missing or invalid authorization token"))?; + + debug!("Validating Core's authorization to complete setup"); + if !self.verify_session_token(token) { + error!("Invalid session token in send_cert request"); + return Err(Status::unauthenticated("Invalid session token")); + } + let der_payload = request.into_inner(); + let cert_der = der_payload.der_data; + debug!( + "Received signed certificate from Core ({} bytes)", + cert_der.len() + ); + + debug!("Parsing received certificate DER data"); + let grpc_cert_pem = + match defguard_certs::der_to_pem(&cert_der, defguard_certs::PemLabel::Certificate) { + Ok(pem) => pem, + Err(err) => { + error!("Failed to convert certificate DER to PEM: {err}"); + self.clear_setup_session(); + return Err(Status::internal(format!( + "Failed to convert certificate DER to PEM: {err}" + ))); + } + }; + debug!("Certificate processed successfully"); let key_pair = { let key_pair = self .key_pair .lock() - .expect("Failed to lock key_pair mutex") + .expect("Failed to acquire lock on key pair during gateway setup when trying to receive certificate") .take(); if let Some(kp) = key_pair { kp } else { - error!("Key pair not found. The setup session may not have been started properly."); - self.setup_in_progress.store(false, Ordering::SeqCst); + error!( + "Key pair not found during Gateway setup. Key pair generation step might have failed." + ); + self.clear_setup_session(); return Err(Status::internal( - "Key pair not found. The setup session may not have been started properly.", + "Key pair not found during Gateway setup. Key pair generation step might have failed.", )); } }; - info!( - "Received certificate of length: {}", - der_payload.der_data.len() - ); - - let cert_pem = match defguard_certs::der_to_pem( - &der_payload.der_data, - defguard_certs::PemLabel::Certificate, - ) { - Ok(pem) => pem, - Err(err) => { - error!("Failed to convert certificate DER format to PEM: {err}"); - self.setup_in_progress.store(false, Ordering::SeqCst); - return Err(Status::internal(format!( - "Failed to convert certificate DER format to PEM: {err}" - ))); - } - }; - - let config = TlsConfig { + let configuration = TlsConfig { grpc_key_pem: key_pair.serialize_pem(), - grpc_cert_pem: cert_pem, + grpc_cert_pem, }; - { - let Some(sender) = SETUP_CHANNEL.0.lock().await.take() else { - error!("Setup channel sender not found"); - self.setup_in_progress.store(false, Ordering::SeqCst); - return Err(Status::internal("Setup channel sender not found")); - }; + let Some(sender) = SETUP_CHANNEL.0.lock().await.take() else { + error!("Setup channel sender not found"); + return Err(Status::internal("Setup channel sender not found")); + }; - sender.send(config).map_err(|_| { - error!("Failed to send setup configuration through channel"); - Status::internal("Failed to send setup configuration through channel") - })?; - } + sender.send(configuration).map_err(|_| { + error!("Failed to send setup configuration through channel"); + Status::internal("Failed to send setup configuration through channel") + })?; - self.setup_in_progress.store(false, Ordering::SeqCst); + debug!("Setup process completed successfully, cleaning up temporary session"); + self.clear_setup_session(); + debug!("Confirming successful setup to Core"); Ok(Response::new(())) } }