diff --git a/crates/lance-graph-callcenter/src/policy.rs b/crates/lance-graph-callcenter/src/policy.rs
index 7f5e819f..81f7a324 100644
--- a/crates/lance-graph-callcenter/src/policy.rs
+++ b/crates/lance-graph-callcenter/src/policy.rs
@@ -19,11 +19,17 @@
 use std::sync::Arc;
 
 #[cfg(feature = "auth-rls-lite")]
-use datafusion::common::tree_node::Transformed;
+use datafusion::arrow::datatypes::DataType;
+#[cfg(feature = "auth-rls-lite")]
+use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
 #[cfg(feature = "auth-rls-lite")]
 use datafusion::common::Result as DFResult;
 #[cfg(feature = "auth-rls-lite")]
-use datafusion::logical_expr::LogicalPlan;
+use datafusion::common::{DataFusionError, ScalarValue};
+#[cfg(feature = "auth-rls-lite")]
+use datafusion::logical_expr::{
+    lit, ColumnarValue, Expr, LogicalPlan, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
+};
 #[cfg(feature = "auth-rls-lite")]
 use datafusion::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule};
 
@@ -107,6 +113,79 @@ pub struct ColumnMaskRewriter {
     pub actor_role: String,
 }
 
+#[cfg(feature = "auth-rls-lite")]
+impl ColumnMaskRewriter {
+    /// Replace a column expression according to the given [`RedactionMode`].
+    ///
+    /// `Hash` mode binds a UDF reference (`policy_hash_v1`) that is
+    /// intentionally NOT registered yet — see [`NotYetWiredHashUdf`]
+    /// and PR-F1b. Plans build, but execution fails loud with
+    /// `NotImplemented("policy_hash_v1 UDF not yet registered ...")`.
+    /// This is the "loud > silent" fix for the silent placeholder hole.
+    /// `Truncate(n)` uses DataFusion's built-in `substr`.
+    fn mask_expr(expr: &Expr, mode: &RedactionMode) -> Expr {
+        match mode {
+            RedactionMode::Null => Expr::Literal(ScalarValue::Null, None),
+            RedactionMode::Constant => lit("[REDACTED]"),
+            RedactionMode::Hash => {
+                // Reference the unregistered policy_hash_v1 UDF. The
+                // plan builds (so call sites compose), but executing
+                // the plan returns a `NotImplemented` error on the
+                // first invocation — preventing silent disclosure if the
+                // wiring is forgotten.
+                Expr::ScalarFunction(
+                    datafusion::logical_expr::expr::ScalarFunction::new_udf(
+                        Arc::new(ScalarUDF::from(NotYetWiredHashUdf::new())),
+                        vec![expr.clone()],
+                    ),
+                )
+            }
+            RedactionMode::Truncate(n) => {
+                // Use DataFusion's built-in `substr(col, 1, n)`.
+                let col_expr = expr.clone();
+                let start = lit(1_i64);
+                let length = lit(*n as i64);
+                Expr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf(
+                    datafusion::functions::unicode::substr(),
+                    vec![col_expr, start, length],
+                ))
+            }
+        }
+    }
+
+    /// Recursively rewrite an expression tree, replacing every
+    /// `Expr::Column(c)` whose name is in the policy with the
+    /// appropriate redaction expression.
+    ///
+    /// Uses `Expr::transform_down` from DataFusion's TreeNode trait —
+    /// it walks into BinaryExpr operands, AggregateFunction args,
+    /// ScalarFunction args, Cast expressions, Sort exprs, and so on.
+    /// This is what closes the WHERE / aggregate leak (Loose End #1);
+    /// whether a node is rewritten at all is decided by `rewrite_plan`.
+    fn rewrite_expr_deep(expr: Expr, policy: &ColumnMaskPolicy) -> DFResult<Transformed<Expr>> {
+        expr.transform_down(|e| match e {
+            Expr::Column(ref col) => {
+                if let Some(mode) = policy.columns.get(col.name()) {
+                    // `Jump` prevents `transform_down` from descending
+                    // into the freshly-built mask expression. Otherwise,
+                    // RedactionMode::Hash (which wraps the column in a
+                    // ScalarFunction whose first arg is the original
+                    // Column) would re-enter the visitor on that Column
+                    // and recurse infinitely.
+                    Ok(Transformed::new(
+                        Self::mask_expr(&e, mode),
+                        true,
+                        TreeNodeRecursion::Jump,
+                    ))
+                } else {
+                    Ok(Transformed::no(e))
+                }
+            }
+            other => Ok(Transformed::no(other)),
+        })
+    }
+}
+
 #[cfg(feature = "auth-rls-lite")]
 impl PolicyRewriter for ColumnMaskRewriter {
     fn kind(&self) -> PolicyKind {
@@ -116,11 +195,52 @@ impl PolicyRewriter for ColumnMaskRewriter {
         "column_mask"
     }
     fn rewrite_plan(&self, plan: LogicalPlan) -> DFResult<Transformed<LogicalPlan>> {
-        // Walk plan; on Projection, rewrite expressions for redacted columns.
-        // For this PR ship the structural skeleton; the actual UDF wrap lands
-        // in a follow-up once redaction UDFs are registered.
-        // TODO: wrap Expr::Column(c) in mask_udf(...) for c in policy.columns
-        Ok(Transformed::no(plan))
+        // Resolve the policy by walking the plan's input chain to the
+        // nearest TableScan. Without a TableScan we have no table
+        // name → no policy applies.
+        let table_name = Self::extract_table_name(&plan);
+        let Some(policy) = table_name
+            .as_deref()
+            .and_then(|t| self.registry.lookup(t))
+        else {
+            return Ok(Transformed::no(plan));
+        };
+        let policy = policy.clone();
+
+        // Walk EVERY expression in this node — Projection's projection
+        // list, Filter's predicate, Aggregate's group/aggr exprs, Sort's
+        // exprs, etc. `map_expressions` dispatches per-variant in
+        // DataFusion 52, so one call covers WHERE / GROUP BY / ORDER BY /
+        // aggregate args (Loose End #1, PR-F1). NOTE(review): a Join has
+        // two inputs, so `extract_table_name` returns `None` for it and
+        // for all of its ancestors — those nodes are currently NOT masked.
+        let transformed = plan.map_expressions(|e| Self::rewrite_expr_deep(e, &policy))?;
+        // `map_expressions` doesn't recompute schemas for some variants
+        // (Projection, Aggregate); recompute so field types stay
+        // consistent after a Column was replaced by a literal/UDF call
+        // of a different type.
+        transformed.map_data(|p| p.recompute_schema())
+    }
+}
+
+#[cfg(feature = "auth-rls-lite")]
+impl ColumnMaskRewriter {
+    /// Walk down the plan tree to find a `TableScan` and extract its name.
+    /// Best-effort heuristic for v1: only single-input chains are followed
+    /// (e.g. `Projection → Filter → TableScan`); a Join yields `None`.
+    fn extract_table_name(plan: &LogicalPlan) -> Option<String> {
+        match plan {
+            LogicalPlan::TableScan(scan) => Some(scan.table_name.table().to_string()),
+            // Recurse through single-input nodes (Filter, Sort, Limit, etc.)
+            other => {
+                let inputs = other.inputs();
+                if inputs.len() == 1 {
+                    Self::extract_table_name(inputs[0])
+                } else {
+                    None
+                }
+            }
+        }
     }
 }
 
@@ -144,6 +264,91 @@ impl OptimizerRule for ColumnMaskRewriter {
     }
 }
 
