Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion crates/lance-graph-callcenter/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,8 @@ persist = ["dep:arrow", "dep:lance"]
query = ["dep:datafusion", "dep:arrow"]
realtime = ["dep:tokio", "dep:tokio-tungstenite", "dep:serde", "dep:serde_json"]
serve = ["realtime", "query", "dep:axum", "dep:tower-http"]
auth = ["dep:serde", "dep:serde_json"]
auth = ["query", "dep:serde", "dep:serde_json"]
full = ["persist", "query", "realtime", "serve", "auth"]

[dev-dependencies]
tokio = { version = "1", features = ["rt-multi-thread", "macros"] }
387 changes: 387 additions & 0 deletions crates/lance-graph-callcenter/src/auth.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,387 @@
//! DM-7 — JWT extraction middleware + `ActorContext` population.
//!
//! **Phase 1 (this file):** Extract and decode JWT payload (base64),
//! populate `ActorContext`. No signature verification — that requires
//! a JWK endpoint or static key, which is deployment-specific.
//!
//! **Phase 2 (future):** Plug in real verification via a `JwkSetProvider`
//! trait. The `JwtMiddleware::extract_actor` API won't change — only the
//! internal verification step gets wired.
//!
//! # JWT Payload Shape (expected claims)
//!
//! ```json
//! {
//! "sub": "user@example.com",
//! "tenant_id": 42,
//! "roles": ["viewer", "editor"]
//! }
//! ```
//!
//! - `sub` (required) — maps to `ActorContext.actor_id`.
//! - `tenant_id` (optional) — maps to `ActorContext.tenant_id` (`TenantId = u64`). Defaults to `0` when absent.
//! - `roles` (optional) — maps to `ActorContext.roles`. Defaults to `[]`.
//!
//! # Zero New Dependencies
//!
//! Uses `serde` + `serde_json` (already gated under `[auth]` feature).
//! Base64 URL-safe decoding is implemented inline (~40 lines) — no
//! `base64` crate, no `jsonwebtoken` crate.
//!
//! Plan: `.claude/plans/callcenter-membrane-v1.md` § DM-7

use lance_graph_contract::auth::{ActorContext, AuthError};
use serde::Deserialize;

/// JWT extraction middleware.
///
/// Stateless namespace type: all functionality is exposed through
/// associated functions (`extract_actor`, `extract_from_header`).
///
/// Phase 1: base64-decode the payload section of a JWT and extract
/// `sub`, `tenant_id`, and `roles` into an `ActorContext`.
///
/// No signature verification in Phase 1 — the token is trusted as-is.
/// Phase 2 will add a `JwkSetProvider` trait for real verification.
#[derive(Debug, Clone, Copy, Default)]
pub struct JwtMiddleware;

/// Deserialization target for the JWT payload claims we care about.
///
/// Both `sub` and `tenant_id` are modelled as `Option` so that a missing
/// claim surfaces as `None` here and is handled explicitly by
/// `JwtMiddleware::extract_actor` rather than failing JSON parsing.
#[derive(Deserialize)]
struct JwtClaims {
    /// JWT `sub` claim — canonical actor identity.
    /// `None` when the claim is absent from the payload.
    sub: Option<String>,
    /// Custom claim: tenant identifier.
    /// `None` when the claim is absent from the payload.
    tenant_id: Option<u64>,
    /// Custom claim: actor roles. Optional; defaults to empty
    /// (via `#[serde(default)]`) when the claim is absent.
    #[serde(default)]
    roles: Vec<String>,
}

