diff --git a/vortex-duckdb/cpp/expr.cpp b/vortex-duckdb/cpp/expr.cpp index 6470a9d338d..afe2573adc2 100644 --- a/vortex-duckdb/cpp/expr.cpp +++ b/vortex-duckdb/cpp/expr.cpp @@ -47,6 +47,12 @@ extern "C" duckdb_vx_expr_class duckdb_vx_expr_get_class(duckdb_vx_expr ffi_expr return static_cast(expr->GetExpressionClass()); } +extern "C" duckdb_logical_type duckdb_vx_expr_get_return_type(duckdb_vx_expr ffi_expr) { + D_ASSERT(ffi_expr); + auto expr = reinterpret_cast(ffi_expr); + return reinterpret_cast(&expr->return_type); +} + extern "C" const char *duckdb_vx_expr_get_bound_column_ref_get_name(duckdb_vx_expr ffi_expr) { if (!ffi_expr) { return nullptr; diff --git a/vortex-duckdb/cpp/include/expr.h b/vortex-duckdb/cpp/include/expr.h index 457a944e5d5..5b7997596d6 100644 --- a/vortex-duckdb/cpp/include/expr.h +++ b/vortex-duckdb/cpp/include/expr.h @@ -213,6 +213,10 @@ typedef enum DUCKDB_VX_EXPR_TYPE { duckdb_vx_expr_class duckdb_vx_expr_get_class(duckdb_vx_expr expr); +/// Return the (bound) return type of the expression. The logical type is borrowed from the +/// expression and must not be freed. +duckdb_logical_type duckdb_vx_expr_get_return_type(duckdb_vx_expr expr); + const char *duckdb_vx_expr_get_bound_column_ref_get_name(duckdb_vx_expr expr); duckdb_value duckdb_vx_expr_bound_constant_get_value(duckdb_vx_expr expr); diff --git a/vortex-duckdb/src/convert/expr.rs b/vortex-duckdb/src/convert/expr.rs index 324086e5775..387b644fe30 100644 --- a/vortex-duckdb/src/convert/expr.rs +++ b/vortex-duckdb/src/convert/expr.rs @@ -22,6 +22,7 @@ use vortex::expr::get_item; use vortex::expr::is_not_null; use vortex::expr::is_null; use vortex::expr::list_contains; +use vortex::expr::list_length; use vortex::expr::lit; use vortex::expr::not; use vortex::expr::or_collect; @@ -37,6 +38,7 @@ use vortex::scalar_fn::fns::like::LikeOptions; use vortex::scalar_fn::fns::literal::Literal; use vortex::scalar_fn::fns::operators::Operator; +use crate::cpp::DUCKDB_TYPE; use crate::cpp::DUCKDB_VX_EXPR_TYPE; use crate::duckdb; use crate::duckdb::BoundFunction; @@ -57,6 +59,20 @@ fn from_bound_str(value: &duckdb::ExpressionRef) -> VortexResult { } } +/// Whether the expression's return type is a `LIST` or fixed-size `ARRAY`. +fn returns_a_list(expr: &duckdb::ExpressionRef) -> bool { + matches!( + expr.return_type().as_type_id(), + DUCKDB_TYPE::DUCKDB_TYPE_LIST | DUCKDB_TYPE::DUCKDB_TYPE_ARRAY + ) +} + +/// Wrap `expr` in `list_length`. Since vortex `list_length` returns u64 but duckdb equivalents +/// return i64, we must cast as well. +fn build_list_length(expr: Expression, nullability: Nullability) -> Expression { + cast(list_length(expr), DType::Primitive(PType::I64, nullability)) +} + fn try_from_bound_function( func: &BoundFunction, col_sub: Option<&Expression>, @@ -115,6 +131,37 @@ fn try_from_bound_function( }; Like.new_expr(LikeOptions::default(), [value, lit(pattern)]) } + "array_length" => { + let children = func.children().collect::>(); + // Only accept array_length(expr) rather than array_length(expr, dim). + if children.len() != 1 { + return Ok(None); + } + let Some(col) = try_from_expression_inner(children[0], col_sub)? else { + return Ok(None); + }; + + // We don't know the column's nullability here, so we set it to nullable. + build_list_length(col, Nullability::Nullable) + } + // len/length semantics depend on the return type of underlying expr. + "len" | "length" => { + let children: Vec<_> = func.children().collect(); + vortex_ensure!(children.len() == 1); + let child = children[0]; + + if returns_a_list(child) { + let Some(col) = try_from_expression_inner(child, col_sub)? else { + return Ok(None); + }; + + // Same nullability rationale as in "array_length" branch. + let list_len_expr = build_list_length(col, Nullability::Nullable); + return Ok(Some(list_len_expr)); + } else { + return Ok(None); + } + } _ => { debug!("bound function {}", func.scalar_function.name()); return Ok(None); @@ -137,6 +184,11 @@ pub(super) fn try_from_bound_expression_with_col_sub( try_from_expression_inner(value, Some(col_sub)) } +fn is_supported_length_alias(func: &BoundFunction) -> bool { + let children: Vec<_> = func.children().collect(); + children.len() == 1 && returns_a_list(children[0]) +} + // Called before pushdown_complex_filter or a table filter expression call. // As we support complex filter pushdown, Duckdb pushes expressions to Vortex. // However, it doesn't know what type of expressions we can handle. Here we list @@ -173,6 +225,8 @@ pub fn can_push_expression(value: &duckdb::ExpressionRef) -> bool { || name == "~~" || name == "!~~" || name == "strlen" + || name == "array_length" + || (matches!(name, "len" | "length") && is_supported_length_alias(&func)) } ExpressionClass::BoundOperator(op) => { if !matches!( @@ -190,6 +244,13 @@ pub fn can_push_expression(value: &duckdb::ExpressionRef) -> bool { } } +/// Applies `list_length` expression to a duckdb field +fn list_length_on_field(field: &DuckdbField) -> Expression { + let col = get_item(field.name.as_str(), root()); + + build_list_length(col, field.dtype.nullability()) +} + pub fn try_from_projection_expression( value: &duckdb::ExpressionRef, field: &DuckdbField, @@ -208,6 +269,13 @@ pub fn try_from_projection_expression( let col = cast(col, dtype); Some(col) } + "array_length" => { + // Only accept array_length(expr) rather than array_length(expr, dim). + (func.children().count() == 1).then(|| list_length_on_field(field)) + } + // len/length have different semantics depending on field dtype. + "len" | "length" => matches!(field.dtype, DType::List(..) | DType::FixedSizeList(..)) + .then(|| list_length_on_field(field)), _ => None, }) } diff --git a/vortex-duckdb/src/duckdb/expr.rs b/vortex-duckdb/src/duckdb/expr.rs index 2b206bc192f..48f255f744e 100644 --- a/vortex-duckdb/src/duckdb/expr.rs +++ b/vortex-duckdb/src/duckdb/expr.rs @@ -10,6 +10,8 @@ use std::ptr; use crate::cpp; use crate::cpp::duckdb_vx_expr_class; use crate::duckdb::DDBString; +use crate::duckdb::LogicalType; +use crate::duckdb::LogicalTypeRef; use crate::duckdb::ScalarFunction; use crate::duckdb::ScalarFunctionRef; use crate::duckdb::Value; @@ -33,6 +35,11 @@ impl ExpressionRef { unsafe { cpp::duckdb_vx_expr_get_class(self.as_ptr()) } } + /// The return type of this expression. + pub fn return_type(&self) -> &LogicalTypeRef { + unsafe { LogicalType::borrow(cpp::duckdb_vx_expr_get_return_type(self.as_ptr())) } + } + /// Match the subclass of the expression. pub fn as_class(&self) -> Option> { Some( diff --git a/vortex-duckdb/src/e2e_test/vortex_scan_test.rs b/vortex-duckdb/src/e2e_test/vortex_scan_test.rs index 6cc28483571..a5fe72ec68a 100644 --- a/vortex-duckdb/src/e2e_test/vortex_scan_test.rs +++ b/vortex-duckdb/src/e2e_test/vortex_scan_test.rs @@ -1014,3 +1014,101 @@ fn test_geometry() { let area = vec.as_slice_with_len::(chunk.len().as_())[0]; assert_eq!(area, 1000.0); } + +/// `SELECT array_length(list)` / `len(list)` / `length(list)` should push the list-length +/// computation into the Vortex scan (computed from offsets, without materializing the list +/// elements) and return the per-row element counts. +#[test] +fn test_vortex_scan_list_length_projection() { + let file = RUNTIME.block_on(async { + let integers = PrimitiveArray::from_iter([ + 10i32, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, + ]); + // Variable-length lists with 3, 4, 1, 5, 2 elements respectively. + let offsets = buffer![0i32, 3, 7, 8, 13, 15]; + let list_array = ListArray::try_new( + integers.into_array(), + offsets.into_array(), + Validity::AllValid, + ) + .unwrap(); + + write_single_column_vortex_file("int_list", list_array).await + }); + + let conn = database_connection(); + let file_path = file.path().to_string_lossy(); + + // `len`/`length` bind to the same DuckDB function set as `array_length` for list arguments. + for func in ["array_length", "len", "length"] { + let result = conn + .query(&format!("SELECT {func}(int_list) FROM '{file_path}'")) + .unwrap(); + + let mut lengths = Vec::new(); + for chunk in result { + let len = chunk.len().as_(); + let vec = chunk.get_vector(0); + lengths.extend_from_slice(vec.as_slice_with_len::(len)); + } + + assert_eq!(lengths, vec![3, 4, 1, 5, 2], "{func}(int_list) mismatch"); + } +} + +/// `WHERE array_length(list) >= k` should push down as a complex filter. +#[test] +fn test_vortex_scan_list_length_filter() { + let file = RUNTIME.block_on(async { + let integers = PrimitiveArray::from_iter([ + 10i32, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, + ]); + // Variable-length lists with 3, 4, 1, 5, 2 elements respectively. + let offsets = buffer![0i32, 3, 7, 8, 13, 15]; + let list_array = ListArray::try_new( + integers.into_array(), + offsets.into_array(), + Validity::AllValid, + ) + .unwrap(); + + write_single_column_vortex_file("int_list", list_array).await + }); + + // Lists with length >= 4: the 4-element and 5-element lists => 2 rows. + let count = scan_vortex_file_single_row::( + file, + "SELECT COUNT(*) FROM ? WHERE array_length(int_list) >= 4", + 0, + ); + assert_eq!(count, 2); +} + +/// `array_length`/`len`/`length` over a FixedSizeList column. The length is the fixed list size. +#[test] +fn test_vortex_scan_fixed_size_list_length_projection() { + let file = RUNTIME.block_on(async { + // 6 fixed-size lists of 4 i32 elements each. + let elements = (0..24i32).collect::(); + let fsl = FixedSizeListArray::new(elements.into_array(), 4, Validity::AllValid, 6); + write_single_column_vortex_file("int_lists", fsl).await + }); + + let conn = database_connection(); + let file_path = file.path().to_string_lossy(); + + for func in ["array_length", "len", "length"] { + let result = conn + .query(&format!("SELECT {func}(int_lists) FROM '{file_path}'")) + .unwrap(); + + let mut lengths = Vec::new(); + for chunk in result { + let len = chunk.len().as_(); + let vec = chunk.get_vector(0); + lengths.extend_from_slice(vec.as_slice_with_len::(len)); + } + + assert_eq!(lengths, vec![4i64; 6], "{func}(int_lists) mismatch"); + } +} diff --git a/vortex-sqllogictest/slt/duckdb/list_length_pushdown.slt b/vortex-sqllogictest/slt/duckdb/list_length_pushdown.slt new file mode 100644 index 00000000000..387e4f15bce --- /dev/null +++ b/vortex-sqllogictest/slt/duckdb/list_length_pushdown.slt @@ -0,0 +1,169 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright the Vortex contributors + +include ../setup.slt.no + +# Two list columns so we can exercise list-length pushdown over the same and +# different columns, in both SELECT and WHERE. +statement ok +CREATE TABLE list_test (id INTEGER, a INTEGER[], b INTEGER[]); + +statement ok +INSERT INTO list_test VALUES + (1, [10, 20, 30], [1]), + (2, [40, 50, 60, 70], [2, 3]), + (3, [80], [4, 5, 6]), + (4, [90, 100, 110, 120, 130], [7, 8, 9, 10]), + (5, [140, 150], []); + +statement ok +COPY (SELECT * FROM list_test) TO '$__TEST_DIR__/list-length.vortex'; + +# array_length projection is pushed into the scan. The "SELECT projections" +# marker in the Vortex EXPLAIN output indicates pushdown succeeded. +query TT +EXPLAIN (FORMAT json) +SELECT array_length(a) FROM '$__TEST_DIR__/list-length.vortex'; +---- +:"SELECT projections".*vortex\.list\.length + +# len/length/array_length all map to list_length for list arguments. +query I +SELECT array_length(a) FROM '$__TEST_DIR__/list-length.vortex' ORDER BY id; +---- +3 +4 +1 +5 +2 + +query TT +EXPLAIN (FORMAT json) +SELECT len(a) FROM '$__TEST_DIR__/list-length.vortex'; +---- +:"SELECT projections".*vortex\.list\.length + +query I +SELECT len(a) FROM '$__TEST_DIR__/list-length.vortex' ORDER BY id; +---- +3 +4 +1 +5 +2 + +query TT +EXPLAIN (FORMAT json) +SELECT length(a) FROM '$__TEST_DIR__/list-length.vortex'; +---- +:"SELECT projections".*vortex\.list\.length + +query I +SELECT length(a) FROM '$__TEST_DIR__/list-length.vortex' ORDER BY id; +---- +3 +4 +1 +5 +2 + +# array_length over the other list column, including the empty list (length 0). +query TT +EXPLAIN (FORMAT json) +SELECT array_length(b) FROM '$__TEST_DIR__/list-length.vortex'; +---- +:"SELECT projections".*vortex\.list\.length + +query I +SELECT array_length(b) FROM '$__TEST_DIR__/list-length.vortex' ORDER BY id; +---- +1 +2 +3 +4 +0 + +# array_length in WHERE is pushed into the scan as a complex filter: it shows up under the +# scan's "Filters" as a list_length expression (and DuckDB drops its own filter operator). The +# SELECT here is a plain column, so `list.length` can only appear via the pushed-down filter. +query TT +EXPLAIN (FORMAT json) SELECT id FROM '$__TEST_DIR__/list-length.vortex' +WHERE array_length(a) >= 4; +---- +:"Filters".*vortex\.list\.length + +# array_length in WHERE on the same column used in the SELECT projection. +query TT +EXPLAIN (FORMAT json) +SELECT array_length(a) FROM '$__TEST_DIR__/list-length.vortex' +WHERE array_length(a) >= 4; +---- +:"Filters".*vortex\.list\.length.*"SELECT projections".*vortex\.list\.length + +query I +SELECT array_length(a) FROM '$__TEST_DIR__/list-length.vortex' +WHERE array_length(a) >= 4 +ORDER BY id; +---- +4 +5 + +# array_length in WHERE filtering a different column than the SELECT projection. +query TT +EXPLAIN (FORMAT json) +SELECT id, array_length(a) FROM '$__TEST_DIR__/list-length.vortex' +WHERE array_length(b) >= 3; +---- +:"Filters".*vortex\.list\.length.*"SELECT projections".*vortex\.list\.length + +query II +SELECT id, array_length(a) FROM '$__TEST_DIR__/list-length.vortex' +WHERE array_length(b) >= 3 +ORDER BY id; +---- +3 1 +4 5 + +# array_length over different columns in both SELECT and WHERE is pushed on both sides: the +# projection on `a` appears under "SELECT projections" and the filter on `b` under "Filters", +# both as list_length expressions. +query TT +EXPLAIN (FORMAT json) +SELECT array_length(a) FROM '$__TEST_DIR__/list-length.vortex' +WHERE array_length(b) >= 2; +---- +:"Filters".*vortex\.list\.length.*"SELECT projections".*vortex\.list\.length + +# array_length over different columns in both SELECT and WHERE. +query TT +EXPLAIN (FORMAT json) +SELECT array_length(a), array_length(b) FROM '$__TEST_DIR__/list-length.vortex' +WHERE array_length(b) >= 2; +---- +:"Filters".*vortex\.list\.length.*"SELECT projections".*vortex\.list\.length + +query II +SELECT array_length(a), array_length(b) FROM '$__TEST_DIR__/list-length.vortex' +WHERE array_length(b) >= 2 +ORDER BY id; +---- +4 2 +1 3 +5 4 + +# Mixing the len/length/array_length aliases across SELECT and WHERE. +query TT +EXPLAIN (FORMAT json) +SELECT len(a) FROM '$__TEST_DIR__/list-length.vortex' +WHERE length(b) >= 2; +---- +:"Filters".*vortex\.list\.length.*"SELECT projections".*vortex\.list\.length + +query I +SELECT len(a) FROM '$__TEST_DIR__/list-length.vortex' +WHERE length(b) >= 2 +ORDER BY id; +---- +4 +1 +5