+// ── policy_hash_v1 — intentionally-unregistered hard-fail UDF ────────────────
+//
+// Loose End #2 (PR-F1 close): the previous Hash redaction returned a
+// silent `lit("***REDACTED***")` placeholder. If the real hash UDF is
+// forgotten in a follow-up wiring, every Hash-masked column would
+// silently render as `"***REDACTED***"` — a string, not a hash, with
+// no surface signal that the policy is mis-wired.
+//
+// This UDF replaces the placeholder. It binds at plan time (so plans
+// COMPOSE), but its `invoke_with_args` returns `NotImplemented` —
+// execution fails loudly with "policy_hash_v1 UDF not yet registered
+// — see PR-F1b". Loud > silent.
+//
+// PR-F1b will replace the body with a real FNV-64 / SHA-256-truncated
+// implementation and register the UDF in the SessionContext.
+#[cfg(feature = "auth-rls-lite")]
+#[derive(Debug)]
+pub struct NotYetWiredHashUdf {
+    signature: Signature,
+}
+
+#[cfg(feature = "auth-rls-lite")]
+impl Default for NotYetWiredHashUdf {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+#[cfg(feature = "auth-rls-lite")]
+impl NotYetWiredHashUdf {
+    pub fn new() -> Self {
+        // Accept any single argument — the actual hash will be over
+        // the column's bytes, which we treat as opaque at plan time.
+        Self {
+            signature: Signature::any(1, Volatility::Immutable),
+        }
+    }
+}
+
+#[cfg(feature = "auth-rls-lite")]
+impl PartialEq for NotYetWiredHashUdf {
+    fn eq(&self, other: &Self) -> bool {
+        self.name() == other.name()
+    }
+}
+#[cfg(feature = "auth-rls-lite")]
+impl Eq for NotYetWiredHashUdf {}
+#[cfg(feature = "auth-rls-lite")]
+impl std::hash::Hash for NotYetWiredHashUdf {
+    fn hash<H: std::hash::Hasher>(&self, s: &mut H) {
+        self.name().hash(s);
+    }
+}
+
+#[cfg(feature = "auth-rls-lite")]
+impl ScalarUDFImpl for NotYetWiredHashUdf {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+    /// Stable name — surfaces in plan rendering so test assertions
+    /// (and operators inspecting EXPLAIN output) can verify the
+    /// wrap was applied.
+    fn name(&self) -> &str {
+        "policy_hash_v1"
+    }
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+    /// Hash output is conventionally a 64-bit unsigned integer
+    /// (FNV-64 is the v1 target). Fixing the return type here means
+    /// upstream operators (aggregates, joins) get a stable schema
+    /// even though the real implementation is deferred.
+    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
+        Ok(DataType::UInt64)
+    }
+    fn invoke_with_args(
+        &self,
+        _args: datafusion::logical_expr::ScalarFunctionArgs,
+    ) -> DFResult<ColumnarValue> {
+        Err(DataFusionError::NotImplemented(
+            "policy_hash_v1 UDF not yet registered — see PR-F1b".into(),
+        ))
+    }
+}
+
 // ── Row encryption policy (stub, no executor yet) ────────────────────────────
 
 /// Row encryption policy: encrypt selected columns at rest using a key
