Skip to content

Commit 94a2b68

Browse files
committed
fix: Evaluate UPDATE expressions only on matching rows
- Use evaluate_selection() instead of evaluate() for UPDATE assignments to avoid evaluating expressions on rows excluded by the WHERE clause. This prevents errors like divide-by-zero on non-updated rows.
- Extract evaluate_filters_to_mask() helper shared by delete_from/update
- Hoist physical expression creation outside batch loop for efficiency
- Simplify strip_column_qualifiers() signature
- Add regression test for divide-by-zero edge case
1 parent cd472f1 commit 94a2b68

File tree

3 files changed

+137
-90
lines changed

3 files changed

+137
-90
lines changed

datafusion/catalog/src/memory/table.rs

Lines changed: 107 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -329,50 +329,35 @@ impl TableProvider for MemTable {
329329
continue;
330330
}
331331

332-
let filter_mask = if filters.is_empty() {
333-
BooleanArray::from(vec![true; batch.num_rows()])
334-
} else {
335-
let mut combined_mask: Option<BooleanArray> = None;
336-
337-
for filter_expr in &filters {
338-
let physical_expr = create_physical_expr(
339-
filter_expr,
340-
&df_schema,
341-
state.execution_props(),
342-
)?;
343-
344-
let result = physical_expr.evaluate(batch)?;
345-
let array = result.into_array(batch.num_rows())?;
346-
let bool_array = array
347-
.as_any()
348-
.downcast_ref::<BooleanArray>()
349-
.ok_or_else(|| {
350-
datafusion_common::DataFusionError::Internal(
351-
"Filter did not evaluate to boolean".to_string(),
352-
)
353-
})?
354-
.clone();
355-
356-
combined_mask = Some(match combined_mask {
357-
Some(existing) => and(&existing, &bool_array)?,
358-
None => bool_array,
359-
});
332+
// Evaluate filters - None means "match all rows"
333+
let filter_mask = evaluate_filters_to_mask(
334+
&filters,
335+
batch,
336+
&df_schema,
337+
state.execution_props(),
338+
)?;
339+
340+
let (delete_count, keep_mask) = match filter_mask {
341+
Some(mask) => {
342+
// Count rows where mask is true (will be deleted)
343+
let count = mask.iter().filter(|v| v == &Some(true)).count();
344+
// Keep rows where predicate is false or NULL (SQL three-valued logic)
345+
let keep: BooleanArray =
346+
mask.iter().map(|v| Some(v != Some(true))).collect();
347+
(count, keep)
348+
}
349+
None => {
350+
// No filters = delete all rows
351+
(
352+
batch.num_rows(),
353+
BooleanArray::from(vec![false; batch.num_rows()]),
354+
)
360355
}
361-
362-
combined_mask.unwrap_or_else(|| {
363-
BooleanArray::from(vec![true; batch.num_rows()])
364-
})
365356
};
366357

367-
let delete_count =
368-
filter_mask.iter().filter(|v| v == &Some(true)).count();
369358
total_deleted += delete_count as u64;
370359

371-
// Keep rows where predicate is false or NULL (SQL three-valued logic)
372-
let keep_mask: BooleanArray =
373-
filter_mask.iter().map(|v| Some(v != Some(true))).collect();
374360
let filtered_batch = filter_record_batch(batch, &keep_mask)?;
375-
376361
if filtered_batch.num_rows() > 0 {
377362
new_batches.push(filtered_batch);
378363
}
@@ -412,15 +397,24 @@ impl TableProvider for MemTable {
412397
}
413398
}
414399

415-
let assignment_map: HashMap<&str, &Expr> = assignments
400+
let df_schema = DFSchema::try_from(Arc::clone(&self.schema))?;
401+
402+
// Create physical expressions for assignments upfront (outside batch loop)
403+
let physical_assignments: HashMap<
404+
String,
405+
Arc<dyn datafusion_physical_plan::PhysicalExpr>,
406+
> = assignments
416407
.iter()
417-
.map(|(name, expr)| (name.as_str(), expr))
418-
.collect();
408+
.map(|(name, expr)| {
409+
let physical_expr =
410+
create_physical_expr(expr, &df_schema, state.execution_props())?;
411+
Ok((name.clone(), physical_expr))
412+
})
413+
.collect::<Result<_>>()?;
419414

420415
*self.sort_order.lock() = vec![];
421416

422417
let mut total_updated: u64 = 0;
423-
let df_schema = DFSchema::try_from(Arc::clone(&self.schema))?;
424418

425419
for partition_data in &self.batches {
426420
let mut partition = partition_data.write().await;
@@ -431,54 +425,39 @@ impl TableProvider for MemTable {
431425
continue;
432426
}
433427

434-
let filter_mask = if filters.is_empty() {
435-
BooleanArray::from(vec![true; batch.num_rows()])
436-
} else {
437-
let mut combined_mask: Option<BooleanArray> = None;
438-
439-
for filter_expr in &filters {
440-
let physical_expr = create_physical_expr(
441-
filter_expr,
442-
&df_schema,
443-
state.execution_props(),
444-
)?;
445-
446-
let result = physical_expr.evaluate(batch)?;
447-
let array = result.into_array(batch.num_rows())?;
448-
let bool_array = array
449-
.as_any()
450-
.downcast_ref::<BooleanArray>()
451-
.ok_or_else(|| {
452-
datafusion_common::DataFusionError::Internal(
453-
"Filter did not evaluate to boolean".to_string(),
454-
)
455-
})?
456-
.clone();
457-
458-
combined_mask = Some(match combined_mask {
459-
Some(existing) => and(&existing, &bool_array)?,
460-
None => bool_array,
461-
});
428+
// Evaluate filters - None means "match all rows"
429+
let filter_mask = evaluate_filters_to_mask(
430+
&filters,
431+
batch,
432+
&df_schema,
433+
state.execution_props(),
434+
)?;
435+
436+
let (update_count, update_mask) = match filter_mask {
437+
Some(mask) => {
438+
// Count rows where mask is true (will be updated)
439+
let count = mask.iter().filter(|v| v == &Some(true)).count();
440+
// Normalize mask: only true (not NULL) triggers update
441+
let normalized: BooleanArray =
442+
mask.iter().map(|v| Some(v == Some(true))).collect();
443+
(count, normalized)
444+
}
445+
None => {
446+
// No filters = update all rows
447+
(
448+
batch.num_rows(),
449+
BooleanArray::from(vec![true; batch.num_rows()]),
450+
)
462451
}
463-
464-
combined_mask.unwrap_or_else(|| {
465-
BooleanArray::from(vec![true; batch.num_rows()])
466-
})
467452
};
468453

469-
let update_count =
470-
filter_mask.iter().filter(|v| v == &Some(true)).count();
471454
total_updated += update_count as u64;
472455

473456
if update_count == 0 {
474457
new_batches.push(batch.clone());
475458
continue;
476459
}
477460

478-
// Normalize mask: only true (not NULL) triggers update
479-
let update_mask: BooleanArray =
480-
filter_mask.iter().map(|v| Some(v == Some(true))).collect();
481-
482461
let mut new_columns: Vec<ArrayRef> =
483462
Vec::with_capacity(batch.num_columns());
484463

@@ -491,16 +470,15 @@ impl TableProvider for MemTable {
491470
))
492471
})?;
493472

494-
let new_column = if let Some(value_expr) =
495-
assignment_map.get(column_name.as_str())
473+
let new_column = if let Some(physical_expr) =
474+
physical_assignments.get(column_name.as_str())
496475
{
497-
let physical_expr = create_physical_expr(
498-
value_expr,
499-
&df_schema,
500-
state.execution_props(),
501-
)?;
502-
503-
let new_values = physical_expr.evaluate(batch)?;
476+
// Use evaluate_selection to only evaluate on matching rows.
477+
// This avoids errors (e.g., divide-by-zero) on rows that won't
478+
// be updated. The result is scattered back with nulls for
479+
// non-matching rows, which zip() will replace with originals.
480+
let new_values =
481+
physical_expr.evaluate_selection(batch, &update_mask)?;
504482
let new_array = new_values.into_array(batch.num_rows())?;
505483

506484
// Convert to &dyn Array which implements Datum
@@ -526,6 +504,46 @@ impl TableProvider for MemTable {
526504
}
527505
}
528506

507+
/// Evaluate filter expressions against a batch and return a combined boolean mask.
508+
/// Returns None if filters is empty (meaning "match all rows").
509+
/// The returned mask has true for rows that match the filter predicates.
510+
fn evaluate_filters_to_mask(
511+
filters: &[Expr],
512+
batch: &RecordBatch,
513+
df_schema: &DFSchema,
514+
execution_props: &datafusion_expr::execution_props::ExecutionProps,
515+
) -> Result<Option<BooleanArray>> {
516+
if filters.is_empty() {
517+
return Ok(None);
518+
}
519+
520+
let mut combined_mask: Option<BooleanArray> = None;
521+
522+
for filter_expr in filters {
523+
let physical_expr =
524+
create_physical_expr(filter_expr, df_schema, execution_props)?;
525+
526+
let result = physical_expr.evaluate(batch)?;
527+
let array = result.into_array(batch.num_rows())?;
528+
let bool_array = array
529+
.as_any()
530+
.downcast_ref::<BooleanArray>()
531+
.ok_or_else(|| {
532+
datafusion_common::DataFusionError::Internal(
533+
"Filter did not evaluate to boolean".to_string(),
534+
)
535+
})?
536+
.clone();
537+
538+
combined_mask = Some(match combined_mask {
539+
Some(existing) => and(&existing, &bool_array)?,
540+
None => bool_array,
541+
});
542+
}
543+
544+
Ok(combined_mask)
545+
}
546+
529547
/// Returns a single row with the count of affected rows.
530548
#[derive(Debug)]
531549
struct DmlResultExec {

datafusion/core/src/physical_planner.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1906,6 +1906,7 @@ fn get_physical_expr_pair(
19061906
/// splitting AND conjunctions into individual expressions.
19071907
/// Column qualifiers are stripped so expressions can be evaluated against
19081908
/// the TableProvider's schema.
1909+
///
19091910
fn extract_dml_filters(input: &Arc<LogicalPlan>) -> Result<Vec<Expr>> {
19101911
let mut filters = Vec::new();
19111912

@@ -1929,7 +1930,7 @@ fn strip_column_qualifiers(expr: Expr) -> Result<Expr> {
19291930
if let Expr::Column(col) = &e
19301931
&& col.relation.is_some()
19311932
{
1932-
// Create unqualified column
1933+
// Strip the qualifier
19331934
return Ok(Transformed::yes(Expr::Column(Column::new_unqualified(
19341935
col.name.clone(),
19351936
))));
@@ -1943,6 +1944,7 @@ fn strip_column_qualifiers(expr: Expr) -> Result<Expr> {
19431944
/// For UPDATE statements, the SQL planner encodes assignments as a projection
19441945
/// over the source table. This function extracts column name and expression pairs
19451946
/// from the projection. Column qualifiers are stripped from the expressions.
1947+
///
19461948
fn extract_update_assignments(input: &Arc<LogicalPlan>) -> Result<Vec<(String, Expr)>> {
19471949
// The UPDATE input plan structure is:
19481950
// Projection(updated columns as expressions with aliases)

datafusion/sqllogictest/test_files/dml_update.slt

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,3 +257,30 @@ UPDATE test_update_error SET nonexistent = 'value';
257257

258258
statement ok
259259
DROP TABLE test_update_error;
260+
261+
# Test UPDATE with expression that would error on non-matching rows
262+
# Regression test: expressions should only be evaluated on rows that match
263+
# the WHERE clause, not all rows. This prevents divide-by-zero errors
264+
# on rows that won't be updated.
265+
statement ok
266+
CREATE TABLE test_update_div(id INT, divisor INT, result INT);
267+
268+
statement ok
269+
INSERT INTO test_update_div VALUES (1, 0, 0), (2, 2, 0), (3, 5, 0);
270+
271+
# This should succeed: 100 / divisor is only evaluated where divisor != 0
272+
# Row 1 (divisor=0) is excluded by the WHERE clause, so the expression is NOT evaluated for it
273+
query I
274+
UPDATE test_update_div SET result = 100 / divisor WHERE divisor != 0;
275+
----
276+
2
277+
278+
query III rowsort
279+
SELECT * FROM test_update_div;
280+
----
281+
1 0 0
282+
2 2 50
283+
3 5 20
284+
285+
statement ok
286+
DROP TABLE test_update_div;

0 commit comments

Comments
 (0)