Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 150 additions & 2 deletions vortex-cuda/src/arrow/canonical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ use cudarc::driver::result as cuda_driver;
use futures::future::BoxFuture;
use vortex::array::ArrayRef;
use vortex::array::Canonical;
use vortex::array::ExecutionCtx;
use vortex::array::IntoArray;
use vortex::array::arrays::DecimalArray;
use vortex::array::arrays::Dict;
use vortex::array::arrays::DictArray;
Expand All @@ -35,6 +37,7 @@ use vortex::array::arrays::extension::ExtensionArrayExt;
use vortex::array::arrays::fixed_size_list::FixedSizeListArrayExt;
use vortex::array::arrays::fixed_size_list::FixedSizeListDataParts;
use vortex::array::arrays::list::ListDataParts;
use vortex::array::arrays::listview::ListViewArrayExt;
use vortex::array::arrays::listview::list_from_list_view;
use vortex::array::arrays::primitive::PrimitiveDataParts;
use vortex::array::arrays::struct_::StructDataParts;
Expand Down Expand Up @@ -63,10 +66,12 @@ use crate::CudaExecutionCtx;
use crate::arrow::ARROW_DEVICE_CUDA;
use crate::arrow::ArrowArray;
use crate::arrow::ArrowDeviceArray;
use crate::arrow::ArrowDeviceArrayWithSchema;
use crate::arrow::ExportDeviceArray;
use crate::arrow::PrivateData;
use crate::arrow::SyncEvent;
use crate::arrow::arrow_device_export_dictionary_codes_dtype;
use crate::arrow::arrow_schema_for_array;
use crate::arrow::cuda_decimal_value_type;
use crate::arrow::list_view::export_device_list_view;
use crate::cub::exclusive_sum_i32;
Expand Down Expand Up @@ -95,6 +100,92 @@ impl ExportDeviceArray for CanonicalDeviceArrayExport {
reserved: Default::default(),
})
}

async fn export_device_array_with_schema(
&self,
array: ArrayRef,
ctx: &mut CudaExecutionCtx,
) -> VortexResult<ArrowDeviceArrayWithSchema> {
let array = rebuild_array_for_export_schema(array, ctx.execution_ctx())?;
let schema = arrow_schema_for_array(&array, ctx)?;
let array = self.export_device_array(array, ctx).await?;
Ok(ArrowDeviceArrayWithSchema { schema, array })
}
}

/// Rebuild arrays whose exported layout differs from their original layout.
fn rebuild_array_for_export_schema(
array: ArrayRef,
ctx: &mut ExecutionCtx,
) -> VortexResult<ArrayRef> {
let array = match array.try_downcast::<Dict>() {
Ok(dict) => {
let parts = dict.into_parts();
let values = rebuild_array_for_export_schema(parts.values, ctx)?;
return Ok(DictArray::try_new(parts.codes, values)?.into_array());
}
Err(array) => array,
};
let array = match array.try_downcast::<Struct>() {
Ok(struct_array) => {
let len = struct_array.len();
let StructDataParts {
struct_fields,
fields,
validity,
} = struct_array.into_data_parts();
let fields = fields
.iter()
.map(|field| rebuild_array_for_export_schema(field.clone(), ctx))
.collect::<VortexResult<Vec<_>>>()?;
return Ok(
StructArray::try_new(struct_fields.names().clone(), fields, len, validity)?
.into_array(),
);
}
Err(array) => array,
};
let array = match array.try_downcast::<List>() {
Ok(list) => {
let ListDataParts {
elements,
offsets,
validity,
..
} = list.into_data_parts();
let elements = rebuild_array_for_export_schema(elements, ctx)?;
return Ok(ListArray::try_new(elements, offsets, validity)?.into_array());
}
Err(array) => array,
};
let array = match array.try_downcast::<FixedSizeList>() {
Ok(fixed_size_list) => {
let len = fixed_size_list.len();
let list_size = fixed_size_list.list_size();
let FixedSizeListDataParts {
elements, validity, ..
} = fixed_size_list.into_data_parts();
let elements = rebuild_array_for_export_schema(elements, ctx)?;
return Ok(
FixedSizeListArray::try_new(elements, list_size, validity, len)?.into_array(),
);
}
Err(array) => array,
};
let array = match array.try_downcast::<ListView>() {
Ok(listview)
if listview.as_ref().is_host() && listview.elements().as_opt::<Dict>().is_some() =>
{
return rebuild_array_for_export_schema(
list_from_list_view(listview, ctx)?.into_array(),
ctx,
);
}
Ok(listview) => return Ok(listview.into_array()),
Err(array) => array,
};

Ok(array)
}