@@ -302,8 +507,381 @@ mod tests {
         let transformed = rewriter
             .rewrite_plan(plan)
             .expect("rewrite should succeed");
-        // Skeleton implementation — should be a no-op until the UDF wrap
-        // lands.
+        // Empty registry — no projection to rewrite, so pass-through.
         assert!(!transformed.transformed);
     }
+
+    // ── Column masking rewrite tests (one per RedactionMode) ────────────
+
+    #[cfg(feature = "auth-rls-lite")]
+    mod mask_rewrite_tests {
+        use super::*;
+        use datafusion::common::tree_node::TreeNode;
+        use datafusion::datasource::{provider_as_source, MemTable};
+        use datafusion::logical_expr::builder::LogicalPlanBuilder;
+        use datafusion::optimizer::OptimizerContext;
+
+        /// Schema: `name` (Utf8), `ssn` (Utf8), `card` (Utf8), `score` (Int32).
+        fn mask_schema() -> Arc<arrow::datatypes::Schema> {
+            Arc::new(arrow::datatypes::Schema::new(vec![
+                arrow::datatypes::Field::new("name", arrow::datatypes::DataType::Utf8, false),
+                arrow::datatypes::Field::new("ssn", arrow::datatypes::DataType::Utf8, false),
+                arrow::datatypes::Field::new("card", arrow::datatypes::DataType::Utf8, false),
+                arrow::datatypes::Field::new("score", arrow::datatypes::DataType::Int32, true),
+            ]))
+        }
+
+        fn mem_source(
+            schema: Arc<arrow::datatypes::Schema>,
+        ) -> Arc<dyn datafusion::datasource::TableProvider> {
+            Arc::new(MemTable::try_new(schema, vec![vec![]]).unwrap())
+        }
+
+        /// Build a Projection → TableScan plan selecting all columns.
+        fn scan_with_projection(table_name: &str) -> LogicalPlan {
+            use datafusion::logical_expr::col;
+            let src = provider_as_source(mem_source(mask_schema()));
+            LogicalPlanBuilder::scan(table_name, src, None)
+                .unwrap()
+                .project(vec![
+                    col("name"),
+                    col("ssn"),
+                    col("card"),
+                    col("score"),
+                ])
+                .unwrap()
+                .build()
+                .unwrap()
+        }
+
+        fn apply(plan: LogicalPlan, rewriter: &ColumnMaskRewriter) -> LogicalPlan {
+            let cfg = OptimizerContext::new();
+            plan.transform_down(|n| rewriter.rewrite(n, &cfg))
+                .unwrap()
+                .data
+        }
+
+        fn make_rewriter(
+            table: &str,
+            masks: Vec<(&str, RedactionMode)>,
+        ) -> ColumnMaskRewriter {
+            let mut columns = HashMap::new();
+            for (col, mode) in masks {
+                columns.insert(col.to_string(), mode);
+            }
+            let mut registry = ColumnMaskRegistry::new();
+            registry.register(ColumnMaskPolicy {
+                table_name: table.to_string(),
+                columns,
+            });
+            ColumnMaskRewriter {
+                registry: Arc::new(registry),
+                actor_role: "analyst".to_string(),
+            }
+        }
+
+        #[test]
+        fn redaction_mode_null_replaces_column_with_null() {
+            let plan = scan_with_projection("customers");
+            let rewriter = make_rewriter("customers", vec![("ssn", RedactionMode::Null)]);
+            let rewritten = apply(plan, &rewriter);
+            let s = format!("{rewritten}");
+            // The SSN column should be replaced with NULL.
+            assert!(
+                s.contains("NULL"),
+                "expected NULL literal in rewritten plan: {s}"
+            );
+            // Other columns should remain.
+            assert!(s.contains("name"), "name column should be preserved: {s}");
+            assert!(s.contains("card"), "card column should be preserved: {s}");
+        }
+
+        #[test]
+        fn redaction_mode_constant_replaces_column_with_redacted() {
+            let plan = scan_with_projection("customers");
+            let rewriter = make_rewriter("customers", vec![("ssn", RedactionMode::Constant)]);
+            let rewritten = apply(plan, &rewriter);
+            let s = format!("{rewritten}");
+            assert!(
+                s.contains("[REDACTED]"),
+                "expected [REDACTED] literal in rewritten plan: {s}"
+            );
+            assert!(s.contains("name"), "name column should be preserved: {s}");
+        }
+
+        #[test]
+        fn redaction_mode_hash_binds_not_yet_wired_udf() {
+            // PR-F1 close (Loose End #2): Hash mode now binds the
+            // intentionally-unregistered `policy_hash_v1` UDF instead
+            // of emitting a silent `***REDACTED***` placeholder.
+            // Plans build (so the rewriter composes), but execution
+            // fails loud — preventing silent disclosure when wiring is
+            // forgotten. The real implementation lands in PR-F1b.
+            let plan = scan_with_projection("customers");
+            let rewriter = make_rewriter("customers", vec![("ssn", RedactionMode::Hash)]);
+            let rewritten = apply(plan, &rewriter);
+            let s = format!("{rewritten}");
+            assert!(
+                s.contains("policy_hash_v1"),
+                "expected policy_hash_v1 UDF reference in rewritten plan: {s}"
+            );
+            assert!(
+                !s.contains("***REDACTED***"),
+                "Hash mode must not emit silent ***REDACTED*** placeholder: {s}"
+            );
+            assert!(s.contains("name"), "name column should be preserved: {s}");
+        }
+
+        #[test]
+        fn redaction_mode_truncate_wraps_column_in_substr() {
+            let plan = scan_with_projection("customers");
+            let rewriter =
+                make_rewriter("customers", vec![("card", RedactionMode::Truncate(4))]);
+            let rewritten = apply(plan, &rewriter);
+            let s = format!("{rewritten}");
+            // Truncate(4) should produce a substr() call.
+            assert!(
+                s.contains("substr"),
+                "expected substr function in rewritten plan: {s}"
+            );
+            // The name column should be untouched.
+            assert!(s.contains("name"), "name column should be preserved: {s}");
+        }
+    }
+
+    // ── Full-tree leak tests (PR-F1 close: Loose End #1) ───────────────────
+    //
+    // These tests pin the security-critical invariant that masked columns
+    // do NOT leak through Filter (WHERE) or Aggregate (MAX) nodes; GROUP BY,
+    // JOIN and Sort are NOT exercised here yet. Plans are built without
+    // MemTable (so they compile under `auth-rls-lite` alone) using
+    // `datafusion::logical_expr::table_scan`.
+    #[cfg(feature = "auth-rls-lite")]
+    mod full_tree_leak_tests {
+        use super::*;
+        use datafusion::common::tree_node::TreeNode;
+        use datafusion::functions_aggregate::expr_fn::max;
+        use datafusion::logical_expr::{col, table_scan, Expr, LogicalPlan};
+        use datafusion::optimizer::OptimizerContext;
+
+        fn users_schema() -> arrow::datatypes::Schema {
+            arrow::datatypes::Schema::new(vec![
+                arrow::datatypes::Field::new("id", arrow::datatypes::DataType::Int64, false),
+                arrow::datatypes::Field::new("ssn", arrow::datatypes::DataType::Utf8, false),
+                arrow::datatypes::Field::new("name", arrow::datatypes::DataType::Utf8, false),
+            ])
+        }
+
+        fn make_rewriter(masks: Vec<(&str, RedactionMode)>) -> ColumnMaskRewriter {
+            let mut columns = HashMap::new();
+            for (col_name, mode) in masks {
+                columns.insert(col_name.to_string(), mode);
+            }
+            let mut registry = ColumnMaskRegistry::new();
+            registry.register(ColumnMaskPolicy {
+                table_name: "users".to_string(),
+                columns,
+            });
+            ColumnMaskRewriter {
+                registry: Arc::new(registry),
+                actor_role: "analyst".to_string(),
+            }
+        }
+
+        fn apply(plan: LogicalPlan, rewriter: &ColumnMaskRewriter) -> LogicalPlan {
+            let cfg = OptimizerContext::new();
+            plan.transform_down(|n| rewriter.rewrite(n, &cfg))
+                .unwrap()
+                .data
+        }
+
+        /// Recursively scan a plan for any Filter node and collect predicate
+        /// expression strings. Used to assert WHERE clauses got rewritten.
+        fn filter_predicates(plan: &LogicalPlan) -> Vec<String> {
+            let mut out = Vec::new();
+            collect_filters(plan, &mut out);
+            out
+        }
+        fn collect_filters(plan: &LogicalPlan, out: &mut Vec<String>) {
+            if let LogicalPlan::Filter(f) = plan {
+                out.push(f.predicate.to_string());
+            }
+            for input in plan.inputs() {
+                collect_filters(input, out);
+            }
+        }
+
+        /// Recursively scan a plan for any Aggregate node and collect
+        /// aggregate-expression strings. Used to assert MAX/SUM/...
+        /// arguments got rewritten.
+        fn aggregate_exprs(plan: &LogicalPlan) -> Vec<String> {
+            let mut out = Vec::new();
+            collect_aggregates(plan, &mut out);
+            out
+        }
+        fn collect_aggregates(plan: &LogicalPlan, out: &mut Vec<String>) {
+            if let LogicalPlan::Aggregate(a) = plan {
+                for e in &a.aggr_expr {
+                    out.push(e.to_string());
+                }
+            }
+            for input in plan.inputs() {
+                collect_aggregates(input, out);
+            }
+        }
+
+        /// Reference plain `Expr::Column("ssn")` inside a Filter predicate
+        /// and confirm the rewriter walked into the Filter, not just the
+        /// outer Projection.
+        ///
+        /// Plan shape:
+        ///   Projection(id) → Filter(ssn = '123-45-6789') → TableScan(users)
+        ///
+        /// Pre-fix:  the rewriter only rewrites Projection (which projects
+        ///           `id`, untouched), and the Filter still references
+        ///           `ssn` directly — leaking the unmasked SSN.
+        /// Post-fix: the Filter predicate's `ssn` reference is replaced
+        ///           with the configured mask (here, the [REDACTED]
+        ///           constant).
+        #[test]
+        fn test_where_clause_does_not_leak_unmasked_column() {
+            let schema = users_schema();
+            let plan = table_scan(Some("users"), &schema, None)
+                .unwrap()
+                .filter(col("ssn").eq(Expr::Literal(
+                    datafusion::common::ScalarValue::Utf8(Some("123-45-6789".into())),
+                    None,
+                )))
+                .unwrap()
+                .project(vec![col("id")])
+                .unwrap()
+                .build()
+                .unwrap();
+
+            let rewriter = make_rewriter(vec![("ssn", RedactionMode::Constant)]);
+            let rewritten = apply(plan, &rewriter);
+
+            let preds = filter_predicates(&rewritten);
+            assert!(
+                !preds.is_empty(),
+                "expected at least one Filter node in the rewritten plan"
+            );
+            // Every Filter predicate must contain the `[REDACTED]` mask
+            // literal and must NOT mention bare `ssn` (the unmasked column).
+            for p in &preds {
+                assert!(
+                    p.contains("[REDACTED]"),
+                    "Filter predicate must contain the mask literal — leaked unmasked ssn: {p}"
+                );
+                assert!(
+                    !contains_bare_ssn(p),
+                    "Filter predicate still references bare `ssn` column — column leaked: {p}"
+                );
+            }
+        }
+
+        /// Reference `Expr::AggregateFunction(MAX, [Column(ssn)])` inside
+        /// an Aggregate node and confirm the aggregate's argument got
+        /// rewritten.
+        ///
+        /// Plan shape:
+        ///   Aggregate(MAX(ssn)) → TableScan(users)
+        ///
+        /// Pre-fix:  rewriter only handles Projection; the Aggregate's
+        ///           `MAX(ssn)` argument is unchanged → MAX runs over
+        ///           unmasked SSN values, exposing the maximum.
+        /// Post-fix: `MAX(ssn)` becomes `MAX([REDACTED])` — the
+        ///           aggregate sees the mask, not the raw column.
+        #[test]
+        fn test_max_ssn_aggregate_is_masked() {
+            let schema = users_schema();
+            let plan = table_scan(Some("users"), &schema, None)
+                .unwrap()
+                .aggregate(Vec::<Expr>::new(), vec![max(col("ssn"))])
+                .unwrap()
+                .build()
+                .unwrap();
+
+            let rewriter = make_rewriter(vec![("ssn", RedactionMode::Constant)]);
+            let rewritten = apply(plan, &rewriter);
+
+            let aggs = aggregate_exprs(&rewritten);
+            assert!(
+                !aggs.is_empty(),
+                "expected at least one Aggregate node in the rewritten plan"
+            );
+            for a in &aggs {
+                assert!(
+                    a.contains("[REDACTED]"),
+                    "Aggregate must operate on the mask literal, not bare ssn: {a}"
+                );
+                assert!(
+                    !contains_bare_ssn(a),
+                    "Aggregate still references bare `ssn` column — column leaked: {a}"
+                );
+            }
+        }
+
+        /// Hash mode binds an unregistered UDF reference so plans BUILD
+        /// (no panic at plan time), but a Hash-masked column never
+        /// resolves to the silent `***REDACTED***` literal.
+        #[test]
+        fn test_hash_mode_binds_not_yet_wired_udf_not_silent_placeholder() {
+            let schema = users_schema();
+            let plan = table_scan(Some("users"), &schema, None)
+                .unwrap()
+                .project(vec![col("ssn")])
+                .unwrap()
+                .build()
+                .unwrap();
+
+            let rewriter = make_rewriter(vec![("ssn", RedactionMode::Hash)]);
+            let rewritten = apply(plan, &rewriter);
+            let s = format!("{rewritten}");
+            // Plan-time: the unregistered UDF name appears in the plan.
+            assert!(
+                s.contains("policy_hash_v1"),
+                "Hash mode must bind the unregistered policy_hash_v1 UDF — got: {s}"
+            );
+            // Plan-time: the silent placeholder must NOT be there.
+            assert!(
+                !s.contains("***REDACTED***"),
+                "Hash mode must not emit silent ***REDACTED*** placeholder — got: {s}"
+            );
+        }
+
+        /// Helper: does the predicate string still mention `ssn` as a
+        /// bare column (not nested inside an alias / mask literal)?
+        ///
+        /// Conservative: if the literal substring "ssn" appears AND it's
+        /// not part of "[REDACTED]" / mask-literal context, treat it as
+        /// a leak. The Display impls used by DataFusion include the
+        /// column name verbatim, e.g. `users.ssn` for a Column ref.
+        fn contains_bare_ssn(s: &str) -> bool {
+            // Either the unqualified or qualified column reference.
+            s.contains("users.ssn") || ends_with_word(s, "ssn")
+        }
+        fn ends_with_word(s: &str, word: &str) -> bool {
+            // Despite the name, this matches `word` ANYWHERE in `s`, as long as
+            // both neighbours are non-identifier bytes (so not "ssn_hash").
+            let bytes = s.as_bytes();
+            let wlen = word.len();
+            if bytes.len() < wlen {
+                return false;
+            }
+            for i in 0..=bytes.len() - wlen {
+                if &bytes[i..i + wlen] == word.as_bytes() {
+                    let before_ok = i == 0
+                        || !(bytes[i - 1].is_ascii_alphanumeric() || bytes[i - 1] == b'_');
+                    let after_ok = i + wlen == bytes.len()
+                        || !(bytes[i + wlen].is_ascii_alphanumeric()
+                            || bytes[i + wlen] == b'_');
+                    if before_ok && after_ok {
+                        return true;
+                    }
+                }
+            }
+            false
+        }
+    }
 }