diff --git a/src/enterprise/license.rs b/src/enterprise/license.rs index a2e394365b..4500bdfb8f 100644 --- a/src/enterprise/license.rs +++ b/src/enterprise/license.rs @@ -18,7 +18,6 @@ use crate::{ VERSION, }; -// FIXME: this should be a hardcoded IP, make sure to add appropriate host headers const LICENSE_SERVER_URL: &str = "https://update-service-dev.defguard.net/api/license/renew"; static LICENSE: RwLock> = RwLock::new(None); @@ -690,4 +689,52 @@ mod test { Utc.with_ymd_and_hms(2024, 8, 21, 9, 58, 35).unwrap() ); } + + #[test] + fn test_invalid_license() { + let license = "CigKIDVhMGRhZDRiOWNmZTRiNzZiYjkzYmI1Y2Q5MGM2ZjdjGLL+lrYGErYCiQEzBAABCgAdFiEE8h/UW/EuSO/G0WM4IRoGfgHZ0SsFAmbFvzUACgkQIRoGfgHZ0SuNQggAioLovxAyrgAn+LPO42QIlVHYG8oTs3jnpM0BMx3cXbfy7M0ECsC10HpzIkundems7SgYO/+iJfMMe4mj3kiA+uwacCmPW6VWTIVEIpX2jqRpv7DcDnUSeAszySZl6KhQS+35IPC0Gs2yQNU4/mDsa4VUv9DiL8s7rMM89fe4QmtjVRpFQVgGLm4IM+mRIXTySB2RwmVzw8+YE4z+w4emLxaKWjw4Q7CQxykkPNGlBj224jozs/Biw9eDYCbJOT/5KXNqZ2peht59n6RMVc0SNKE26E8hDmJ61M0Tzj57wQ6nZ3yh6KGyTdCIc9Y9wcrHwZ1Yw1tdh8j/fULUyPtNyA=="; + let license = License::from_base64(license).unwrap(); + assert!(validate_license(Some(&license)).is_err()); + assert!(validate_license(None).is_err()); + + // One day past the expiry date, non-subscription license + let license = License { + customer_id: "test".to_string(), + subscription: false, + valid_until: Some(Utc::now() - TimeDelta::days(1)), + }; + assert!(validate_license(Some(&license)).is_err()); + + // One day before the expiry date, non-subscription license + let license = License { + customer_id: "test".to_string(), + subscription: false, + valid_until: Some(Utc::now() + TimeDelta::days(1)), + }; + assert!(validate_license(Some(&license)).is_ok()); + + // No expiry date, non-subscription license + let license = License { + customer_id: "test".to_string(), + subscription: false, + valid_until: None, + }; + assert!(validate_license(Some(&license)).is_ok()); + + // One day past the maximum overdue date + let license = License { + customer_id: "test".to_string(), + subscription: true, + valid_until: Some(Utc::now() - MAX_OVERDUE_TIME - TimeDelta::days(1)), + }; + assert!(validate_license(Some(&license)).is_err()); + + // One day before the maximum overdue date + let license = License { + customer_id: "test".to_string(), + subscription: true, + valid_until: Some(Utc::now() - MAX_OVERDUE_TIME + TimeDelta::days(1)), + }; + assert!(validate_license(Some(&license)).is_ok()); + } } diff --git a/tests/openid_login.rs b/tests/openid_login.rs index 89ac2b1b24..f1b7f8d4e5 100644 --- a/tests/openid_login.rs +++ b/tests/openid_login.rs @@ -1,5 +1,11 @@ +use chrono::{Duration, Utc}; use defguard::{ - config::DefGuardConfig, db::DbPool, enterprise::handlers::openid_providers::AddProviderData, + config::DefGuardConfig, + db::DbPool, + enterprise::{ + handlers::openid_providers::AddProviderData, + license::{set_cached_license, License}, + }, handlers::Auth, }; use reqwest::{StatusCode, Url}; @@ -67,4 +73,14 @@ async fn test_openid_providers() { assert!(state.is_some()); let redirect_uri = url.query_pairs().find(|(key, _)| key == "redirect_uri"); assert!(redirect_uri.is_some()); + + // Test that the endpoint is forbidden when the license is expired + let new_license = License { + customer_id: "test".to_string(), + subscription: false, + valid_until: Some(Utc::now() - Duration::days(1)), + }; + set_cached_license(Some(new_license)); + let response = client.get("/api/v1/openid/auth_info").send().await; + assert_eq!(response.status(), StatusCode::FORBIDDEN); }