/// Export arrays whose Arrow layout depends on their concrete children before CUDA
Expand Down Expand Up @@ -2139,7 +2230,7 @@ mod tests {
}

#[crate::test]
async fn test_export_host_non_contiguous_dictionary_list_view_preserves_dictionary_child()
async fn test_export_host_non_contiguous_dictionary_list_view_schema_matches_rebuilt_child()
-> VortexResult<()> {
let mut ctx = CudaSession::create_execution_ctx(&VortexSession::empty())
.vortex_expect("failed to create execution context");
Expand All @@ -2165,7 +2256,13 @@ mod tests {
"",
Field::new(
Field::LIST_FIELD_DEFAULT_NAME,
DataType::Dictionary(Box::new(DataType::Int16), Box::new(DataType::Int32)),
DataType::Dictionary(
Box::new(DataType::Int64),
Box::new(DataType::Dictionary(
Box::new(DataType::Int16),
Box::new(DataType::Int32),
)),
),
true,
),
false,
Expand All @@ -2180,6 +2277,57 @@ mod tests {
assert!(!dict_child.dictionary.is_null());
assert_eq!(dict_child.length, 5);
assert_eq!(dict_child.n_buffers, 2);
let nested_dict = unsafe { &*dict_child.dictionary };
assert!(!nested_dict.dictionary.is_null());

unsafe { release_exported_array(&raw mut exported.array.array) };
Ok(())
}

// Regression test: with an average list size >= 128 the host list-view rebuild picks its
// list-by-list strategy, which may canonicalize Dict elements. The schema must describe the
// rebuilt child layout.
#[crate::test]
async fn test_export_host_large_lists_dictionary_list_view_schema_matches_rebuilt_child()
-> VortexResult<()> {
let mut ctx = CudaSession::create_execution_ctx(&VortexSession::empty())
.vortex_expect("failed to create execution context");

let elements = DictArray::try_new(
PrimitiveArray::from_option_iter(
(0..256u32).map(|i| (i % 5 != 0).then_some((i % 3) as u8)),
)
.into_array(),
PrimitiveArray::from_iter([10i32, 20, 30]).into_array(),
)?
.into_array();
let array = ListViewArray::new(
elements,
PrimitiveArray::from_iter([128i32, 0]).into_array(),
PrimitiveArray::from_iter([128i32, 128]).into_array(),
Validity::NonNullable,
)
.into_array();
let mut exported = array.export_device_array_with_schema(&mut ctx).await?;

let field = Field::try_from(&exported.schema)?;
assert_eq!(
field,
Field::new_list(
"",
Field::new(Field::LIST_FIELD_DEFAULT_NAME, DataType::Int32, true),
false,
)
);
assert_eq!(
private_data_buffer_i32_values(&exported.array.array, 1)?,
[0, 128, 256]
);
let list_children = unsafe { std::slice::from_raw_parts(exported.array.array.children, 1) };
let child = unsafe { &*list_children[0] };
assert!(child.dictionary.is_null());
assert_eq!(child.length, 256);
assert_eq!(child.n_buffers, 2);

unsafe { release_exported_array(&raw mut exported.array.array) };
Ok(())
Expand Down
18 changes: 14 additions & 4 deletions vortex-cuda/src/arrow/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,14 +255,13 @@ impl DeviceArrayExt for ArrayRef {
self,
ctx: &mut CudaExecutionCtx,
) -> VortexResult<ArrowDeviceArrayWithSchema> {
let schema = arrow_schema_for_array(&self, ctx)?;
let array = self.export_device_array(ctx).await?;
Ok(ArrowDeviceArrayWithSchema { schema, array })
let exporter = Arc::clone(ctx.exporter());
exporter.export_device_array_with_schema(self, ctx).await
}
}

/// Build the Arrow C schema that describes the exported device array.
fn arrow_schema_for_array(
pub(crate) fn arrow_schema_for_array(
array: &ArrayRef,
ctx: &mut CudaExecutionCtx,
) -> VortexResult<FFI_ArrowSchema> {
Expand Down Expand Up @@ -479,4 +478,15 @@ pub trait ExportDeviceArray: Debug + Send + Sync + 'static {
array: ArrayRef,
ctx: &mut CudaExecutionCtx,
) -> VortexResult<ArrowDeviceArray>;

/// Export a Vortex array as an [`ArrowDeviceArray`] with a matching Arrow C schema.
async fn export_device_array_with_schema(
&self,
array: ArrayRef,
ctx: &mut CudaExecutionCtx,
) -> VortexResult<ArrowDeviceArrayWithSchema> {
let schema = arrow_schema_for_array(&array, ctx)?;
let array = self.export_device_array(array, ctx).await?;
Ok(ArrowDeviceArrayWithSchema { schema, array })
}
}
62 changes: 54 additions & 8 deletions vortex-ffi/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,30 @@ pub(crate) fn vx_error_new(message: &str) -> *mut vx_error {
}

/// Write an error message to `error` which has not been populated before.
/// A null `error` pointer discards the message.
pub(crate) fn write_error(error: *mut *mut vx_error, message: &str) {
assert!(!error.is_null());
if error.is_null() {
return;
}
unsafe { error.write(vx_error_new(message)) };
}

/// Clear `*error_out` to null unless `error_out` itself is null.
fn clear_error(error_out: *mut *mut vx_error) {
if error_out.is_null() {
return;
}
unsafe { error_out.write(ptr::null_mut()) };
}

#[inline]
pub fn try_or_default<T: Default>(
error_out: *mut *mut vx_error,
function: impl FnOnce() -> VortexResult<T>,
) -> T {
match function() {
Ok(value) => {
unsafe { error_out.write(ptr::null_mut()) };
clear_error(error_out);
value
}
Err(err) => {
Expand All @@ -51,19 +62,16 @@ pub fn try_or_default<T: Default>(

/// Run `function`, returning its value on success and `error_value` on failure.
///
/// On success `*error_out` is cleared to null; on failure the error is written to `*error_out`
/// when it is non-null.
// Writes through `error_out` but stays safe like the other error-out helpers here; the raw-pointer
// contract is documented at the C boundary.
#[allow(clippy::not_unsafe_ptr_arg_deref)]
/// `error_out` may be null, in which case error details are discarded. When it is non-null,
/// `*error_out` is cleared to null on success and set to an owned `vx_error` on failure.
pub fn try_or<T>(
error_out: *mut *mut vx_error,
error_value: T,
function: impl FnOnce() -> VortexResult<T>,
) -> T {
match function() {
Ok(value) => {
unsafe { error_out.write(ptr::null_mut()) };
clear_error(error_out);
value
}
Err(err) => {
Expand All @@ -81,3 +89,41 @@ pub fn try_or<T>(
pub unsafe extern "C-unwind" fn vx_error_get_message(error: *const vx_error) -> *const vx_string {
vx_string::new_ref(&vx_error::as_ref(error).message)
}

#[cfg(test)]
mod tests {
use std::ptr;

use vortex::error::vortex_err;

use super::*;
use crate::error::vx_error_free;

#[test]
fn test_try_or_null_error_out() {
// A null error_out must be tolerated on both the success and failure paths.
assert_eq!(try_or(ptr::null_mut(), -1, || Ok(42)), 42);
assert_eq!(try_or(ptr::null_mut(), -1, || Err(vortex_err!("boom"))), -1);
}

#[test]
fn test_try_or_default_null_error_out() {
assert_eq!(try_or_default(ptr::null_mut(), || Ok(42)), 42);
assert_eq!(
try_or_default::<i32>(ptr::null_mut(), || Err(vortex_err!("boom"))),
0
);
}

#[test]
fn test_try_or_writes_and_clears_error_out() {
let mut error: *mut vx_error = ptr::null_mut();

assert_eq!(try_or(&raw mut error, -1, || Err(vortex_err!("boom"))), -1);
assert!(!error.is_null());
unsafe { vx_error_free(error) };

assert_eq!(try_or(&raw mut error, -1, || Ok(42)), 42);
assert!(error.is_null());
}
}
Loading