Merged
24 commits
ca9b1cd
feat: Add DELETE/UPDATE hooks to TableProvider trait
ethan-tyler Dec 7, 2025
0e9e691
test: Add unit tests for DmlCapabilities
ethan-tyler Dec 7, 2025
695716c
fix: Use assert_eq! instead of assert! for const values
ethan-tyler Dec 7, 2025
ebd7bcd
feat: Implement DELETE/UPDATE for MemTable
ethan-tyler Dec 7, 2025
5b35c4c
test: Update existing delete.slt and update.slt for MemTable DML
ethan-tyler Dec 7, 2025
01d7e0e
Merge pull request #1 from ethan-tyler/feat/tableprovider-dml-hooks
ethan-tyler Dec 7, 2025
a59d9a6
Merge pull request #2 from ethan-tyler/feat/memtable-dml-impl
ethan-tyler Dec 7, 2025
8b7205b
feat: Add DELETE/UPDATE hooks to TableProvider trait
ethan-tyler Dec 7, 2025
04fd988
test: Add unit tests for DmlCapabilities
ethan-tyler Dec 7, 2025
30d98c5
fix: Use assert_eq! instead of assert! for const values
ethan-tyler Dec 7, 2025
b038d04
fix: Use assert!() for bool assertions and update expected DML errors
ethan-tyler Dec 7, 2025
d21432f
feat: Implement DELETE/UPDATE for MemTable
ethan-tyler Dec 7, 2025
4adc590
test: Update existing delete.slt and update.slt for MemTable DML
ethan-tyler Dec 7, 2025
3e635b3
refactor: Remove DmlCapabilities, simplify DELETE/UPDATE API
ethan-tyler Dec 9, 2025
a713047
refactor: Remove DmlCapabilities, simplify DELETE/UPDATE API
ethan-tyler Dec 9, 2025
9a5788a
test: Update existing delete.slt and update.slt for MemTable DML
ethan-tyler Dec 9, 2025
14d6075
Merge branch 'main' into feat/tableprovider-dml-hooks
ethan-tyler Dec 9, 2025
c67dc67
Merge feat/memtable-dml-impl: Add MemTable DML implementation and tests
ethan-tyler Dec 9, 2025
9a206fa
Merge main into feat/tableprovider-dml-hooks
ethan-tyler Dec 9, 2025
bee5f1e
fix: Remove stale DmlCapabilities exports from merge
ethan-tyler Dec 9, 2025
8cf796c
Merge branch 'main' into feat/tableprovider-dml-hooks
ethan-tyler Dec 9, 2025
cd472f1
Merge branch 'main' into feat/tableprovider-dml-hooks
ethan-tyler Dec 19, 2025
94a2b68
fix: Evaluate UPDATE expressions only on matching rows
ethan-tyler Dec 19, 2025
f75188d
Merge branch 'main' into feat/tableprovider-dml-hooks
alamb Dec 30, 2025
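
The earlier commits in this list add DELETE/UPDATE hooks to the TableProvider trait itself; that trait file is not included in the diff below. As a sketch only, the shape of those hooks (signatures inferred from the MemTable implementation below, shown here on a standalone trait for illustration, with assumed "not implemented" default bodies) might look like this:

use std::sync::Arc;

use async_trait::async_trait;
use datafusion_common::{Result, not_impl_err};
use datafusion_expr::Expr;
use datafusion_physical_plan::ExecutionPlan;
use datafusion_session::Session;

// Sketch only: in the PR these are methods on TableProvider; the default
// bodies returning a "not implemented" error are an assumption.
#[async_trait]
pub trait DmlHooks {
    /// DELETE rows matching `filters`; the returned plan reports the affected-row count.
    async fn delete_from(
        &self,
        _state: &dyn Session,
        _filters: Vec<Expr>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        not_impl_err!("DELETE is not supported for this table")
    }

    /// UPDATE rows matching `filters` by applying `assignments`; the returned
    /// plan reports the affected-row count.
    async fn update(
        &self,
        _state: &dyn Session,
        _assignments: Vec<(String, Expr)>,
        _filters: Vec<Expr>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        not_impl_err!("UPDATE is not supported for this table")
    }
}
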
349 changes: 346 additions & 3 deletions datafusion/catalog/src/memory/table.rs
@@ -24,7 +24,12 @@ use std::sync::Arc;

use crate::TableProvider;

use arrow::datatypes::SchemaRef;
use arrow::array::{
Array, ArrayRef, BooleanArray, RecordBatch as ArrowRecordBatch, UInt64Array,
};
use arrow::compute::kernels::zip::zip;
use arrow::compute::{and, filter_record_batch};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion_common::error::Result;
use datafusion_common::{Constraints, DFSchema, SchemaExt, not_impl_err, plan_err};
@@ -34,10 +39,14 @@ use datafusion_datasource::sink::DataSinkExec;
use datafusion_datasource::source::DataSourceExec;
use datafusion_expr::dml::InsertOp;
use datafusion_expr::{Expr, SortExpr, TableType};
use datafusion_physical_expr::{LexOrdering, create_physical_sort_exprs};
use datafusion_physical_expr::{
LexOrdering, create_physical_expr, create_physical_sort_exprs,
};
use datafusion_physical_plan::repartition::RepartitionExec;
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use datafusion_physical_plan::{
ExecutionPlan, ExecutionPlanProperties, Partitioning, common,
DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, Partitioning,
PlanProperties, common,
};
use datafusion_session::Session;

@@ -295,4 +304,338 @@ impl TableProvider for MemTable {
fn get_column_default(&self, column: &str) -> Option<&Expr> {
self.column_defaults.get(column)
}

async fn delete_from(
&self,
state: &dyn Session,
filters: Vec<Expr>,
) -> Result<Arc<dyn ExecutionPlan>> {
// Early exit if table has no partitions
if self.batches.is_empty() {
return Ok(Arc::new(DmlResultExec::new(0)));
}

*self.sort_order.lock() = vec![];

let mut total_deleted: u64 = 0;
let df_schema = DFSchema::try_from(Arc::clone(&self.schema))?;

for partition_data in &self.batches {
let mut partition = partition_data.write().await;
let mut new_batches = Vec::with_capacity(partition.len());

for batch in partition.iter() {
if batch.num_rows() == 0 {
continue;
}

// Evaluate filters - None means "match all rows"
let filter_mask = evaluate_filters_to_mask(
&filters,
batch,
&df_schema,
state.execution_props(),
)?;

let (delete_count, keep_mask) = match filter_mask {
Some(mask) => {
// Count rows where mask is true (will be deleted)
let count = mask.iter().filter(|v| v == &Some(true)).count();
// Keep rows where predicate is false or NULL (SQL three-valued logic)
let keep: BooleanArray =
mask.iter().map(|v| Some(v != Some(true))).collect();
(count, keep)
}
None => {
// No filters = delete all rows
(
batch.num_rows(),
BooleanArray::from(vec![false; batch.num_rows()]),
)
}
};

total_deleted += delete_count as u64;

let filtered_batch = filter_record_batch(batch, &keep_mask)?;
if filtered_batch.num_rows() > 0 {
new_batches.push(filtered_batch);
}
}

*partition = new_batches;
}

Ok(Arc::new(DmlResultExec::new(total_deleted)))
}

async fn update(
&self,
state: &dyn Session,
assignments: Vec<(String, Expr)>,
filters: Vec<Expr>,
) -> Result<Arc<dyn ExecutionPlan>> {
// Early exit if table has no partitions
if self.batches.is_empty() {
return Ok(Arc::new(DmlResultExec::new(0)));
}

// Validate column names upfront with clear error messages
let available_columns: Vec<&str> = self
.schema
.fields()
.iter()
.map(|f| f.name().as_str())
.collect();
for (column_name, _) in &assignments {
if self.schema.field_with_name(column_name).is_err() {
return plan_err!(
"UPDATE failed: column '{}' does not exist. Available columns: {}",
column_name,
available_columns.join(", ")
);
}
}

let df_schema = DFSchema::try_from(Arc::clone(&self.schema))?;

// Create physical expressions for assignments upfront (outside batch loop)
let physical_assignments: HashMap<
String,
Arc<dyn datafusion_physical_plan::PhysicalExpr>,
> = assignments
.iter()
.map(|(name, expr)| {
let physical_expr =
create_physical_expr(expr, &df_schema, state.execution_props())?;
Ok((name.clone(), physical_expr))
})
.collect::<Result<_>>()?;

*self.sort_order.lock() = vec![];

let mut total_updated: u64 = 0;

for partition_data in &self.batches {
let mut partition = partition_data.write().await;
let mut new_batches = Vec::with_capacity(partition.len());

for batch in partition.iter() {
if batch.num_rows() == 0 {
continue;
}

// Evaluate filters - None means "match all rows"
let filter_mask = evaluate_filters_to_mask(
&filters,
batch,
&df_schema,
state.execution_props(),
)?;

let (update_count, update_mask) = match filter_mask {
Some(mask) => {
// Count rows where mask is true (will be updated)
let count = mask.iter().filter(|v| v == &Some(true)).count();
// Normalize mask: only true (not NULL) triggers update
let normalized: BooleanArray =
mask.iter().map(|v| Some(v == Some(true))).collect();
(count, normalized)
}
None => {
// No filters = update all rows
(
batch.num_rows(),
BooleanArray::from(vec![true; batch.num_rows()]),
)
}
};

total_updated += update_count as u64;

if update_count == 0 {
new_batches.push(batch.clone());
continue;
}

let mut new_columns: Vec<ArrayRef> =
Vec::with_capacity(batch.num_columns());

for field in self.schema.fields() {
let column_name = field.name();
let original_column =
batch.column_by_name(column_name).ok_or_else(|| {
datafusion_common::DataFusionError::Internal(format!(
"Column '{column_name}' not found in batch"
))
})?;

let new_column = if let Some(physical_expr) =
physical_assignments.get(column_name.as_str())
{
// Use evaluate_selection to only evaluate on matching rows.
// This avoids errors (e.g., divide-by-zero) on rows that won't
// be updated. The result is scattered back with nulls for
// non-matching rows, which zip() will replace with originals.
let new_values =
physical_expr.evaluate_selection(batch, &update_mask)?;
let new_array = new_values.into_array(batch.num_rows())?;

// Convert to &dyn Array which implements Datum
let new_arr: &dyn Array = new_array.as_ref();
let orig_arr: &dyn Array = original_column.as_ref();
zip(&update_mask, &new_arr, &orig_arr)?
} else {
Arc::clone(original_column)
};

new_columns.push(new_column);
}

let updated_batch =
ArrowRecordBatch::try_new(Arc::clone(&self.schema), new_columns)?;
new_batches.push(updated_batch);
}

*partition = new_batches;
}

Ok(Arc::new(DmlResultExec::new(total_updated)))
}
}

/// Evaluate filter expressions against a batch and return a combined boolean mask.
/// Returns None if filters is empty (meaning "match all rows").
/// The returned mask has true for rows that match the filter predicates.
fn evaluate_filters_to_mask(
filters: &[Expr],
batch: &RecordBatch,
df_schema: &DFSchema,
execution_props: &datafusion_expr::execution_props::ExecutionProps,
) -> Result<Option<BooleanArray>> {
if filters.is_empty() {
return Ok(None);
}

let mut combined_mask: Option<BooleanArray> = None;

for filter_expr in filters {
let physical_expr =
create_physical_expr(filter_expr, df_schema, execution_props)?;

let result = physical_expr.evaluate(batch)?;
let array = result.into_array(batch.num_rows())?;
let bool_array = array
.as_any()
.downcast_ref::<BooleanArray>()
.ok_or_else(|| {
datafusion_common::DataFusionError::Internal(
"Filter did not evaluate to boolean".to_string(),
)
})?
.clone();

combined_mask = Some(match combined_mask {
Some(existing) => and(&existing, &bool_array)?,
None => bool_array,
});
}

Ok(combined_mask)
}

/// Returns a single row with the count of affected rows.
#[derive(Debug)]
struct DmlResultExec {
rows_affected: u64,
schema: SchemaRef,
properties: PlanProperties,
}

impl DmlResultExec {
fn new(rows_affected: u64) -> Self {
let schema = Arc::new(Schema::new(vec![Field::new(
"count",
DataType::UInt64,
false,
)]));

let properties = PlanProperties::new(
datafusion_physical_expr::EquivalenceProperties::new(Arc::clone(&schema)),
Partitioning::UnknownPartitioning(1),
datafusion_physical_plan::execution_plan::EmissionType::Final,
datafusion_physical_plan::execution_plan::Boundedness::Bounded,
);

Self {
rows_affected,
schema,
properties,
}
}
}

impl DisplayAs for DmlResultExec {
fn fmt_as(
&self,
t: DisplayFormatType,
f: &mut std::fmt::Formatter,
) -> std::fmt::Result {
match t {
DisplayFormatType::Default
| DisplayFormatType::Verbose
| DisplayFormatType::TreeRender => {
write!(f, "DmlResultExec: rows_affected={}", self.rows_affected)
}
}
}
}

impl ExecutionPlan for DmlResultExec {
fn name(&self) -> &str {
"DmlResultExec"
}

fn as_any(&self) -> &dyn Any {
self
}

fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}

fn properties(&self) -> &PlanProperties {
&self.properties
}

fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![]
}

fn with_new_children(
self: Arc<Self>,
_children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
Ok(self)
}

fn execute(
&self,
_partition: usize,
_context: Arc<datafusion_execution::TaskContext>,
) -> Result<datafusion_execution::SendableRecordBatchStream> {
// Create a single batch with the count
let count_array = UInt64Array::from(vec![self.rows_affected]);
let batch = ArrowRecordBatch::try_new(
Arc::clone(&self.schema),
vec![Arc::new(count_array) as ArrayRef],
)?;

// Create a stream that yields just this one batch
let stream = futures::stream::iter(vec![Ok(batch)]);
Ok(Box::pin(RecordBatchStreamAdapter::new(
Arc::clone(&self.schema),
stream,
)))
}
}
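
For reference, a minimal usage sketch (not part of this PR's diff) of how the new MemTable hooks could be driven directly on this branch; the table setup, the SessionContext/collect plumbing, and the example expressions are illustrative assumptions:

use std::sync::Arc;

use datafusion::arrow::array::Int32Array;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::{MemTable, TableProvider};
use datafusion::error::Result;
use datafusion::physical_plan::collect;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    // In-memory table with a single Int32 column "a" holding 1, 2, 3
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )?;
    let table = MemTable::try_new(Arc::clone(&schema), vec![vec![batch]])?;

    let ctx = SessionContext::new();
    let state = ctx.state();

    // UPDATE: set a = a * 10 where a = 2; the returned plan yields a single
    // row with a "count" column reporting how many rows were updated.
    let plan = table
        .update(
            &state,
            vec![("a".to_string(), col("a") * lit(10))],
            vec![col("a").eq(lit(2))],
        )
        .await?;
    let updated = collect(plan, ctx.task_ctx()).await?;
    println!("rows updated: {:?}", updated[0].column(0));

    // DELETE: remove rows where a > 5 (the row updated above, now 20, matches)
    let plan = table.delete_from(&state, vec![col("a").gt(lit(5))]).await?;
    let deleted = collect(plan, ctx.task_ctx()).await?;
    println!("rows deleted: {:?}", deleted[0].column(0));

    Ok(())
}
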