Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 58 additions & 7 deletions vortex-datafusion/src/convert/exprs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
use std::sync::Arc;

use arrow_schema::DataType;
use arrow_schema::Field;
use arrow_schema::Schema;
use datafusion_common::Result as DFResult;
use datafusion_common::exec_datafusion_err;
Expand All @@ -20,9 +21,9 @@ use datafusion_physical_expr::utils::collect_columns;
use datafusion_physical_expr_common::physical_expr::is_dynamic_physical_expr;
use datafusion_physical_plan::expressions as df_expr;
use itertools::Itertools;
use vortex::dtype::DType;
use vortex::VortexSessionDefault;
use vortex::array::arrow::ArrowSessionExt;
use vortex::dtype::Nullability;
use vortex::dtype::arrow::FromArrowType;
use vortex::expr::Expression;
use vortex::expr::and_collect;
use vortex::expr::byte_length;
Expand All @@ -42,6 +43,7 @@ use vortex::scalar_fn::fns::binary::Binary;
use vortex::scalar_fn::fns::like::Like;
use vortex::scalar_fn::fns::like::LikeOptions;
use vortex::scalar_fn::fns::operators::Operator;
use vortex::session::VortexSession;

use crate::convert::FromDataFusion;

Expand Down Expand Up @@ -109,10 +111,29 @@ pub trait ExpressionConvertor: Send + Sync {
}

/// The default [`ExpressionConvertor`] implementation.
#[derive(Default)]
pub struct DefaultExpressionConvertor {}
pub struct DefaultExpressionConvertor {
/// Session used to resolve Arrow → Vortex dtypes through the extension
/// plugin registry, so registered extension types (e.g. UUID ⇄
/// `FixedSizeBinary[16]`) convert correctly instead of hitting the static,
/// non-plugin-aware `DType::from_arrow`.
session: VortexSession,
}

impl Default for DefaultExpressionConvertor {
fn default() -> Self {
Self {
session: VortexSession::default(),
}
}
}

impl DefaultExpressionConvertor {
/// Create a convertor that resolves Arrow extension types using `session`'s
/// dtype registry.
pub fn new(session: VortexSession) -> Self {
Self { session }
}

/// Attempts to convert DataFusion's `octet_length` function to Vortex `byte_length`.
fn try_convert_octet_length(&self, scalar_fn: &ScalarFunctionExpr) -> DFResult<Expression> {
let [input] = scalar_fn.args() else {
Expand All @@ -122,8 +143,15 @@ impl DefaultExpressionConvertor {
};

let input = self.convert(input.as_ref())?;
let return_dtype =
DType::from_arrow((scalar_fn.return_type(), scalar_fn.nullable().into()));
let return_dtype = self
.session
.arrow()
.from_arrow_field(&Field::new(
"",
scalar_fn.return_type().clone(),
scalar_fn.nullable(),
))
.map_err(|e| exec_datafusion_err!("Failed to convert return type to dtype: {e}"))?;
Ok(cast(byte_length(input), return_dtype))
}

Expand Down Expand Up @@ -246,7 +274,11 @@ impl ExpressionConvertor for DefaultExpressionConvertor {
}

if let Some(cast_expr) = df.downcast_ref::<df_expr::CastExpr>() {
let cast_dtype = DType::from_arrow(cast_expr.target_field().as_ref());
let cast_dtype = self
.session
.arrow()
.from_arrow_field(cast_expr.target_field().as_ref())
.map_err(|e| exec_datafusion_err!("Failed to convert cast target to dtype: {e}"))?;
let child = self.convert(cast_expr.expr().as_ref())?;
return Ok(cast(child, cast_dtype));
}
Expand Down Expand Up @@ -975,6 +1007,25 @@ mod tests {
Ok(())
}

/// A cast whose target is a UUID-tagged `FixedSizeBinary(16)` must resolve
/// through the dtype extension registry (UUID is registered on the default
/// session) instead of the static, non-plugin-aware `DType::from_arrow`,
/// which does not support `FixedSizeBinary` and previously panicked here.
#[test]
fn test_cast_to_uuid_resolves_via_registry() -> anyhow::Result<()> {
use arrow_schema::extension::Uuid;

let mut uuid_field = Field::new("id", DataType::FixedSizeBinary(16), true);
uuid_field.try_with_extension_type(Uuid)?;

let child = Arc::new(df_expr::Column::new("id", 0)) as Arc<dyn PhysicalExpr>;
let cast = df_expr::CastExpr::new_with_target_field(child, Arc::new(uuid_field), None);

// Must convert without panicking — the static path would `unimplemented!()`.
DefaultExpressionConvertor::default().convert(&cast)?;
Ok(())
}

/// Test that applying a CASE expression to an Arrow RecordBatch using DataFusion
/// matches the result of applying the converted Vortex expression.
#[test]
Expand Down
3 changes: 2 additions & 1 deletion vortex-datafusion/src/persistent/source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ impl VortexSource {
let full_schema = table_schema.table_schema();
let indices = (0..full_schema.fields().len()).collect::<Vec<_>>();
let projection = ProjectionExprs::from_indices(&indices, full_schema);
let expression_convertor = Arc::new(DefaultExpressionConvertor::new(session.clone()));

Self {
session,
Expand All @@ -231,7 +232,7 @@ impl VortexSource {
_unused_df_metrics: Default::default(),
layout_readers: Arc::new(DashMap::default()),
natural_split_ranges: Arc::new(DashMap::default()),
expression_convertor: Arc::new(DefaultExpressionConvertor::default()),
expression_convertor,
vortex_reader_factory: None,
vx_metrics_registry: Arc::new(DefaultMetricsRegistry::default()),
file_metadata_cache: None,
Expand Down
Loading