2 changes: 1 addition & 1 deletion datafusion/common/src/scalar/mod.rs

@@ -1746,7 +1746,7 @@ impl ScalarValue {
    }

    /// Converts `Vec<ScalarValue>` where each element has type corresponding to
-    /// `data_type`, to a [`ListArray`].
+    /// `data_type`, to a single element [`ListArray`].
    ///
    /// Example
    /// ```
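The in-source example is truncated here; as a hedged, minimal sketch of what the sharpened comment describes (assuming the `ScalarValue::new_list(&[ScalarValue], &DataType) -> Arc<ListArray>` signature from this era of DataFusion, the same call used by `DistinctCountAccumulator::state` below):

```rust
use arrow::array::Array;
use arrow::datatypes::DataType;
use datafusion_common::ScalarValue;

fn main() {
    let scalars = vec![
        ScalarValue::Int32(Some(1)),
        ScalarValue::Int32(Some(2)),
    ];
    // The result is a ListArray with a single row whose value is [1, 2],
    // not a two-row array -- hence "single element" in the doc comment.
    let list = ScalarValue::new_list(&scalars, &DataType::Int32);
    assert_eq!(list.len(), 1);
}
```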
32 changes: 25 additions & 7 deletions datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
@@ -47,7 +47,7 @@ use crate::binary_map::OutputType;
use crate::expressions::format_state_name;
use crate::{AggregateExpr, PhysicalExpr};

-/// Expression for a COUNT(DISTINCT) aggregation.
+/// Expression for a `COUNT(DISTINCT)` aggregation.
#[derive(Debug)]
pub struct DistinctCount {
    /// Column name
@@ -100,6 +100,7 @@ impl AggregateExpr for DistinctCount {
        use TimeUnit::*;

        Ok(match &self.state_data_type {
+            // try to use a specialized accumulator if possible, otherwise fall back to the generic accumulator
            Int8 => Box::new(PrimitiveDistinctCountAccumulator::<Int8Type>::new()),
            Int16 => Box::new(PrimitiveDistinctCountAccumulator::<Int16Type>::new()),
            Int32 => Box::new(PrimitiveDistinctCountAccumulator::<Int32Type>::new()),
@@ -157,6 +158,7 @@ impl AggregateExpr for DistinctCount {
                OutputType::Binary,
            )),

+            // Use the generic accumulator based on `ScalarValue` for all other types
            _ => Box::new(DistinctCountAccumulator {
                values: HashSet::default(),
                state_data_type: self.state_data_type.clone(),
@@ -183,7 +185,11 @@ impl PartialEq<dyn Any> for DistinctCount {
}

/// General purpose distinct accumulator that works for any DataType by using
-/// [`ScalarValue`]. Some types have specialized accumulators that are (much)
+/// [`ScalarValue`].
+///
+/// It stores intermediate results as a `ListArray`.
+///
+/// Note that many types have specialized accumulators that are (much)
/// more efficient such as [`PrimitiveDistinctCountAccumulator`] and
/// [`BytesDistinctCountAccumulator`]
#[derive(Debug)]
@@ -193,8 +199,9 @@ struct DistinctCountAccumulator {
}

impl DistinctCountAccumulator {
-    // calculating the size for fixed length values, taking first batch size * number of batches
-    // This method is faster than .full_size(), however it is not suitable for variable length values like strings or complex types
+    // calculating the size for fixed length values, taking first batch size *
+    // number of batches. This method is faster than .full_size(), however it is
+    // not suitable for variable length values like strings or complex types
    fn fixed_size(&self) -> usize {
        std::mem::size_of_val(self)
            + (std::mem::size_of::<ScalarValue>() * self.values.capacity())
@@ -207,7 +214,8 @@ impl DistinctCountAccumulator {
            + std::mem::size_of::<DataType>()
    }

-    // calculates the size as accurate as possible, call to this method is expensive
+    // calculates the size as accurately as possible. Note that calling this
+    // method is expensive
    fn full_size(&self) -> usize {
        std::mem::size_of_val(self)
            + (std::mem::size_of::<ScalarValue>() * self.values.capacity())
@@ -221,6 +229,7 @@ impl DistinctCountAccumulator {
}

impl Accumulator for DistinctCountAccumulator {
+    /// Returns the distinct values seen so far as a (one element) `ListArray`.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        let scalars = self.values.iter().cloned().collect::<Vec<_>>();
        let arr = ScalarValue::new_list(scalars.as_slice(), &self.state_data_type);
@@ -246,15 +255,24 @@ impl Accumulator for DistinctCountAccumulator {
        })
    }

+    /// Merges multiple sets of distinct values into the current set.
+    ///
+    /// The input to this function is a `ListArray` with **multiple** rows,
+    /// where each row contains the values from one partial aggregation phase
+    /// (e.g. the result of calling `Self::state` on multiple accumulators).
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if states.is_empty() {
            return Ok(());
        }
        assert_eq!(states.len(), 1, "array_agg states must be singleton!");
        let array = &states[0];
        let list_array = array.as_list::<i32>();
-        let inner_array = list_array.value(0);
-        self.update_batch(&[inner_array])
+        for inner_array in list_array.iter() {
Contributor Author: This is the actual bug fix -- to use all rows, not just the first. The rest of this PR is tests / comment improvements.

+            let inner_array = inner_array
+                .expect("counts are always non null, so are intermediate results");
Contributor: I noticed this when updating my local repo. Is `expect` something that should be used here? My understanding is that it panics on `None`. Given the method returns `Result`, I would expect an `Err` to be returned instead -- am I missing something in my understanding of Rust here?

jayzhan211 (Mar 20, 2024): It panics, but that is fine if the value is guaranteed to be non-null. I am looking into how the array was built in the comment above, but failed. 😢

Contributor Author: This is a good point, and I think it would be better UX to avoid panicking even if something "impossible" happens. I made the change in #9712.

+            self.update_batch(&[inner_array])?;
+        }
+        Ok(())
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
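To make the fix concrete, here is a hedged, self-contained sketch (hypothetical data built with standard arrow-rs builders; not code from this PR) of why iterating over every row of the state `ListArray` matters:

```rust
use std::sync::Arc;
use arrow::array::{cast::as_list_array, Array, ArrayRef, ListBuilder, StringBuilder};

fn main() {
    // Two partial aggregates reported distinct sets for the SAME group:
    // row 0 holds {"a", "b"} and row 1 holds {"c"}.
    let mut builder = ListBuilder::new(StringBuilder::new());
    builder.values().append_value("a");
    builder.values().append_value("b");
    builder.append(true); // row 0: ["a", "b"]
    builder.values().append_value("c");
    builder.append(true); // row 1: ["c"]
    let state: ArrayRef = Arc::new(builder.finish());

    // Before the fix, merge_batch read only `list_array.value(0)`, merged
    // {"a", "b"}, and reported COUNT(DISTINCT) = 2. Iterating every row
    // (as the fixed code does) also merges {"c"}, giving the correct 3.
    let list_array = as_list_array(&state);
    assert_eq!(list_array.len(), 2);
    for row in list_array.iter().flatten() {
        // each `row` is an ArrayRef of distinct values from one partial state
        println!("{row:?}");
    }
}
```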
67 changes: 67 additions & 0 deletions datafusion/sqllogictest/test_files/dictionary.slt
@@ -280,3 +280,70 @@ ORDER BY
2023-12-20T01:20:00 1000 f2 foo
2023-12-20T01:30:00 1000 f1 32.0
2023-12-20T01:30:00 1000 f2 foo

# Cleanup
statement ok
drop view m1;

statement ok
drop view m2;

######
# Create a table using UNION ALL to get 2 partitions (very important:
# two partitions produce two partial aggregate states, which exercises
# the multi-row merge path fixed by this PR)
######
statement ok
create table m3_source as
select * from (values('foo', 'bar', 1))
UNION ALL
select * from (values('foo', 'baz', 1));

######
# Now, create a table with the same data, but column2 has type `Dictionary(Int32, Utf8)` to trigger the fallback code
Contributor: Why does the cast to the dictionary trigger the fallback code? Does it refer to merge_batch?

jayzhan211 (Mar 19, 2024): Specifically, why is the key of the dictionary the sub-group index after casting? 🤔 Given

group 1: "a", "b"
group 2: "c"

we get

(0, "a"), (1, "b"), and (0, "c")

Contributor Author:

> why does the cast to the dictionary trigger the fallback code?

The reason the dictionary triggers merge is that when grouping on strings or primitive values, the DistinctCountAccumulator code path is not used. Instead, one of the specialized implementations (like BytesDistinctCountAccumulator), which use the GroupsAccumulator interface, is used.

Dictionary encoded columns run the DistinctCountAccumulator path:
https://github.com/apache/arrow-datafusion/blob/b0b329ba39403b9e87156d6f9b8c5464dc6d2480/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs#L160-L163

> specifically, why the key of dict is the sub group index after casting? 🤔

What is happening is that we are doing a two phase group by (illustrated here):

https://github.com/apache/arrow-datafusion/blob/b0b329ba39403b9e87156d6f9b8c5464dc6d2480/datafusion/expr/src/accumulator.rs#L99-L131

So there are two different partial group bys happening, and each partial group by produces a set of distinct values. Using your example, I think it would be more like the following (where the same group appears in multiple partial results):

group 1 (partial): "a", "b"
group 1 (partial): "c"

The merge is called to combine the results together with a two element array:

("a", "b")
("c")

But I may be misunderstanding your question.
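To illustrate the dispatch described above, here is a simplified, hypothetical sketch (the helper `uses_generic_accumulator` is invented for illustration; the real match in count_distinct/mod.rs lists many more specialized types and constructs the accumulators directly):

```rust
use arrow::datatypes::DataType;

/// Hypothetical stand-in for the accumulator dispatch: types without a
/// specialized accumulator fall through to the generic
/// DistinctCountAccumulator (the `_ =>` arm in create_accumulator).
fn uses_generic_accumulator(dt: &DataType) -> bool {
    !matches!(
        dt,
        DataType::Int8
            | DataType::Int16
            | DataType::Int32
            | DataType::Int64
            | DataType::Utf8
            | DataType::Binary // ... other specialized types elided
    )
}

fn main() {
    // Dictionary-encoded strings are not in the specialized list, so they
    // take the generic path whose merge_batch this PR fixes.
    let dict = DataType::Dictionary(
        Box::new(DataType::Int32),
        Box::new(DataType::Utf8),
    );
    assert!(uses_generic_accumulator(&dict));
    assert!(!uses_generic_accumulator(&DataType::Utf8));
}
```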

Contributor Author: I also filed #9695 to add some more coverage of array operations.

Contributor: I'm curious about how and where the DictionaryArray was built. It is quite hard to trace the previous caller of GroupedHashAggregateStream::poll_next with RUST_BACKTRACE.

https://github.com/apache/arrow-datafusion/blob/b0b329ba39403b9e87156d6f9b8c5464dc6d2480/datafusion/physical-plan/src/aggregates/row_hash.rs#L434
batch: RecordBatch { schema: Schema { fields: [Field { name: "column3", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "COUNT(DISTINCT m3.column1)[count distinct]", data_type: List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "COUNT(DISTINCT m3.column2)[count distinct]", data_type: List(Field { name: "item", data_type: Dictionary(Int32, Utf8), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }], metadata: {} }, columns: [PrimitiveArray<Int64>
[
  1,
  1,
], ListArray
[
  StringArray
[
  "foo",
],
  StringArray
[
  "foo",
],
], ListArray
[
  DictionaryArray {keys: PrimitiveArray<Int32>
[
  0,
] values: StringArray
[
  "bar",
  "baz",
]}
,
  DictionaryArray {keys: PrimitiveArray<Int32>
[
  1,
] values: StringArray
[
  "bar",
  "baz",
]}
,
]], row_count: 2 }

Contributor Author:

> I'm curious about how and where the DictionaryArray was built.

I think it comes from emitting ScalarValue::Dictionary values that are combined into an array here:

https://github.com/apache/arrow-datafusion/blob/b87dd6143c2dc089b07f74780bd525c4369e68a3/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs#L304-L309
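A hedged sketch of that mechanism (hypothetical data; it assumes `ScalarValue::iter_to_array`, which this era of DataFusion provides, combines dictionary-typed scalars into a `DictionaryArray`):

```rust
use arrow::array::Array;
use arrow::datatypes::DataType;
use datafusion_common::{Result, ScalarValue};

fn main() -> Result<()> {
    // One dictionary-typed scalar per group, as the adapter emits them
    let scalars = vec![
        ScalarValue::Dictionary(
            Box::new(DataType::Int32),
            Box::new(ScalarValue::from("bar")),
        ),
        ScalarValue::Dictionary(
            Box::new(DataType::Int32),
            Box::new(ScalarValue::from("baz")),
        ),
    ];
    // Combining the per-group scalars yields a single dictionary column
    let array = ScalarValue::iter_to_array(scalars)?;
    // Expected to print: Dictionary(Int32, Utf8)
    println!("{:?}", array.data_type());
    Ok(())
}
```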

Contributor: It seems it has already been converted at that point.

######
statement ok
create table m3 as
select
column1,
arrow_cast(column2, 'Dictionary(Int32, Utf8)') as "column2",
column3
from m3_source;

# there are two values in column2
query T?I rowsort
SELECT *
FROM m3;
----
foo bar 1
foo baz 1

# There is 1 distinct value in column1
query I
SELECT count(distinct column1)
FROM m3
GROUP BY column3;
----
1

# There are 2 distinct values in column2
query I
SELECT count(distinct column2)
FROM m3
GROUP BY column3;
----
2

# Should still get the same results when querying in the same query
query II
SELECT count(distinct column1), count(distinct column2)
FROM m3
GROUP BY column3;
----
1 2
Contributor Author: this query returns `1 1` without the code change in this PR.



# Cleanup
statement ok
drop table m3;

statement ok
drop table m3_source;