impl JwtMiddleware {
    /// Extract an `ActorContext` from a raw JWT token string.
    ///
    /// The token must be in the standard `header.payload.signature`
    /// format. Only the payload section is decoded and parsed; the header
    /// and signature segments are not inspected.
    ///
    /// Claim mapping:
    ///
    /// - `sub` — required and must be non-empty; becomes the actor id.
    /// - `tenant_id` — falls back to `0` when the claim is absent.
    /// - `roles` — falls back to an empty list when the claim is absent.
    ///
    /// # Phase 1 Limitations
    ///
    /// - **No signature verification.** The signature section is ignored.
    ///   Deploy behind a reverse proxy or API gateway that validates
    ///   signatures before traffic reaches this layer.
    /// - **No expiry checking.** `exp` / `nbf` / `iat` are ignored.
    ///   Phase 2 will enforce temporal validity.
    ///
    /// # Errors
    ///
    /// - `AuthError::MalformedToken` — token doesn't have 3 dot-separated parts.
    /// - `AuthError::InvalidBase64` — payload isn't valid base64url.
    /// - `AuthError::InvalidPayload` — payload JSON can't be parsed.
    /// - `AuthError::MissingSub` — the `sub` claim is missing or empty.
    pub fn extract_actor(token: &str) -> Result<ActorContext, AuthError> {
        // A compact JWT is exactly `header.payload.signature`.
        let parts: Vec<&str> = token.split('.').collect();
        if parts.len() != 3 {
            return Err(AuthError::MalformedToken);
        }

        // Only the payload (middle segment) is decoded in Phase 1.
        let payload_bytes = base64url_decode(parts[1])?;

        let claims: JwtClaims = serde_json::from_slice(&payload_bytes)
            .map_err(|e| AuthError::InvalidPayload(e.to_string()))?;

        // An empty `sub` identifies nobody, so treat it like a missing one.
        let actor_id = claims
            .sub
            .filter(|sub| !sub.is_empty())
            .ok_or(AuthError::MissingSub)?;

        // A token without `tenant_id` is assigned tenant 0.
        let tenant_id = claims.tenant_id.unwrap_or(0);

        Ok(ActorContext::new(actor_id, tenant_id, claims.roles))
    }

    /// Extract an `ActorContext` from an `Authorization` header value.
    ///
    /// Accepts `<scheme> <token>` where the scheme is `Bearer` in any ASCII
    /// casing (`Bearer`, `bearer`, `BEARER`, ...), because HTTP
    /// authentication scheme names are case-insensitive. A value that does
    /// not start with a `Bearer` scheme is passed to `extract_actor`
    /// unchanged, so a bare token is still accepted.
    ///
    /// # Errors
    ///
    /// Propagates every error of `extract_actor`.
    pub fn extract_from_header(header_value: &str) -> Result<ActorContext, AuthError> {
        let token = match header_value.split_once(' ') {
            // Tolerate more than one space between scheme and credentials.
            Some((scheme, credentials)) if scheme.eq_ignore_ascii_case("Bearer") => {
                credentials.trim_start_matches(' ')
            }
            // No (recognised) scheme: treat the whole value as the token.
            _ => header_value,
        };

        Self::extract_actor(token)
    }
}

// ═══════════════════════════════════════════════════════════════════════════
// MINIMAL BASE64URL DECODER
// ═══════════════════════════════════════════════════════════════════════════

/// Decode a base64url-encoded string (RFC 4648 §5).
///
/// JWT segments use URL-safe base64 without padding characters (`=`).
/// This decoder accepts both unpadded input and input with trailing `=`
/// padding; all trailing `=` characters are stripped before decoding.
///
/// A `=` anywhere other than the end of the input is rejected, as is any
/// byte outside the base64url alphabet (A-Z, a-z, 0-9, `-`, `_`).
/// Unused low-order bits in a final partial group are ignored.
///
/// No external crate is used.
///
/// # Errors
///
/// Returns `AuthError::InvalidBase64` when the input contains a byte
/// outside the alphabet or when its unpadded length is `1 (mod 4)`,
/// which no valid encoding can produce.
fn base64url_decode(input: &str) -> Result<Vec<u8>, AuthError> {
    /// Map one base64url alphabet byte to its 6-bit value.
    fn sextet(c: u8) -> Result<u8, AuthError> {
        match c {
            b'A'..=b'Z' => Ok(c - b'A'),
            b'a'..=b'z' => Ok(c - b'a' + 26),
            b'0'..=b'9' => Ok(c - b'0' + 52),
            b'-' => Ok(62),
            b'_' => Ok(63),
            // Includes `=`: trailing padding is stripped by the caller, so
            // any `=` seen here sits inside the data and is invalid.
            _ => Err(AuthError::InvalidBase64),
        }
    }

    let data = input.trim_end_matches('=').as_bytes();

    // 4 input chars encode 3 bytes; a dangling single char carries only
    // 6 bits and can never form a whole byte.
    if data.len() % 4 == 1 {
        return Err(AuthError::InvalidBase64);
    }

    let mut out = Vec::with_capacity(data.len() * 3 / 4);

    for group in data.chunks(4) {
        match *group {
            // Full group: 4 sextets -> 3 bytes.
            [a, b, c, d] => {
                let (a, b, c, d) = (sextet(a)?, sextet(b)?, sextet(c)?, sextet(d)?);
                out.push((a << 2) | (b >> 4));
                out.push((b << 4) | (c >> 2));
                out.push((c << 6) | d);
            }
            // Trailing 3 chars -> 2 bytes.
            [a, b, c] => {
                let (a, b, c) = (sextet(a)?, sextet(b)?, sextet(c)?);
                out.push((a << 2) | (b >> 4));
                out.push((b << 4) | (c >> 2));
            }
            // Trailing 2 chars -> 1 byte.
            [a, b] => {
                let (a, b) = (sextet(a)?, sextet(b)?);
                out.push((a << 2) | (b >> 4));
            }
            // Unreachable after the length check above; kept as an error
            // rather than a panic.
            _ => return Err(AuthError::InvalidBase64),
        }
    }

    Ok(out)
}

