Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 67 additions & 15 deletions arrow-select/src/coalesce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
//! [`take`]: crate::take::take
use crate::filter::filter_record_batch;
use arrow_array::types::{BinaryViewType, StringViewType};
use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch};
use arrow_array::{downcast_primitive, Array, ArrayRef, BooleanArray, RecordBatch};
use arrow_schema::{ArrowError, DataType, SchemaRef};
use std::collections::VecDeque;
use std::sync::Arc;
Expand All @@ -31,9 +31,11 @@ use std::sync::Arc;

mod byte_view;
mod generic;
mod primitive;

use byte_view::InProgressByteViewArray;
use generic::GenericInProgressArray;
use primitive::InProgressPrimitiveArray;

/// Concatenate multiple [`RecordBatch`]es
///
Expand Down Expand Up @@ -322,7 +324,15 @@ impl BatchCoalescer {

/// Return a new `InProgressArray` for the given data type
fn create_in_progress_array(data_type: &DataType, batch_size: usize) -> Box<dyn InProgressArray> {
match data_type {
macro_rules! instantiate_primitive {
($t:ty) => {
Box::new(InProgressPrimitiveArray::<$t>::new(batch_size))
};
}

downcast_primitive! {
// Instantiate InProgressPrimitiveArray for each primitive type
data_type => (instantiate_primitive),
DataType::Utf8View => Box::new(InProgressByteViewArray::<StringViewType>::new(batch_size)),
DataType::BinaryView => {
Box::new(InProgressByteViewArray::<BinaryViewType>::new(batch_size))
Expand Down Expand Up @@ -364,7 +374,9 @@ mod tests {
use crate::concat::concat_batches;
use arrow_array::builder::StringViewBuilder;
use arrow_array::cast::AsArray;
use arrow_array::{BinaryViewArray, RecordBatchOptions, StringViewArray, UInt32Array};
use arrow_array::{
BinaryViewArray, RecordBatchOptions, StringArray, StringViewArray, UInt32Array,
};
use arrow_schema::{DataType, Field, Schema};
use std::ops::Range;

Expand Down Expand Up @@ -456,6 +468,27 @@ mod tests {
.run();
}

/// Feeds two batches built by `uint32_batch_non_null` (3000 + 1040 rows)
/// through the `Test` harness with a batch size of 1024 and expects output
/// batches of 1024, 1024, 1024 and 968 rows.
#[test]
fn test_coalesce_non_null() {
Test::new()
// 4040 rows of uint32 (3000 + 1040)
.with_batch(uint32_batch_non_null(0..3000))
.with_batch(uint32_batch_non_null(0..1040))
.with_batch_size(1024)
// 3 * 1024 = 3072 rows in full batches, leaving 4040 - 3072 = 968
.with_expected_output_sizes(vec![1024, 1024, 1024, 968])
.run();
}
/// Feeds two `Utf8` batches (3000 + 1040 rows) through the `Test` harness with
/// a batch size of 1024 and expects output batches of 1024, 1024, 1024 and
/// 968 rows.
#[test]
fn test_utf8_split() {
    // Note from the PR author: this case is needed to cover
    // `GenericInProgressArray`.
    Test::new()
        // 4040 rows of utf8 strings in total, split into batches of 1024
        .with_batch(utf8_batch(0..3000))
        .with_batch(utf8_batch(0..1040))
        .with_batch_size(1024)
        // 3 * 1024 = 3072 rows in full batches, leaving 4040 - 3072 = 968
        .with_expected_output_sizes(vec![1024, 1024, 1024, 968])
        .run();
}

#[test]
fn test_string_view_no_views() {
let output_batches = Test::new()
Expand Down Expand Up @@ -941,15 +974,37 @@ mod tests {
}
}

/// Build a RecordBatch with one nullable `UInt32` column named `c0` that
/// covers `range`, where every value divisible by three is replaced by null.
fn uint32_batch(range: Range<u32>) -> RecordBatch {
    // Null out multiples of three, keep everything else as-is.
    let array = UInt32Array::from_iter(range.map(|i| match i % 3 {
        0 => None,
        _ => Some(i),
    }));

    let field = Field::new("c0", DataType::UInt32, true);
    let schema = Arc::new(Schema::new(vec![field]));
    RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
}

/// Return a RecordBatch with a single non-nullable `UInt32` column named `c0`
/// that holds every value of the specified range (no nulls).
fn uint32_batch_non_null(range: Range<u32>) -> RecordBatch {
    // The field is declared non-nullable to match the null-free array below.
    let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]));

    let array = UInt32Array::from_iter_values(range);
    RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap()
}

/// Build a RecordBatch with one nullable `Utf8` column named `c0` backed by a
/// StringArray. Row `i` of `range` holds the string `value{i}`, except that
/// rows where `i` is divisible by three are `None`.
fn utf8_batch(range: Range<u32>) -> RecordBatch {
    // Maps an index to its cell: null for multiples of three, text otherwise.
    let to_cell = |i: u32| match i % 3 {
        0 => None,
        _ => Some(format!("value{}", i)),
    };
    let array = StringArray::from_iter(range.map(to_cell));

    let field = Field::new("c0", DataType::Utf8, true);
    let schema = Arc::new(Schema::new(vec![field]));
    RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
}

/// Return a RecordBatch with a StringViewArray with (only) the specified values
Expand All @@ -960,14 +1015,11 @@ mod tests {
false,
)]));

