From 55fac1431fb5189ebe8b728c3bc2218d9fb6c8a7 Mon Sep 17 00:00:00 2001 From: Mryange Date: Sat, 27 Jun 2026 21:58:40 +0800 Subject: [PATCH] upd --- .../exprs/aggregate/aggregate_function_avg.h | 24 +- .../exprs/aggregate/aggregate_function_sum.h | 5 +- .../array/function_array_aggregation.cpp | 745 +++++++++--------- .../function/array/function_array_join.h | 16 - .../function/array/function_array_mapped.h | 4 +- 5 files changed, 391 insertions(+), 403 deletions(-) diff --git a/be/src/exprs/aggregate/aggregate_function_avg.h b/be/src/exprs/aggregate/aggregate_function_avg.h index 8cd542401a26a2..3062a2c6564cb8 100644 --- a/be/src/exprs/aggregate/aggregate_function_avg.h +++ b/be/src/exprs/aggregate/aggregate_function_avg.h @@ -21,8 +21,8 @@ #pragma once #include -#include +#include #include #include #include @@ -60,6 +60,28 @@ struct AggregateFunctionAvgData { AggregateFunctionAvgData& operator=(const AggregateFunctionAvgData& src) = default; + void reset() { + sum = {}; + count = 0; + } + + template + NO_SANITIZE_UNDEFINED void add(typename PrimitiveTypeTraits::CppType value) { +#ifdef __clang__ +#pragma clang fp reassociate(on) +#endif + if constexpr (InputType == TYPE_DECIMALV2) { + sum += value; + } else if constexpr (is_decimal(InputType)) { + sum += value.value; + } else { + sum += static_cast(value); + } + ++count; + } + + bool has_value() const { return count > 0; } + template ResultT result(ResultType multiplier) const { if (!count) { diff --git a/be/src/exprs/aggregate/aggregate_function_sum.h b/be/src/exprs/aggregate/aggregate_function_sum.h index c42c77f7d13d90..d6099356e63ee8 100644 --- a/be/src/exprs/aggregate/aggregate_function_sum.h +++ b/be/src/exprs/aggregate/aggregate_function_sum.h @@ -20,8 +20,7 @@ #pragma once -#include - +#include #include #include @@ -50,6 +49,8 @@ template struct AggregateFunctionSumData { typename PrimitiveTypeTraits::CppType sum {}; + void reset() { sum = {}; } + NO_SANITIZE_UNDEFINED void add(typename PrimitiveTypeTraits::CppType value) { #ifdef __clang__ #pragma clang fp reassociate(on) diff --git a/be/src/exprs/function/array/function_array_aggregation.cpp b/be/src/exprs/function/array/function_array_aggregation.cpp index e9a20652ed21d4..c90515ec64de69 100644 --- a/be/src/exprs/function/array/function_array_aggregation.cpp +++ b/be/src/exprs/function/array/function_array_aggregation.cpp @@ -18,20 +18,22 @@ // https://github.com/ClickHouse/ClickHouse/blob/master/src/Functions/array/arrayAggregation.cpp // and modified by Doris -#include -#include - +#include +#include +#include #include #include #include -#include "common/exception.h" +#include "common/check.h" +#include "common/compare.h" #include "common/status.h" -#include "core/arena.h" #include "core/block/block.h" #include "core/block/column_numbers.h" +#include "core/call_on_type_index.h" #include "core/column/column.h" #include "core/column/column_array.h" +#include "core/column/column_array_view.h" #include "core/column/column_decimal.h" #include "core/column/column_nullable.h" #include "core/data_type/data_type.h" @@ -40,18 +42,14 @@ #include "core/data_type/define_primitive_type.h" #include "core/data_type/primitive_type.h" #include "core/types.h" -#include "exprs/aggregate/aggregate_function.h" #include "exprs/aggregate/aggregate_function_avg.h" -#include "exprs/aggregate/aggregate_function_min_max.h" #include "exprs/aggregate/aggregate_function_product.h" -#include "exprs/aggregate/aggregate_function_simple_factory.h" #include "exprs/aggregate/aggregate_function_sum.h" -#include "exprs/aggregate/helpers.h" #include "exprs/function/array/function_array_join.h" #include "exprs/function/array/function_array_mapped.h" #include "exprs/function/function.h" #include "exprs/function/simple_function_factory.h" -#include "storage/predicate/column_predicate.h" +#include "util/simd/bits.h" namespace doris { @@ -75,11 +73,18 @@ template <> struct AggregateFunctionTraits { template struct TypeTraits { - static constexpr PrimitiveType ResultType = - Element == TYPE_DECIMALV2 ? TYPE_DECIMALV2 - : is_float_or_double(Element) - ? TYPE_DOUBLE - : (Element == TYPE_LARGEINT ? TYPE_LARGEINT : TYPE_BIGINT); + static consteval PrimitiveType get_result_type() { + if constexpr (Element == TYPE_DECIMALV2) { + return TYPE_DECIMALV2; + } else if constexpr (is_float_or_double(Element)) { + return TYPE_DOUBLE; + } else if constexpr (Element == TYPE_LARGEINT) { + return TYPE_LARGEINT; + } else { + return TYPE_BIGINT; + } + } + static constexpr PrimitiveType ResultType = get_result_type(); using AggregateDataType = AggregateFunctionSumData; using Function = AggregateFunctionSum; }; @@ -109,28 +114,232 @@ struct AggregateFunctionTraits { }; }; -template -struct ArrayAggregateFunctionCreator { - template - using Function = typename Derived::template TypeTraits::Function; - - static auto create(const DataTypePtr& data_type_ptr, const AggregateFunctionAttr& attr) - -> AggregateFunctionPtr { - if constexpr (AggOp == AggregateOperation::MIN || AggOp == AggregateOperation::MAX) { - return creator_with_type_list< - TYPE_TINYINT, TYPE_SMALLINT, TYPE_INT, TYPE_BIGINT, TYPE_LARGEINT, TYPE_FLOAT, - TYPE_DOUBLE, TYPE_DECIMAL32, TYPE_DECIMAL64, TYPE_DECIMAL128I, - TYPE_DECIMAL256>::create(DataTypes {make_nullable(data_type_ptr)}, - true, attr); +template +inline constexpr bool is_array_agg_decimal = is_decimal(T) || T == TYPE_DECIMALV2; + +template +typename PrimitiveTypeTraits::CppType decimal_scale_multiplier(UInt32 scale) { + typename PrimitiveTypeTraits::CppType multiplier {}; + multiplier.value = PrimitiveTypeTraits::DataType::get_scale_multiplier(scale); + return multiplier; +} + +template +struct ArrayAggState; + +template +struct ArrayAggState { + using InputCppType = typename PrimitiveTypeTraits::CppType; + using ResultCppType = typename PrimitiveTypeTraits::CppType; + using ResultColumnType = typename PrimitiveTypeTraits::ColumnType; + + AggregateFunctionSumData data; + bool has = false; + + ArrayAggState(UInt32 /*input_scale*/, UInt32 /*result_scale*/) {} + + void reset() { + data.reset(); + has = false; + } + + void update(InputCppType value) { + data.add(ResultCppType(value)); + has = true; + } + + bool has_value() const { return has; } + + void insert_result_into(IColumn& to) const { + assert_cast(to).get_data().push_back(data.get()); + } +}; + +template +struct ArrayAggState { + using InputCppType = typename PrimitiveTypeTraits::CppType; + using ResultCppType = typename PrimitiveTypeTraits::CppType; + using ResultColumnType = typename PrimitiveTypeTraits::ColumnType; + + AggregateFunctionAvgData data; + ResultCppType multiplier {}; + + explicit ArrayAggState(UInt32 input_scale, UInt32 result_scale) { + if constexpr (is_array_agg_decimal) { + multiplier = decimal_scale_multiplier(result_scale - input_scale); + } + } + + void reset() { data.reset(); } + + void update(InputCppType value) { data.template add(value); } + + bool has_value() const { return data.has_value(); } + + void insert_result_into(IColumn& to) const { + auto& column = assert_cast(to); + if constexpr (is_array_agg_decimal) { + column.get_data().push_back(data.template result(multiplier)); + } else { + column.get_data().push_back(data.template result()); + } + } +}; + +template +struct ArrayAggState { + using InputCppType = typename PrimitiveTypeTraits::CppType; + using ResultCppType = typename PrimitiveTypeTraits::CppType; + using ResultColumnType = typename PrimitiveTypeTraits::ColumnType; + + AggregateFunctionProductData data; + ResultCppType multiplier {}; + bool has = false; + + explicit ArrayAggState(UInt32 input_scale, UInt32 /*result_scale*/) { + if constexpr (is_array_agg_decimal) { + multiplier = decimal_scale_multiplier(input_scale); + } + } + + void reset() { + if constexpr (is_array_agg_decimal) { + data.reset(multiplier); } else { - return creator_with_type_list< - TYPE_TINYINT, TYPE_SMALLINT, TYPE_INT, TYPE_BIGINT, TYPE_LARGEINT, TYPE_FLOAT, - TYPE_DOUBLE>::create(DataTypes {make_nullable(data_type_ptr)}, true, - attr); + data.reset(ResultCppType(1)); } + has = false; + } + + void update(InputCppType value) { + data.add(ResultCppType(value), multiplier); + has = true; + } + + bool has_value() const { return has; } + + void insert_result_into(IColumn& to) const { + assert_cast(to).get_data().push_back(data.get()); } }; +template +struct ArrayMinMaxState { + using CppType = typename ColumnElementView::ElementType; + using ResultColumnType = typename PrimitiveTypeTraits::ColumnType; + + bool has = false; + CppType value {}; + + ArrayMinMaxState(UInt32 /*input_scale*/, UInt32 /*result_scale*/) {} + + void reset() { + has = false; + value = {}; + } + + void update(CppType input) { + if (!has) { + value = input; + has = true; + return; + } + if constexpr (is_min) { + if (Compare::less(input, value)) { + value = input; + } + } else { + if (Compare::greater(input, value)) { + value = input; + } + } + } + + bool has_value() const { return has; } + + void insert_result_into(IColumn& to) const { + if constexpr (is_string_type(InputType)) { + assert_cast(to).insert_data(value.data, value.size); + } else { + assert_cast(to).get_data().push_back(value); + } + } +}; + +template +struct ArrayAggState + : public ArrayMinMaxState { + using ArrayMinMaxState::ArrayMinMaxState; +}; + +template +struct ArrayAggState + : public ArrayMinMaxState { + using ArrayMinMaxState::ArrayMinMaxState; +}; + +template +bool execute_array_agg_fast_path(ColumnPtr& res_ptr, const ColumnPtr& array_column, + MutableColumnPtr result_column, UInt32 input_scale, + UInt32 result_scale) { + using State = ArrayAggState; + + auto array_view = ColumnArrayView::create(array_column); + auto res_column = ColumnNullable::create(std::move(result_column), ColumnUInt8::create()); + auto& nullable_result = *res_column; + auto& nested_result = nullable_result.get_nested_column(); + auto& null_map = nullable_result.get_null_map_data(); + nested_result.reserve(array_view.size()); + null_map.reserve(array_view.size()); + const bool has_nested_null = + array_view.size() > 0 && simd::contain_one(array_view.get_null_map_data(), + array_view.row_end(array_view.size() - 1)); + + State state(input_scale, result_scale); + for (size_t row = 0; row < array_view.size(); ++row) { + if (array_view.is_null_at(row)) { + nullable_result.insert_default(); + continue; + } + + state.reset(); + if constexpr (is_string_type(Element)) { + auto array_data = array_view[row]; + for (size_t i = 0; i < array_data.size(); ++i) { + if (!array_data.is_null_at(i)) { + state.update(array_data.value_at(i)); + } + } + } else { + if (has_nested_null) { + auto array_data = array_view[row]; + for (size_t i = 0; i < array_data.size(); ++i) { + if (!array_data.is_null_at(i)) { + state.update(array_data.value_at(i)); + } + } + } else { + const auto* data = array_view.get_data(); + const auto begin = array_view.row_begin(row); + const auto end = array_view.row_end(row); + for (size_t i = begin; i < end; ++i) { + state.update(data[i]); + } + } + } + + if (state.has_value()) { + state.insert_result_into(nested_result); + null_map.push_back(0); + } else { + nullable_result.insert_default(); + } + } + + res_ptr = std::move(res_column); + return true; +} + template struct ArrayAggregateImpl { using column_type = ColumnArray; @@ -140,179 +349,136 @@ struct ArrayAggregateImpl { static size_t _get_number_of_arguments() { return 1; } - static bool skip_return_type_check() { return false; } - - static DataTypePtr get_return_type(const DataTypes& arguments) { - using Function = - ArrayAggregateFunctionCreator>; - const DataTypeArray* data_type_array = - static_cast(remove_nullable(arguments[0]).get()); - auto function = Function::create(data_type_array->get_nested_type(), - {.is_window_function = false, .column_names = {}}); - if (function) { - return function->get_return_type(); - } else { - throw doris::Exception(ErrorCode::INVALID_ARGUMENT, - "Unexpected type {} for aggregation {}", - data_type_array->get_nested_type()->get_name(), operation); - } + static Status execute(Block& block, const ColumnNumbers& arguments, uint32_t result, + const DataTypeArray* data_type_array, const ColumnArray& array) { + return execute(block, arguments, result, nullptr, data_type_array, array); } static Status execute(Block& block, const ColumnNumbers& arguments, uint32_t result, - const DataTypeArray* data_type_array, const ColumnArray& array) { + const DataTypePtr& result_type, const DataTypeArray* data_type_array, + const ColumnArray& array) { ColumnPtr res; DataTypePtr type = data_type_array->get_nested_type(); - const IColumn* data = array.get_data_ptr().get(); - - const auto& offsets = array.get_offsets(); - if constexpr (operation == AggregateOperation::MAX || - operation == AggregateOperation::MIN) { - // min/max can only be applied on ip type - if (execute_type(res, type, data, offsets) || - execute_type(res, type, data, offsets)) { - block.replace_by_position(result, std::move(res)); - return Status::OK(); + ColumnPtr array_column = array.get_ptr(); + + const auto nested_type = remove_nullable(type)->get_primitive_type(); + const bool matched = dispatch_switch_all(nested_type, [&](auto type_tag) { + constexpr PrimitiveType element_type = decltype(type_tag)::PType; + if constexpr (element_type == TYPE_DATE || element_type == TYPE_DATETIME || + element_type == TYPE_TIMEV2 || element_type == TYPE_DECIMALV2) { + return false; + } else if constexpr (element_type == TYPE_IPV4 || element_type == TYPE_IPV6) { + if constexpr (operation != AggregateOperation::MAX && + operation != AggregateOperation::MIN) { + return false; + } else { + return execute_type(res, type, array_column, result_type); + } + } else { + return execute_type(res, type, array_column, result_type); } - } + }); - if (execute_type(res, type, data, offsets) || - execute_type(res, type, data, offsets) || - execute_type(res, type, data, offsets) || - execute_type(res, type, data, offsets) || - execute_type(res, type, data, offsets) || - execute_type(res, type, data, offsets) || - execute_type(res, type, data, offsets) || - execute_type(res, type, data, offsets) || - execute_type(res, type, data, offsets) || - execute_type(res, type, data, offsets) || - execute_type(res, type, data, offsets) || - execute_type(res, type, data, offsets) || - execute_type(res, type, data, offsets) || - execute_type(res, type, data, offsets) || - execute_type(res, type, data, offsets) || - execute_type(res, type, data, offsets)) { + if (matched) { block.replace_by_position(result, std::move(res)); return Status::OK(); } else { - return Status::RuntimeError("Unexpected column for aggregation: {}", data->get_name()); - } - } - - template - static bool execute_type_impl(ColumnPtr& res_ptr, const DataTypePtr& type, const IColumn* data, - const ColumnArray::Offsets64& offsets, - CreateColumnFunc create_column_func) { - using Function = - ArrayAggregateFunctionCreator>; - - const ColumnType* column = - is_column_nullable(*data) - ? check_and_get_column( - static_cast(data)->get_nested_column()) - : check_and_get_column(&*data); - if (!column) { - return false; - } - - ColumnPtr res_column = create_column_func(column); - res_column = make_nullable(res_column); - assert_cast(res_column->assert_mutable_ref()).reserve(offsets.size()); - - auto function = Function::create(type, {.is_window_function = false, .column_names = {}}); - auto guard = AggregateFunctionGuard(function.get()); - Arena arena; - auto nullable_column = make_nullable(data->get_ptr()); - const IColumn* columns[] = {nullable_column.get()}; - for (int64_t i = 0; i < offsets.size(); ++i) { - auto start = offsets[i - 1]; // -1 is ok. - auto end = offsets[i]; - bool is_empty = (start == end); - if (is_empty) { - res_column->assert_mutable()->insert_default(); - continue; - } - function->reset(guard.data()); - function->add_batch_range(start, end - 1, guard.data(), columns, arena, - is_column_nullable(*data)); - function->insert_result_into(guard.data(), res_column->assert_mutable_ref()); + return Status::RuntimeError("Unexpected column for aggregation: {}", + array.get_data().get_name()); } - res_ptr = std::move(res_column); - return true; } template - static bool execute_type(ColumnPtr& res_ptr, const DataTypePtr& type, const IColumn* data, - const ColumnArray::Offsets64& offsets) { + static bool execute_type(ColumnPtr& res_ptr, const DataTypePtr& type, + const ColumnPtr& array_column, const DataTypePtr& result_type) { if constexpr (is_string_type(Element)) { - if (operation == AggregateOperation::SUM || operation == AggregateOperation::PRODUCT || - operation == AggregateOperation::AVERAGE) { + if constexpr (operation == AggregateOperation::SUM || + operation == AggregateOperation::PRODUCT || + operation == AggregateOperation::AVERAGE) { return false; + } else { + return execute_array_agg_fast_path( + res_ptr, array_column, ColumnString::create(), 0, 0); } - - auto create_column = [](const ColumnString*) -> ColumnPtr { - return ColumnString::create(); - }; - - return execute_type_impl(res_ptr, type, data, - offsets, create_column); } else { - if constexpr ((operation == AggregateOperation::SUM || - operation == AggregateOperation::PRODUCT || - operation == AggregateOperation::AVERAGE) && - (is_date_type(Element) || is_timestamptz_type(Element) || - is_decimalv3(Element))) { + if constexpr (operation == AggregateOperation::SUM && Element == TYPE_BOOLEAN) { + return false; + } else if constexpr ((operation == AggregateOperation::SUM || + operation == AggregateOperation::PRODUCT || + operation == AggregateOperation::AVERAGE) && + (is_date_type(Element) || is_timestamptz_type(Element))) { return false; + } else if constexpr ((operation == AggregateOperation::SUM || + operation == AggregateOperation::PRODUCT || + operation == AggregateOperation::AVERAGE) && + is_decimalv3(Element)) { + return execute_decimalv3_type(res_ptr, type, array_column, result_type); } else { - using ColVecType = typename PrimitiveTypeTraits::ColumnType; static constexpr PrimitiveType ResultType = AggregateFunctionTraits< operation>::template TypeTraits::ResultType; using ColVecResultType = typename PrimitiveTypeTraits::ColumnType; - auto create_column = [](const ColVecType* column) -> ColumnPtr { + auto create_column = [](UInt32 scale) -> MutableColumnPtr { if constexpr (is_decimal(Element)) { - return ColVecResultType::create(0, column->get_scale()); + return ColVecResultType::create(0, scale); } else { return ColVecResultType::create(); } }; - return execute_type_impl( - res_ptr, type, data, offsets, create_column); + UInt32 input_scale = 0; + UInt32 result_scale = 0; + if constexpr (is_decimal(Element)) { + input_scale = get_decimal_scale(*remove_nullable(type)); + if constexpr (operation == AggregateOperation::AVERAGE) { + using AvgFunction = typename AggregateFunctionTraits< + operation>::template TypeTraits::Function; + result_scale = std::max(AvgFunction::DEFAULT_MIN_AVG_DECIMAL128_SCALE, + input_scale); + } else { + result_scale = input_scale; + } + } + return execute_array_agg_fast_path( + res_ptr, array_column, create_column(result_scale), input_scale, + result_scale); } } } + + template + static bool execute_decimalv3_type(ColumnPtr& res_ptr, const DataTypePtr& type, + const ColumnPtr& array_column, + const DataTypePtr& result_type) { + DORIS_CHECK(result_type != nullptr); + const auto result_primitive_type = remove_nullable(result_type)->get_primitive_type(); + return dispatch_switch_decimalv3(result_primitive_type, [&](auto result_type_tag) { + constexpr PrimitiveType result_type_value = decltype(result_type_tag)::PType; + if constexpr (result_type_value == TYPE_DECIMAL128I || + result_type_value == TYPE_DECIMAL256) { + using ColVecResultType = + typename PrimitiveTypeTraits::ColumnType; + + const auto input_scale = get_decimal_scale(*remove_nullable(type)); + const auto result_scale = get_decimal_scale(*remove_nullable(result_type)); + return execute_array_agg_fast_path( + res_ptr, array_column, ColVecResultType::create(0, result_scale), + input_scale, result_scale); + } else { + return false; + } + }); + } }; struct NameArrayMin { static constexpr auto name = "array_min"; }; -template <> -struct ArrayAggregateFunctionCreator> { - static auto create(const DataTypePtr& data_type_ptr, const AggregateFunctionAttr& attr) - -> AggregateFunctionPtr { - return create_aggregate_function_single_value( - NameArrayMin::name, {make_nullable(data_type_ptr)}, make_nullable(data_type_ptr), - true, attr); - } -}; - struct NameArrayMax { static constexpr auto name = "array_max"; }; -template <> -struct ArrayAggregateFunctionCreator> { - static auto create(const DataTypePtr& data_type_ptr, const AggregateFunctionAttr& attr) - -> AggregateFunctionPtr { - return create_aggregate_function_single_value( - NameArrayMax::name, {make_nullable(data_type_ptr)}, make_nullable(data_type_ptr), - true, attr); - } -}; - struct NameArraySum { static constexpr auto name = "array_sum"; }; @@ -337,154 +503,11 @@ using FunctionArrayProduct = using FunctionArrayJoin = FunctionArrayMapped; -template -struct AggregateFunctionTraitsWithResultType; - -template <> -struct AggregateFunctionTraitsWithResultType { - template - struct TypeTraits { - using AggregateDataType = AggregateFunctionSumData; - using Function = AggregateFunctionSum; - }; -}; -template <> -struct AggregateFunctionTraitsWithResultType { - template - struct TypeTraits { - using AggregateDataType = AggregateFunctionAvgData; - using Function = AggregateFunctionAvg; - }; -}; -template <> -struct AggregateFunctionTraitsWithResultType { - template - struct TypeTraits { - using AggregateDataType = AggregateFunctionProductData; - using Function = AggregateFunctionProduct; - }; -}; -template -struct ArrayAggregateFunctionCreatorWithResultType { - template - using Function = typename Derived::template TypeTraits::Function; - - static auto create(const DataTypePtr& data_type_ptr, const DataTypePtr& result_type_ptr, - const AggregateFunctionAttr& attr) -> AggregateFunctionPtr { - return creator_with_type_list< - TYPE_DECIMAL32, TYPE_DECIMAL64, TYPE_DECIMAL128I, - TYPE_DECIMAL256>::creator_with_result_type("", - DataTypes {make_nullable( - data_type_ptr)}, - result_type_ptr, true, attr); - } -}; -template -struct ArrayAggregateImplDecimalV3; -template - requires(operation == AggregateOperation::SUM || operation == AggregateOperation::PRODUCT || - operation == AggregateOperation::AVERAGE) -struct ArrayAggregateImplDecimalV3 { - using column_type = ColumnArray; - using data_type = DataTypeArray; - - static bool _is_variadic() { return false; } - - static size_t _get_number_of_arguments() { return 1; } - - static bool skip_return_type_check() { return true; } - - static DataTypePtr get_return_type(const DataTypes& arguments) { - throw doris::Exception( - ErrorCode::NOT_IMPLEMENTED_ERROR, - "get_return_type is not implemented for ArrayAggregateImplDecimalV3"); - __builtin_unreachable(); - } - - static Status execute(Block& block, const ColumnNumbers& arguments, uint32_t result, - const DataTypePtr& result_type, const DataTypeArray* data_type_array, - const ColumnArray& array) { - ColumnPtr res; - DataTypePtr type = data_type_array->get_nested_type(); - const IColumn* data = array.get_data_ptr().get(); - - const auto& offsets = array.get_offsets(); - - if (execute_type(res, result_type, type, data, offsets) || - execute_type(res, result_type, type, data, offsets) || - execute_type(res, result_type, type, data, offsets) || - execute_type(res, result_type, type, data, offsets)) { - block.replace_by_position(result, std::move(res)); - return Status::OK(); - } else { - return Status::RuntimeError("Unexpected column for aggregation: {}", data->get_name()); - } - } - - template - static bool execute_type_impl(ColumnPtr& res_ptr, const DataTypePtr& result_type, - const DataTypePtr& type, const IColumn* data, - const ColumnArray::Offsets64& offsets, - CreateColumnFunc create_column_func) { - using Function = ArrayAggregateFunctionCreatorWithResultType< - AggregateFunctionTraitsWithResultType>; - - const ColumnType* column = - is_column_nullable(*data) - ? check_and_get_column( - static_cast(data)->get_nested_column()) - : check_and_get_column(&*data); - if (!column) { - return false; - } - - ColumnPtr res_column = create_column_func(column); - res_column = make_nullable(res_column); - assert_cast(res_column->assert_mutable_ref()).reserve(offsets.size()); - - auto function = Function::create(type, result_type, - {.is_window_function = false, .column_names = {}}); - auto guard = AggregateFunctionGuard(function.get()); - Arena arena; - auto nullable_column = make_nullable(data->get_ptr()); - const IColumn* columns[] = {nullable_column.get()}; - for (int64_t i = 0; i < offsets.size(); ++i) { - auto start = offsets[i - 1]; // -1 is ok. - auto end = offsets[i]; - bool is_empty = (start == end); - if (is_empty) { - res_column->assert_mutable()->insert_default(); - continue; - } - function->reset(guard.data()); - function->add_batch_range(start, end - 1, guard.data(), columns, arena, - is_column_nullable(*data)); - function->insert_result_into(guard.data(), res_column->assert_mutable_ref()); - } - res_ptr = std::move(res_column); - return true; - } - - template - static bool execute_type(ColumnPtr& res_ptr, const DataTypePtr& result_type, - const DataTypePtr& type, const IColumn* data, - const ColumnArray::Offsets64& offsets) { - using ColVecType = typename PrimitiveTypeTraits::ColumnType; - using ColVecResultType = typename PrimitiveTypeTraits::ColumnType; - - auto create_column = [](const ColVecType* column) -> ColumnPtr { - return ColVecResultType::create(0, column->get_scale()); - }; - - return execute_type_impl(res_ptr, result_type, type, - data, offsets, create_column); - } -}; - -template +template class FunctionArrayAggDecimalV3 : public IFunction { public: static constexpr auto name = Name::name; + static_assert(is_decimalv3(ResultType)); explicit FunctionArrayAggDecimalV3(DataTypePtr result_type) : _result_type(std::move(result_type)) {} @@ -493,101 +516,61 @@ class FunctionArrayAggDecimalV3 : public IFunction { uint32_t result, size_t input_rows_count) const override { const auto& typed_column = block.get_by_position(arguments[0]); auto ptr = typed_column.column->convert_to_full_column_if_const(); - const typename Impl::column_type* column_array; + const ColumnArray* column_array; if (is_column_nullable(*ptr)) { - column_array = assert_cast( + column_array = assert_cast( assert_cast(ptr.get())->get_nested_column_ptr().get()); } else { - column_array = assert_cast(ptr.get()); + column_array = assert_cast(ptr.get()); } const auto* data_type_array = assert_cast(remove_nullable(typed_column.type).get()); - return Impl::execute(block, arguments, result, _result_type, data_type_array, - *column_array); + return ArrayAggregateImpl::execute(block, arguments, result, _result_type, + data_type_array, *column_array); } - bool is_variadic() const override { return Impl::_is_variadic(); } + bool is_variadic() const override { return ArrayAggregateImpl::_is_variadic(); } - size_t get_number_of_arguments() const override { return Impl::_get_number_of_arguments(); } + size_t get_number_of_arguments() const override { + return ArrayAggregateImpl::_get_number_of_arguments(); + } - bool skip_return_type_check() const override { return Impl::skip_return_type_check(); } + bool skip_return_type_check() const override { return true; } DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { - return Impl::get_return_type(arguments); + return _result_type; } private: DataTypePtr _result_type; }; -template -struct ArraySumDecimalV3Attributes { - static_assert(is_decimalv3(ResultType)); - using AggregateDataType = AggregateFunctionSumData; - using Function = FunctionArrayAggDecimalV3< - ArrayAggregateImplDecimalV3, NameArraySum>; -}; -template -using ArraySumDecimalV3 = typename ArraySumDecimalV3Attributes::Function; -template -struct ArrayAvgDecimalV3Attributes { - static_assert(is_decimalv3(ResultType)); - using AggregateDataType = AggregateFunctionAvgData; - using Function = FunctionArrayAggDecimalV3< - ArrayAggregateImplDecimalV3, NameArrayAverage>; +template +struct ArrayAggDecimalV3Function { + template + using Type = FunctionArrayAggDecimalV3; }; -template -using ArrayAvgDecimalV3 = typename ArrayAvgDecimalV3Attributes::Function; -template -struct ArrayProductDecimalV3Attributes { - static_assert(is_decimalv3(ResultType)); - using AggregateDataType = AggregateFunctionProductData; - using Function = FunctionArrayAggDecimalV3< - ArrayAggregateImplDecimalV3, NameArrayProduct>; -}; -template -using ArrayProductDecimalV3 = typename ArrayProductDecimalV3Attributes::Function; +template +void register_array_reduce_agg_function(SimpleFunctionFactory& factory) { + ArrayAggFunctionCreator creator = [](const DataTypePtr& result_type) -> FunctionBuilderPtr { + if (is_decimalv3(result_type->get_primitive_type())) { + return DefaultFunctionBuilder::create_array_agg_function_decimalv3< + ArrayAggDecimalV3Function::template Type>(result_type); + } else { + return std::make_shared(Function::create()); + } + }; + factory.register_array_agg_function(Name::name, creator); +} + void register_array_reduce_agg_functions(SimpleFunctionFactory& factory) { - { - ArrayAggFunctionCreator creator = [&](const DataTypePtr& result_type) { - if (is_decimalv3(result_type->get_primitive_type())) { - return DefaultFunctionBuilder::create_array_agg_function_decimalv3< - ArraySumDecimalV3>(result_type); - } else { - FunctionBuilderPtr func = - std::make_shared(FunctionArraySum::create()); - return func; - } - }; - factory.register_array_agg_function(NameArraySum::name, creator); - } - { - ArrayAggFunctionCreator creator = [&](const DataTypePtr& result_type) { - if (is_decimalv3(result_type->get_primitive_type())) { - return DefaultFunctionBuilder::create_array_agg_function_decimalv3< - ArrayAvgDecimalV3>(result_type); - } else { - FunctionBuilderPtr func = - std::make_shared(FunctionArrayAverage::create()); - return func; - } - }; - factory.register_array_agg_function(NameArrayAverage::name, creator); - } - { - ArrayAggFunctionCreator creator = [&](const DataTypePtr& result_type) { - if (is_decimalv3(result_type->get_primitive_type())) { - return DefaultFunctionBuilder::create_array_agg_function_decimalv3< - ArrayProductDecimalV3>(result_type); - } else { - FunctionBuilderPtr func = - std::make_shared(FunctionArrayProduct::create()); - return func; - } - }; - factory.register_array_agg_function(NameArrayProduct::name, creator); - } + register_array_reduce_agg_function( + factory); + register_array_reduce_agg_function(factory); + register_array_reduce_agg_function(factory); } void register_function_array_aggregation(SimpleFunctionFactory& factory) { diff --git a/be/src/exprs/function/array/function_array_join.h b/be/src/exprs/function/array/function_array_join.h index d674851060c5c3..9b537094fdc1ca 100644 --- a/be/src/exprs/function/array/function_array_join.h +++ b/be/src/exprs/function/array/function_array_join.h @@ -40,22 +40,6 @@ struct ArrayJoinImpl { static size_t _get_number_of_arguments() { return 0; } - static DataTypePtr get_return_type(const DataTypes& arguments) { - DCHECK(arguments[0]->get_primitive_type() == TYPE_ARRAY) - << "first argument for function: array_join should be DataTypeArray" - << " and arguments[0] is " << arguments[0]->get_name(); - DCHECK(is_string_type(arguments[1]->get_primitive_type())) - << "second argument for function: array_join should be DataTypeString" - << ", and arguments[1] is " << arguments[1]->get_name(); - if (arguments.size() > 2) { - DCHECK(is_string_type(arguments[2]->get_primitive_type())) - << "third argument for function: array_join should be DataTypeString" - << ", and arguments[2] is " << arguments[2]->get_name(); - } - - return std::make_shared(); - } - static Status execute(Block& block, const ColumnNumbers& arguments, uint32_t result, const DataTypeArray* data_type_array, const ColumnArray& array) { ColumnPtr src_column = block.get_by_position(arguments[0]).column; diff --git a/be/src/exprs/function/array/function_array_mapped.h b/be/src/exprs/function/array/function_array_mapped.h index 83296f2eba2e50..ae6983bef2108e 100644 --- a/be/src/exprs/function/array/function_array_mapped.h +++ b/be/src/exprs/function/array/function_array_mapped.h @@ -65,9 +65,7 @@ class FunctionArrayMapped : public IFunction { size_t get_number_of_arguments() const override { return Impl::_get_number_of_arguments(); } - DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { - return Impl::get_return_type(arguments); - } + bool skip_return_type_check() const override { return true; } }; } // namespace doris