diff --git a/be/src/exprs/aggregate/aggregate_function_window.h b/be/src/exprs/aggregate/aggregate_function_window.h index d491a1db99e81f..243f268df31461 100644 --- a/be/src/exprs/aggregate/aggregate_function_window.h +++ b/be/src/exprs/aggregate/aggregate_function_window.h @@ -623,14 +623,15 @@ struct WindowFunctionNthValueImpl : Data { this->_frame_total_rows ? this->_frame_start_pose : real_frame_start; this->_frame_total_rows += real_frame_end - real_frame_start; int64_t offset = assert_cast(*columns[1]) - .get_data()[0] - - 1; - if (offset >= this->_frame_total_rows) { + .get_data()[0]; + DCHECK_NE(offset, 0); + int64_t row_position = offset > 0 ? offset - 1 : this->_frame_total_rows + offset; + if (row_position < 0 || row_position >= this->_frame_total_rows) { // offset is beyond the frame, so set null this->set_is_null(); return; } - this->set_value(columns, offset + this->_frame_start_pose); + this->set_value(columns, row_position + this->_frame_start_pose); } static const char* name() { return "nth_value"; } diff --git a/be/test/exprs/aggregate/agg_window_nth_value_test.cpp b/be/test/exprs/aggregate/agg_window_nth_value_test.cpp new file mode 100644 index 00000000000000..61ac798ba56662 --- /dev/null +++ b/be/test/exprs/aggregate/agg_window_nth_value_test.cpp @@ -0,0 +1,81 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include "core/column/column_nullable.h" +#include "core/column/column_string.h" +#include "core/column/column_vector.h" +#include "core/data_type/data_type_number.h" +#include "core/data_type/data_type_string.h" +#include "exprs/aggregate/aggregate_function.h" +#include "exprs/aggregate/aggregate_function_simple_factory.h" + +namespace doris { + +void register_aggregate_function_window_lead_lag_first_last( + AggregateFunctionSimpleFactory& factory); + +TEST(AggregateWindowNthValueTest, UpperBoundedLowerUnboundedFrame) { + AggregateFunctionSimpleFactory factory; + register_aggregate_function_window_lead_lag_first_last(factory); + + DataTypes argument_types = {std::make_shared(), + std::make_shared()}; + auto function = factory.get("nth_value", argument_types, nullptr, true, -1, + {.is_window_function = true, .column_names = {}}); + ASSERT_NE(function, nullptr); + + auto value_column = ColumnString::create(); + value_column->insert_data("C", 1); + value_column->insert_data("B", 1); + value_column->insert_data("A", 1); + + auto offset_column = ColumnInt64::create(); + offset_column->insert_value(-2); + + const IColumn* columns[] = {value_column.get(), offset_column.get()}; + + Arena arena; + auto* place = reinterpret_cast(arena.alloc(function->size_of_data())); + function->create(place); + + auto result_column = ColumnNullable::create(ColumnString::create(), ColumnUInt8::create()); + UInt8 use_null_result = false; + UInt8 could_use_previous_result = false; + + function->add_range_single_place(0, 3, 0, 1, place, columns, arena, &use_null_result, + &could_use_previous_result); + function->insert_result_into(place, *result_column); + + function->add_range_single_place(0, 3, 1, 2, place, columns, arena, &use_null_result, + &could_use_previous_result); + function->insert_result_into(place, *result_column); + + function->add_range_single_place(0, 3, 2, 3, place, columns, arena, &use_null_result, + &could_use_previous_result); + function->insert_result_into(place, *result_column); + + ASSERT_EQ(result_column->size(), 3); + EXPECT_TRUE(result_column->is_null_at(0)); + EXPECT_EQ(result_column->get_data_at(1).to_string(), "C"); + EXPECT_EQ(result_column->get_data_at(2).to_string(), "B"); + + function->destroy(place); +} + +} // namespace doris diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/WindowFunctionChecker.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/WindowFunctionChecker.java index 1a8cc394adba03..20d59e9bc4d74a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/WindowFunctionChecker.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/WindowFunctionChecker.java @@ -21,6 +21,7 @@ import org.apache.doris.nereids.properties.OrderKey; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.OrderExpression; +import org.apache.doris.nereids.trees.expressions.Subtract; import org.apache.doris.nereids.trees.expressions.WindowExpression; import org.apache.doris.nereids.trees.expressions.WindowFrame; import org.apache.doris.nereids.trees.expressions.WindowFrame.FrameBoundType; @@ -38,11 +39,14 @@ import org.apache.doris.nereids.trees.expressions.functions.window.PercentRank; import org.apache.doris.nereids.trees.expressions.functions.window.Rank; import org.apache.doris.nereids.trees.expressions.functions.window.RowNumber; +import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral; import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; +import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; import java.util.List; import java.util.Optional; @@ -456,12 +460,21 @@ private void checkWindowFrameAfterFunc(WindowFrame wf) { // e.g. (3 preceding, unbounded following) -> (unbounded preceding, 3 following) windowExpression = windowExpression.withWindowFrame(wf.reverseWindow()); - // reverse WindowFunction, which is used only for first_value() and last_value() + // adjust window functions whose result depends on the order within the frame. Expression windowFunction = windowExpression.getFunction(); if (windowFunction instanceof FirstOrLastValue) { - // windowExpression = windowExpression.withChildren( - // ImmutableList.of(((FirstOrLastValue) windowFunction).reverse())); windowExpression = windowExpression.withFunction(((FirstOrLastValue) windowFunction).reverse()); + } else if (windowFunction instanceof NthValue) { + NthValue nthValue = (NthValue) windowFunction; + Expression reversedOffset; + Expression offset = nthValue.getArgument(1); + if (offset instanceof BigIntLiteral) { + reversedOffset = new BigIntLiteral(-((BigIntLiteral) offset).getValue()); + } else { + reversedOffset = new Subtract(new IntegerLiteral(0), nthValue.child(1)); + } + windowExpression = windowExpression.withFunction( + nthValue.withChildren(ImmutableList.of(nthValue.child(0), reversedOffset))); } } } diff --git a/regression-test/data/query_p0/sql_functions/window_functions/test_nthvalue_function.out b/regression-test/data/query_p0/sql_functions/window_functions/test_nthvalue_function.out index 58fdaad0ee7ceb..94e99985fe130e 100644 --- a/regression-test/data/query_p0/sql_functions/window_functions/test_nthvalue_function.out +++ b/regression-test/data/query_p0/sql_functions/window_functions/test_nthvalue_function.out @@ -262,3 +262,8 @@ 11 true 11 15 true 15 +-- !select_upper_bounded -- +1 B B +2 C C +3 \N \N + diff --git a/regression-test/suites/query_p0/sql_functions/window_functions/test_nthvalue_function.groovy b/regression-test/suites/query_p0/sql_functions/window_functions/test_nthvalue_function.groovy index 71f96b1dadce70..b9172a1d4816f7 100644 --- a/regression-test/suites/query_p0/sql_functions/window_functions/test_nthvalue_function.groovy +++ b/regression-test/suites/query_p0/sql_functions/window_functions/test_nthvalue_function.groovy @@ -113,8 +113,36 @@ suite("test_nthvalue_function") { qt_select_14 " SELECT k1,k6, nth_value(k1, 1) OVER(PARTITION BY k6 ORDER BY k1 RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) FROM baseall order by k6,k1; " qt_select_15 "SELECT k1, k6, nth_value(k1, 5) OVER(PARTITION BY k6 ORDER BY k1 ROWS BETWEEN 5 PRECEDING AND 1 FOLLOWING) FROM baseall order by k6,k1; " qt_select_16 "SELECT k1, k6, nth_value(k1, 4) OVER(PARTITION BY k6 ORDER BY k1 ROWS BETWEEN 3 PRECEDING AND CURRENT ROW) FROM baseall order by k6,k1; " -} + sql "DROP TABLE IF EXISTS test_nthvalue_upper_bounded" + sql """ + CREATE TABLE test_nthvalue_upper_bounded ( + seq int, + v varchar(10) + ) DUPLICATE KEY(seq) + DISTRIBUTED BY HASH(seq) BUCKETS 1 + PROPERTIES ( + "replication_num" = "1" + ) + """ + sql """ + INSERT INTO test_nthvalue_upper_bounded VALUES + (1, 'A'), + (2, 'B'), + (3, 'C') + """ + qt_select_upper_bounded """ + SELECT + seq, + nth_value(v, 2) OVER ( + ORDER BY seq + ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING + ) AS actual, + lead(v, 1, NULL) OVER (ORDER BY seq) AS expected + FROM test_nthvalue_upper_bounded + ORDER BY seq + """ +}