RecordBatch::try_new(
Arc::clone(&schema),
vec![Arc::new(StringViewArray::from_iter(values))],
)
.unwrap()
let array = StringViewArray::from_iter(values);
RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap()
}

/// Return a RecordBatch with a StringViewArray with num_rows by repeating
/// values over and over.
fn stringview_batch_repeated<'a>(
num_rows: usize,
Expand Down
101 changes: 101 additions & 0 deletions arrow-select/src/coalesce/primitive.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::coalesce::InProgressArray;
use arrow_array::cast::AsArray;
use arrow_array::{Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray};
use arrow_buffer::{NullBufferBuilder, ScalarBuffer};
use arrow_schema::ArrowError;
use std::fmt::Debug;
use std::sync::Arc;

/// InProgressArray for [`PrimitiveArray`]
///
/// Accumulates the values and null mask of rows copied out of the current
/// source array until `finish` turns them into a [`PrimitiveArray`].
#[derive(Debug)]
pub(crate) struct InProgressPrimitiveArray<T: ArrowPrimitiveType> {
/// The current source, if any
source: Option<ArrayRef>,
/// The target batch size; used to size the null builder and the space
/// reserved for `current`
batch_size: usize,
/// In progress nulls
nulls: NullBufferBuilder,
/// The values accumulated so far for the in progress array
current: Vec<T::Native>,
}

impl<T: ArrowPrimitiveType> InProgressPrimitiveArray<T> {
    /// Create a new `InProgressPrimitiveArray`
    ///
    /// `batch_size` is the target number of rows per output array. No value
    /// storage is allocated here; see `ensure_capacity`.
    pub(crate) fn new(batch_size: usize) -> Self {
        Self {
            batch_size,
            source: None,
            nulls: NullBufferBuilder::new(batch_size),
            current: vec![],
        }
    }

    /// Allocate space for output values if necessary.
    ///
    /// This is done on write (when we know it is necessary) rather than
    /// eagerly to avoid allocations that are not used.
    ///
    /// After this call `current` can hold at least `batch_size` values in
    /// total without reallocating.
    fn ensure_capacity(&mut self) {
        // `Vec::reserve` takes the number of *additional* elements beyond the
        // current length. Passing `batch_size` unconditionally would demand
        // `len + batch_size` capacity and could reallocate (and copy) a
        // partially filled buffer on every call, so only request the part of
        // the batch that is still missing.
        let missing = self.batch_size.saturating_sub(self.current.len());
        self.current.reserve(missing);
    }
}

impl<T: ArrowPrimitiveType + Debug> InProgressArray for InProgressPrimitiveArray<T> {
    /// Replace the array that subsequent `copy_rows` calls read from.
    fn set_source(&mut self, source: Option<ArrayRef>) {
        self.source = source;
    }

    /// Append rows `offset..offset + len` of the current source to the
    /// in-progress values and null mask.
    ///
    /// # Errors
    ///
    /// Returns `ArrowError::InvalidArgumentError` if no source has been set.
    fn copy_rows(&mut self, offset: usize, len: usize) -> Result<(), ArrowError> {
        self.ensure_capacity();

        let source = match &self.source {
            Some(array) => array.as_primitive::<T>(),
            None => {
                return Err(ArrowError::InvalidArgumentError(
                    "Internal Error: InProgressPrimitiveArray: source not set".to_string(),
                ))
            }
        };

        // Carry over the null mask for the copied rows; a source without a
        // null buffer contributes `len` valid entries.
        match source.nulls() {
            Some(source_nulls) => {
                let copied = source_nulls.slice(offset, len);
                self.nulls.append_buffer(&copied);
            }
            None => self.nulls.append_n_non_nulls(len),
        }

        // Note from the PR author: this is the meat of the change -- copy the
        // values into the final destination incrementally.
        self.current
            .extend_from_slice(&source.values()[offset..offset + len]);

        Ok(())
    }

    /// Build a [`PrimitiveArray`] from everything copied so far and reset the
    /// in-progress state for the next batch.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `PrimitiveArray::try_new`.
    fn finish(&mut self) -> Result<ArrayRef, ArrowError> {
        // Take the accumulated values, leaving an empty Vec behind.
        let values = std::mem::take(&mut self.current);
        // Take the accumulated null mask and start over with a fresh builder.
        let nulls = self.nulls.finish();
        self.nulls = NullBufferBuilder::new(self.batch_size);

        let array = PrimitiveArray::<T>::try_new(ScalarBuffer::from(values), nulls)?;
        Ok(Arc::new(array))
    }
}
Loading