From f1c082fa6adfff37434a23c32ad4bc0c4eea3837 Mon Sep 17 00:00:00 2001 From: morrySnow Date: Fri, 26 Jun 2026 17:38:01 +0800 Subject: [PATCH] [fix](be) Fix nth_value for upper bounded windows (#64864) ### What problem does this PR solve? Related PR: #50559 Problem Summary: nth_value over an upper-bounded/lower-unbounded window frame such as ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING is normalized by reversing the order and frame, then evaluating nth_value with a negative offset over a cumulative frame. The BE nth_value window state replaced the tracked frame row count on each range update, so later rows in the cumulative frame could address the wrong row or return NULL. The fix keeps the cumulative frame row count across range updates, and FE preserves literal offset typing when negating bigint nth_value arguments. This adds BE unit coverage for the reversed cumulative execution path and a regression case comparing nth_value to lead for the original SQL frame. ### Release note Fix nth_value results for ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING window frames. --- .../aggregate/aggregate_function_window.h | 9 ++- .../aggregate/agg_window_nth_value_test.cpp | 81 +++++++++++++++++++ .../rules/analysis/WindowFunctionChecker.java | 18 ++++- .../test_nthvalue_function.out | 5 ++ .../test_nthvalue_function.groovy | 30 ++++++- 5 files changed, 135 insertions(+), 8 deletions(-) create mode 100644 be/test/exprs/aggregate/agg_window_nth_value_test.cpp diff --git a/be/src/exprs/aggregate/aggregate_function_window.h b/be/src/exprs/aggregate/aggregate_function_window.h index 9b654c77ff5eb1..10de5cb51ef854 100644 --- a/be/src/exprs/aggregate/aggregate_function_window.h +++ b/be/src/exprs/aggregate/aggregate_function_window.h @@ -624,14 +624,15 @@ struct WindowFunctionNthValueImpl : Data { this->_frame_total_rows ? this->_frame_start_pose : real_frame_start; this->_frame_total_rows += real_frame_end - real_frame_start; int64_t offset = assert_cast(*columns[1]) - .get_data()[0] - - 1; - if (offset >= this->_frame_total_rows) { + .get_data()[0]; + DCHECK_NE(offset, 0); + int64_t row_position = offset > 0 ? offset - 1 : this->_frame_total_rows + offset; + if (row_position < 0 || row_position >= this->_frame_total_rows) { // offset is beyond the frame, so set null this->set_is_null(); return; } - this->set_value(columns, offset + this->_frame_start_pose); + this->set_value(columns, row_position + this->_frame_start_pose); } static const char* name() { return "nth_value"; } diff --git a/be/test/exprs/aggregate/agg_window_nth_value_test.cpp b/be/test/exprs/aggregate/agg_window_nth_value_test.cpp new file mode 100644 index 00000000000000..61ac798ba56662 --- /dev/null +++ b/be/test/exprs/aggregate/agg_window_nth_value_test.cpp @@ -0,0 +1,81 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include "core/column/column_nullable.h" +#include "core/column/column_string.h" +#include "core/column/column_vector.h" +#include "core/data_type/data_type_number.h" +#include "core/data_type/data_type_string.h" +#include "exprs/aggregate/aggregate_function.h" +#include "exprs/aggregate/aggregate_function_simple_factory.h" + +namespace doris { + +void register_aggregate_function_window_lead_lag_first_last( + AggregateFunctionSimpleFactory& factory); + +TEST(AggregateWindowNthValueTest, UpperBoundedLowerUnboundedFrame) { + AggregateFunctionSimpleFactory factory; + register_aggregate_function_window_lead_lag_first_last(factory); + + DataTypes argument_types = {std::make_shared(), + std::make_shared()}; + auto function = factory.get("nth_value", argument_types, nullptr, true, -1, + {.is_window_function = true, .column_names = {}}); + ASSERT_NE(function, nullptr); + + auto value_column = ColumnString::create(); + value_column->insert_data("C", 1); + value_column->insert_data("B", 1); + value_column->insert_data("A", 1); + + auto offset_column = ColumnInt64::create(); + offset_column->insert_value(-2); + + const IColumn* columns[] = {value_column.get(), offset_column.get()}; + + Arena arena; + auto* place = reinterpret_cast(arena.alloc(function->size_of_data())); + function->create(place); + + auto result_column = ColumnNullable::create(ColumnString::create(), ColumnUInt8::create()); + UInt8 use_null_result = false; + UInt8 could_use_previous_result = false; + + function->add_range_single_place(0, 3, 0, 1, place, columns, arena, &use_null_result, + &could_use_previous_result); + function->insert_result_into(place, *result_column); + + function->add_range_single_place(0, 3, 1, 2, place, columns, arena, &use_null_result, + &could_use_previous_result); + function->insert_result_into(place, *result_column); + + function->add_range_single_place(0, 3, 2, 3, place, columns, arena, &use_null_result, + &could_use_previous_result); + function->insert_result_into(place, *result_column); + + ASSERT_EQ(result_column->size(), 3); + EXPECT_TRUE(result_column->is_null_at(0)); + EXPECT_EQ(result_column->get_data_at(1).to_string(), "C"); + EXPECT_EQ(result_column->get_data_at(2).to_string(), "B"); + + function->destroy(place); +} + +} // namespace doris diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/WindowFunctionChecker.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/WindowFunctionChecker.java index cbc5061eaccf6a..9142cce8e0f82c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/WindowFunctionChecker.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/WindowFunctionChecker.java @@ -21,6 +21,7 @@ import org.apache.doris.nereids.properties.OrderKey; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.OrderExpression; +import org.apache.doris.nereids.trees.expressions.Subtract; import org.apache.doris.nereids.trees.expressions.WindowExpression; import org.apache.doris.nereids.trees.expressions.WindowFrame; import org.apache.doris.nereids.trees.expressions.WindowFrame.FrameBoundType; @@ -38,7 +39,9 @@ import org.apache.doris.nereids.trees.expressions.functions.window.PercentRank; import org.apache.doris.nereids.trees.expressions.functions.window.Rank; import org.apache.doris.nereids.trees.expressions.functions.window.RowNumber; +import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral; import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; +import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor; import org.apache.doris.nereids.util.TypeCoercionUtils; @@ -476,12 +479,21 @@ private void checkWindowFrameAfterFunc(WindowFrame wf) { // e.g. (3 preceding, unbounded following) -> (unbounded preceding, 3 following) windowExpression = windowExpression.withWindowFrame(wf.reverseWindow()); - // reverse WindowFunction, which is used only for first_value() and last_value() + // adjust window functions whose result depends on the order within the frame. Expression windowFunction = windowExpression.getFunction(); if (windowFunction instanceof FirstOrLastValue) { - // windowExpression = windowExpression.withChildren( - // ImmutableList.of(((FirstOrLastValue) windowFunction).reverse())); windowExpression = windowExpression.withFunction(((FirstOrLastValue) windowFunction).reverse()); + } else if (windowFunction instanceof NthValue) { + NthValue nthValue = (NthValue) windowFunction; + Expression reversedOffset; + Expression offset = nthValue.getArgument(1); + if (offset instanceof BigIntLiteral) { + reversedOffset = new BigIntLiteral(-((BigIntLiteral) offset).getValue()); + } else { + reversedOffset = new Subtract(new IntegerLiteral(0), nthValue.child(1)); + } + windowExpression = windowExpression.withFunction( + nthValue.withChildren(ImmutableList.of(nthValue.child(0), reversedOffset))); } } } diff --git a/regression-test/data/query_p0/sql_functions/window_functions/test_nthvalue_function.out b/regression-test/data/query_p0/sql_functions/window_functions/test_nthvalue_function.out index 58fdaad0ee7ceb..94e99985fe130e 100644 --- a/regression-test/data/query_p0/sql_functions/window_functions/test_nthvalue_function.out +++ b/regression-test/data/query_p0/sql_functions/window_functions/test_nthvalue_function.out @@ -262,3 +262,8 @@ 11 true 11 15 true 15 +-- !select_upper_bounded -- +1 B B +2 C C +3 \N \N + diff --git a/regression-test/suites/query_p0/sql_functions/window_functions/test_nthvalue_function.groovy b/regression-test/suites/query_p0/sql_functions/window_functions/test_nthvalue_function.groovy index 71f96b1dadce70..b9172a1d4816f7 100644 --- a/regression-test/suites/query_p0/sql_functions/window_functions/test_nthvalue_function.groovy +++ b/regression-test/suites/query_p0/sql_functions/window_functions/test_nthvalue_function.groovy @@ -113,8 +113,36 @@ suite("test_nthvalue_function") { qt_select_14 " SELECT k1,k6, nth_value(k1, 1) OVER(PARTITION BY k6 ORDER BY k1 RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) FROM baseall order by k6,k1; " qt_select_15 "SELECT k1, k6, nth_value(k1, 5) OVER(PARTITION BY k6 ORDER BY k1 ROWS BETWEEN 5 PRECEDING AND 1 FOLLOWING) FROM baseall order by k6,k1; " qt_select_16 "SELECT k1, k6, nth_value(k1, 4) OVER(PARTITION BY k6 ORDER BY k1 ROWS BETWEEN 3 PRECEDING AND CURRENT ROW) FROM baseall order by k6,k1; " -} + sql "DROP TABLE IF EXISTS test_nthvalue_upper_bounded" + sql """ + CREATE TABLE test_nthvalue_upper_bounded ( + seq int, + v varchar(10) + ) DUPLICATE KEY(seq) + DISTRIBUTED BY HASH(seq) BUCKETS 1 + PROPERTIES ( + "replication_num" = "1" + ) + """ + sql """ + INSERT INTO test_nthvalue_upper_bounded VALUES + (1, 'A'), + (2, 'B'), + (3, 'C') + """ + qt_select_upper_bounded """ + SELECT + seq, + nth_value(v, 2) OVER ( + ORDER BY seq + ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING + ) AS actual, + lead(v, 1, NULL) OVER (ORDER BY seq) AS expected + FROM test_nthvalue_upper_bounded + ORDER BY seq + """ +}