/// Test helper: encode `input` as unpadded base64url.
///
/// Each chunk of up to three bytes is packed into a 24-bit word and then
/// emitted as `chunk.len() + 1` alphabet characters, so no `=` padding is
/// ever produced.
#[cfg(test)]
fn base64url_encode(input: &[u8]) -> String {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";

    let mut encoded = String::with_capacity((input.len() + 2) / 3 * 4);

    for chunk in input.chunks(3) {
        // Left-align the chunk's bytes in a 24-bit word; absent bytes stay 0.
        let mut word: u32 = 0;
        for (position, &byte) in chunk.iter().enumerate() {
            word |= u32::from(byte) << (16 - 8 * position);
        }

        // n input bytes yield n + 1 sextets, taken from the top of the word.
        for sextet_index in 0..=chunk.len() {
            let index = ((word >> (18 - 6 * sextet_index)) & 0x3F) as usize;
            encoded.push(ALPHABET[index] as char);
        }
    }

    encoded
}

/// Test helper: wrap `payload_json` in a minimal unsigned JWT.
///
/// The result is `<header>.<payload>.` — an `alg: none` header, the
/// base64url-encoded payload, and an empty signature segment (Phase 1
/// doesn't verify signatures).
#[cfg(test)]
fn make_test_jwt(payload_json: &str) -> String {
    let segments = [
        base64url_encode(br#"{"alg":"none","typ":"JWT"}"#),
        base64url_encode(payload_json.as_bytes()),
        // Empty signature: keeps the trailing dot so the token has 3 parts.
        String::new(),
    ];
    segments.join(".")
}

// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Wrap `payload_json` in an unsigned test JWT and run it through
    /// `JwtMiddleware::extract_actor`.
    fn extract(payload_json: &str) -> Result<ActorContext, AuthError> {
        JwtMiddleware::extract_actor(&make_test_jwt(payload_json))
    }

    // ── Base64url decoder tests ──

    /// Encoding then decoding arbitrary bytes (including non-ASCII) is lossless.
    #[test]
    fn base64url_roundtrip() {
        let input: &[u8] = b"Hello, JWT world! \xF0\x9F\x94\x91";
        let roundtripped = base64url_decode(&base64url_encode(input)).unwrap();
        assert_eq!(roundtripped, input);
    }

    /// The empty string decodes to no bytes.
    #[test]
    fn base64url_empty() {
        let decoded = base64url_decode("").unwrap();
        assert!(decoded.is_empty());
    }

    /// "Hello" decodes identically with and without trailing `=` padding.
    #[test]
    fn base64url_padding_tolerance() {
        for encoded in ["SGVsbG8", "SGVsbG8="] {
            assert_eq!(base64url_decode(encoded).unwrap(), b"Hello");
        }
    }

    /// Bytes outside the base64url alphabet are rejected.
    #[test]
    fn base64url_invalid_char() {
        let outcome = base64url_decode("!!!");
        assert_eq!(outcome, Err(AuthError::InvalidBase64));
    }

    /// An unpadded length of 1 (mod 4) is never valid.
    #[test]
    fn base64url_invalid_length() {
        let outcome = base64url_decode("A");
        assert_eq!(outcome, Err(AuthError::InvalidBase64));
    }

    // ── JWT extraction tests ──

    /// All three claims present: each is mapped onto the actor context.
    #[test]
    fn valid_jwt_full_claims() {
        let actor = extract(
            r#"{"sub":"user@example.com","tenant_id":42,"roles":["admin","viewer"]}"#,
        )
        .unwrap();
        assert_eq!(actor.actor_id, "user@example.com");
        assert_eq!(actor.tenant_id, 42);
        assert_eq!(actor.roles, vec!["admin", "viewer"]);
        assert!(actor.is_admin());
    }

    /// `roles` omitted: the actor ends up with no roles and is not an admin.
    #[test]
    fn valid_jwt_minimal_claims() {
        let actor = extract(r#"{"sub":"bot-123","tenant_id":1}"#).unwrap();
        assert_eq!(actor.actor_id, "bot-123");
        assert_eq!(actor.tenant_id, 1);
        assert!(actor.roles.is_empty());
        assert!(!actor.is_admin());
    }

    /// An explicit empty `roles` array stays empty.
    #[test]
    fn valid_jwt_empty_roles() {
        let actor = extract(r#"{"sub":"x","tenant_id":0,"roles":[]}"#).unwrap();
        assert!(actor.roles.is_empty());
    }

    /// `tenant_id` omitted: the tenant falls back to 0.
    #[test]
    fn valid_jwt_missing_tenant_defaults_to_zero() {
        let actor = extract(r#"{"sub":"x"}"#).unwrap();
        assert_eq!(actor.tenant_id, 0);
    }

    /// `sub` omitted: extraction fails with `MissingSub`.
    #[test]
    fn missing_sub_error() {
        let outcome = extract(r#"{"tenant_id":1,"roles":["viewer"]}"#);
        assert_eq!(outcome, Err(AuthError::MissingSub));
    }

    /// An empty `sub` is treated the same as a missing one.
    #[test]
    fn empty_sub_error() {
        let outcome = extract(r#"{"sub":"","tenant_id":1}"#);
        assert_eq!(outcome, Err(AuthError::MissingSub));
    }

    /// A string with no dots is not a JWT.
    #[test]
    fn malformed_token_no_dots() {
        let outcome = JwtMiddleware::extract_actor("not-a-jwt");
        assert_eq!(outcome, Err(AuthError::MalformedToken));
    }

    /// Two segments are one too few.
    #[test]
    fn malformed_token_two_parts() {
        let outcome = JwtMiddleware::extract_actor("header.payload");
        assert_eq!(outcome, Err(AuthError::MalformedToken));
    }

    /// Four segments are one too many.
    #[test]
    fn malformed_token_four_parts() {
        let outcome = JwtMiddleware::extract_actor("a.b.c.d");
        assert_eq!(outcome, Err(AuthError::MalformedToken));
    }

    /// Three segments, but the payload segment is not base64url.
    #[test]
    fn invalid_base64_payload() {
        let outcome = JwtMiddleware::extract_actor("header.!!!invalid.sig");
        assert!(matches!(outcome, Err(AuthError::InvalidBase64)));
    }

    /// The payload decodes fine but is not JSON.
    #[test]
    fn invalid_json_payload() {
        let token = [
            base64url_encode(b"{}"),
            base64url_encode(b"not json at all {{{"),
            String::new(),
        ]
        .join(".");
        let outcome = JwtMiddleware::extract_actor(&token);
        assert!(matches!(outcome, Err(AuthError::InvalidPayload(_))));
    }

    /// `Bearer <token>` header: the prefix is stripped before parsing.
    #[test]
    fn extract_from_bearer_header() {
        let header_value = format!(
            "Bearer {}",
            make_test_jwt(r#"{"sub":"user@test.com","tenant_id":7}"#)
        );
        let actor = JwtMiddleware::extract_from_header(&header_value).unwrap();
        assert_eq!(actor.actor_id, "user@test.com");
        assert_eq!(actor.tenant_id, 7);
    }

    /// The lowercase `bearer ` prefix is accepted as well.
    #[test]
    fn extract_from_header_lowercase_bearer() {
        let header_value = format!("bearer {}", make_test_jwt(r#"{"sub":"x","tenant_id":1}"#));
        let actor = JwtMiddleware::extract_from_header(&header_value).unwrap();
        assert_eq!(actor.actor_id, "x");
    }

    /// A bare token with no scheme prefix is parsed as-is.
    #[test]
    fn extract_from_header_no_prefix() {
        let bare_token = make_test_jwt(r#"{"sub":"x","tenant_id":1}"#);
        let actor = JwtMiddleware::extract_from_header(&bare_token).unwrap();
        assert_eq!(actor.actor_id, "x");
    }
}
Loading
Loading