Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions crates/defguard_core/src/enterprise/posture/evaluation.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use defguard_common::db::Id;
use defguard_proto::enterprise::posture::{
DevicePostureCheckRequest, DevicePostureData, UnavailableReason,
bool_check::Result as BoolResult, string_check::Result as StringResult,
Expand Down Expand Up @@ -202,19 +203,21 @@ fn evaluate_os_rule(
/// Returns [`PostureResult::Fail`] with accumulated [`FailureReason`]s otherwise.
pub async fn validate_posture(
pool: &PgPool,
request: &DevicePostureCheckRequest,
location_id: Id,
pubkey: &str,
posture_data: &Option<DevicePostureData>,
) -> Result<PostureResult, PostureCheckError> {
debug!(
"Performing posture check for device {}: {:?}",
request.pubkey, request.device_posture_data
pubkey, posture_data
);

// If location has no assigned postures - pass immediately (no license required).
let posture_ids = DevicePostureLocation::find_by_location(pool, request.location_id).await?;
let posture_ids = DevicePostureLocation::find_by_location(pool, location_id).await?;
if posture_ids.is_empty() {
debug!(
"No posture policies assigned to location {} — passing device {}",
request.location_id, request.pubkey
location_id, pubkey
);
return Ok(PostureResult::Pass);
}
Expand All @@ -223,17 +226,17 @@ pub async fn validate_posture(
if !is_enterprise_license_active() {
warn!(
"No active enterprise license - posture check aborted for device {}",
request.pubkey
pubkey
);
return Err(PostureCheckError::NoActiveEnterpriseLicense);
}

let data = match request.device_posture_data.as_ref() {
let data = match posture_data.as_ref() {
Some(d) => d,
None => {
info!(
"Missing posture data - posture check failed for device {}",
request.pubkey
pubkey
);
return Ok(PostureResult::Fail(vec![FailureReason::MissingPostureData]));
}
Expand Down Expand Up @@ -305,7 +308,7 @@ pub async fn validate_posture(
}

if all_failures.is_empty() {
info!("Posture check passed for device {}", request.pubkey);
info!("Posture check passed for device {}", pubkey);
Ok(PostureResult::Pass)
} else {
Ok(PostureResult::Fail(all_failures))
Expand Down
76 changes: 61 additions & 15 deletions crates/defguard_core/src/grpc/proxy/client_mfa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ impl ClientMfaServer {
pub async fn start_client_mfa_login(
&mut self,
request: ClientMfaStartRequest,
) -> Result<ClientMfaStartResponse, Status> {
) -> Result<ClientMfaStartOutcome, Status> {
debug!("Starting desktop client login: {request:?}");
// fetch location
let Ok(Some(location)) =
Expand Down Expand Up @@ -203,6 +203,39 @@ impl ClientMfaServer {
// validate user is allowed to connect to a given location
Self::validate_location_access(&self.pool, &location, &user_info).await?;

// Evaluate postures if necessary.
let has_postures = location.has_postures(&self.pool).await.map_err(|err| {
error!(
"Failed to fetch postures for location {}({}): {err}",
location.name, location.id
);
Status::internal("unexpected error")
})?;
if has_postures {
let posture_result = validate_posture(
&self.pool,
location.id,
&request.pubkey,
&request.posture_data,
)
.await
.map_err(|err| match err {
PostureCheckError::NoActiveEnterpriseLicense => {
Status::failed_precondition("enterprise license required for posture checks")
}
PostureCheckError::DbError(e) => {
error!("DB error during posture validation: {e}");
Status::internal("unexpected error")
}
})?;

// Posture check failed - return payload with reasons
if let PostureResult::Fail(reasons) = posture_result {
let failed_checks = reasons.iter().map(|r| r.to_string()).collect();
return Ok(ClientMfaStartOutcome::Rejected { failed_checks });
}
}

user.verify_mfa_state(&self.pool).await.map_err(|err| {
error!(
"Failed to verify MFA state for user {}: {err}",
Expand Down Expand Up @@ -378,10 +411,10 @@ impl ClientMfaServer {
},
);

Ok(ClientMfaStartResponse {
Ok(ClientMfaStartOutcome::Approved(ClientMfaStartResponse {
token,
challenge: response_challenge,
})
}))
}

/// Checks if given user is allowed to access a location
Expand Down Expand Up @@ -824,18 +857,22 @@ impl ClientMfaServer {
Self::validate_location_access(&self.pool, &location, &user_info).await?;

// Evaluate posture.
let posture_result =
validate_posture(&self.pool, &request)
.await
.map_err(|err| match err {
PostureCheckError::NoActiveEnterpriseLicense => Status::failed_precondition(
"enterprise license required for posture checks",
),
PostureCheckError::DbError(e) => {
error!("DB error during posture validation: {e}");
Status::internal("unexpected error")
}
})?;
let posture_result = validate_posture(
&self.pool,
location.id,
&request.pubkey,
&request.device_posture_data,
)
.await
.map_err(|err| match err {
PostureCheckError::NoActiveEnterpriseLicense => {
Status::failed_precondition("enterprise license required for posture checks")
}
PostureCheckError::DbError(e) => {
error!("DB error during posture validation: {e}");
Status::internal("unexpected error")
}
})?;

// Posture check failed - return payload with reasons
if let PostureResult::Fail(reasons) = posture_result {
Expand Down Expand Up @@ -992,6 +1029,15 @@ pub enum PostureCheckOutcome {
Rejected { failed_checks: Vec<String> },
}

/// Result of a [`ClientMfaServer::start_client_mfa_login`] call.
/// Adds posture check outcome info.
pub enum ClientMfaStartOutcome {
/// Posture evaluation succeeded or was unnecessary.
Approved(ClientMfaStartResponse),
/// Posture evaluation failed; the contained list describes which checks failed.
Rejected { failed_checks: Vec<String> },
}

#[cfg(test)]
mod tests {
use std::{
Expand Down
18 changes: 8 additions & 10 deletions crates/defguard_core/src/location_management/allowed_peers.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use defguard_common::db::{Id, models::WireguardNetwork};
use defguard_proto::gateway::Peer;
use sqlx::{PgExecutor, query};
use sqlx::{PgConnection, PgExecutor, query};

use crate::grpc::should_prevent_service_location_usage;

Expand All @@ -11,13 +11,10 @@ use crate::grpc::should_prevent_service_location_usage;
///
/// If the location is a service location, only returns peers if enterprise features are enabled.
/// MFA-enabled locations only return peers backed by an active session with a runtime preshared key.
pub async fn get_location_allowed_peers<'e, E>(
pub async fn get_location_allowed_peers(
location: &WireguardNetwork<Id>,
executor: E,
) -> sqlx::Result<Vec<Peer>>
where
E: PgExecutor<'e>,
{
conn: &mut PgConnection,
) -> sqlx::Result<Vec<Peer>> {
debug!("Fetching all allowed peers for location {}", location.id);

if should_prevent_service_location_usage(location) {
Expand All @@ -28,7 +25,8 @@ where
return Ok(Vec::new());
}

if !location.mfa_enabled() {
let has_postures = location.has_postures(&mut *conn).await?;
if !location.mfa_enabled() && !has_postures {
let rows = query!(
"SELECT d.wireguard_pubkey pubkey, \
ARRAY( \
Expand All @@ -44,7 +42,7 @@ where
ORDER BY d.id ASC",
location.id,
)
.fetch_all(executor)
.fetch_all(&mut *conn)
.await?;

return Ok(rows
Expand Down Expand Up @@ -84,7 +82,7 @@ where
ORDER BY d.id ASC",
location.id,
)
.fetch_all(executor)
.fetch_all(&mut *conn)
.await?;

Ok(rows
Expand Down
2 changes: 1 addition & 1 deletion crates/defguard_gateway_manager/src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ impl GatewayHandler {
);
}

let peers = get_location_allowed_peers(&network, &self.pool).await?;
let peers = get_location_allowed_peers(&network, &mut *conn).await?;

let maybe_firewall_config = try_get_location_firewall_config(&network, &mut conn).await?;
let payload = Some(core_response::Payload::Config(Configuration::new(
Expand Down
13 changes: 10 additions & 3 deletions crates/defguard_proxy_manager/src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use defguard_core::{
},
grpc::{
GatewayEvent,
proxy::client_mfa::{ClientLoginSession, ClientMfaServer, PostureCheckOutcome},
proxy::client_mfa::{ClientLoginSession, ClientMfaServer, ClientMfaStartOutcome, PostureCheckOutcome},
},
version::{IncompatibleComponents, IncompatibleProxyData, is_proxy_version_supported},
};
Expand Down Expand Up @@ -663,9 +663,16 @@ impl ProxyHandler {
.start_client_mfa_login(request)
.await
{
Ok(response_payload) => {
Ok(ClientMfaStartOutcome::Approved(response_payload)) => {
Some(core_response::Payload::ClientMfaStart(response_payload))
}
},
Ok(ClientMfaStartOutcome::Rejected{ failed_checks }) => {
Some(core_response::Payload::DevicePostureRejected(
DevicePostureRejection {
failed_posture_checks: failed_checks,
},
))
},
Err(err) => {
error!("client MFA start error {err}");
Some(core_response::Payload::CoreError(err.into()))
Expand Down
Loading