diff --git a/modules/calcite/src/main/java/org/apache/ignite/internal/processors/query/calcite/exec/exp/agg/AccumulatorFactoryProvider.java b/modules/calcite/src/main/java/org/apache/ignite/internal/processors/query/calcite/exec/exp/agg/AccumulatorFactoryProvider.java new file mode 100644 index 0000000000000..c1bce8b31947f --- /dev/null +++ b/modules/calcite/src/main/java/org/apache/ignite/internal/processors/query/calcite/exec/exp/agg/AccumulatorFactoryProvider.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.internal.processors.query.calcite.exec.exp.agg; + +import java.util.function.Supplier; +import org.apache.calcite.plan.Context; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.tools.Frameworks; +import org.apache.ignite.internal.processors.query.calcite.exec.ExecutionContext; +import org.apache.ignite.plugin.PluginProvider; +import org.jetbrains.annotations.Nullable; + +/** + * Factory that selects and creates an accumulator supplier for an aggregate call. Allows overriding standard aggregate + * functions. + * + *

It can be set via {@link PluginProvider} when creating a configuration using + * {@link PluginProvider#createComponent} via {@link Frameworks.ConfigBuilder#context(Context)}.

+ */ +@FunctionalInterface +public interface AccumulatorFactoryProvider { + /** @return Accumulator supplier, {@code null} if no accumulator is required for this aggregate call. */ + @Nullable Supplier> factory(AggregateCall call, ExecutionContext ctx); +} diff --git a/modules/calcite/src/main/java/org/apache/ignite/internal/processors/query/calcite/exec/exp/agg/Accumulators.java b/modules/calcite/src/main/java/org/apache/ignite/internal/processors/query/calcite/exec/exp/agg/Accumulators.java index 9c15e4a7a2375..085d212be25d2 100644 --- a/modules/calcite/src/main/java/org/apache/ignite/internal/processors/query/calcite/exec/exp/agg/Accumulators.java +++ b/modules/calcite/src/main/java/org/apache/ignite/internal/processors/query/calcite/exec/exp/agg/Accumulators.java @@ -71,6 +71,15 @@ private static Supplier> accumulatorFunctionFactory( ) { RowHandler hnd = ctx.rowHandler(); + AccumulatorFactoryProvider prov = ctx.unwrap(AccumulatorFactoryProvider.class); + + if (prov != null) { + Supplier> fac = prov.factory(call, ctx); + + if (fac != null) + return fac; + } + switch (call.getAggregation().getName()) { case "COUNT": return () -> new LongCount<>(call, hnd); @@ -280,7 +289,7 @@ private static Supplier> maxFactory(AggregateCall call, R } /** */ - private abstract static class AbstractAccumulator implements Accumulator { + public abstract static class AbstractAccumulator implements Accumulator { /** */ private final RowHandler hnd; @@ -288,13 +297,13 @@ private abstract static class AbstractAccumulator implements Accumulator hnd) { + protected AbstractAccumulator(AggregateCall aggCall, RowHandler hnd) { this.aggCall = aggCall; this.hnd = hnd; } /** */ - T get(int idx, Row row) { + protected T get(int idx, Row row) { assert idx < arguments().size() : "idx=" + idx + "; arguments=" + arguments(); return (T)hnd.get(arguments().get(idx), row); @@ -311,7 +320,7 @@ protected List arguments() { } /** */ - int columnCount(Row row) { + protected int columnCount(Row row) { return hnd.columnCount(row); } } @@ -1344,8 +1353,9 @@ public ListAggAccumulator(AggregateCall aggCall, RowHandler hnd) { if (builder == null) builder = new StringBuilder(); - if (builder.length() != 0) + if (!builder.isEmpty()) builder.append(extractSeparator(row)); + builder.append(val); } diff --git a/modules/calcite/src/test/java/org/apache/ignite/internal/processors/query/calcite/integration/OperatorsExtensionIntegrationTest.java b/modules/calcite/src/test/java/org/apache/ignite/internal/processors/query/calcite/integration/OperatorsExtensionIntegrationTest.java index 53b6d23bcf14f..d08259d0877d6 100644 --- a/modules/calcite/src/test/java/org/apache/ignite/internal/processors/query/calcite/integration/OperatorsExtensionIntegrationTest.java +++ b/modules/calcite/src/test/java/org/apache/ignite/internal/processors/query/calcite/integration/OperatorsExtensionIntegrationTest.java @@ -18,12 +18,18 @@ import java.math.BigDecimal; import java.sql.Timestamp; +import java.util.List; +import java.util.function.Supplier; import com.google.common.collect.ImmutableList; import org.apache.calcite.adapter.enumerable.NullPolicy; import org.apache.calcite.avatica.util.TimeUnitRange; import org.apache.calcite.linq4j.tree.Expressions; +import org.apache.calcite.plan.Contexts; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlCall; import org.apache.calcite.sql.SqlFunction; import org.apache.calcite.sql.SqlFunctionCategory; @@ -43,12 +49,19 @@ import org.apache.calcite.tools.FrameworkConfig; import org.apache.calcite.tools.Frameworks; import org.apache.calcite.util.BuiltInMethod; +import org.apache.calcite.util.Optionality; import org.apache.ignite.configuration.IgniteConfiguration; import org.apache.ignite.internal.processors.query.calcite.CalciteQueryProcessor; +import org.apache.ignite.internal.processors.query.calcite.exec.ExecutionContext; +import org.apache.ignite.internal.processors.query.calcite.exec.RowHandler; import org.apache.ignite.internal.processors.query.calcite.exec.exp.RexImpTable; +import org.apache.ignite.internal.processors.query.calcite.exec.exp.agg.Accumulator; +import org.apache.ignite.internal.processors.query.calcite.exec.exp.agg.AccumulatorFactoryProvider; +import org.apache.ignite.internal.processors.query.calcite.exec.exp.agg.Accumulators; import org.apache.ignite.internal.processors.query.calcite.prepare.IgniteConvertletTable; import org.apache.ignite.internal.processors.query.calcite.prepare.IgniteSqlNodeRewriter; import org.apache.ignite.internal.processors.query.calcite.prepare.IgniteSqlValidator; +import org.apache.ignite.internal.processors.query.calcite.type.IgniteTypeFactory; import org.apache.ignite.plugin.AbstractTestPluginProvider; import org.apache.ignite.plugin.PluginContext; import org.jetbrains.annotations.Nullable; @@ -75,6 +88,9 @@ public class OperatorsExtensionIntegrationTest extends AbstractBasicIntegrationT .sqlValidatorConfig( ((IgniteSqlValidator.Config)CalciteQueryProcessor.FRAMEWORK_CONFIG.getSqlValidatorConfig()) .withSqlNodeRewriter(new SqlRewriter())) + .context(Contexts.chain( + CalciteQueryProcessor.FRAMEWORK_CONFIG.getContext(), + Contexts.of(new AccumulatorFactoryProviderImpl()))) .build(); return (T)cfg; @@ -134,6 +150,14 @@ public void testOperatorsCallsInViews() { assertQuery("SELECT val_str from my_view").returns(new BigDecimal("0")).check(); } + /** */ + @Test + public void testCustomAggregateFunction() { + assertQuery("SELECT TEST_SUM(x) FROM (VALUES (1), (2), (3)) t(x)") + .returns(6L) + .check(); + } + /** Rewrites LTRIM with 2 parameters. */ public static SqlCall rewriteLtrim(SqlValidator validator, SqlCall call) { if (call.operandCount() != 2) @@ -193,6 +217,9 @@ public static class OperatorTable extends ReflectiveSqlOperatorTable { OperandTypes.STRING_STRING, SqlFunctionCategory.STRING ); + + /** */ + public static final SqlAggFunction TEST_SUM = new SqlTestSumAggFunction(); } /** Extended convertlet table. */ @@ -229,4 +256,73 @@ private static class SqlRewriter implements IgniteSqlNodeRewriter { return node; } } + + /** */ + private static class AccumulatorFactoryProviderImpl implements AccumulatorFactoryProvider { + /** {@inheritDoc} */ + @Override public @Nullable Supplier> factory(AggregateCall call, ExecutionContext ctx) { + if (call.getAggregation().getName().equals(OperatorTable.TEST_SUM.getName())) + return () -> new TestSum<>(call, ctx.rowHandler()); + + return null; + } + } + + /** */ + public static class SqlTestSumAggFunction extends SqlAggFunction { + /** */ + public SqlTestSumAggFunction() { + super( + "TEST_SUM", + null, + SqlKind.SUM, + ReturnTypes.AGG_SUM, + null, + OperandTypes.NUMERIC, + SqlFunctionCategory.NUMERIC, + false, + false, + Optionality.FORBIDDEN + ); + } + } + + /** */ + private static class TestSum extends Accumulators.AbstractAccumulator { + /** */ + private long sum; + + /** */ + protected TestSum(AggregateCall aggCall, RowHandler hnd) { + super(aggCall, hnd); + } + + /** {@inheritDoc} */ + @Override public void add(Row row) { + Number val = get(0, row); + + if (val != null) + sum += val.longValue(); + } + + /** {@inheritDoc} */ + @Override public void apply(Accumulator other) { + sum += ((TestSum)other).sum; + } + + /** {@inheritDoc} */ + @Override public Object end() { + return sum; + } + + /** {@inheritDoc} */ + @Override public List argumentTypes(IgniteTypeFactory typeFactory) { + return List.of(typeFactory.createTypeWithNullability(typeFactory.createSqlType(SqlTypeName.BIGINT), true)); + } + + /** {@inheritDoc} */ + @Override public RelDataType returnType(IgniteTypeFactory typeFactory) { + return typeFactory.createSqlType(org.apache.calcite.sql.type.SqlTypeName.BIGINT); + } + } } diff --git a/modules/core/src/main/java/org/apache/ignite/internal/processors/query/QueryContext.java b/modules/core/src/main/java/org/apache/ignite/internal/processors/query/QueryContext.java index a9dbe02255ece..ac7dfc2eb90e0 100644 --- a/modules/core/src/main/java/org/apache/ignite/internal/processors/query/QueryContext.java +++ b/modules/core/src/main/java/org/apache/ignite/internal/processors/query/QueryContext.java @@ -21,6 +21,7 @@ import java.util.Arrays; import java.util.List; import org.apache.ignite.internal.util.typedef.F; +import org.jetbrains.annotations.Nullable; /** */ public final class QueryContext { @@ -36,10 +37,10 @@ private QueryContext(Object[] params) { } /** - * Finds an instance of an interface implemented by this object, - * or returns null if this object does not support that interface. + * Finds an instance of an interface implemented by this object + * or returns {@code null} if this object does not support that interface. */ - public C unwrap(Class aClass) { + public @Nullable C unwrap(Class aClass) { if (Object[].class == aClass) return aClass.cast(params); @@ -50,12 +51,12 @@ public C unwrap(Class aClass) { * @param params Context parameters. * @return Query context. */ - public static QueryContext of(Object... params) { + public static QueryContext of(@Nullable Object... params) { return !F.isEmpty(params) ? new QueryContext(build(null, params).toArray()) : new QueryContext(EMPTY); } /** */ - private static List build(List dst, Object[] src) { + private static List build(List dst, @Nullable Object[] src) { if (dst == null) dst = new ArrayList<>();