From 2472cf446d14578bf01934bb5e3c829d24ca344c Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Wed, 29 Apr 2026 23:26:46 -0700 Subject: [PATCH 1/7] [SPARK-56395][SQL] Add NEAREST BY top-K ranking join (catalyst-side) This PR introduces the SQL grammar, logical plan, analyzer checks, and optimizer rewrite for the new `NEAREST BY` clause. The DataFrame / PySpark / Spark Connect API surface is split into a follow-up PR. --- .../resources/error/error-conditions.json | 48 ++++ docs/sql-ref-ansi-compliance.md | 5 + docs/sql-ref-syntax-qry-select-join.md | 26 +- .../spark/sql/catalyst/parser/SqlBaseLexer.g4 | 5 + .../sql/catalyst/parser/SqlBaseParser.g4 | 16 +- .../spark/sql/errors/QueryParsingErrors.scala | 25 ++ .../sql/catalyst/analysis/CheckAnalysis.scala | 34 +++ .../analysis/DeduplicateRelations.scala | 7 +- .../sql/catalyst/optimizer/Optimizer.scala | 6 +- .../optimizer/RewriteNearestByJoin.scala | 125 +++++++++ .../sql/catalyst/parser/AstBuilder.scala | 86 ++++--- .../spark/sql/catalyst/plans/joinTypes.scala | 62 +++++ .../plans/logical/basicLogicalOperators.scala | 59 +++++ .../sql/catalyst/trees/TreePatterns.scala | 1 + .../optimizer/RewriteNearestByJoinSuite.scala | 115 +++++++++ .../sql/catalyst/parser/PlanParserSuite.scala | 139 ++++++++++ .../SparkConnectDatabaseMetaDataSuite.scala | 2 +- .../analyzer-results/join-nearest-by.sql.out | 238 ++++++++++++++++++ .../sql-tests/inputs/join-nearest-by.sql | 57 +++++ .../sql-tests/results/join-nearest-by.sql.out | 220 ++++++++++++++++ .../results/keywords-enforced.sql.out | 5 + .../sql-tests/results/keywords.sql.out | 5 + .../results/nonansi/keywords.sql.out | 5 + .../ThriftServerWithSparkContextSuite.scala | 2 +- 24 files changed, 1258 insertions(+), 35 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala create mode 100644 
sql/core/src/test/resources/sql-tests/analyzer-results/join-nearest-by.sql.out create mode 100644 sql/core/src/test/resources/sql-tests/inputs/join-nearest-by.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/join-nearest-by.sql.out diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 7822dc05502c0..850fae02f719a 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -5313,6 +5313,49 @@ ], "sqlState" : "0A000" }, + "NEAREST_BY_JOIN" : { + "message" : [ + "Invalid nearest-by join." + ], + "subClass" : { + "EXACT_WITH_NONDETERMINISTIC_EXPRESSION" : { + "message" : [ + "EXACT nearest-by join is incompatible with the nondeterministic ranking expression <expression>. Use APPROX, or replace the expression with a deterministic one." + ] + }, + "NON_ORDERABLE_RANKING_EXPRESSION" : { + "message" : [ + "The ranking expression <expression> of type <type> is not orderable." + ] + }, + "NUM_RESULTS_OUT_OF_RANGE" : { + "message" : [ + "The number of results <numResults> must be between <min> and <max>." + ] + }, + "STREAMING_NOT_SUPPORTED" : { + "message" : [ + "NEAREST BY join is not supported with streaming DataFrames/Datasets." + ] + }, + "UNSUPPORTED_DIRECTION" : { + "message" : [ + "Unsupported nearest-by join direction '<direction>'. Supported nearest-by join directions include: <supported>." + ] + }, + "UNSUPPORTED_JOIN_TYPE" : { + "message" : [ + "Unsupported nearest-by join type <joinType>. Supported types: <supported>." + ] + }, + "UNSUPPORTED_MODE" : { + "message" : [ + "Unsupported nearest-by join mode '<mode>'. Supported modes include: <supported>." + ] + } + }, + "sqlState" : "42604" + }, "NEGATIVE_SCALE_DISALLOWED" : { "message" : [ "Negative scale is not allowed: ''. Set the config to \"true\" to allow it." @@ -7837,6 +7880,11 @@ "Referencing a lateral column alias in window expression ." 
] }, + "LATERAL_JOIN_NEAREST_BY" : { + "message" : [ + "LATERAL correlation with NEAREST BY clause." + ] + }, "LATERAL_JOIN_USING" : { "message" : [ "JOIN USING with LATERAL correlation." diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index 8621eca79a6c8..8542cd3d89865 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -420,6 +420,7 @@ Below is a list of all the keywords in Spark SQL. |ANTI|non-reserved|strict-non-reserved|non-reserved| |ANY|reserved|non-reserved|reserved| |ANY_VALUE|non-reserved|non-reserved|non-reserved| +|APPROX|non-reserved|non-reserved|non-reserved| |ARCHIVE|non-reserved|non-reserved|non-reserved| |ARRAY|non-reserved|non-reserved|reserved| |AS|reserved|non-reserved|reserved| @@ -515,6 +516,7 @@ Below is a list of all the keywords in Spark SQL. |DFS|non-reserved|non-reserved|non-reserved| |DIRECTORIES|non-reserved|non-reserved|non-reserved| |DIRECTORY|non-reserved|non-reserved|non-reserved| +|DISTANCE|non-reserved|non-reserved|non-reserved| |DISTINCT|reserved|non-reserved|reserved| |DISTRIBUTE|non-reserved|non-reserved|non-reserved| |DIV|non-reserved|non-reserved|not a keyword| @@ -528,6 +530,7 @@ Below is a list of all the keywords in Spark SQL. |ESCAPE|reserved|non-reserved|reserved| |ESCAPED|non-reserved|non-reserved|non-reserved| |EVOLUTION|non-reserved|non-reserved|non-reserved| +|EXACT|non-reserved|non-reserved|non-reserved| |EXCEPT|reserved|strict-non-reserved|reserved| |EXCHANGE|non-reserved|non-reserved|non-reserved| |EXCLUDE|non-reserved|non-reserved|non-reserved| @@ -648,6 +651,7 @@ Below is a list of all the keywords in Spark SQL. 
|NANOSECOND|non-reserved|non-reserved|non-reserved| |NANOSECONDS|non-reserved|non-reserved|non-reserved| |NATURAL|reserved|strict-non-reserved|reserved| +|NEAREST|non-reserved|non-reserved|non-reserved| |NEXT|non-reserved|non-reserved|non-reserved| |NO|non-reserved|non-reserved|reserved| |NONE|non-reserved|non-reserved|reserved| @@ -738,6 +742,7 @@ Below is a list of all the keywords in Spark SQL. |SETS|non-reserved|non-reserved|non-reserved| |SHORT|non-reserved|non-reserved|non-reserved| |SHOW|non-reserved|non-reserved|non-reserved| +|SIMILARITY|non-reserved|non-reserved|non-reserved| |SINGLE|non-reserved|non-reserved|non-reserved| |SKEWED|non-reserved|non-reserved|non-reserved| |SMALLINT|non-reserved|non-reserved|reserved| diff --git a/docs/sql-ref-syntax-qry-select-join.md b/docs/sql-ref-syntax-qry-select-join.md index 698884dc28b57..68fb6eda9353e 100644 --- a/docs/sql-ref-syntax-qry-select-join.md +++ b/docs/sql-ref-syntax-qry-select-join.md @@ -26,7 +26,7 @@ A SQL join is used to combine rows from two relations based on join criteria. Th ### Syntax ```sql -relation { [ join_type ] JOIN [ LATERAL ] relation [ join_criteria ] | NATURAL join_type JOIN [ LATERAL ] relation } +relation { [ join_type ] JOIN [ LATERAL ] relation [ join_criteria | nearest_by_clause ] | NATURAL join_type JOIN [ LATERAL ] relation } ``` ### Parameters @@ -53,6 +53,30 @@ relation { [ join_type ] JOIN [ LATERAL ] relation [ join_criteria ] | NATURAL j Specifies an expression with a return type of boolean. +* **nearest_by_clause** + + Specifies a nearest-by top-K ranking join. For each row on the left (query side), returns up to `num_results` rows from the right (base side), ranked by `ranking_expression`. Only `INNER` (the default) and `LEFT OUTER` join types are supported with this clause. + + **Syntax:** `{ APPROX | EXACT } NEAREST [ num_results ] BY { DISTANCE | SIMILARITY } ranking_expression` + + `APPROX | EXACT` + + Controls the search algorithm contract. 
`APPROX` allows the optimizer to use faster approximate strategies (such as indexed nearest-neighbor search when available). `EXACT` forces brute-force evaluation and requires `ranking_expression` to be deterministic. + + `num_results` + + A positive integer literal between 1 and 100000 that limits the number of matches per left row. Defaults to 1 when omitted. + + `DISTANCE | SIMILARITY` + + `DISTANCE` ranks rows by smallest value of `ranking_expression` first. `SIMILARITY` ranks rows by largest value first. + + `ranking_expression` + + A scalar expression that returns an orderable type. + + **Performance note.** The current implementation evaluates the full cross-product of the left and right sides and bounds memory per left row by `num_results`. Per-query work is `O(|left| × |right| × log num_results)`. Index-backed approximate strategies (transparent to `APPROX` queries) are planned in a future release; until then, pre-filter the right side (e.g. via a subquery) when it is large. + ### Join Types #### **Inner Join** diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 index 59a0034f922e4..f4834b4ecf623 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 @@ -139,6 +139,7 @@ AND: 'AND'; ANTI: 'ANTI'; ANY: 'ANY'; ANY_VALUE: 'ANY_VALUE'; +APPROX: 'APPROX'; ARCHIVE: 'ARCHIVE'; ARRAY: 'ARRAY' {incComplexTypeLevelCounter();}; AS: 'AS'; @@ -234,6 +235,7 @@ DETERMINISTIC: 'DETERMINISTIC'; DFS: 'DFS'; DIRECTORIES: 'DIRECTORIES'; DIRECTORY: 'DIRECTORY'; +DISTANCE: 'DISTANCE'; DISTINCT: 'DISTINCT'; DISTRIBUTE: 'DISTRIBUTE'; DIV: 'DIV'; @@ -247,6 +249,7 @@ ENFORCED: 'ENFORCED'; ESCAPE: 'ESCAPE'; ESCAPED: 'ESCAPED'; EVOLUTION: 'EVOLUTION'; +EXACT: 'EXACT'; EXCEPT: 'EXCEPT'; EXCHANGE: 'EXCHANGE'; EXCLUDE: 'EXCLUDE'; @@ -366,6 +369,7 @@ 
NAMESPACES: 'NAMESPACES'; NANOSECOND: 'NANOSECOND'; NANOSECONDS: 'NANOSECONDS'; NATURAL: 'NATURAL'; +NEAREST: 'NEAREST'; NEXT: 'NEXT'; NO: 'NO'; NONE: 'NONE'; @@ -456,6 +460,7 @@ SETMINUS: 'MINUS'; SETS: 'SETS'; SHORT: 'SHORT'; SHOW: 'SHOW'; +SIMILARITY: 'SIMILARITY'; SINGLE: 'SINGLE'; SKEWED: 'SKEWED'; SMALLINT: 'SMALLINT'; diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 1a0382dbe10c4..735921681cdcd 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -1049,7 +1049,7 @@ relationExtension ; joinRelation - : (joinType) JOIN LATERAL? right=relationPrimary joinCriteria? + : (joinType) JOIN LATERAL? right=relationPrimary (joinCriteria | nearestByClause)? | NATURAL joinType JOIN LATERAL? right=relationPrimary ; @@ -1068,6 +1068,10 @@ joinCriteria | USING identifierList ; +nearestByClause + : (APPROX | EXACT) NEAREST num=INTEGER_VALUE? BY (DISTANCE | SIMILARITY) expression + ; + sample : TABLESAMPLE LEFT_PAREN sampleMethod? RIGHT_PAREN (REPEATABLE LEFT_PAREN seed=integerValue RIGHT_PAREN)? 
; @@ -1930,6 +1934,7 @@ ansiNonReserved | ANALYZE | ANTI | ANY_VALUE + | APPROX | ARCHIVE | ARRAY | ASC @@ -2006,6 +2011,7 @@ ansiNonReserved | DFS | DIRECTORIES | DIRECTORY + | DISTANCE | DISTRIBUTE | DIV | DO @@ -2015,6 +2021,7 @@ ansiNonReserved | ENFORCED | ESCAPED | EVOLUTION + | EXACT | EXCHANGE | EXCLUDE | EXCLUSIVE @@ -2112,6 +2119,7 @@ ansiNonReserved | NAMESPACES | NANOSECOND | NANOSECONDS + | NEAREST | NEXT | NO | NONE @@ -2187,6 +2195,7 @@ ansiNonReserved | SETS | SHORT | SHOW + | SIMILARITY | SINGLE | SKEWED | SMALLINT @@ -2303,6 +2312,7 @@ nonReserved | AND | ANY | ANY_VALUE + | APPROX | ARCHIVE | ARRAY | AS @@ -2398,6 +2408,7 @@ nonReserved | DFS | DIRECTORIES | DIRECTORY + | DISTANCE | DISTINCT | DISTRIBUTE | DIV @@ -2411,6 +2422,7 @@ nonReserved | ESCAPE | ESCAPED | EVOLUTION + | EXACT | EXCHANGE | EXCLUDE | EXCLUSIVE @@ -2523,6 +2535,7 @@ nonReserved | NAMESPACES | NANOSECOND | NANOSECONDS + | NEAREST | NEXT | NO | NONE @@ -2609,6 +2622,7 @@ nonReserved | SETS | SHORT | SHOW + | SIMILARITY | SINGLE | SKEWED | SMALLINT diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala index 69ff4c9cd108f..9a7e833342c20 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala @@ -203,6 +203,31 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { ctx) } + def nearestByJoinWithLateralUnsupportedError(ctx: ParserRuleContext): Throwable = { + new ParseException( + errorClass = "UNSUPPORTED_FEATURE.LATERAL_JOIN_NEAREST_BY", + messageParameters = Map.empty, + ctx) + } + + def unsupportedNearestByJoinTypeError(ctx: ParserRuleContext, joinType: String): Throwable = { + new ParseException( + errorClass = "NEAREST_BY_JOIN.UNSUPPORTED_JOIN_TYPE", + messageParameters = + Map("joinType" -> toSQLStmt(joinType), "supported" -> "'INNER', 
'LEFT OUTER'"), + ctx) + } + + def nearestByJoinNumResultsOutOfRangeError( + ctx: ParserRuleContext, + numResults: String, + max: Int): Throwable = { + new ParseException( + errorClass = "NEAREST_BY_JOIN.NUM_RESULTS_OUT_OF_RANGE", + messageParameters = Map("numResults" -> numResults, "min" -> "1", "max" -> max.toString), + ctx) + } + def repetitiveWindowDefinitionError(name: String, ctx: WindowClauseContext): Throwable = { new ParseException( errorClass = "INVALID_SQL_SYNTAX.REPETITIVE_WINDOW_DEFINITION", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index b923d442e6d98..19d1375d962ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -657,6 +657,29 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString messageParameters = Map.empty) } + // Reject streaming inputs early. The optimizer rewrite introduces + // `MonotonicallyIncreasingID()`, which is per-batch only and would silently produce + // incorrect results across micro-batches; failing at analysis time is clearer than + // letting the streaming check fire on an incidental MID node. 
+ case j: NearestByJoin if j.isStreaming => + j.failAnalysis( + errorClass = "NEAREST_BY_JOIN.STREAMING_NOT_SUPPORTED", + messageParameters = Map.empty) + + case j @ NearestByJoin(_, _, _, _, _, rankingExpression, _) + if !RowOrdering.isOrderable(rankingExpression.dataType) => + j.failAnalysis( + errorClass = "NEAREST_BY_JOIN.NON_ORDERABLE_RANKING_EXPRESSION", + messageParameters = Map( + "expression" -> toSQLExpr(rankingExpression), + "type" -> toSQLType(rankingExpression.dataType))) + + case j @ NearestByJoin(_, _, _, false, _, rankingExpression, _) + if !rankingExpression.deterministic => + j.failAnalysis( + errorClass = "NEAREST_BY_JOIN.EXACT_WITH_NONDETERMINISTIC_EXPRESSION", + messageParameters = Map("expression" -> toSQLExpr(rankingExpression))) + case a: Aggregate => a.groupingExpressions.foreach( expression => @@ -949,6 +972,17 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString context = j.origin.getQueryContext, summary = j.origin.context.summary) + case j: NearestByJoin if !j.duplicateResolved => + val conflictingAttributes = + j.left.outputSet.intersect(j.right.outputSet).map(toSQLExpr(_)).mkString(", ") + throw SparkException.internalError( + msg = s""" + |Failure when resolving conflicting references in ${j.nodeName}: + |${planToString(plan)} + |Conflicting attributes: $conflictingAttributes.""".stripMargin, + context = j.origin.getQueryContext, + summary = j.origin.context.summary) + // TODO: although map type is not orderable, technically map type should be able to be // used in equality comparison, remove this type check once we support it. 
case o if mapColumnInSetOperation(o).isDefined => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala index 2a2440117e401..ec2ba4f692216 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala @@ -36,7 +36,8 @@ object DeduplicateRelations extends Rule[LogicalPlan] { def noMissingInput(p: LogicalPlan) = !p.exists(_.missingInput.nonEmpty) newPlan.resolveOperatorsUpWithPruning( - _.containsAnyPattern(JOIN, LATERAL_JOIN, AS_OF_JOIN, INTERSECT, EXCEPT, UNION, COMMAND), + _.containsAnyPattern( + JOIN, LATERAL_JOIN, AS_OF_JOIN, NEAREST_BY_JOIN, INTERSECT, EXCEPT, UNION, COMMAND), ruleId) { case p: LogicalPlan if !p.childrenResolved => p // To resolve duplicate expression IDs for Join. @@ -50,6 +51,10 @@ object DeduplicateRelations extends Rule[LogicalPlan] { case j @ AsOfJoin(left, right, _, _, _, _, _) if !j.duplicateResolved && noMissingInput(right) => j.copy(right = dedupRight(left, right)) + // Resolve duplicate output for NearestByJoin. + case j @ NearestByJoin(left, right, _, _, _, _, _) + if !j.duplicateResolved && noMissingInput(right) => + j.copy(right = dedupRight(left, right)) // intersect/except will be rewritten to join at the beginning of optimizer. Here we need to // deduplicate the right side plan, so that we won't produce an invalid self-join later. 
case i @ Intersect(left, right, _) if !i.duplicateResolved && noMissingInput(right) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 25e7479d8897a..618940ac10684 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -339,6 +339,7 @@ abstract class Optimizer(catalogManager: CatalogManager) ReplaceCurrentLike(catalogManager), SpecialDatetimeValues, RewriteAsOfJoin, + RewriteNearestByJoin, EvalInlineTables, ReplaceTranspose, RewriteCollationJoin @@ -2545,8 +2546,11 @@ object CheckCartesianProducts extends Rule[LogicalPlan] with PredicateHelper { if (conf.crossJoinEnabled) { plan } else plan.transformWithPruning(_.containsAnyPattern(INNER_LIKE_JOIN, OUTER_JOIN)) { + // Skip joins that were synthesized by `RewriteNearestByJoin`: the cross product is an + // intentional, bounded part of that rewrite (see `NearestByJoin.SYNTHETIC_JOIN_TAG`). case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, _, _) - if isCartesianProduct(j) => + if isCartesianProduct(j) && + j.getTagValue(NearestByJoin.SYNTHETIC_JOIN_TAG).isEmpty => throw QueryCompilationErrors.joinConditionMissingOrTrivialError(j, left, right) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala new file mode 100644 index 0000000000000..0d4d45ad44ec3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ + +/** + * Replaces a logical [[NearestByJoin]] operator with a `Generate(Inline(...))` over an + * `Aggregate` that tags each left row with a unique id, cross-joins with the right side, and + * groups by the unique id to compute the top-K matches via `MAX_BY`/`MIN_BY` (K-overload). + * + * Input Pseudo-Query: + * {{{ + * SELECT * FROM left [INNER | LEFT OUTER] JOIN right + * {APPROX | EXACT} NEAREST k BY {DISTANCE | SIMILARITY} expr + * }}} + * + * Rewritten Plan (SIMILARITY, INNER join type): + * {{{ + * Generate inline(_matches), [N], outer=false, [right.col1, right.col2, ...] + * +- Aggregate [__qid], + * [first(left.col0) AS left.col0, ..., first(left.colN-1) AS left.colN-1, + * max_by(struct(right.*), expr, k) AS _matches] + * +- Join LeftOuter + * :- Project [left.*, monotonically_increasing_id() AS __qid] + * : +- left + * +- right + * }}} + * + * For `DISTANCE`, `MIN_BY` is used instead of `MAX_BY`. 
For `LEFT OUTER`, the `Generate` is + * constructed with `outer = true` so left rows with no matches (empty/null `_matches`) are + * preserved with `NULL` right-side columns. + * + * In this initial implementation both `APPROX` and `EXACT` take the same brute-force rewrite + * path. `APPROX` establishes the contract for future indexed-ANN strategies. + */ +object RewriteNearestByJoin extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithNewOutput { + case j @ NearestByJoin(left, right, joinType, _, numResults, rankingExpression, direction) => + // 1. Tag each left row with a unique id so that rows from the same left row can later be + // grouped together after the cross-join with `right`. + val qidAlias = Alias(MonotonicallyIncreasingID(), "__qid")() + val taggedLeft = Project(left.output :+ qidAlias, left) + val qidAttr = qidAlias.toAttribute + + // 2. LEFT OUTER-join the tagged left with right (no join condition). LEFT OUTER + // (rather than INNER) preserves left rows even when `right` is empty, so that a + // `LEFT OUTER NEAREST BY` query still returns those rows with `NULL` right-side + // columns after the aggregate + inline below. When `right` is non-empty every left + // row already has right-row pairings, so LEFT OUTER and INNER are equivalent. + // + // Tag the join so `CheckCartesianProducts` skips it: the rewrite intentionally + // materializes a cross product bounded by the downstream `MaxMinByK` aggregate, so + // `spark.sql.crossJoin.enabled = false` should not reject user queries written as + // `NEAREST BY`. + val join = Join(taggedLeft, right, LeftOuter, None, JoinHint.NONE) + join.setTagValue(NearestByJoin.SYNTHETIC_JOIN_TAG, ()) + + // 3. Aggregate grouped by `__qid`: + // - first(col) for every left column so it flows to the output. + // - max_by/min_by(struct(right.*), ranking, k) as `_matches`. 
+ // The ranking expression references left and right columns directly; no outer + // reference is needed because both sides are present in the joined input. + val rightStruct = CreateStruct(right.output) + // reverse = true -> MIN_BY (smallest ranking value first, for DISTANCE) + // reverse = false -> MAX_BY (largest ranking value first, for SIMILARITY) + val reverse = direction match { + case NearestByDistance => true + case NearestBySimilarity => false + } + val topK = MaxMinByK( + rightStruct, + rankingExpression, + Literal(numResults), + reverse = reverse).toAggregateExpression() + val matchesAlias = Alias(topK, "__nearest_matches__")() + + // Carry left columns through with `First`. Within a `__qid` group every row has the same + // left values (each group corresponds to one left row), so `First` is effectively a no-op. + // We use `First` rather than adding all left columns to the GROUP BY because grouping by + // `__qid` alone keeps the shuffle key small. + val firstLeftAggs = left.output.map { attr => + Alias( + First(attr, ignoreNulls = false).toAggregateExpression(), + attr.name)(exprId = attr.exprId, qualifier = attr.qualifier) + } + val aggregate = Aggregate(Seq(qidAttr), firstLeftAggs :+ matchesAlias, join) + + // 4. Generate inline(_matches) expands the K-element array into K rows, exposing each + // struct field as a top-level column. `outer = true` for LEFT OUTER preserves the + // left row with NULL right columns when there are no matches. 
+ val generatorOutput = right.output.map { a => + AttributeReference(a.name, a.dataType, nullable = true)(qualifier = a.qualifier) + } + val generate = Generate( + Inline(matchesAlias.toAttribute), + unrequiredChildIndex = Seq(aggregate.output.indexOf(matchesAlias.toAttribute)), + outer = joinType == LeftOuter, + qualifier = None, + generatorOutput = generatorOutput, + child = aggregate) + + val attrMapping = j.output.zip(generate.output) + generate -> attrMapping + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index df83f558c892e..930fdc2d93176 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2362,39 +2362,67 @@ class AstBuilder extends DataTypeAstBuilder } } - // Resolve the join type and join condition - val (joinType, condition) = Option(ctx.joinCriteria) match { - case Some(c) if c.USING != null => - if (ctx.LATERAL != null) { - throw QueryParsingErrors.lateralJoinWithUsingJoinUnsupportedError(ctx) + if (ctx.nearestByClause != null) { + if (ctx.LATERAL != null) { + throw QueryParsingErrors.nearestByJoinWithLateralUnsupportedError(ctx) + } + if (!Seq(Inner, LeftOuter).contains(baseJoinType)) { + throw QueryParsingErrors.unsupportedNearestByJoinTypeError(ctx, baseJoinType.sql) + } + val clause = ctx.nearestByClause + val approx = clause.APPROX != null + val numResults = Option(clause.num).map { n => + // Guard against literals that overflow Long. 
+ val value = try n.getText.toLong catch { + case _: NumberFormatException => + throw QueryParsingErrors.nearestByJoinNumResultsOutOfRangeError( + ctx, n.getText, NearestByJoin.MaxNumResults) } - (UsingJoin(baseJoinType, visitIdentifierList(c.identifierList)), None) - case Some(c) if c.booleanExpression != null => - (baseJoinType, Option(expression(c.booleanExpression))) - case Some(c) => - throw SparkException.internalError(s"Unimplemented joinCriteria: $c") - case None if ctx.NATURAL != null => - if (ctx.LATERAL != null) { - throw QueryParsingErrors.incompatibleJoinTypesError( - joinType1 = ctx.LATERAL.toString, joinType2 = ctx.NATURAL.toString, ctx = ctx - ) + if (value < 1 || value > NearestByJoin.MaxNumResults) { + throw QueryParsingErrors.nearestByJoinNumResultsOutOfRangeError( + ctx, value.toString, NearestByJoin.MaxNumResults) } - if (baseJoinType == Cross) { - throw QueryParsingErrors.incompatibleJoinTypesError( - joinType1 = ctx.NATURAL.toString, joinType2 = baseJoinType.toString, ctx = ctx - ) + value.toInt + }.getOrElse(1) + val direction = if (clause.DISTANCE != null) NearestByDistance else NearestBySimilarity + val rankingExpr = expression(clause.expression) + NearestByJoin( + base, plan(ctx.right), baseJoinType, approx, numResults, rankingExpr, direction) + } else { + // Resolve the join type and join condition + val (joinType, condition) = Option(ctx.joinCriteria) match { + case Some(c) if c.USING != null => + if (ctx.LATERAL != null) { + throw QueryParsingErrors.lateralJoinWithUsingJoinUnsupportedError(ctx) + } + (UsingJoin(baseJoinType, visitIdentifierList(c.identifierList)), None) + case Some(c) if c.booleanExpression != null => + (baseJoinType, Option(expression(c.booleanExpression))) + case Some(c) => + throw SparkException.internalError(s"Unimplemented joinCriteria: $c") + case None if ctx.NATURAL != null => + if (ctx.LATERAL != null) { + throw QueryParsingErrors.incompatibleJoinTypesError( + joinType1 = ctx.LATERAL.toString, joinType2 = 
ctx.NATURAL.toString, ctx = ctx + ) + } + if (baseJoinType == Cross) { + throw QueryParsingErrors.incompatibleJoinTypesError( + joinType1 = ctx.NATURAL.toString, joinType2 = baseJoinType.toString, ctx = ctx + ) + } + (NaturalJoin(baseJoinType), None) + case None => + (baseJoinType, None) + } + if (ctx.LATERAL != null) { + if (!Seq(Inner, Cross, LeftOuter).contains(joinType)) { + throw QueryParsingErrors.unsupportedLateralJoinTypeError(ctx, joinType.sql) } - (NaturalJoin(baseJoinType), None) - case None => - (baseJoinType, None) - } - if (ctx.LATERAL != null) { - if (!Seq(Inner, Cross, LeftOuter).contains(joinType)) { - throw QueryParsingErrors.unsupportedLateralJoinTypeError(ctx, joinType.sql) + LateralJoin(base, LateralSubquery(plan(ctx.right)), joinType, condition) + } else { + Join(base, plan(ctx.right), joinType, condition, JoinHint.NONE) } - LateralJoin(base, LateralSubquery(plan(ctx.right)), joinType, condition) - } else { - Join(base, plan(ctx.right), joinType, condition, JoinHint.NONE) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index 9f8c62fe58408..569cd05a46ba8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -181,3 +181,65 @@ object LateralJoinType { ) } } + +object NearestByDirection { + + val supported = Seq("distance", "similarity") + + def apply(direction: String): NearestByDirection = { + direction.toLowerCase(Locale.ROOT) match { + case "distance" => NearestByDistance + case "similarity" => NearestBySimilarity + case _ => + throw new AnalysisException( + errorClass = "NEAREST_BY_JOIN.UNSUPPORTED_DIRECTION", + messageParameters = Map( + "direction" -> direction, + "supported" -> supported.mkString("'", "', '", "'"))) + } + } +} + +sealed abstract class NearestByDirection + +case 
object NearestByDistance extends NearestByDirection +case object NearestBySimilarity extends NearestByDirection + +object NearestByJoinType { + + /** Strings accepted by the Dataset API. */ + val supported = Seq("inner", "leftouter", "left", "left_outer") + + /** Display string used in `NEAREST_BY_JOIN.UNSUPPORTED_JOIN_TYPE` error messages. Matches the + * parser-side wording so the same error class reports the same `supported` value across the + * SQL and DataFrame paths. */ + val supportedDisplay = "'INNER', 'LEFT OUTER'" + + def apply(typ: String): JoinType = typ.toLowerCase(Locale.ROOT).replace("_", "") match { + case "inner" => Inner + case "leftouter" | "left" => LeftOuter + case _ => + throw new AnalysisException( + errorClass = "NEAREST_BY_JOIN.UNSUPPORTED_JOIN_TYPE", + messageParameters = Map( + "joinType" -> typ, + "supported" -> supportedDisplay)) + } +} + +object NearestByJoinMode { + + val supported = Seq("approx", "exact") + + /** Returns true for APPROX, false for EXACT. */ + def apply(mode: String): Boolean = mode.toLowerCase(Locale.ROOT) match { + case "approx" => true + case "exact" => false + case _ => + throw new AnalysisException( + errorClass = "NEAREST_BY_JOIN.UNSUPPORTED_MODE", + messageParameters = Map( + "mode" -> mode, + "supported" -> supported.mkString("'", "', '", "'"))) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 8e9f264698caf..d4b91f8e26f15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -2420,3 +2420,62 @@ object AsOfJoin { } } } + +object NearestByJoin { + /** Upper bound on `numResults`. Mirrors the K-overload limit of `MaxMinByK`. 
*/ + val MaxNumResults: Int = 100000 + + /** + * Tag set by `RewriteNearestByJoin` on the synthetic `Join` it produces. The synthetic join + * has no condition by construction (it is the cross-join step in the rewrite, bounded by the + * subsequent `MaxMinByK` aggregate). `CheckCartesianProducts` skips any join carrying this + * tag so that user queries written as `NEAREST BY` are not rejected when + * `spark.sql.crossJoin.enabled` is set to `false`. + */ + val SYNTHETIC_JOIN_TAG: TreeNodeTag[Unit] = TreeNodeTag("nearestBySyntheticJoin") +} + +/** + * A logical plan for nearest-by top-K ranking join. For each row on the left side it returns up to + * `numResults` rows from the right side ordered by `rankingExpression`: + * - `NearestByDistance`: smallest values of `rankingExpression` first. + * - `NearestBySimilarity`: largest values of `rankingExpression` first. + * + * When `approx` is true, the optimizer is allowed to use approximate strategies such as indexed + * nearest-neighbor search. When `approx` is false (EXACT), brute-force evaluation is used and the + * ranking expression must be deterministic. + */ +case class NearestByJoin( + left: LogicalPlan, + right: LogicalPlan, + joinType: JoinType, + approx: Boolean, + numResults: Int, + rankingExpression: Expression, + direction: NearestByDirection) extends BinaryNode { + + require(Seq(Inner, LeftOuter).contains(joinType), + s"Unsupported nearest-by join type $joinType") + + // Right-side attributes are always declared nullable because the rewrite materializes them + // through `Inline` over `MaxMinByK`'s `ArrayType(.., containsNull = true)`, which widens + // every struct field to nullable. Declaring them nullable here keeps the analyzed schema + // consistent with the optimized plan (and with what users see in cached or written outputs). 
+ override def output: Seq[Attribute] = + left.output ++ right.output.map(_.withNullability(true)) + + def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty + + override lazy val resolved: Boolean = { + childrenResolved && + expressions.forall(_.resolved) && + duplicateResolved + } + + final override val nodePatterns: Seq[TreePattern] = Seq(NEAREST_BY_JOIN) + + override protected def withNewChildrenInternal( + newLeft: LogicalPlan, newRight: LogicalPlan): NearestByJoin = { + copy(left = newLeft, right = newRight) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala index 6af98240160bc..d94a506da82d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -154,6 +154,7 @@ object TreePattern extends Enumeration { val LOGICAL_QUERY_STAGE: Value = Value val METRIC_VIEW_PLACEHOLDER: Value = Value val NATURAL_LIKE_JOIN: Value = Value + val NEAREST_BY_JOIN: Value = Value val NO_GROUPING_AGGREGATE_REFERENCE: Value = Value val OFFSET: Value = Value val OUTER_JOIN: Value = Value diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala new file mode 100644 index 0000000000000..f81aa5174f113 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, CreateStruct, Inline, Literal, MonotonicallyIncreasingID} +import org.apache.spark.sql.catalyst.expressions.aggregate.{First, MaxMinByK} +import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, NearestByDistance, NearestBySimilarity, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Generate, Join, JoinHint, LocalRelation, NearestByJoin, Project} + +class RewriteNearestByJoinSuite extends PlanTest { + + private def expectedRewrite( + left: LocalRelation, + right: LocalRelation, + numResults: Int, + ranking: org.apache.spark.sql.catalyst.expressions.Expression, + reverse: Boolean, + outer: Boolean) = { + val qidAlias = Alias(MonotonicallyIncreasingID(), "__qid")() + val taggedLeft = Project(left.output :+ qidAlias, left) + val join = Join(taggedLeft, right, LeftOuter, None, JoinHint.NONE) + + val rightStruct = CreateStruct(right.output) + val topKAgg = MaxMinByK( + rightStruct, ranking, Literal(numResults), reverse = reverse) + .toAggregateExpression() + val matchesAlias = Alias(topKAgg, "__nearest_matches__")() + val firstLeftAggs = left.output.map { attr => + Alias( + First(attr, ignoreNulls = 
false).toAggregateExpression(), + attr.name)(exprId = attr.exprId, qualifier = attr.qualifier) + } + val aggregate = Aggregate( + Seq(qidAlias.toAttribute), firstLeftAggs :+ matchesAlias, join) + + val generatorOutput = right.output.map { a => + AttributeReference(a.name, a.dataType, nullable = true)(qualifier = a.qualifier) + } + Generate( + Inline(matchesAlias.toAttribute), + unrequiredChildIndex = Seq(aggregate.output.indexOf(matchesAlias.toAttribute)), + outer = outer, + qualifier = None, + generatorOutput = generatorOutput, + child = aggregate) + } + + test("similarity, inner, k=5") { + val left = LocalRelation($"a".int, $"b".int) + val right = LocalRelation($"x".int, $"y".int) + val query = NearestByJoin( + left, right, Inner, approx = true, numResults = 5, + rankingExpression = left.output(0) + right.output(0), + direction = NearestBySimilarity) + + val rewritten = RewriteNearestByJoin(query.analyze) + val expected = expectedRewrite( + left, right, 5, + ranking = left.output(0) + right.output(0), + reverse = false, outer = false) + + comparePlans(rewritten, expected, checkAnalysis = false) + } + + test("distance, inner, k=3") { + val left = LocalRelation($"a".int, $"b".int) + val right = LocalRelation($"x".int, $"y".int) + val query = NearestByJoin( + left, right, Inner, approx = true, numResults = 3, + rankingExpression = left.output(0) - right.output(0), + direction = NearestByDistance) + + val rewritten = RewriteNearestByJoin(query.analyze) + val expected = expectedRewrite( + left, right, 3, + ranking = left.output(0) - right.output(0), + reverse = true, outer = false) + + comparePlans(rewritten, expected, checkAnalysis = false) + } + + test("similarity, left outer, k=1") { + val left = LocalRelation($"a".int, $"b".int) + val right = LocalRelation($"x".int, $"y".int) + val query = NearestByJoin( + left, right, LeftOuter, approx = true, numResults = 1, + rankingExpression = left.output(0) + right.output(0), + direction = NearestBySimilarity) + + val 
rewritten = RewriteNearestByJoin(query.analyze) + val expected = expectedRewrite( + left, right, 1, + ranking = left.output(0) + right.output(0), + reverse = false, outer = true) + + comparePlans(rewritten, expected, checkAnalysis = false) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index edaa7aee5cabb..6124c69fbedda 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -826,6 +826,145 @@ class PlanParserSuite extends AnalysisTest { ) } + test("nearest-by join") { + assertEqual( + "select * from t join u approx nearest 5 by similarity t.a + u.a", + NearestByJoin( + table("t"), + table("u"), + Inner, + approx = true, + numResults = 5, + rankingExpression = $"t.a" + $"u.a", + direction = NearestBySimilarity).select(star())) + + assertEqual( + "select * from t inner join u exact nearest 3 by distance t.a - u.a", + NearestByJoin( + table("t"), + table("u"), + Inner, + approx = false, + numResults = 3, + rankingExpression = $"t.a" - $"u.a", + direction = NearestByDistance).select(star())) + + assertEqual( + "select * from t left outer join u approx nearest by similarity t.a + u.a", + NearestByJoin( + table("t"), + table("u"), + LeftOuter, + approx = true, + numResults = 1, + rankingExpression = $"t.a" + $"u.a", + direction = NearestBySimilarity).select(star())) + + // Unsupported join type. 
+ val sqlRightOuter = + "select * from t right outer join u approx nearest 1 by similarity t.a" + checkError( + exception = parseException(sqlRightOuter), + condition = "NEAREST_BY_JOIN.UNSUPPORTED_JOIN_TYPE", + parameters = Map( + "joinType" -> "RIGHT OUTER", + "supported" -> "'INNER', 'LEFT OUTER'"), + context = ExpectedContext( + fragment = "right outer join u approx nearest 1 by similarity t.a", + start = 16, + stop = 68)) + + val sqlFullOuter = + "select * from t full outer join u approx nearest 1 by similarity t.a" + checkError( + exception = parseException(sqlFullOuter), + condition = "NEAREST_BY_JOIN.UNSUPPORTED_JOIN_TYPE", + parameters = Map( + "joinType" -> "FULL OUTER", + "supported" -> "'INNER', 'LEFT OUTER'"), + context = ExpectedContext( + fragment = "full outer join u approx nearest 1 by similarity t.a", + start = 16, + stop = 67)) + + val sqlCross = + "select * from t cross join u approx nearest 1 by similarity t.a" + checkError( + exception = parseException(sqlCross), + condition = "NEAREST_BY_JOIN.UNSUPPORTED_JOIN_TYPE", + parameters = Map( + "joinType" -> "CROSS", + "supported" -> "'INNER', 'LEFT OUTER'"), + context = ExpectedContext( + fragment = "cross join u approx nearest 1 by similarity t.a", + start = 16, + stop = 62)) + + // LATERAL + NEAREST BY not allowed. + val sqlLateral = + "select * from t join lateral (select * from u) uu approx nearest 1 by similarity 1" + checkError( + exception = parseException(sqlLateral), + condition = "UNSUPPORTED_FEATURE.LATERAL_JOIN_NEAREST_BY", + parameters = Map.empty, + context = ExpectedContext( + fragment = "join lateral (select * from u) uu approx nearest 1 by similarity 1", + start = 16, + stop = 81)) + + // num_results out of range. 
+ val sqlTooSmall = + "select * from t join u approx nearest 0 by similarity t.a" + checkError( + exception = parseException(sqlTooSmall), + condition = "NEAREST_BY_JOIN.NUM_RESULTS_OUT_OF_RANGE", + parameters = Map("numResults" -> "0", "min" -> "1", "max" -> "100000"), + context = ExpectedContext( + fragment = "join u approx nearest 0 by similarity t.a", + start = 16, + stop = 56)) + + val sqlTooLarge = + "select * from t join u approx nearest 100001 by distance t.a" + checkError( + exception = parseException(sqlTooLarge), + condition = "NEAREST_BY_JOIN.NUM_RESULTS_OUT_OF_RANGE", + parameters = Map("numResults" -> "100001", "min" -> "1", "max" -> "100000"), + context = ExpectedContext( + fragment = "join u approx nearest 100001 by distance t.a", + start = 16, + stop = 59)) + + // Literal that overflows Long (>19 digits) should surface as the standard out-of-range + // error, not an unwrapped NumberFormatException. + val sqlOverflow = + "select * from t join u approx nearest 99999999999999999999 by distance t.a" + checkError( + exception = parseException(sqlOverflow), + condition = "NEAREST_BY_JOIN.NUM_RESULTS_OUT_OF_RANGE", + parameters = Map( + "numResults" -> "99999999999999999999", + "min" -> "1", + "max" -> "100000"), + context = ExpectedContext( + fragment = "join u approx nearest 99999999999999999999 by distance t.a", + start = 16, + stop = 73)) + } + + test("nearest-by keywords are non-reserved (usable as identifiers)") { + // The five new keywords (APPROX, DISTANCE, EXACT, NEAREST, SIMILARITY) must remain + // non-reserved so they can continue to be used as column or table identifiers. + Seq("approx", "distance", "exact", "nearest", "similarity").foreach { kw => + // As a column identifier in the SELECT list. + parsePlan(s"select $kw from t") + // As a table identifier in the FROM clause. + parsePlan(s"select * from $kw") + } + // All five together in a single SELECT list. 
+ parsePlan("select approx, distance, exact, nearest, similarity from t") + } + test("sampled relations") { val sql = "select * from t" assertEqual(s"$sql tablesample(100 rows)", diff --git a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectDatabaseMetaDataSuite.scala b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectDatabaseMetaDataSuite.scala index a0b4711c2747b..1cfe05a2b5c1f 100644 --- a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectDatabaseMetaDataSuite.scala +++ b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectDatabaseMetaDataSuite.scala @@ -210,7 +210,7 @@ class SparkConnectDatabaseMetaDataSuite extends ConnectFunSuite with RemoteSpark val metadata = conn.getMetaData // scalastyle:off line.size.limit // CURRENT_PATH is excluded: getSQLKeywords drops SQL:2003 reserved words (see companion). - assert(metadata.getSQLKeywords === 
"ADD,AFTER,AGGREGATE,ALWAYS,ANALYZE,ANTI,ANY_VALUE,ARCHIVE,ASC,BINDING,BUCKET,BUCKETS,BYTE,CACHE,CASCADE,CATALOG,CATALOGS,CHANGE,CHANGES,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATION,COLLATIONS,COLLECTION,COLUMNS,COMMENT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONTAINS,CONTINUE,COST,CURRENT_DATABASE,CURRENT_SCHEMA,DATA,DATABASE,DATABASES,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAYOFYEAR,DAYS,DBPROPERTIES,DEFAULT_PATH,DEFINED,DEFINER,DELAY,DELIMITED,DESC,DFS,DIRECTORIES,DIRECTORY,DISTRIBUTE,DIV,DO,ELSEIF,ENFORCED,ESCAPED,EVOLUTION,EXCHANGE,EXCLUDE,EXCLUSIVE,EXIT,EXPLAIN,EXPORT,EXTEND,EXTENDED,FIELDS,FILEFORMAT,FIRST,FLOW,FOLLOWING,FORMAT,FORMATTED,FOUND,FUNCTIONS,GENERATED,GEOGRAPHY,GEOMETRY,HANDLER,HOURS,IDENTIFIED,IDENTIFIER,IF,IGNORE,ILIKE,IMMEDIATE,INCLUDE,INCLUSIVE,INCREMENT,INDEX,INDEXES,INPATH,INPUT,INPUTFORMAT,INVOKER,ITEMS,ITERATE,JSON,KEY,KEYS,LAST,LAZY,LEAVE,LEVEL,LIMIT,LINES,LIST,LOAD,LOCATION,LOCK,LOCKS,LOGICAL,LONG,LOOP,MACRO,MAP,MATCHED,MATERIALIZED,MEASURE,METRICS,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTES,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NORELY,NULLS,OFFSET,OPTION,OPTIONS,OUTPUTFORMAT,OVERWRITE,PARTITIONED,PARTITIONS,PATH,PERCENT,PIVOT,PLACING,PRECEDING,PRINCIPALS,PROCEDURES,PROPERTIES,PURGE,QUALIFY,QUARTER,QUERY,RECORDREADER,RECORDWRITER,RECOVER,RECURSION,REDUCE,REFRESH,RELY,RENAME,REPAIR,REPEAT,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,ROLE,ROLES,SCHEMA,SCHEMAS,SECONDS,SECURITY,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SETS,SHORT,SHOW,SINGLE,SKEWED,SORT,SORTED,SOURCE,STATISTICS,STORED,STRATIFY,STREAM,STREAMING,STRING,STRUCT,SUBSTR,SYNC,SYSTEM_PATH,SYSTEM_TIME,SYSTEM_VERSION,TABLES,TARGET,TBLPROPERTIES,TERMINATED,TIMEDIFF,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TOUCH,TRANSACTION,TRANSACTIONS,TRANSFORM,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNLOCK,UNPIVOT,UNSET,UNTIL,USE,VAR,VARIABLE,VARIANT,VERSION,VIEW,VIEWS,VOID,WATERMARK,WEEK,WEEKS,WHILE,X,YEARS,ZONE") + 
assert(metadata.getSQLKeywords === "ADD,AFTER,AGGREGATE,ALWAYS,ANALYZE,ANTI,ANY_VALUE,APPROX,ARCHIVE,ASC,BINDING,BUCKET,BUCKETS,BYTE,CACHE,CASCADE,CATALOG,CATALOGS,CHANGE,CHANGES,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATION,COLLATIONS,COLLECTION,COLUMNS,COMMENT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONTAINS,CONTINUE,COST,CURRENT_DATABASE,CURRENT_SCHEMA,DATA,DATABASE,DATABASES,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAYOFYEAR,DAYS,DBPROPERTIES,DEFAULT_PATH,DEFINED,DEFINER,DELAY,DELIMITED,DESC,DFS,DIRECTORIES,DIRECTORY,DISTANCE,DISTRIBUTE,DIV,DO,ELSEIF,ENFORCED,ESCAPED,EVOLUTION,EXACT,EXCHANGE,EXCLUDE,EXCLUSIVE,EXIT,EXPLAIN,EXPORT,EXTEND,EXTENDED,FIELDS,FILEFORMAT,FIRST,FLOW,FOLLOWING,FORMAT,FORMATTED,FOUND,FUNCTIONS,GENERATED,GEOGRAPHY,GEOMETRY,HANDLER,HOURS,IDENTIFIED,IDENTIFIER,IF,IGNORE,ILIKE,IMMEDIATE,INCLUDE,INCLUSIVE,INCREMENT,INDEX,INDEXES,INPATH,INPUT,INPUTFORMAT,INVOKER,ITEMS,ITERATE,JSON,KEY,KEYS,LAST,LAZY,LEAVE,LEVEL,LIMIT,LINES,LIST,LOAD,LOCATION,LOCK,LOCKS,LOGICAL,LONG,LOOP,MACRO,MAP,MATCHED,MATERIALIZED,MEASURE,METRICS,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTES,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NEAREST,NORELY,NULLS,OFFSET,OPTION,OPTIONS,OUTPUTFORMAT,OVERWRITE,PARTITIONED,PARTITIONS,PATH,PERCENT,PIVOT,PLACING,PRECEDING,PRINCIPALS,PROCEDURES,PROPERTIES,PURGE,QUALIFY,QUARTER,QUERY,RECORDREADER,RECORDWRITER,RECOVER,RECURSION,REDUCE,REFRESH,RELY,RENAME,REPAIR,REPEAT,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,ROLE,ROLES,SCHEMA,SCHEMAS,SECONDS,SECURITY,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SETS,SHORT,SHOW,SIMILARITY,SINGLE,SKEWED,SORT,SORTED,SOURCE,STATISTICS,STORED,STRATIFY,STREAM,STREAMING,STRING,STRUCT,SUBSTR,SYNC,SYSTEM_PATH,SYSTEM_TIME,SYSTEM_VERSION,TABLES,TARGET,TBLPROPERTIES,TERMINATED,TIMEDIFF,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TOUCH,TRANSACTION,TRANSACTIONS,TRANSFORM,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNLOCK,UNPIVOT,UNSET,UNTIL,USE,VAR,VARIABLE,VARIANT,V
ERSION,VIEW,VIEWS,VOID,WATERMARK,WEEK,WEEKS,WHILE,X,YEARS,ZONE") // scalastyle:on line.size.limit } } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/join-nearest-by.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/join-nearest-by.sql.out new file mode 100644 index 0000000000000..5abb671190cb6 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/join-nearest-by.sql.out @@ -0,0 +1,238 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +CREATE VIEW users(user_id, score) AS VALUES (1, 10.0), (2, 20.0), (3, 30.0) +-- !query analysis +CreateViewCommand `spark_catalog`.`default`.`users`, [(user_id,None), (score,None)], VALUES (1, 10.0), (2, 20.0), (3, 30.0), false, false, PersistedView, COMPENSATION, true + +- LocalRelation [col1#x, col2#x] + + +-- !query +CREATE VIEW products(product, pscore) AS VALUES ('A', 11.0), ('B', 22.0), ('C', 5.0) +-- !query analysis +CreateViewCommand `spark_catalog`.`default`.`products`, [(product,None), (pscore,None)], VALUES ('A', 11.0), ('B', 22.0), ('C', 5.0), false, false, PersistedView, COMPENSATION, true + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT u.user_id, p.product +FROM users u JOIN products p + APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore) +-- !query analysis +Project [user_id#x, product#x] ++- NearestByJoin Inner, true, 1, -abs((score#x - pscore#x)), NearestBySimilarity + :- SubqueryAlias u + : +- SubqueryAlias spark_catalog.default.users + : +- View (`spark_catalog`.`default`.`users`, [user_id#x, score#x]) + : +- Project [cast(col1#x as int) AS user_id#x, cast(col2#x as decimal(3,1)) AS score#x] + : +- LocalRelation [col1#x, col2#x] + +- SubqueryAlias p + +- SubqueryAlias spark_catalog.default.products + +- View (`spark_catalog`.`default`.`products`, [product#x, pscore#x]) + +- Project [cast(col1#x as string) AS product#x, cast(col2#x as decimal(3,1)) AS pscore#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT u.user_id, 
p.product, p.pscore +FROM users u JOIN products p + APPROX NEAREST 2 BY DISTANCE abs(u.score - p.pscore) +-- !query analysis +Project [user_id#x, product#x, pscore#x] ++- NearestByJoin Inner, true, 2, abs((score#x - pscore#x)), NearestByDistance + :- SubqueryAlias u + : +- SubqueryAlias spark_catalog.default.users + : +- View (`spark_catalog`.`default`.`users`, [user_id#x, score#x]) + : +- Project [cast(col1#x as int) AS user_id#x, cast(col2#x as decimal(3,1)) AS score#x] + : +- LocalRelation [col1#x, col2#x] + +- SubqueryAlias p + +- SubqueryAlias spark_catalog.default.products + +- View (`spark_catalog`.`default`.`products`, [product#x, pscore#x]) + +- Project [cast(col1#x as string) AS product#x, cast(col2#x as decimal(3,1)) AS pscore#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT u.user_id, p.product +FROM users u INNER JOIN products p + EXACT NEAREST BY SIMILARITY -abs(u.score - p.pscore) +-- !query analysis +Project [user_id#x, product#x] ++- NearestByJoin Inner, false, 1, -abs((score#x - pscore#x)), NearestBySimilarity + :- SubqueryAlias u + : +- SubqueryAlias spark_catalog.default.users + : +- View (`spark_catalog`.`default`.`users`, [user_id#x, score#x]) + : +- Project [cast(col1#x as int) AS user_id#x, cast(col2#x as decimal(3,1)) AS score#x] + : +- LocalRelation [col1#x, col2#x] + +- SubqueryAlias p + +- SubqueryAlias spark_catalog.default.products + +- View (`spark_catalog`.`default`.`products`, [product#x, pscore#x]) + +- Project [cast(col1#x as string) AS product#x, cast(col2#x as decimal(3,1)) AS pscore#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT u.user_id, p.product +FROM users u LEFT OUTER JOIN (SELECT * FROM products WHERE false) p + APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore) +-- !query analysis +Project [user_id#x, product#x] ++- NearestByJoin LeftOuter, true, 1, -abs((score#x - pscore#x)), NearestBySimilarity + :- SubqueryAlias u + : +- SubqueryAlias spark_catalog.default.users + : +- View 
(`spark_catalog`.`default`.`users`, [user_id#x, score#x]) + : +- Project [cast(col1#x as int) AS user_id#x, cast(col2#x as decimal(3,1)) AS score#x] + : +- LocalRelation [col1#x, col2#x] + +- SubqueryAlias p + +- Project [product#x, pscore#x] + +- Filter false + +- SubqueryAlias spark_catalog.default.products + +- View (`spark_catalog`.`default`.`products`, [product#x, pscore#x]) + +- Project [cast(col1#x as string) AS product#x, cast(col2#x as decimal(3,1)) AS pscore#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT u.user_id, p.product +FROM users u INNER JOIN products p + APPROX NEAREST 1 BY DISTANCE abs(u.score - p.pscore) +-- !query analysis +Project [user_id#x, product#x] ++- NearestByJoin Inner, true, 1, abs((score#x - pscore#x)), NearestByDistance + :- SubqueryAlias u + : +- SubqueryAlias spark_catalog.default.users + : +- View (`spark_catalog`.`default`.`users`, [user_id#x, score#x]) + : +- Project [cast(col1#x as int) AS user_id#x, cast(col2#x as decimal(3,1)) AS score#x] + : +- LocalRelation [col1#x, col2#x] + +- SubqueryAlias p + +- SubqueryAlias spark_catalog.default.products + +- View (`spark_catalog`.`default`.`products`, [product#x, pscore#x]) + +- Project [cast(col1#x as string) AS product#x, cast(col2#x as decimal(3,1)) AS pscore#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT u.user_id, p.product +FROM users u RIGHT OUTER JOIN products p + APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore) +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "NEAREST_BY_JOIN.UNSUPPORTED_JOIN_TYPE", + "sqlState" : "42604", + "messageParameters" : { + "joinType" : "RIGHT OUTER", + "supported" : "'INNER', 'LEFT OUTER'" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 42, + "stopIndex" : 126, + "fragment" : "RIGHT OUTER JOIN products p\n APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore)" + } ] +} + + +-- !query +SELECT u.user_id, p.product +FROM users u 
JOIN products p + APPROX NEAREST 0 BY SIMILARITY -abs(u.score - p.pscore) +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "NEAREST_BY_JOIN.NUM_RESULTS_OUT_OF_RANGE", + "sqlState" : "42604", + "messageParameters" : { + "max" : "100000", + "min" : "1", + "numResults" : "0" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 42, + "stopIndex" : 114, + "fragment" : "JOIN products p\n APPROX NEAREST 0 BY SIMILARITY -abs(u.score - p.pscore)" + } ] +} + + +-- !query +SELECT u.user_id, p.product +FROM users u JOIN products p + APPROX NEAREST 100001 BY SIMILARITY -abs(u.score - p.pscore) +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "NEAREST_BY_JOIN.NUM_RESULTS_OUT_OF_RANGE", + "sqlState" : "42604", + "messageParameters" : { + "max" : "100000", + "min" : "1", + "numResults" : "100001" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 42, + "stopIndex" : 119, + "fragment" : "JOIN products p\n APPROX NEAREST 100001 BY SIMILARITY -abs(u.score - p.pscore)" + } ] +} + + +-- !query +SELECT u.user_id, p.product +FROM users u JOIN products p + APPROX NEAREST 1 BY SIMILARITY map(u.score, p.pscore) +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "NEAREST_BY_JOIN.NON_ORDERABLE_RANKING_EXPRESSION", + "sqlState" : "42604", + "messageParameters" : { + "expression" : "\"map(score, pscore)\"", + "type" : "\"MAP<DECIMAL(3,1), DECIMAL(3,1)>\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 42, + "stopIndex" : 112, + "fragment" : "JOIN products p\n APPROX NEAREST 1 BY SIMILARITY map(u.score, p.pscore)" + } ] +} + + +-- !query +SELECT u.user_id, p.product +FROM users u JOIN products p + EXACT NEAREST 1 BY SIMILARITY rand() + p.pscore +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : 
"NEAREST_BY_JOIN.EXACT_WITH_NONDETERMINISTIC_EXPRESSION", + "sqlState" : "42604", + "messageParameters" : { + "expression" : "\"(rand() + pscore)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 42, + "stopIndex" : 106, + "fragment" : "JOIN products p\n EXACT NEAREST 1 BY SIMILARITY rand() + p.pscore" + } ] +} + + +-- !query +DROP VIEW users +-- !query analysis +DropTableCommand `spark_catalog`.`default`.`users`, false, true, false + + +-- !query +DROP VIEW products +-- !query analysis +DropTableCommand `spark_catalog`.`default`.`products`, false, true, false diff --git a/sql/core/src/test/resources/sql-tests/inputs/join-nearest-by.sql b/sql/core/src/test/resources/sql-tests/inputs/join-nearest-by.sql new file mode 100644 index 0000000000000..8dca4f895483a --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/join-nearest-by.sql @@ -0,0 +1,57 @@ +-- Test cases for NEAREST BY top-K ranking join. + +CREATE VIEW users(user_id, score) AS VALUES (1, 10.0), (2, 20.0), (3, 30.0); +CREATE VIEW products(product, pscore) AS VALUES ('A', 11.0), ('B', 22.0), ('C', 5.0); + +-- Basic APPROX NEAREST BY SIMILARITY with k = 1 +SELECT u.user_id, p.product +FROM users u JOIN products p + APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore); + +-- APPROX NEAREST BY DISTANCE with k = 2 +SELECT u.user_id, p.product, p.pscore +FROM users u JOIN products p + APPROX NEAREST 2 BY DISTANCE abs(u.score - p.pscore); + +-- EXACT NEAREST BY SIMILARITY with default k = 1 +SELECT u.user_id, p.product +FROM users u INNER JOIN products p + EXACT NEAREST BY SIMILARITY -abs(u.score - p.pscore); + +-- LEFT OUTER JOIN with NEAREST BY, empty right side +SELECT u.user_id, p.product +FROM users u LEFT OUTER JOIN (SELECT * FROM products WHERE false) p + APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore); + +-- Explicit INNER keyword +SELECT u.user_id, p.product +FROM users u INNER JOIN products p + APPROX NEAREST 1 BY DISTANCE abs(u.score - 
p.pscore); + +-- Error: unsupported join type (RIGHT OUTER) +SELECT u.user_id, p.product +FROM users u RIGHT OUTER JOIN products p + APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore); + +-- Error: num_results out of range (0) +SELECT u.user_id, p.product +FROM users u JOIN products p + APPROX NEAREST 0 BY SIMILARITY -abs(u.score - p.pscore); + +-- Error: num_results out of range (100001) +SELECT u.user_id, p.product +FROM users u JOIN products p + APPROX NEAREST 100001 BY SIMILARITY -abs(u.score - p.pscore); + +-- Error: non-orderable ranking expression +SELECT u.user_id, p.product +FROM users u JOIN products p + APPROX NEAREST 1 BY SIMILARITY map(u.score, p.pscore); + +-- Error: EXACT mode with nondeterministic ranking expression +SELECT u.user_id, p.product +FROM users u JOIN products p + EXACT NEAREST 1 BY SIMILARITY rand() + p.pscore; + +DROP VIEW users; +DROP VIEW products; diff --git a/sql/core/src/test/resources/sql-tests/results/join-nearest-by.sql.out b/sql/core/src/test/resources/sql-tests/results/join-nearest-by.sql.out new file mode 100644 index 0000000000000..afaf4fbdb9e24 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/join-nearest-by.sql.out @@ -0,0 +1,220 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +CREATE VIEW users(user_id, score) AS VALUES (1, 10.0), (2, 20.0), (3, 30.0) +-- !query schema +struct<> +-- !query output + + + +-- !query +CREATE VIEW products(product, pscore) AS VALUES ('A', 11.0), ('B', 22.0), ('C', 5.0) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT u.user_id, p.product +FROM users u JOIN products p + APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore) +-- !query schema +struct<user_id:int,product:string> +-- !query output +1 A +2 B +3 B + + +-- !query +SELECT u.user_id, p.product, p.pscore +FROM users u JOIN products p + APPROX NEAREST 2 BY DISTANCE abs(u.score - p.pscore) +-- !query schema +struct<user_id:int,product:string,pscore:decimal(3,1)> +-- !query output +1 A 11.0 +1 C 5.0 +2 A 11.0 +2 B 22.0 +3 A 11.0 +3 B 22.0 + + +-- !query 
+SELECT u.user_id, p.product +FROM users u INNER JOIN products p + EXACT NEAREST BY SIMILARITY -abs(u.score - p.pscore) +-- !query schema +struct<user_id:int,product:string> +-- !query output +1 A +2 B +3 B + + +-- !query +SELECT u.user_id, p.product +FROM users u LEFT OUTER JOIN (SELECT * FROM products WHERE false) p + APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore) +-- !query schema +struct<user_id:int,product:string> +-- !query output +1 NULL +2 NULL +3 NULL + + +-- !query +SELECT u.user_id, p.product +FROM users u INNER JOIN products p + APPROX NEAREST 1 BY DISTANCE abs(u.score - p.pscore) +-- !query schema +struct<user_id:int,product:string> +-- !query output +1 A +2 B +3 B + + +-- !query +SELECT u.user_id, p.product +FROM users u RIGHT OUTER JOIN products p + APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "NEAREST_BY_JOIN.UNSUPPORTED_JOIN_TYPE", + "sqlState" : "42604", + "messageParameters" : { + "joinType" : "RIGHT OUTER", + "supported" : "'INNER', 'LEFT OUTER'" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 42, + "stopIndex" : 126, + "fragment" : "RIGHT OUTER JOIN products p\n APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore)" + } ] +} + + +-- !query +SELECT u.user_id, p.product +FROM users u JOIN products p + APPROX NEAREST 0 BY SIMILARITY -abs(u.score - p.pscore) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "NEAREST_BY_JOIN.NUM_RESULTS_OUT_OF_RANGE", + "sqlState" : "42604", + "messageParameters" : { + "max" : "100000", + "min" : "1", + "numResults" : "0" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 42, + "stopIndex" : 114, + "fragment" : "JOIN products p\n APPROX NEAREST 0 BY SIMILARITY -abs(u.score - p.pscore)" + } ] +} + + +-- !query +SELECT u.user_id, p.product +FROM users u JOIN products p + APPROX NEAREST 100001 BY SIMILARITY -abs(u.score - 
p.pscore) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "NEAREST_BY_JOIN.NUM_RESULTS_OUT_OF_RANGE", + "sqlState" : "42604", + "messageParameters" : { + "max" : "100000", + "min" : "1", + "numResults" : "100001" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 42, + "stopIndex" : 119, + "fragment" : "JOIN products p\n APPROX NEAREST 100001 BY SIMILARITY -abs(u.score - p.pscore)" + } ] +} + + +-- !query +SELECT u.user_id, p.product +FROM users u JOIN products p + APPROX NEAREST 1 BY SIMILARITY map(u.score, p.pscore) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "NEAREST_BY_JOIN.NON_ORDERABLE_RANKING_EXPRESSION", + "sqlState" : "42604", + "messageParameters" : { + "expression" : "\"map(score, pscore)\"", + "type" : "\"MAP<DECIMAL(3,1), DECIMAL(3,1)>\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 42, + "stopIndex" : 112, + "fragment" : "JOIN products p\n APPROX NEAREST 1 BY SIMILARITY map(u.score, p.pscore)" + } ] +} + + +-- !query +SELECT u.user_id, p.product +FROM users u JOIN products p + EXACT NEAREST 1 BY SIMILARITY rand() + p.pscore +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "NEAREST_BY_JOIN.EXACT_WITH_NONDETERMINISTIC_EXPRESSION", + "sqlState" : "42604", + "messageParameters" : { + "expression" : "\"(rand() + pscore)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 42, + "stopIndex" : 106, + "fragment" : "JOIN products p\n EXACT NEAREST 1 BY SIMILARITY rand() + p.pscore" + } ] +} + + +-- !query +DROP VIEW users +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP VIEW products +-- !query schema +struct<> +-- !query output + diff --git a/sql/core/src/test/resources/sql-tests/results/keywords-enforced.sql.out 
b/sql/core/src/test/resources/sql-tests/results/keywords-enforced.sql.out index 11a103e6cc0e6..6f9e8fde5d9f1 100644 --- a/sql/core/src/test/resources/sql-tests/results/keywords-enforced.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/keywords-enforced.sql.out @@ -15,6 +15,7 @@ AND true ANTI false ANY true ANY_VALUE false +APPROX false ARCHIVE false ARRAY false AS true @@ -110,6 +111,7 @@ DETERMINISTIC false DFS false DIRECTORIES false DIRECTORY false +DISTANCE false DISTINCT true DISTRIBUTE false DIV false @@ -123,6 +125,7 @@ ENFORCED false ESCAPE true ESCAPED false EVOLUTION false +EXACT false EXCEPT true EXCHANGE false EXCLUDE false @@ -243,6 +246,7 @@ NAMESPACES false NANOSECOND false NANOSECONDS false NATURAL true +NEAREST false NEXT false NO false NONE false @@ -331,6 +335,7 @@ SET false SETS false SHORT false SHOW false +SIMILARITY false SINGLE false SKEWED false SMALLINT false diff --git a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out index 1a7db9df073f4..1fdb51507bc1b 100644 --- a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out @@ -15,6 +15,7 @@ AND false ANTI false ANY false ANY_VALUE false +APPROX false ARCHIVE false ARRAY false AS false @@ -110,6 +111,7 @@ DETERMINISTIC false DFS false DIRECTORIES false DIRECTORY false +DISTANCE false DISTINCT false DISTRIBUTE false DIV false @@ -123,6 +125,7 @@ ENFORCED false ESCAPE false ESCAPED false EVOLUTION false +EXACT false EXCEPT false EXCHANGE false EXCLUDE false @@ -243,6 +246,7 @@ NAMESPACES false NANOSECOND false NANOSECONDS false NATURAL false +NEAREST false NEXT false NO false NONE false @@ -331,6 +335,7 @@ SET false SETS false SHORT false SHOW false +SIMILARITY false SINGLE false SKEWED false SMALLINT false diff --git a/sql/core/src/test/resources/sql-tests/results/nonansi/keywords.sql.out 
b/sql/core/src/test/resources/sql-tests/results/nonansi/keywords.sql.out index 1a7db9df073f4..1fdb51507bc1b 100644 --- a/sql/core/src/test/resources/sql-tests/results/nonansi/keywords.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/nonansi/keywords.sql.out @@ -15,6 +15,7 @@ AND false ANTI false ANY false ANY_VALUE false +APPROX false ARCHIVE false ARRAY false AS false @@ -110,6 +111,7 @@ DETERMINISTIC false DFS false DIRECTORIES false DIRECTORY false +DISTANCE false DISTINCT false DISTRIBUTE false DIV false @@ -123,6 +125,7 @@ ENFORCED false ESCAPE false ESCAPED false EVOLUTION false +EXACT false EXCEPT false EXCHANGE false EXCLUDE false @@ -243,6 +246,7 @@ NAMESPACES false NANOSECOND false NANOSECONDS false NATURAL false +NEAREST false NEXT false NO false NONE false @@ -331,6 +335,7 @@ SET false SETS false SHORT false SHOW false +SIMILARITY false SINGLE false SKEWED false SMALLINT false diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala index c9696a1b2fe68..5067f7dfbcc54 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala @@ -214,7 +214,7 @@ trait ThriftServerWithSparkContextSuite extends SharedThriftServer { val sessionHandle = client.openSession(user, "") val infoValue = client.getInfo(sessionHandle, GetInfoType.CLI_ODBC_KEYWORDS) // scalastyle:off line.size.limit - assert(infoValue.getStringValue == 
"ADD,AFTER,AGGREGATE,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,ASENSITIVE,AT,ATOMIC,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALL,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHANGES,CHAR,CHARACTER,CHECK,CLEAR,CLOSE,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLATIONS,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONDITION,CONSTRAINT,CONTAINS,CONTINUE,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATABASE,CURRENT_DATE,CURRENT_PATH,CURRENT_SCHEMA,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,CURSOR,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DECLARE,DEFAULT,DEFAULT_PATH,DEFINED,DEFINER,DELAY,DELETE,DELIMITED,DESC,DESCRIBE,DETERMINISTIC,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DO,DOUBLE,DROP,ELSE,ELSEIF,END,ENFORCED,ESCAPE,ESCAPED,EVOLUTION,EXCEPT,EXCHANGE,EXCLUDE,EXCLUSIVE,EXECUTE,EXISTS,EXIT,EXPLAIN,EXPORT,EXTEND,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FLOW,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FOUND,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GEOGRAPHY,GEOMETRY,GLOBAL,GRANT,GROUP,GROUPING,HANDLER,HAVING,HOUR,HOURS,IDENTIFIED,IDENTIFIER,IDENTITY,IF,IGNORE,ILIKE,IMMEDIATE,IMPORT,IN,INCLUDE,INCLUSIVE,INCREMENT,INDEX,INDEXES,INNER,INPATH,INPUT,INPUTFORMAT,INSENSITIVE,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,INVOKER,IS,ITEMS,ITERATE,JOIN,JSON,KEY,KEYS,LANGUAGE,LAST,LATERAL,LAZY,LEADING,LEAVE,LEFT,LEVEL,LIKE,LIMIT,LINES,LIST,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,LOOP,MACRO,MAP,MATCHED,MATERIALIZED,MAX,MEASURE,MERGE,METRICS,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MODIFIES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NEXT,NO,NONE,NORELY,NOT,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPEN,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,
PARTITIONS,PATH,PERCENT,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROCEDURE,PROCEDURES,PROPERTIES,PURGE,QUALIFY,QUARTER,QUERY,RANGE,READ,READS,REAL,RECORDREADER,RECORDWRITER,RECOVER,RECURSION,RECURSIVE,REDUCE,REFERENCES,REFRESH,RELY,RENAME,REPAIR,REPEAT,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,RETURN,RETURNS,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SECURITY,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,SPECIFIC,SQL,SQLEXCEPTION,SQLSTATE,START,STATISTICS,STORED,STRATIFY,STREAM,STREAMING,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_PATH,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMEDIFF,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UNTIL,UPDATE,USE,USER,USING,VALUE,VALUES,VAR,VARCHAR,VARIABLE,VARIANT,VERSION,VIEW,VIEWS,VOID,WATERMARK,WEEK,WEEKS,WHEN,WHERE,WHILE,WINDOW,WITH,WITHIN,WITHOUT,X,YEAR,YEARS,ZONE") + assert(infoValue.getStringValue == 
"ADD,AFTER,AGGREGATE,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,APPROX,ARCHIVE,ARRAY,AS,ASC,ASENSITIVE,AT,ATOMIC,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALL,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHANGES,CHAR,CHARACTER,CHECK,CLEAR,CLOSE,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLATIONS,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONDITION,CONSTRAINT,CONTAINS,CONTINUE,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATABASE,CURRENT_DATE,CURRENT_PATH,CURRENT_SCHEMA,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,CURSOR,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DECLARE,DEFAULT,DEFAULT_PATH,DEFINED,DEFINER,DELAY,DELETE,DELIMITED,DESC,DESCRIBE,DETERMINISTIC,DFS,DIRECTORIES,DIRECTORY,DISTANCE,DISTINCT,DISTRIBUTE,DIV,DO,DOUBLE,DROP,ELSE,ELSEIF,END,ENFORCED,ESCAPE,ESCAPED,EVOLUTION,EXACT,EXCEPT,EXCHANGE,EXCLUDE,EXCLUSIVE,EXECUTE,EXISTS,EXIT,EXPLAIN,EXPORT,EXTEND,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FLOW,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FOUND,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GEOGRAPHY,GEOMETRY,GLOBAL,GRANT,GROUP,GROUPING,HANDLER,HAVING,HOUR,HOURS,IDENTIFIED,IDENTIFIER,IDENTITY,IF,IGNORE,ILIKE,IMMEDIATE,IMPORT,IN,INCLUDE,INCLUSIVE,INCREMENT,INDEX,INDEXES,INNER,INPATH,INPUT,INPUTFORMAT,INSENSITIVE,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,INVOKER,IS,ITEMS,ITERATE,JOIN,JSON,KEY,KEYS,LANGUAGE,LAST,LATERAL,LAZY,LEADING,LEAVE,LEFT,LEVEL,LIKE,LIMIT,LINES,LIST,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,LOOP,MACRO,MAP,MATCHED,MATERIALIZED,MAX,MEASURE,MERGE,METRICS,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MODIFIES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NEAREST,NEXT,NO,NONE,NORELY,NOT,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPEN,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OV
ERWRITE,PARTITION,PARTITIONED,PARTITIONS,PATH,PERCENT,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROCEDURE,PROCEDURES,PROPERTIES,PURGE,QUALIFY,QUARTER,QUERY,RANGE,READ,READS,REAL,RECORDREADER,RECORDWRITER,RECOVER,RECURSION,RECURSIVE,REDUCE,REFERENCES,REFRESH,RELY,RENAME,REPAIR,REPEAT,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,RETURN,RETURNS,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SECURITY,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SIMILARITY,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,SPECIFIC,SQL,SQLEXCEPTION,SQLSTATE,START,STATISTICS,STORED,STRATIFY,STREAM,STREAMING,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_PATH,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMEDIFF,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UNTIL,UPDATE,USE,USER,USING,VALUE,VALUES,VAR,VARCHAR,VARIABLE,VARIANT,VERSION,VIEW,VIEWS,VOID,WATERMARK,WEEK,WEEKS,WHEN,WHERE,WHILE,WINDOW,WITH,WITHIN,WITHOUT,X,YEAR,YEARS,ZONE") // scalastyle:on line.size.limit } } From bd534c4f768677717e124b801f6726025ff4ffcf Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Fri, 1 May 2026 16:47:20 -0700 Subject: [PATCH 2/7] Code review. 
--- .../resources/error/error-conditions.json | 2 +- .../spark/sql/errors/QueryParsingErrors.scala | 8 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 13 +- .../sql/catalyst/optimizer/Optimizer.scala | 35 ++- .../optimizer/RewriteNearestByJoin.scala | 40 +++- .../sql/catalyst/parser/AstBuilder.scala | 3 +- .../plans/logical/basicLogicalOperators.scala | 26 +- .../analysis/AnalysisErrorSuite.scala | 16 +- .../optimizer/RewriteNearestByJoinSuite.scala | 145 +++++++++++- .../analyzer-results/join-nearest-by.sql.out | 201 ++++++++++++++++ .../sql-tests/inputs/join-nearest-by.sql | 98 ++++++++ .../sql-tests/results/join-nearest-by.sql.out | 222 ++++++++++++++++++ 12 files changed, 762 insertions(+), 47 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 850fae02f719a..9eba289660505 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -5335,7 +5335,7 @@ }, "STREAMING_NOT_SUPPORTED" : { "message" : [ - "NEAREST BY join is not supported with streaming DataFrames/Datasets." + "nearest-by join is not supported with streaming DataFrames/Datasets." 
] }, "UNSUPPORTED_DIRECTION" : { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala index 9a7e833342c20..33d7aaef17b81 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala @@ -210,11 +210,13 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { ctx) } - def unsupportedNearestByJoinTypeError(ctx: ParserRuleContext, joinType: String): Throwable = { + def unsupportedNearestByJoinTypeError( + ctx: ParserRuleContext, + joinType: String, + supported: String): Throwable = { new ParseException( errorClass = "NEAREST_BY_JOIN.UNSUPPORTED_JOIN_TYPE", - messageParameters = - Map("joinType" -> toSQLStmt(joinType), "supported" -> "'INNER', 'LEFT OUTER'"), + messageParameters = Map("joinType" -> toSQLStmt(joinType), "supported" -> supported), ctx) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 19d1375d962ab..e231fe20d6186 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -657,10 +657,15 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString messageParameters = Map.empty) } - // Reject streaming inputs early. The optimizer rewrite introduces - // `MonotonicallyIncreasingID()`, which is per-batch only and would silently produce - // incorrect results across micro-batches; failing at analysis time is clearer than - // letting the streaming check fire on an incidental MID node. + // Reject streaming inputs early. 
The optimizer rewrite groups by a `__qid` derived + // from `MonotonicallyIncreasingID()` and feeds it to a global `Aggregate`, which + // Spark turns into a stateful streaming aggregation. Because MID restarts per + // micro-batch, `__qid` values collide across batches, and the stateful aggregate + // silently merges state from old batches into new rows that share the same key -- + // producing wrong top-K results. Failing at analysis time is clearer than letting + // this slip through. Streaming support is tracked as a follow-up; resolving it does + // not require streaming-aware MID and is likely to come from a different grouping + // strategy or a dedicated physical operator. case j: NearestByJoin if j.isStreaming => j.failAnalysis( errorClass = "NEAREST_BY_JOIN.STREAMING_NOT_SUPPORTED", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 618940ac10684..97bfd799a4c6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -2542,17 +2542,36 @@ object CheckCartesianProducts extends Rule[LogicalPlan] with PredicateHelper { } } - def apply(plan: LogicalPlan): LogicalPlan = + def apply(plan: LogicalPlan): LogicalPlan = { if (conf.crossJoinEnabled) { - plan - } else plan.transformWithPruning(_.containsAnyPattern(INNER_LIKE_JOIN, OUTER_JOIN)) { - // Skip joins that were synthesized by `RewriteNearestByJoin`: the cross product is an - // intentional, bounded part of that rewrite (see `NearestByJoin.SYNTHETIC_JOIN_TAG`). + return plan + } + + // Joins synthesized by `RewriteNearestByJoin` are an intentional, bounded cross-product + // wrapped by a `MaxMinByK` aggregate. 
Identify them by their unambiguous post-rewrite + // signature -- `Aggregate(_, exprs, Join(_, _, LeftOuter, None, _))` where `exprs` + // contains a `MaxMinByK` -- and skip them so user queries written as `NEAREST BY` are not + // rejected when `spark.sql.crossJoin.enabled = false`. We use structural detection rather + // than a `TreeNodeTag` because a tag set on the `Join` would be silently dropped by any + // intervening optimizer rule that constructs a fresh `Join` via the case-class + // constructor without calling `copyTagsFrom`. + val nearestByJoins: java.util.IdentityHashMap[Join, Unit] = { + val acc = new java.util.IdentityHashMap[Join, Unit]() + plan.foreach { + case Aggregate(_, exprs, j @ Join(_, _, LeftOuter, None, _), _) + if exprs.exists(_.exists(_.isInstanceOf[MaxMinByK])) => + acc.put(j, ()) + case _ => + } + acc + } + + plan.transformWithPruning(_.containsAnyPattern(INNER_LIKE_JOIN, OUTER_JOIN)) { case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, _, _) - if isCartesianProduct(j) && - j.getTagValue(NearestByJoin.SYNTHETIC_JOIN_TAG).isEmpty => - throw QueryCompilationErrors.joinConditionMissingOrTrivialError(j, left, right) + if isCartesianProduct(j) && !nearestByJoins.containsKey(j) => + throw QueryCompilationErrors.joinConditionMissingOrTrivialError(j, left, right) } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala index 0d4d45ad44ec3..1ccf19ebcfb0b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.rules._ * +- Aggregate [__qid], * [first(left.col0) AS left.col0, ..., first(left.colN-1) AS left.colN-1, * max_by(struct(right.*), expr, k) AS _matches] - * 
+- Join Inner + * +- Join LeftOuter * :- Project [left.*, monotonically_increasing_id() AS __qid] * : +- left * +- right @@ -50,8 +50,17 @@ import org.apache.spark.sql.catalyst.rules._ * constructed with `outer = true` so left rows with no matches (empty/null `_matches`) are * preserved with `NULL` right-side columns. * - * In this initial implementation both `APPROX` and `EXACT` take the same brute-force rewrite - * path. `APPROX` establishes the contract for future indexed-ANN strategies. + * If `rankingExpression` is nondeterministic (legal only under `APPROX`), an extra + * `Project` is inserted above the `Join` to materialize the value as `__ranking__`. The + * standard projection machinery runs `Nondeterministic.initialize(partitionIndex)` on every + * nondeterministic descendant before any value is evaluated, so `MaxMinByK` only ever sees a + * plain `AttributeReference` and never evaluates a nondeterministic expression directly. + * + * Unlike [[RewriteAsOfJoin]], which uses a correlated scalar subquery, this rule materializes + * the cross product directly. A scalar subquery returns a single value per left row, so it + * cannot carry K matches without an array-valued subquery + `Generate(Inline(...))` -- which + * collapses back to the same cross product after decorrelation. The aggregate-then-inline form + * makes the intended shape explicit and avoids round-tripping through subquery decorrelation. */ object RewriteNearestByJoin extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithNewOutput { @@ -68,14 +77,20 @@ object RewriteNearestByJoin extends Rule[LogicalPlan] { // columns after the aggregate + inline below. When `right` is non-empty every left // row already has right-row pairings, so LEFT OUTER and INNER are equivalent. 
// - // Tag the join so `CheckCartesianProducts` skips it: the rewrite intentionally - // materializes a cross product bounded by the downstream `MaxMinByK` aggregate, so - // `spark.sql.crossJoin.enabled = false` should not reject user queries written as - // `NEAREST BY`. + // `CheckCartesianProducts` recognizes this synthetic join structurally (by its + // parent `Aggregate` containing a `MaxMinByK`) and skips it, so user queries + // written as `NEAREST BY` are not rejected when `spark.sql.crossJoin.enabled` is + // false. val join = Join(taggedLeft, right, LeftOuter, None, JoinHint.NONE) - join.setTagValue(NearestByJoin.SYNTHETIC_JOIN_TAG, ()) - // 3. Aggregate grouped by `__qid`: + val (aggInput, rankingForAgg) = if (!rankingExpression.deterministic) { + val rankingAlias = Alias(rankingExpression, "__ranking__")() + Project(join.output :+ rankingAlias, join) -> rankingAlias.toAttribute + } else { + join -> rankingExpression + } + + // 3. Aggregate grouped by `__qid`: // - first(col) for every left column so it flows to the output. // - max_by/min_by(struct(right.*), ranking, k) as `_matches`. // The ranking expression references left and right columns directly; no outer @@ -89,7 +104,7 @@ object RewriteNearestByJoin extends Rule[LogicalPlan] { } val topK = MaxMinByK( rightStruct, - rankingExpression, + rankingForAgg, Literal(numResults), reverse = reverse).toAggregateExpression() val matchesAlias = Alias(topK, "__nearest_matches__")() @@ -103,13 +118,14 @@ object RewriteNearestByJoin extends Rule[LogicalPlan] { First(attr, ignoreNulls = false).toAggregateExpression(), attr.name)(exprId = attr.exprId, qualifier = attr.qualifier) } - val aggregate = Aggregate(Seq(qidAttr), firstLeftAggs :+ matchesAlias, join) + val aggregate = Aggregate(Seq(qidAttr), firstLeftAggs :+ matchesAlias, aggInput) // 4. Generate inline(_matches) expands the K-element array into K rows, exposing each // struct field as a top-level column.
`outer = true` for LEFT OUTER preserves the // left row with NULL right columns when there are no matches. val generatorOutput = right.output.map { a => - AttributeReference(a.name, a.dataType, nullable = true)(qualifier = a.qualifier) + AttributeReference(a.name, a.dataType, nullable = true, a.metadata)( + qualifier = a.qualifier) } val generate = Generate( Inline(matchesAlias.toAttribute), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 930fdc2d93176..c233f70cc7b33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2367,7 +2367,8 @@ class AstBuilder extends DataTypeAstBuilder throw QueryParsingErrors.nearestByJoinWithLateralUnsupportedError(ctx) } if (!Seq(Inner, LeftOuter).contains(baseJoinType)) { - throw QueryParsingErrors.unsupportedNearestByJoinTypeError(ctx, baseJoinType.sql) + throw QueryParsingErrors.unsupportedNearestByJoinTypeError( + ctx, baseJoinType.sql, NearestByJoinType.supportedDisplay) } val clause = ctx.nearestByClause val approx = clause.APPROX != null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index d4b91f8e26f15..7fb9f6b13e445 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -2424,26 +2424,18 @@ object AsOfJoin { object NearestByJoin { /** Upper bound on `numResults`. Mirrors the K-overload limit of `MaxMinByK`. */ val MaxNumResults: Int = 100000 - - /** - * Tag set by `RewriteNearestByJoin` on the synthetic `Join` it produces. 
The synthetic join - * has no condition by construction (it is the cross-join step in the rewrite, bounded by the - * subsequent `MaxMinByK` aggregate). `CheckCartesianProducts` skips any join carrying this - * tag so that user queries written as `NEAREST BY` are not rejected when - * `spark.sql.crossJoin.enabled` is set to `false`. - */ - val SYNTHETIC_JOIN_TAG: TreeNodeTag[Unit] = TreeNodeTag("nearestBySyntheticJoin") } /** - * A logical plan for nearest-by top-K ranking join. For each row on the left side it returns up to - * `numResults` rows from the right side ordered by `rankingExpression`: + * A logical plan for a nearest-by top-K ranking join. For each row on the left side it returns + * up to `numResults` rows from the right side ordered by `rankingExpression`: * - `NearestByDistance`: smallest values of `rankingExpression` first. * - `NearestBySimilarity`: largest values of `rankingExpression` first. * - * When `approx` is true, the optimizer is allowed to use approximate strategies such as indexed - * nearest-neighbor search. When `approx` is false (EXACT), brute-force evaluation is used and the - * ranking expression must be deterministic. + * The `approx` field records the user's APPROX/EXACT choice from the SPIP. Today both modes + * use the same brute-force rewrite. The flag is preserved on the logical plan so future + * indexed approximate-nearest-neighbor strategies can fire only when `approx = true`, + * leaving EXACT queries unaffected. See the SPIP linked from SPARK-56395. 
*/ case class NearestByJoin( left: LogicalPlan, @@ -2452,11 +2444,15 @@ case class NearestByJoin( approx: Boolean, numResults: Int, rankingExpression: Expression, - direction: NearestByDirection) extends BinaryNode { + direction: NearestByDirection) + extends BinaryNode with SupportsNonDeterministicExpression { require(Seq(Inner, LeftOuter).contains(joinType), s"Unsupported nearest-by join type $joinType") + // APPROX permits a nondeterministic ranking expression (per the SPIP); EXACT rejects it. + override def allowNonDeterministicExpression: Boolean = approx + // Right-side attributes are always declared nullable because the rewrite materializes them // through `Inline` over `MaxMinByK`'s `ArrayType(.., containsNull = true)`, which widens // every struct field to nullable. Declaring them nullable here keeps the analyzed schema diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index ee644fc62a1ab..0f245d1c22001 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{Count, Max} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.catalyst.plans.{AsOfJoinDirection, Cross, Inner, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans.{AsOfJoinDirection, Cross, Inner, LeftOuter, NearestBySimilarity, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.errors.DataTypeErrorsBase import org.apache.spark.sql.internal.SQLConf @@ -924,6 +924,20 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase {
|Conflicting attributes: "a".""".stripMargin)) } + test("NearestByJoin with a streaming input fails analysis") { + val streamingLeft = LocalRelation( + Seq(AttributeReference("a", IntegerType)()), Nil, isStreaming = true) + val batchRight = LocalRelation(AttributeReference("b", IntegerType)()) + val nearestBy = NearestByJoin( + streamingLeft, batchRight, Inner, approx = true, numResults = 1, + rankingExpression = streamingLeft.output.head + batchRight.output.head, + direction = NearestBySimilarity) + assertAnalysisErrorCondition( + nearestBy, + expectedErrorCondition = "NEAREST_BY_JOIN.STREAMING_NOT_SUPPORTED", + expectedMessageParameters = Map.empty) + } + test("check grouping expression data types") { def checkDataType(dataType: DataType): Unit = { val plan = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala index f81aa5174f113..a9132f5f9d4a9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, CreateStruct, Inline, Literal, MonotonicallyIncreasingID} -import org.apache.spark.sql.catalyst.expressions.aggregate.{First, MaxMinByK} +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, CreateStruct, Inline, Literal, MonotonicallyIncreasingID, Rand} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, First, MaxMinByK} import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, NearestByDistance, NearestBySimilarity, PlanTest} import 
org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Generate, Join, JoinHint, LocalRelation, NearestByJoin, Project} @@ -112,4 +112,145 @@ class RewriteNearestByJoinSuite extends PlanTest { comparePlans(rewritten, expected, checkAnalysis = false) } + + test("distance, left outer, k=2") { + val left = LocalRelation($"a".int, $"b".int) + val right = LocalRelation($"x".int, $"y".int) + val query = NearestByJoin( + left, right, LeftOuter, approx = true, numResults = 2, + rankingExpression = left.output(0) - right.output(0), + direction = NearestByDistance) + + val rewritten = RewriteNearestByJoin(query.analyze) + val expected = expectedRewrite( + left, right, 2, + ranking = left.output(0) - right.output(0), + reverse = true, outer = true) + + comparePlans(rewritten, expected, checkAnalysis = false) + } + + test("EXACT (approx = false) produces the same rewrite as APPROX") { + // Locks in the current invariant that APPROX and EXACT lower through the same + // brute-force rewrite. If a future change diverges them (e.g. an APPROX-only + // indexed-ANN strategy lands), this test fails and forces an intentional update. 
+ val left = LocalRelation($"a".int, $"b".int) + val right = LocalRelation($"x".int, $"y".int) + val query = NearestByJoin( + left, right, Inner, approx = false, numResults = 5, + rankingExpression = left.output(0) + right.output(0), + direction = NearestBySimilarity) + + val rewritten = RewriteNearestByJoin(query.analyze) + val expected = expectedRewrite( + left, right, 5, + ranking = left.output(0) + right.output(0), + reverse = false, outer = false) + + comparePlans(rewritten, expected, checkAnalysis = false) + } + + test("k = 1 (lower boundary)") { + val left = LocalRelation($"a".int, $"b".int) + val right = LocalRelation($"x".int, $"y".int) + val query = NearestByJoin( + left, right, Inner, approx = true, numResults = 1, + rankingExpression = left.output(0) + right.output(0), + direction = NearestBySimilarity) + + val rewritten = RewriteNearestByJoin(query.analyze) + val expected = expectedRewrite( + left, right, 1, + ranking = left.output(0) + right.output(0), + reverse = false, outer = false) + + comparePlans(rewritten, expected, checkAnalysis = false) + } + + test("k = NearestByJoin.MaxNumResults (upper boundary)") { + val left = LocalRelation($"a".int, $"b".int) + val right = LocalRelation($"x".int, $"y".int) + val query = NearestByJoin( + left, right, Inner, approx = true, numResults = NearestByJoin.MaxNumResults, + rankingExpression = left.output(0) + right.output(0), + direction = NearestBySimilarity) + + val rewritten = RewriteNearestByJoin(query.analyze) + val expected = expectedRewrite( + left, right, NearestByJoin.MaxNumResults, + ranking = left.output(0) + right.output(0), + reverse = false, outer = false) + + comparePlans(rewritten, expected, checkAnalysis = false) + } + + test("self-join: rewrite resolves duplicate ExprIds via DeduplicateRelations") { + // Exercises the NearestByJoin arm in DeduplicateRelations. 
Without it, `.analyze` on + // a self-join would leave the right side sharing ExprIds with the left and the + // CheckAnalysis arm would throw an internal error. + val t = LocalRelation($"a".int, $"b".int) + val query = NearestByJoin( + t, t, Inner, approx = true, numResults = 1, + rankingExpression = t.output(0) + t.output(0), + direction = NearestBySimilarity) + + val rewritten = RewriteNearestByJoin(query.analyze) + val tDup = LocalRelation($"a".int, $"b".int) + val expected = expectedRewrite( + t, tDup, 1, + ranking = t.output(0) + tDup.output(0), + reverse = false, outer = false) + + comparePlans(rewritten, expected, checkAnalysis = false) + } + + test("APPROX with nondeterministic ranking pre-materializes via Project") { + // Locks in the Project-injection shape: when the ranking expression is nondeterministic + // (legal only under APPROX), the rewrite inserts a Project above the Join that aliases + // the ranking value as `__ranking__`. MaxMinByK then sees a plain AttributeReference as + // its ordering input. This relies on Projection's standard partition-aware initialization + // to call `Rand.initialize` once per partition before any value is evaluated; otherwise + // MaxMinByK would call `eval` on an uninitialized Rand and throw at runtime. If a future + // optimizer change folds this Project away, this test fails and forces an intentional + // update. 
+ val left = LocalRelation($"a".int, $"b".int) + val right = LocalRelation($"x".int, $"y".int) + val ranking = Rand(Literal(0L)) + right.output(0) + val query = NearestByJoin( + left, right, Inner, approx = true, numResults = 1, + rankingExpression = ranking, + direction = NearestBySimilarity) + + val rewritten = RewriteNearestByJoin(query.analyze) + + val agg = rewritten.collect { case a: Aggregate => a }.head + assert(agg.child.isInstanceOf[Project], + s"expected materializing Project above the Join when ranking is nondeterministic, " + + s"got ${agg.child.getClass.getSimpleName}") + val maxMinByK = agg.aggregateExpressions.collectFirst { + case Alias(AggregateExpression(m: MaxMinByK, _, _, _, _), "__nearest_matches__") => m + }.getOrElse(fail("expected MaxMinByK aggregate in the rewritten plan")) + assert(maxMinByK.orderingExpr.isInstanceOf[AttributeReference], + "ranking expression should be materialized as an attribute, not evaluated inside MaxMinByK") + assert(maxMinByK.orderingExpr.asInstanceOf[AttributeReference].name == "__ranking__") + assert(rewritten.exists(_.expressions.exists(_.exists(_.isInstanceOf[Rand]))), + "Rand should still appear in the plan -- inside the materializing Project, not lost") + } + + test("APPROX with deterministic ranking does NOT inject the materializing Project") { + // Counterpart to the test above: confirms the Project-injection is gated on + // `!rankingExpression.deterministic` so the deterministic path's plan shape is unchanged. 
+ val left = LocalRelation($"a".int, $"b".int) + val right = LocalRelation($"x".int, $"y".int) + val query = NearestByJoin( + left, right, Inner, approx = true, numResults = 1, + rankingExpression = left.output(0) + right.output(0), + direction = NearestBySimilarity) + + val rewritten = RewriteNearestByJoin(query.analyze) + val agg = rewritten.collect { case a: Aggregate => a }.head + assert(agg.child.isInstanceOf[Join], + s"expected Aggregate's child to be the Join directly when ranking is deterministic, " + + s"got ${agg.child.getClass.getSimpleName}") + } } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/join-nearest-by.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/join-nearest-by.sql.out index 5abb671190cb6..79cbee6001619 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/join-nearest-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/join-nearest-by.sql.out @@ -110,6 +110,28 @@ Project [user_id#x, product#x] +- LocalRelation [col1#x, col2#x] +-- !query +SELECT a.user_id AS a_id, b.user_id AS b_id +FROM users a JOIN users b + APPROX NEAREST 1 BY DISTANCE abs(a.score - b.score) +ORDER BY a.user_id, b.user_id +-- !query analysis +Project [a_id#x, b_id#x] ++- Sort [user_id#x ASC NULLS FIRST, user_id#x ASC NULLS FIRST], true + +- Project [user_id#x AS a_id#x, user_id#x AS b_id#x, user_id#x, user_id#x] + +- NearestByJoin Inner, true, 1, abs((score#x - score#x)), NearestByDistance + :- SubqueryAlias a + : +- SubqueryAlias spark_catalog.default.users + : +- View (`spark_catalog`.`default`.`users`, [user_id#x, score#x]) + : +- Project [cast(col1#x as int) AS user_id#x, cast(col2#x as decimal(3,1)) AS score#x] + : +- LocalRelation [col1#x, col2#x] + +- SubqueryAlias b + +- SubqueryAlias spark_catalog.default.users + +- View (`spark_catalog`.`default`.`users`, [user_id#x, score#x]) + +- Project [cast(col1#x as int) AS user_id#x, cast(col2#x as decimal(3,1)) AS score#x] + +- LocalRelation 
[col1#x, col2#x] + + -- !query SELECT u.user_id, p.product FROM users u RIGHT OUTER JOIN products p @@ -226,6 +248,185 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException } +-- !query +SELECT COUNT(*) AS num_rows +FROM ( + SELECT u.user_id, p.product + FROM users u JOIN products p + APPROX NEAREST 1 BY SIMILARITY rand() + p.pscore +) +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT COUNT(*) AS num_rows +FROM ( + SELECT u.user_id, p.product + FROM users u JOIN products p + APPROX NEAREST 2 BY DISTANCE rand() + p.pscore +) +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +EXPLAIN +SELECT u.user_id, p.product +FROM users u JOIN products p + APPROX NEAREST 1 BY SIMILARITY rand(0) + p.pscore +-- !query analysis +ExplainCommand 'Project ['u.user_id, 'p.product], SimpleMode + + +-- !query +SET spark.sql.crossJoin.enabled = false +-- !query analysis +SetCommand (spark.sql.crossJoin.enabled,Some(false)) + + +-- !query +SELECT u.user_id, p.product +FROM users u JOIN products p + APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore) +-- !query analysis +Project [user_id#x, product#x] ++- NearestByJoin Inner, true, 1, -abs((score#x - pscore#x)), NearestBySimilarity + :- SubqueryAlias u + : +- SubqueryAlias spark_catalog.default.users + : +- View (`spark_catalog`.`default`.`users`, [user_id#x, score#x]) + : +- Project [cast(col1#x as int) AS user_id#x, cast(col2#x as decimal(3,1)) AS score#x] + : +- LocalRelation [col1#x, col2#x] + +- SubqueryAlias p + +- SubqueryAlias spark_catalog.default.products + +- View (`spark_catalog`.`default`.`products`, [product#x, pscore#x]) + +- Project [cast(col1#x as string) AS product#x, cast(col2#x as decimal(3,1)) AS pscore#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT u.user_id, p.product +FROM users u JOIN products p + APPROX NEAREST 2 BY DISTANCE abs(u.score - p.pscore) +WHERE p.product != 'C' +-- !query analysis +Project 
[user_id#x, product#x] ++- Filter NOT (product#x = C) + +- NearestByJoin Inner, true, 2, abs((score#x - pscore#x)), NearestByDistance + :- SubqueryAlias u + : +- SubqueryAlias spark_catalog.default.users + : +- View (`spark_catalog`.`default`.`users`, [user_id#x, score#x]) + : +- Project [cast(col1#x as int) AS user_id#x, cast(col2#x as decimal(3,1)) AS score#x] + : +- LocalRelation [col1#x, col2#x] + +- SubqueryAlias p + +- SubqueryAlias spark_catalog.default.products + +- View (`spark_catalog`.`default`.`products`, [product#x, pscore#x]) + +- Project [cast(col1#x as string) AS product#x, cast(col2#x as decimal(3,1)) AS pscore#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT u.user_id, p.product +FROM users u JOIN products p + APPROX NEAREST 2 BY DISTANCE abs(u.score - p.pscore) +WHERE u.user_id > 1 +-- !query analysis +Project [user_id#x, product#x] ++- Filter (user_id#x > 1) + +- NearestByJoin Inner, true, 2, abs((score#x - pscore#x)), NearestByDistance + :- SubqueryAlias u + : +- SubqueryAlias spark_catalog.default.users + : +- View (`spark_catalog`.`default`.`users`, [user_id#x, score#x]) + : +- Project [cast(col1#x as int) AS user_id#x, cast(col2#x as decimal(3,1)) AS score#x] + : +- LocalRelation [col1#x, col2#x] + +- SubqueryAlias p + +- SubqueryAlias spark_catalog.default.products + +- View (`spark_catalog`.`default`.`products`, [product#x, pscore#x]) + +- Project [cast(col1#x as string) AS product#x, cast(col2#x as decimal(3,1)) AS pscore#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT u.user_id, p.product +FROM users u LEFT OUTER JOIN products p + EXACT NEAREST 1 BY DISTANCE abs(u.score - p.pscore) +-- !query analysis +Project [user_id#x, product#x] ++- NearestByJoin LeftOuter, false, 1, abs((score#x - pscore#x)), NearestByDistance + :- SubqueryAlias u + : +- SubqueryAlias spark_catalog.default.users + : +- View (`spark_catalog`.`default`.`users`, [user_id#x, score#x]) + : +- Project [cast(col1#x as int) AS user_id#x, 
cast(col2#x as decimal(3,1)) AS score#x] + : +- LocalRelation [col1#x, col2#x] + +- SubqueryAlias p + +- SubqueryAlias spark_catalog.default.products + +- View (`spark_catalog`.`default`.`products`, [product#x, pscore#x]) + +- Project [cast(col1#x as string) AS product#x, cast(col2#x as decimal(3,1)) AS pscore#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +EXPLAIN +SELECT u.user_id, p.product +FROM users u JOIN products p + APPROX NEAREST 2 BY DISTANCE abs(u.score - p.pscore) +WHERE u.user_id > 1 +-- !query analysis +ExplainCommand 'Project ['u.user_id, 'p.product], SimpleMode + + +-- !query +EXPLAIN +SELECT u.user_id, p.product +FROM users u JOIN products p + APPROX NEAREST 2 BY DISTANCE abs(u.score - p.pscore) +WHERE p.product != 'C' +-- !query analysis +ExplainCommand 'Project ['u.user_id, 'p.product], SimpleMode + + +-- !query +SET spark.sql.crossJoin.enabled = true +-- !query analysis +SetCommand (spark.sql.crossJoin.enabled,Some(true)) + + +-- !query +CREATE OR REPLACE TEMP VIEW tied_products(product, pscore) + AS VALUES ('A', 10.0), ('B', 10.0), ('C', 10.0) +-- !query analysis +CreateViewCommand `tied_products`, [(product,None), (pscore,None)], VALUES ('A', 10.0), ('B', 10.0), ('C', 10.0), false, true, LocalTempView, UNSUPPORTED, true + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT u.user_id, COUNT(*) AS num_matches +FROM users u JOIN tied_products p + APPROX NEAREST 2 BY DISTANCE abs(u.score - p.pscore) +GROUP BY u.user_id +ORDER BY u.user_id +-- !query analysis +Sort [user_id#x ASC NULLS FIRST], true ++- Aggregate [user_id#x], [user_id#x, count(1) AS num_matches#xL] + +- NearestByJoin Inner, true, 2, abs((score#x - pscore#x)), NearestByDistance + :- SubqueryAlias u + : +- SubqueryAlias spark_catalog.default.users + : +- View (`spark_catalog`.`default`.`users`, [user_id#x, score#x]) + : +- Project [cast(col1#x as int) AS user_id#x, cast(col2#x as decimal(3,1)) AS score#x] + : +- LocalRelation [col1#x, col2#x] + +- SubqueryAlias p + +- 
SubqueryAlias tied_products + +- View (`tied_products`, [product#x, pscore#x]) + +- Project [cast(col1#x as string) AS product#x, cast(col2#x as decimal(3,1)) AS pscore#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +DROP VIEW tied_products +-- !query analysis +DropTempViewCommand tied_products, false + + -- !query DROP VIEW users -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/inputs/join-nearest-by.sql b/sql/core/src/test/resources/sql-tests/inputs/join-nearest-by.sql index 8dca4f895483a..ad6506cd50191 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/join-nearest-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/join-nearest-by.sql @@ -28,6 +28,14 @@ SELECT u.user_id, p.product FROM users u INNER JOIN products p APPROX NEAREST 1 BY DISTANCE abs(u.score - p.pscore); +-- Self-join: same relation on both sides. Exercises DeduplicateRelations' NearestByJoin +-- arm, which rewrites the right side with fresh ExprIds so the join resolves. Each row's +-- nearest match by `abs(score - score)` is itself, so the output is deterministic. +SELECT a.user_id AS a_id, b.user_id AS b_id +FROM users a JOIN users b + APPROX NEAREST 1 BY DISTANCE abs(a.score - b.score) +ORDER BY a.user_id, b.user_id; + -- Error: unsupported join type (RIGHT OUTER) SELECT u.user_id, p.product FROM users u RIGHT OUTER JOIN products p @@ -53,5 +61,95 @@ SELECT u.user_id, p.product FROM users u JOIN products p EXACT NEAREST 1 BY SIMILARITY rand() + p.pscore; +-- APPROX permits a nondeterministic ranking expression (per the SPIP). Rows differ run to +-- run, so we only assert the row count: one match per left row when k = 1. +SELECT COUNT(*) AS num_rows +FROM ( + SELECT u.user_id, p.product + FROM users u JOIN products p + APPROX NEAREST 1 BY SIMILARITY rand() + p.pscore +); + +-- Same with k = 2 to exercise the multi-match path with rand(). 
+SELECT COUNT(*) AS num_rows +FROM ( + SELECT u.user_id, p.product + FROM users u JOIN products p + APPROX NEAREST 2 BY DISTANCE rand() + p.pscore +); + +-- EXPLAIN of APPROX + nondeterministic ranking. Locks in the plan shape: the rewrite +-- injects a Project above the Join that materializes `rand(0) + p.pscore` as `__ranking__`. +-- An explicit seed is used so the EXPLAIN string is byte-stable across runs (without it, +-- `rand()` synthesizes a fresh random seed each time and the seed appears in the EXPLAIN). +EXPLAIN +SELECT u.user_id, p.product +FROM users u JOIN products p + APPROX NEAREST 1 BY SIMILARITY rand(0) + p.pscore; + +-- spark.sql.crossJoin.enabled = false must NOT reject NEAREST BY queries. +-- The synthetic LEFT OUTER cross-join inside the rewrite is recognized structurally +-- by `CheckCartesianProducts` (its parent `Aggregate` contains `MaxMinByK`) and skipped. +SET spark.sql.crossJoin.enabled = false; + +-- Basic NEAREST BY with crossJoin disabled. +SELECT u.user_id, p.product +FROM users u JOIN products p + APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore); + +-- NEAREST BY with a top-level filter on a right-side column. This exercises the path +-- where filter pushdown / column pruning may run between the rewrite (FinishAnalysis batch) +-- and `CheckCartesianProducts` (a much later batch). +SELECT u.user_id, p.product +FROM users u JOIN products p + APPROX NEAREST 2 BY DISTANCE abs(u.score - p.pscore) +WHERE p.product != 'C'; + +-- NEAREST BY with a top-level filter on a left-side column. +SELECT u.user_id, p.product +FROM users u JOIN products p + APPROX NEAREST 2 BY DISTANCE abs(u.score - p.pscore) +WHERE u.user_id > 1; + +-- LEFT OUTER NEAREST BY with crossJoin disabled. +SELECT u.user_id, p.product +FROM users u LEFT OUTER JOIN products p + EXACT NEAREST 1 BY DISTANCE abs(u.score - p.pscore); + +-- EXPLAIN of a query whose left-side predicate (user_id > 1) is pushed down to the left +-- input of the rewrite's synthetic join.
Demonstrates that CheckCartesianProducts succeeds +-- AFTER pushdown rules run, and that the rewrite's Aggregate -> Join shape is preserved in +-- the optimized plan. +EXPLAIN +SELECT u.user_id, p.product +FROM users u JOIN products p + APPROX NEAREST 2 BY DISTANCE abs(u.score - p.pscore) +WHERE u.user_id > 1; + +-- EXPLAIN of a query whose right-side predicate (p.product != 'C') cannot push below the +-- rewrite's Generate(inline) and stays above it. Demonstrates that the optimizer pipeline +-- runs end-to-end without CheckCartesianProducts rejecting the synthetic join. +EXPLAIN +SELECT u.user_id, p.product +FROM users u JOIN products p + APPROX NEAREST 2 BY DISTANCE abs(u.score - p.pscore) +WHERE p.product != 'C'; + +SET spark.sql.crossJoin.enabled = true; + +-- Tie behavior: when multiple right rows have equal ranking values for a given left row, +-- MaxMinByK breaks ties arbitrarily (the SPIP marks tie-break as unspecified). We can't +-- pin specific rows, but the operator must still return exactly `numResults` matches per +-- left row when enough candidates exist. 
+CREATE OR REPLACE TEMP VIEW tied_products(product, pscore) + AS VALUES ('A', 10.0), ('B', 10.0), ('C', 10.0); + +SELECT u.user_id, COUNT(*) AS num_matches +FROM users u JOIN tied_products p + APPROX NEAREST 2 BY DISTANCE abs(u.score - p.pscore) +GROUP BY u.user_id +ORDER BY u.user_id; + +DROP VIEW tied_products; DROP VIEW users; DROP VIEW products; diff --git a/sql/core/src/test/resources/sql-tests/results/join-nearest-by.sql.out b/sql/core/src/test/resources/sql-tests/results/join-nearest-by.sql.out index afaf4fbdb9e24..44fb44af8c4c9 100644 --- a/sql/core/src/test/resources/sql-tests/results/join-nearest-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/join-nearest-by.sql.out @@ -78,6 +78,19 @@ struct 3 B +-- !query +SELECT a.user_id AS a_id, b.user_id AS b_id +FROM users a JOIN users b + APPROX NEAREST 1 BY DISTANCE abs(a.score - b.score) +ORDER BY a.user_id, b.user_id +-- !query schema +struct +-- !query output +1 1 +2 2 +3 3 + + -- !query SELECT u.user_id, p.product FROM users u RIGHT OUTER JOIN products p @@ -204,6 +217,215 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException } +-- !query +SELECT COUNT(*) AS num_rows +FROM ( + SELECT u.user_id, p.product + FROM users u JOIN products p + APPROX NEAREST 1 BY SIMILARITY rand() + p.pscore +) +-- !query schema +struct +-- !query output +3 + + +-- !query +SELECT COUNT(*) AS num_rows +FROM ( + SELECT u.user_id, p.product + FROM users u JOIN products p + APPROX NEAREST 2 BY DISTANCE rand() + p.pscore +) +-- !query schema +struct +-- !query output +6 + + +-- !query +EXPLAIN +SELECT u.user_id, p.product +FROM users u JOIN products p + APPROX NEAREST 1 BY SIMILARITY rand(0) + p.pscore +-- !query schema +struct +-- !query output +== Physical Plan == +AdaptiveSparkPlan isFinalPlan=false ++- Project [user_id#x, product#x] + +- Generate inline(__nearest_matches__#x), [user_id#x], false, [product#x, pscore#x] + +- Filter ((size(__nearest_matches__#x, false) > 0) AND isnotnull(__nearest_matches__#x)) + +- 
SortAggregate(key=[__qid#xL], functions=[first(user_id#x, false), max_by(named_struct(product, product#x, pscore, pscore#x), __ranking__#x, 1, false, 0, 0)]) + +- Sort [__qid#xL ASC NULLS FIRST], false, 0 + +- Exchange hashpartitioning(__qid#xL, 4), ENSURE_REQUIREMENTS, [plan_id=x] + +- SortAggregate(key=[__qid#xL], functions=[partial_first(user_id#x, false), partial_max_by(named_struct(product, product#x, pscore, pscore#x), __ranking__#x, 1, false, 0, 0)]) + +- Sort [__qid#xL ASC NULLS FIRST], false, 0 + +- Project [user_id#x, __qid#xL, product#x, pscore#x, (rand(0) + cast(pscore#x as double)) AS __ranking__#x] + +- BroadcastNestedLoopJoin BuildRight, LeftOuter + :- Project [col1#x AS user_id#x, monotonically_increasing_id() AS __qid#xL] + : +- LocalTableScan [col1#x, col2#x] + +- BroadcastExchange IdentityBroadcastMode, [plan_id=x] + +- Project [col1#x AS product#x, col2#x AS pscore#x] + +- LocalTableScan [col1#x, col2#x] + + +-- !query +SET spark.sql.crossJoin.enabled = false +-- !query schema +struct +-- !query output +spark.sql.crossJoin.enabled false + + +-- !query +SELECT u.user_id, p.product +FROM users u JOIN products p + APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore) +-- !query schema +struct +-- !query output +1 A +2 B +3 B + + +-- !query +SELECT u.user_id, p.product +FROM users u JOIN products p + APPROX NEAREST 2 BY DISTANCE abs(u.score - p.pscore) +WHERE p.product != 'C' +-- !query schema +struct +-- !query output +1 A +2 A +2 B +3 A +3 B + + +-- !query +SELECT u.user_id, p.product +FROM users u JOIN products p + APPROX NEAREST 2 BY DISTANCE abs(u.score - p.pscore) +WHERE u.user_id > 1 +-- !query schema +struct +-- !query output +2 A +2 B +3 A +3 B + + +-- !query +SELECT u.user_id, p.product +FROM users u LEFT OUTER JOIN products p + EXACT NEAREST 1 BY DISTANCE abs(u.score - p.pscore) +-- !query schema +struct +-- !query output +1 A +2 B +3 B + + +-- !query +EXPLAIN +SELECT u.user_id, p.product +FROM users u JOIN products p + APPROX NEAREST 2 
BY DISTANCE abs(u.score - p.pscore) +WHERE u.user_id > 1 +-- !query schema +struct +-- !query output +== Physical Plan == +AdaptiveSparkPlan isFinalPlan=false ++- Project [user_id#x, product#x] + +- Generate inline(__nearest_matches__#x), [user_id#x], false, [product#x, pscore#x] + +- Filter ((size(__nearest_matches__#x, false) > 0) AND isnotnull(__nearest_matches__#x)) + +- SortAggregate(key=[__qid#xL], functions=[first(user_id#x, false), min_by(named_struct(product, product#x, pscore, pscore#x), abs((score#x - pscore#x)), 2, true, 0, 0)]) + +- Sort [__qid#xL ASC NULLS FIRST], false, 0 + +- Exchange hashpartitioning(__qid#xL, 4), ENSURE_REQUIREMENTS, [plan_id=x] + +- SortAggregate(key=[__qid#xL], functions=[partial_first(user_id#x, false), partial_min_by(named_struct(product, product#x, pscore, pscore#x), abs((score#x - pscore#x)), 2, true, 0, 0)]) + +- Sort [__qid#xL ASC NULLS FIRST], false, 0 + +- BroadcastNestedLoopJoin BuildRight, LeftOuter + :- Filter (user_id#x > 1) + : +- Project [col1#x AS user_id#x, col2#x AS score#x, monotonically_increasing_id() AS __qid#xL] + : +- LocalTableScan [col1#x, col2#x] + +- BroadcastExchange IdentityBroadcastMode, [plan_id=x] + +- Project [col1#x AS product#x, col2#x AS pscore#x] + +- LocalTableScan [col1#x, col2#x] + + +-- !query +EXPLAIN +SELECT u.user_id, p.product +FROM users u JOIN products p + APPROX NEAREST 2 BY DISTANCE abs(u.score - p.pscore) +WHERE p.product != 'C' +-- !query schema +struct +-- !query output +== Physical Plan == +AdaptiveSparkPlan isFinalPlan=false ++- Project [user_id#x, product#x] + +- Filter (isnotnull(product#x) AND NOT (product#x = C)) + +- Generate inline(__nearest_matches__#x), [user_id#x], false, [product#x, pscore#x] + +- Filter ((size(__nearest_matches__#x, false) > 0) AND isnotnull(__nearest_matches__#x)) + +- SortAggregate(key=[__qid#xL], functions=[first(user_id#x, false), min_by(named_struct(product, product#x, pscore, pscore#x), abs((score#x - pscore#x)), 2, true, 0, 0)]) + +- Sort 
[__qid#xL ASC NULLS FIRST], false, 0 + +- Exchange hashpartitioning(__qid#xL, 4), ENSURE_REQUIREMENTS, [plan_id=x] + +- SortAggregate(key=[__qid#xL], functions=[partial_first(user_id#x, false), partial_min_by(named_struct(product, product#x, pscore, pscore#x), abs((score#x - pscore#x)), 2, true, 0, 0)]) + +- Sort [__qid#xL ASC NULLS FIRST], false, 0 + +- BroadcastNestedLoopJoin BuildRight, LeftOuter + :- Project [col1#x AS user_id#x, col2#x AS score#x, monotonically_increasing_id() AS __qid#xL] + : +- LocalTableScan [col1#x, col2#x] + +- BroadcastExchange IdentityBroadcastMode, [plan_id=x] + +- Project [col1#x AS product#x, col2#x AS pscore#x] + +- LocalTableScan [col1#x, col2#x] + + +-- !query +SET spark.sql.crossJoin.enabled = true +-- !query schema +struct +-- !query output +spark.sql.crossJoin.enabled true + + +-- !query +CREATE OR REPLACE TEMP VIEW tied_products(product, pscore) + AS VALUES ('A', 10.0), ('B', 10.0), ('C', 10.0) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT u.user_id, COUNT(*) AS num_matches +FROM users u JOIN tied_products p + APPROX NEAREST 2 BY DISTANCE abs(u.score - p.pscore) +GROUP BY u.user_id +ORDER BY u.user_id +-- !query schema +struct +-- !query output +1 2 +2 2 +3 2 + + +-- !query +DROP VIEW tied_products +-- !query schema +struct<> +-- !query output + + + -- !query DROP VIEW users -- !query schema From b2e11eefb645e96a075c95d116e25f00d569ad0d Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Mon, 4 May 2026 12:48:43 -0700 Subject: [PATCH 3/7] Code review 2. 
--- .../resources/error/error-conditions.json | 6 +- docs/sql-ref-syntax-qry-select-join.md | 4 +- .../sql/catalyst/optimizer/Optimizer.scala | 30 +---- .../optimizer/RewriteNearestByJoin.scala | 29 +++-- .../sql/catalyst/parser/AstBuilder.scala | 65 ++++++---- .../plans/logical/NearestByJoin.scala | 98 ++++++++++++++ .../plans/logical/basicLogicalOperators.scala | 54 -------- .../sql/catalyst/rules/RuleIdCollection.scala | 1 + .../analyzer-results/join-nearest-by.sql.out | 121 +++++++----------- .../sql-tests/inputs/join-nearest-by.sql | 59 ++++----- .../sql-tests/results/join-nearest-by.sql.out | 111 ++++++++-------- 11 files changed, 290 insertions(+), 288 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/NearestByJoin.scala diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 9eba289660505..36df08a5744cb 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -5325,17 +5325,17 @@ }, "NON_ORDERABLE_RANKING_EXPRESSION" : { "message" : [ - "The ranking expression of type is not orderable." + "The ranking expression of type is not orderable. Provide an expression that returns an orderable type, such as a numeric distance like abs(a.col - b.col) or a numeric similarity score." ] }, "NUM_RESULTS_OUT_OF_RANGE" : { "message" : [ - "The number of results must be between and ." + "The number of results must be between and . Update the literal in `APPROX NEAREST BY ...` (or `EXACT NEAREST BY ...`) to fall within that range." ] }, "STREAMING_NOT_SUPPORTED" : { "message" : [ - "nearest-by join is not supported with streaming DataFrames/Datasets." + "Nearest-by join is not supported with streaming DataFrames/Datasets." 
] }, "UNSUPPORTED_DIRECTION" : { diff --git a/docs/sql-ref-syntax-qry-select-join.md b/docs/sql-ref-syntax-qry-select-join.md index 68fb6eda9353e..a082a13707bdd 100644 --- a/docs/sql-ref-syntax-qry-select-join.md +++ b/docs/sql-ref-syntax-qry-select-join.md @@ -69,11 +69,11 @@ relation { [ join_type ] JOIN [ LATERAL ] relation [ join_criteria | nearest_by_ `DISTANCE | SIMILARITY` - `DISTANCE` ranks rows by smallest value of `ranking_expression` first. `SIMILARITY` ranks rows by largest value first. + `DISTANCE` ranks rows by smallest value of `ranking_expression` first. `SIMILARITY` ranks rows by largest value first. Matched right-side rows are emitted in best-first order: smallest ranking value first under `DISTANCE`, largest first under `SIMILARITY`. (Downstream operators may reorder; add an explicit `ORDER BY` if you need to lock in the ordering.) `ranking_expression` - A scalar expression that returns an orderable type. + A scalar expression that returns an orderable type. Must be deterministic with `EXACT`; may be nondeterministic with `APPROX` (e.g., `rand()` for randomized tie-breaking). The expression is evaluated once per (left, right) pair on the brute-force path, so avoid expensive or side-effecting UDFs in ranking expressions. **Performance note.** The current implementation evaluates the full cross-product of the left and right sides and bounds memory per left row by `num_results`. Per-query work is `O(|left| × |right| × log num_results)`. Index-backed approximate strategies (transparent to `APPROX` queries) are planned in a future release; until then, pre-filter the right side (e.g. via a subquery) when it is large. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 97bfd799a4c6d..e4d53b697af80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -2542,36 +2542,14 @@ object CheckCartesianProducts extends Rule[LogicalPlan] with PredicateHelper { } } - def apply(plan: LogicalPlan): LogicalPlan = { + def apply(plan: LogicalPlan): LogicalPlan = if (conf.crossJoinEnabled) { - return plan - } - - // Joins synthesized by `RewriteNearestByJoin` are an intentional, bounded cross-product - // wrapped by a `MaxMinByK` aggregate. Identify them by their unambiguous post-rewrite - // signature -- `Aggregate(_, exprs, Join(_, _, LeftOuter, None, _))` where `exprs` - // contains a `MaxMinByK` -- and skip them so user queries written as `NEAREST BY` are not - // rejected when `spark.sql.crossJoin.enabled = false`. We use structural detection rather - // than a `TreeNodeTag` because a tag set on the `Join` would be silently dropped by any - // intervening optimizer rule that constructs a fresh `Join` via the case-class - // constructor without calling `copyTagsFrom`. 
- val nearestByJoins: java.util.IdentityHashMap[Join, Unit] = { - val acc = new java.util.IdentityHashMap[Join, Unit]() - plan.foreach { - case Aggregate(_, exprs, j @ Join(_, _, LeftOuter, None, _), _) - if exprs.exists(_.exists(_.isInstanceOf[MaxMinByK])) => - acc.put(j, ()) - case _ => - } - acc - } - - plan.transformWithPruning(_.containsAnyPattern(INNER_LIKE_JOIN, OUTER_JOIN)) { + plan + } else plan.transformWithPruning(_.containsAnyPattern(INNER_LIKE_JOIN, OUTER_JOIN)) { case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, _, _) - if isCartesianProduct(j) && !nearestByJoins.containsKey(j) => + if isCartesianProduct(j) => throw QueryCompilationErrors.joinConditionMissingOrTrivialError(j, left, right) } - } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala index 1ccf19ebcfb0b..fabd3203f3c98 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala @@ -50,6 +50,11 @@ import org.apache.spark.sql.catalyst.rules._ * constructed with `outer = true` so left rows with no matches (empty/null `_matches`) are * preserved with `NULL` right-side columns. * + * The matches in `_matches` are produced by `MaxMinByK` ordered by the ranking value: best + * match first (largest ranking value for `SIMILARITY`, smallest for `DISTANCE`). `Inline` + * preserves array order, so the K rows emitted per left row appear best-first in the output + * of this rule. (Downstream operators may reorder.) + * * If `rankingExpression` is nondeterministic (legal only under `APPROX`), an extra * `Project` is inserted above the `Join` to materialize the value as `__ranking__`. 
The * standard projection machinery runs `Nondeterministic.initialize(partitionIndex)` on every @@ -63,7 +68,7 @@ import org.apache.spark.sql.catalyst.rules._ * makes the intended shape explicit and avoids round-tripping through subquery decorrelation. */ object RewriteNearestByJoin extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithNewOutput { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case j @ NearestByJoin(left, right, joinType, _, numResults, rankingExpression, direction) => // 1. Tag each left row with a unique id so that rows from the same left row can later be // grouped together after the cross-join with `right`. @@ -77,10 +82,11 @@ object RewriteNearestByJoin extends Rule[LogicalPlan] { // columns after the aggregate + inline below. When `right` is non-empty every left // row already has right-row pairings, so LEFT OUTER and INNER are equivalent. // - // `CheckCartesianProducts` recognizes this synthetic join structurally (by its - // parent `Aggregate` containing a `MaxMinByK`) and skips it, so user queries - // written as `NEAREST BY` are not rejected when `spark.sql.crossJoin.enabled` is - // false. + // This synthetic join is an unconditioned cross-product, so `NEAREST BY` queries + // are subject to `CheckCartesianProducts` and will be rejected when the user has + // set `spark.sql.crossJoin.enabled = false`. That is intentional: if the user has + // opted out of cross-products, the NEAREST BY rewrite -- which is itself a bounded + // cross-product today -- should not silently bypass that choice. val join = Join(taggedLeft, right, LeftOuter, None, JoinHint.NONE) val (aggInput, rankingForAgg) = if (!rankingExpression.deterministic) { @@ -122,20 +128,21 @@ object RewriteNearestByJoin extends Rule[LogicalPlan] { // 4. Generate inline(_matches) expands the K-element array into K rows, exposing each // struct field as a top-level column. 
`outer = true` for LEFT OUTER preserves the - // left row with NULL right columns when there are no matches. + // left row with NULL right columns when there are no matches. Preserving the right + // side's `ExprId`s in `generatorOutput` (rather than allocating fresh ones) keeps + // `generate.output` byte-for-byte equivalent to `j.output` -- which already used + // those ExprIds with `nullable = true` -- so parent-operator references continue to + // resolve naturally and the rule can use plain `transformUp` without an attrMapping. val generatorOutput = right.output.map { a => AttributeReference(a.name, a.dataType, nullable = true, a.metadata)( - qualifier = a.qualifier) + exprId = a.exprId, qualifier = a.qualifier) } - val generate = Generate( + Generate( Inline(matchesAlias.toAttribute), unrequiredChildIndex = Seq(aggregate.output.indexOf(matchesAlias.toAttribute)), outer = joinType == LeftOuter, qualifier = None, generatorOutput = generatorOutput, child = aggregate) - - val attrMapping = j.output.zip(generate.output) - generate -> attrMapping } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index c233f70cc7b33..929fb2b4ceb15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2363,32 +2363,7 @@ class AstBuilder extends DataTypeAstBuilder } if (ctx.nearestByClause != null) { - if (ctx.LATERAL != null) { - throw QueryParsingErrors.nearestByJoinWithLateralUnsupportedError(ctx) - } - if (!Seq(Inner, LeftOuter).contains(baseJoinType)) { - throw QueryParsingErrors.unsupportedNearestByJoinTypeError( - ctx, baseJoinType.sql, NearestByJoinType.supportedDisplay) - } - val clause = ctx.nearestByClause - val approx = clause.APPROX != null - val numResults = Option(clause.num).map { n => - // Guard against 
literals that overflow Long. - val value = try n.getText.toLong catch { - case _: NumberFormatException => - throw QueryParsingErrors.nearestByJoinNumResultsOutOfRangeError( - ctx, n.getText, NearestByJoin.MaxNumResults) - } - if (value < 1 || value > NearestByJoin.MaxNumResults) { - throw QueryParsingErrors.nearestByJoinNumResultsOutOfRangeError( - ctx, value.toString, NearestByJoin.MaxNumResults) - } - value.toInt - }.getOrElse(1) - val direction = if (clause.DISTANCE != null) NearestByDistance else NearestBySimilarity - val rankingExpr = expression(clause.expression) - NearestByJoin( - base, plan(ctx.right), baseJoinType, approx, numResults, rankingExpr, direction) + withNearestByJoin(ctx, base, baseJoinType) } else { // Resolve the join type and join condition val (joinType, condition) = Option(ctx.joinCriteria) match { @@ -2428,6 +2403,44 @@ class AstBuilder extends DataTypeAstBuilder } } + /** + * Build a [[NearestByJoin]] from the parsed `NEAREST BY` clause attached to a join relation. + * Validates that the clause is not combined with `LATERAL` and that the base join type is one + * of the supported types (`INNER` or `LEFT OUTER`), parses `num_results` (with bounds checks), + * the direction (`DISTANCE` / `SIMILARITY`), and the ranking expression. + */ + private def withNearestByJoin( + ctx: JoinRelationContext, + base: LogicalPlan, + baseJoinType: JoinType): NearestByJoin = { + if (ctx.LATERAL != null) { + throw QueryParsingErrors.nearestByJoinWithLateralUnsupportedError(ctx) + } + if (!Seq(Inner, LeftOuter).contains(baseJoinType)) { + throw QueryParsingErrors.unsupportedNearestByJoinTypeError( + ctx, baseJoinType.sql, NearestByJoinType.supportedDisplay) + } + val clause = ctx.nearestByClause + val approx = clause.APPROX != null + val numResults = Option(clause.num).map { n => + // Guard against literals that overflow Long. 
+ val value = try n.getText.toLong catch { + case _: NumberFormatException => + throw QueryParsingErrors.nearestByJoinNumResultsOutOfRangeError( + ctx, n.getText, NearestByJoin.MaxNumResults) + } + if (value < 1 || value > NearestByJoin.MaxNumResults) { + throw QueryParsingErrors.nearestByJoinNumResultsOutOfRangeError( + ctx, value.toString, NearestByJoin.MaxNumResults) + } + value.toInt + }.getOrElse(1) + val direction = if (clause.DISTANCE != null) NearestByDistance else NearestBySimilarity + val rankingExpr = expression(clause.expression) + NearestByJoin( + base, plan(ctx.right), baseJoinType, approx, numResults, rankingExpr, direction) + } + /** * Add a [[Sample]] to a logical plan. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/NearestByJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/NearestByJoin.scala new file mode 100644 index 0000000000000..27a92353267c7 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/NearestByJoin.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftOuter, NearestByDirection} +import org.apache.spark.sql.catalyst.trees.TreePattern._ + +object NearestByJoin { + /** Upper bound on `numResults`. Mirrors the K-overload limit of `MaxMinByK`. */ + val MaxNumResults: Int = 100000 +} + +/** + * A logical plan for a nearest-by top-K ranking join. For each row on the left side it returns + * up to `numResults` rows from the right side ordered by `rankingExpression`: + * - `NearestByDistance`: smallest values of `rankingExpression` first. + * - `NearestBySimilarity`: largest values of `rankingExpression` first. + * + * The `approx` field records the user's APPROX/EXACT choice. Today both modes use the same + * brute-force rewrite. The flag is preserved on the logical plan so that future indexed + * approximate-nearest-neighbor strategies can fire only when `approx = true`, leaving EXACT + * queries unaffected. + * + * @param left The left (query) side of the join. + * @param right The right (base) side of the join, against which each left row finds matches. + * @param joinType Must be `Inner` or `LeftOuter`. `Inner` drops left rows with no matches; + * `LeftOuter` preserves them with `NULL` right-side columns. + * @param approx `true` for `APPROX` mode, `false` for `EXACT` mode. `APPROX` permits a + * nondeterministic `rankingExpression` and is the contract future indexed + * approximate-nearest-neighbor strategies key off; `EXACT` requires + * determinism (enforced by `CheckAnalysis`). + * @param numResults The K in top-K: the maximum number of right-side matches returned per + * left row. Bounded above by `NearestByJoin.MaxNumResults`. + * @param rankingExpression Scalar expression evaluated per (left, right) pair. Must return + * an orderable type. 
Rows are ranked by its value, with ordering + * determined by `direction`. + * @param direction `NearestByDistance` (smaller is better) or `NearestBySimilarity` (larger + * is better). Selects whether the rewrite uses `MIN_BY` or `MAX_BY`. + */ +case class NearestByJoin( + left: LogicalPlan, + right: LogicalPlan, + joinType: JoinType, + approx: Boolean, + numResults: Int, + rankingExpression: Expression, + direction: NearestByDirection) + extends BinaryNode with SupportsNonDeterministicExpression { + + require(Seq(Inner, LeftOuter).contains(joinType), + s"Unsupported nearest-by join type $joinType") + + // `APPROX` mode permits a nondeterministic ranking expression (e.g. `rand()` for randomized + // tie-breaking). `EXACT` mode requires determinism, and that requirement is enforced + // separately by the `NEAREST_BY_JOIN.EXACT_WITH_NONDETERMINISTIC_EXPRESSION` arm in + // `CheckAnalysis`. Returning `approx` from this override is what lets APPROX queries pass + // the generic `INVALID_NON_DETERMINISTIC_EXPRESSIONS` check that fires on operators not on + // the analyzer's whitelist. + override def allowNonDeterministicExpression: Boolean = approx + + // Right-side attributes are always declared nullable because the rewrite materializes them + // through `Inline` over `MaxMinByK`'s `ArrayType(.., containsNull = true)`, which widens + // every struct field to nullable. Declaring them nullable here keeps the analyzed schema + // consistent with the optimized plan (and with what users see in cached or written outputs). 
+ override def output: Seq[Attribute] = + left.output ++ right.output.map(_.withNullability(true)) + + def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty + + override lazy val resolved: Boolean = { + childrenResolved && + expressions.forall(_.resolved) && + duplicateResolved + } + + final override val nodePatterns: Seq[TreePattern] = Seq(NEAREST_BY_JOIN) + + override protected def withNewChildrenInternal( + newLeft: LogicalPlan, newRight: LogicalPlan): NearestByJoin = { + copy(left = newLeft, right = newRight) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 7fb9f6b13e445..772c3a22b2770 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -2421,57 +2421,3 @@ object AsOfJoin { } } -object NearestByJoin { - /** Upper bound on `numResults`. Mirrors the K-overload limit of `MaxMinByK`. */ - val MaxNumResults: Int = 100000 -} - -/** - * A logical plan for a nearest-by top-K ranking join. For each row on the left side it returns - * up to `numResults` rows from the right side ordered by `rankingExpression`: - * - `NearestByDistance`: smallest values of `rankingExpression` first. - * - `NearestBySimilarity`: largest values of `rankingExpression` first. - * - * The `approx` field records the user's APPROX/EXACT choice from the SPIP. Today both modes - * use the same brute-force rewrite. The flag is preserved on the logical plan so future - * indexed approximate-nearest-neighbor strategies can fire only when `approx = true`, - * leaving EXACT queries unaffected. See the SPIP linked from SPARK-56395. 
- */ -case class NearestByJoin( - left: LogicalPlan, - right: LogicalPlan, - joinType: JoinType, - approx: Boolean, - numResults: Int, - rankingExpression: Expression, - direction: NearestByDirection) - extends BinaryNode with SupportsNonDeterministicExpression { - - require(Seq(Inner, LeftOuter).contains(joinType), - s"Unsupported nearest-by join type $joinType") - - // APPROX permits a nondeterministic ranking expression (per the SPIP); the rewrite - override def allowNonDeterministicExpression: Boolean = approx - - // Right-side attributes are always declared nullable because the rewrite materializes them - // through `Inline` over `MaxMinByK`'s `ArrayType(.., containsNull = true)`, which widens - // every struct field to nullable. Declaring them nullable here keeps the analyzed schema - // consistent with the optimized plan (and with what users see in cached or written outputs). - override def output: Seq[Attribute] = - left.output ++ right.output.map(_.withNullability(true)) - - def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty - - override lazy val resolved: Boolean = { - childrenResolved && - expressions.forall(_.resolved) && - duplicateResolved - } - - final override val nodePatterns: Seq[TreePattern] = Seq(NEAREST_BY_JOIN) - - override protected def withNewChildrenInternal( - newLeft: LogicalPlan, newRight: LogicalPlan): NearestByJoin = { - copy(left = newLeft, right = newRight) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala index 7956a9692dc61..a890d43f0672c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala @@ -176,6 +176,7 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.optimizer.ReplaceIntersectWithSemiJoin" :: 
"org.apache.spark.sql.catalyst.optimizer.ReplaceNullWithFalseInPredicate" :: "org.apache.spark.sql.catalyst.optimizer.RewriteAsOfJoin" :: + "org.apache.spark.sql.catalyst.optimizer.RewriteNearestByJoin" :: "org.apache.spark.sql.catalyst.optimizer.RewriteExceptAll" :: "org.apache.spark.sql.catalyst.optimizer.RewriteIntersectAll" :: "org.apache.spark.sql.catalyst.optimizer.SimplifyBinaryComparison" :: diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/join-nearest-by.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/join-nearest-by.sql.out index 79cbee6001619..57308f6b8ff11 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/join-nearest-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/join-nearest-by.sql.out @@ -32,6 +32,44 @@ Project [user_id#x, product#x] +- LocalRelation [col1#x, col2#x] +-- !query +SELECT * +FROM users u JOIN products p + APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore) +-- !query analysis +Project [user_id#x, score#x, product#x, pscore#x] ++- NearestByJoin Inner, true, 1, -abs((score#x - pscore#x)), NearestBySimilarity + :- SubqueryAlias u + : +- SubqueryAlias spark_catalog.default.users + : +- View (`spark_catalog`.`default`.`users`, [user_id#x, score#x]) + : +- Project [cast(col1#x as int) AS user_id#x, cast(col2#x as decimal(3,1)) AS score#x] + : +- LocalRelation [col1#x, col2#x] + +- SubqueryAlias p + +- SubqueryAlias spark_catalog.default.products + +- View (`spark_catalog`.`default`.`products`, [product#x, pscore#x]) + +- Project [cast(col1#x as string) AS product#x, cast(col2#x as decimal(3,1)) AS pscore#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT * +FROM users u LEFT OUTER JOIN products p + APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore) +-- !query analysis +Project [user_id#x, score#x, product#x, pscore#x] ++- NearestByJoin LeftOuter, true, 1, -abs((score#x - pscore#x)), NearestBySimilarity + :- SubqueryAlias u + : +- SubqueryAlias 
spark_catalog.default.users + : +- View (`spark_catalog`.`default`.`users`, [user_id#x, score#x]) + : +- Project [cast(col1#x as int) AS user_id#x, cast(col2#x as decimal(3,1)) AS score#x] + : +- LocalRelation [col1#x, col2#x] + +- SubqueryAlias p + +- SubqueryAlias spark_catalog.default.products + +- View (`spark_catalog`.`default`.`products`, [product#x, pscore#x]) + +- Project [cast(col1#x as string) AS product#x, cast(col2#x as decimal(3,1)) AS pscore#x] + +- LocalRelation [col1#x, col2#x] + + -- !query SELECT u.user_id, p.product, p.pscore FROM users u JOIN products p @@ -280,79 +318,38 @@ ExplainCommand 'Project ['u.user_id, 'p.product], SimpleMode -- !query -SET spark.sql.crossJoin.enabled = false --- !query analysis -SetCommand (spark.sql.crossJoin.enabled,Some(false)) - - --- !query +EXPLAIN SELECT u.user_id, p.product FROM users u JOIN products p - APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore) + APPROX NEAREST 2 BY DISTANCE abs(u.score - p.pscore) +WHERE u.user_id > 1 -- !query analysis -Project [user_id#x, product#x] -+- NearestByJoin Inner, true, 1, -abs((score#x - pscore#x)), NearestBySimilarity - :- SubqueryAlias u - : +- SubqueryAlias spark_catalog.default.users - : +- View (`spark_catalog`.`default`.`users`, [user_id#x, score#x]) - : +- Project [cast(col1#x as int) AS user_id#x, cast(col2#x as decimal(3,1)) AS score#x] - : +- LocalRelation [col1#x, col2#x] - +- SubqueryAlias p - +- SubqueryAlias spark_catalog.default.products - +- View (`spark_catalog`.`default`.`products`, [product#x, pscore#x]) - +- Project [cast(col1#x as string) AS product#x, cast(col2#x as decimal(3,1)) AS pscore#x] - +- LocalRelation [col1#x, col2#x] +ExplainCommand 'Project ['u.user_id, 'p.product], SimpleMode -- !query +EXPLAIN SELECT u.user_id, p.product FROM users u JOIN products p APPROX NEAREST 2 BY DISTANCE abs(u.score - p.pscore) WHERE p.product != 'C' -- !query analysis -Project [user_id#x, product#x] -+- Filter NOT (product#x = C) - +- NearestByJoin Inner, 
true, 2, abs((score#x - pscore#x)), NearestByDistance - :- SubqueryAlias u - : +- SubqueryAlias spark_catalog.default.users - : +- View (`spark_catalog`.`default`.`users`, [user_id#x, score#x]) - : +- Project [cast(col1#x as int) AS user_id#x, cast(col2#x as decimal(3,1)) AS score#x] - : +- LocalRelation [col1#x, col2#x] - +- SubqueryAlias p - +- SubqueryAlias spark_catalog.default.products - +- View (`spark_catalog`.`default`.`products`, [product#x, pscore#x]) - +- Project [cast(col1#x as string) AS product#x, cast(col2#x as decimal(3,1)) AS pscore#x] - +- LocalRelation [col1#x, col2#x] +ExplainCommand 'Project ['u.user_id, 'p.product], SimpleMode -- !query -SELECT u.user_id, p.product -FROM users u JOIN products p - APPROX NEAREST 2 BY DISTANCE abs(u.score - p.pscore) -WHERE u.user_id > 1 +SET spark.sql.crossJoin.enabled = false -- !query analysis -Project [user_id#x, product#x] -+- Filter (user_id#x > 1) - +- NearestByJoin Inner, true, 2, abs((score#x - pscore#x)), NearestByDistance - :- SubqueryAlias u - : +- SubqueryAlias spark_catalog.default.users - : +- View (`spark_catalog`.`default`.`users`, [user_id#x, score#x]) - : +- Project [cast(col1#x as int) AS user_id#x, cast(col2#x as decimal(3,1)) AS score#x] - : +- LocalRelation [col1#x, col2#x] - +- SubqueryAlias p - +- SubqueryAlias spark_catalog.default.products - +- View (`spark_catalog`.`default`.`products`, [product#x, pscore#x]) - +- Project [cast(col1#x as string) AS product#x, cast(col2#x as decimal(3,1)) AS pscore#x] - +- LocalRelation [col1#x, col2#x] +SetCommand (spark.sql.crossJoin.enabled,Some(false)) -- !query SELECT u.user_id, p.product -FROM users u LEFT OUTER JOIN products p - EXACT NEAREST 1 BY DISTANCE abs(u.score - p.pscore) +FROM users u JOIN products p + APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore) -- !query analysis Project [user_id#x, product#x] -+- NearestByJoin LeftOuter, false, 1, abs((score#x - pscore#x)), NearestByDistance ++- NearestByJoin Inner, true, 1, -abs((score#x 
- pscore#x)), NearestBySimilarity :- SubqueryAlias u : +- SubqueryAlias spark_catalog.default.users : +- View (`spark_catalog`.`default`.`users`, [user_id#x, score#x]) @@ -365,26 +362,6 @@ Project [user_id#x, product#x] +- LocalRelation [col1#x, col2#x] --- !query -EXPLAIN -SELECT u.user_id, p.product -FROM users u JOIN products p - APPROX NEAREST 2 BY DISTANCE abs(u.score - p.pscore) -WHERE u.user_id > 1 --- !query analysis -ExplainCommand 'Project ['u.user_id, 'p.product], SimpleMode - - --- !query -EXPLAIN -SELECT u.user_id, p.product -FROM users u JOIN products p - APPROX NEAREST 2 BY DISTANCE abs(u.score - p.pscore) -WHERE p.product != 'C' --- !query analysis -ExplainCommand 'Project ['u.user_id, 'p.product], SimpleMode - - -- !query SET spark.sql.crossJoin.enabled = true -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/inputs/join-nearest-by.sql b/sql/core/src/test/resources/sql-tests/inputs/join-nearest-by.sql index ad6506cd50191..20b9b2fb73169 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/join-nearest-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/join-nearest-by.sql @@ -8,6 +8,19 @@ SELECT u.user_id, p.product FROM users u JOIN products p APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore); +-- SELECT * to validate the output schema. Must surface only the user-visible columns from +-- left and right (`user_id`, `score`, `product`, `pscore`) -- no rewrite-internal columns +-- (`__qid`, `__nearest_matches__`, `__ranking__`) and no Generator-aliased names. +SELECT * +FROM users u JOIN products p + APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore); + +-- Same schema check but for LEFT OUTER. Right-side columns are nullable in this mode (left +-- rows with no matches surface as NULL); the schema still must not leak internal columns. 
+SELECT * +FROM users u LEFT OUTER JOIN products p + APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore); + -- APPROX NEAREST BY DISTANCE with k = 2 SELECT u.user_id, p.product, p.pscore FROM users u JOIN products p @@ -87,39 +100,9 @@ SELECT u.user_id, p.product FROM users u JOIN products p APPROX NEAREST 1 BY SIMILARITY rand(0) + p.pscore; --- spark.sql.crossJoin.enabled = false must NOT reject NEAREST BY queries. --- The synthetic LEFT OUTER cross-join inside the rewrite is recognized structurally --- by `CheckCartesianProducts` (its parent `Aggregate` contains `MaxMinByK`) and skipped. -SET spark.sql.crossJoin.enabled = false; - --- Basic NEAREST BY with crossJoin disabled. -SELECT u.user_id, p.product -FROM users u JOIN products p - APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore); - --- NEAREST BY with a top-level filter on a right-side column. This exercises the path --- where filter pushdown / column pruning may run between the rewrite (FinishAnalysis batch) --- and `CheckCartesianProducts` (a much later batch). -SELECT u.user_id, p.product -FROM users u JOIN products p - APPROX NEAREST 2 BY DISTANCE abs(u.score - p.pscore) -WHERE p.product != 'C'; - --- NEAREST BY with a top-level filter on a left-side column. -SELECT u.user_id, p.product -FROM users u JOIN products p - APPROX NEAREST 2 BY DISTANCE abs(u.score - p.pscore) -WHERE u.user_id > 1; - --- LEFT OUTER NEAREST BY with crossJoin disabled. -SELECT u.user_id, p.product -FROM users u LEFT OUTER JOIN products p - EXACT NEAREST 1 BY DISTANCE abs(u.score - p.pscore); - -- EXPLAIN of a query whose left-side predicate (user_id > 1) is pushed down to the left --- input of the rewrite's synthetic join. Demonstrates that CheckCartesianProducts succeeds --- AFTER pushdown rules run, and that the rewrite's Aggregate -> Join shape is preserved in --- the optimized plan. +-- input of the rewrite's synthetic join. 
Demonstrates that pushdown rules walk through +-- the rewrite's Generate -> Aggregate -> Join shape and reach the underlying left input. EXPLAIN SELECT u.user_id, p.product FROM users u JOIN products p @@ -128,13 +111,23 @@ WHERE u.user_id > 1; -- EXPLAIN of a query whose right-side predicate (p.product != 'C') cannot push below the -- rewrite's Generate(inline) and stays above it. Demonstrates that the optimizer pipeline --- runs end-to-end without CheckCartesianProducts rejecting the synthetic join. +-- runs end-to-end and the rewrite's plan shape (Generate over Aggregate over Join) survives +-- to physical planning even when a top-level filter cannot be pushed into it. EXPLAIN SELECT u.user_id, p.product FROM users u JOIN products p APPROX NEAREST 2 BY DISTANCE abs(u.score - p.pscore) WHERE p.product != 'C'; +-- The rewrite produces an unconditioned cross-product internally. When the user has opted +-- out of cross-products via `spark.sql.crossJoin.enabled = false`, NEAREST BY queries are +-- rejected by `CheckCartesianProducts` -- the rewrite does not bypass the user's choice. 
+SET spark.sql.crossJoin.enabled = false; + +SELECT u.user_id, p.product +FROM users u JOIN products p + APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore); + SET spark.sql.crossJoin.enabled = true; -- Tie behavior: when multiple right rows have equal ranking values for a given left row, diff --git a/sql/core/src/test/resources/sql-tests/results/join-nearest-by.sql.out b/sql/core/src/test/resources/sql-tests/results/join-nearest-by.sql.out index 44fb44af8c4c9..3a872cab39665 100644 --- a/sql/core/src/test/resources/sql-tests/results/join-nearest-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/join-nearest-by.sql.out @@ -27,6 +27,30 @@ struct 3 B +-- !query +SELECT * +FROM users u JOIN products p + APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore) +-- !query schema +struct +-- !query output +1 10.0 A 11.0 +2 20.0 B 22.0 +3 30.0 B 22.0 + + +-- !query +SELECT * +FROM users u LEFT OUTER JOIN products p + APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore) +-- !query schema +struct +-- !query output +1 10.0 A 11.0 +2 20.0 B 22.0 +3 30.0 B 22.0 + + -- !query SELECT u.user_id, p.product, p.pscore FROM users u JOIN products p @@ -270,67 +294,6 @@ AdaptiveSparkPlan isFinalPlan=false +- LocalTableScan [col1#x, col2#x] --- !query -SET spark.sql.crossJoin.enabled = false --- !query schema -struct --- !query output -spark.sql.crossJoin.enabled false - - --- !query -SELECT u.user_id, p.product -FROM users u JOIN products p - APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore) --- !query schema -struct --- !query output -1 A -2 B -3 B - - --- !query -SELECT u.user_id, p.product -FROM users u JOIN products p - APPROX NEAREST 2 BY DISTANCE abs(u.score - p.pscore) -WHERE p.product != 'C' --- !query schema -struct --- !query output -1 A -2 A -2 B -3 A -3 B - - --- !query -SELECT u.user_id, p.product -FROM users u JOIN products p - APPROX NEAREST 2 BY DISTANCE abs(u.score - p.pscore) -WHERE u.user_id > 1 --- !query schema -struct --- !query output -2 A 
-2 B -3 A -3 B - - --- !query -SELECT u.user_id, p.product -FROM users u LEFT OUTER JOIN products p - EXACT NEAREST 1 BY DISTANCE abs(u.score - p.pscore) --- !query schema -struct --- !query output -1 A -2 B -3 B - - -- !query EXPLAIN SELECT u.user_id, p.product @@ -387,6 +350,32 @@ AdaptiveSparkPlan isFinalPlan=false +- LocalTableScan [col1#x, col2#x] +-- !query +SET spark.sql.crossJoin.enabled = false +-- !query schema +struct +-- !query output +spark.sql.crossJoin.enabled false + + +-- !query +SELECT u.user_id, p.product +FROM users u JOIN products p + APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_1211", + "messageParameters" : { + "joinType" : "LEFT OUTER", + "leftPlan" : "Project [col1#x AS user_id#x, col2#x AS score#x, monotonically_increasing_id() AS __qid#xL]\n+- LocalRelation [col1#x, col2#x]", + "rightPlan" : "Project [col1#x AS product#x, col2#x AS pscore#x]\n+- LocalRelation [col1#x, col2#x]" + } +} + + -- !query SET spark.sql.crossJoin.enabled = true -- !query schema From 808affff06a016d432222cde026bd5c15daf7cd5 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Mon, 4 May 2026 15:25:40 -0700 Subject: [PATCH 4/7] Code review 3. --- .../plans/logical/NearestByJoin.scala | 14 ++++++---- .../optimizer/RewriteNearestByJoinSuite.scala | 27 +++++++++++++++++++ 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/NearestByJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/NearestByJoin.scala index 27a92353267c7..9df79ba128b8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/NearestByJoin.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/NearestByJoin.scala @@ -74,12 +74,16 @@ case class NearestByJoin( // the analyzer's whitelist. 
override def allowNonDeterministicExpression: Boolean = approx - // Right-side attributes are always declared nullable because the rewrite materializes them - // through `Inline` over `MaxMinByK`'s `ArrayType(.., containsNull = true)`, which widens - // every struct field to nullable. Declaring them nullable here keeps the analyzed schema - // consistent with the optimized plan (and with what users see in cached or written outputs). + // Both left- and right-side attributes are declared nullable to match the schema produced + // by `RewriteNearestByJoin`. Right-side attributes are widened because the rewrite + // materializes them through `Inline` over `MaxMinByK`'s `ArrayType(.., containsNull = true)`, + // which widens every struct field to nullable. Left-side attributes are widened because the + // rewrite carries each left column through a `First` aggregate, whose result type is always + // nullable (`First` may return `null` for empty groups). Declaring both nullable here keeps + // the analyzed schema consistent with the optimized plan (and with what users see in cached + // or written outputs). 
override def output: Seq[Attribute] = - left.output ++ right.output.map(_.withNullability(true)) + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala index a9132f5f9d4a9..580d0b07066ad 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cre import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, First, MaxMinByK} import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, NearestByDistance, NearestBySimilarity, PlanTest} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Generate, Join, JoinHint, LocalRelation, NearestByJoin, Project} +import org.apache.spark.sql.types.IntegerType class RewriteNearestByJoinSuite extends PlanTest { @@ -253,4 +254,30 @@ class RewriteNearestByJoinSuite extends PlanTest { s"expected Aggregate's child to be the Join directly when ranking is deterministic, " + s"got ${agg.child.getClass.getSimpleName}") } + + test("output declares both left- and right-side attributes nullable") { + // The rewrite carries left columns through `First` aggregates (always nullable result type) + // and right columns through `Inline` over `MaxMinByK`'s `ArrayType(.., containsNull = true)` + // (every struct field becomes nullable). 
NearestByJoin.output must reflect both widenings + // so the analyzed schema matches the optimized plan; otherwise cached / written outputs + // would advertise a stricter nullability than the data actually carries. + val left = LocalRelation( + AttributeReference("a", IntegerType, nullable = false)(), + AttributeReference("b", IntegerType, nullable = false)()) + val right = LocalRelation( + AttributeReference("x", IntegerType, nullable = false)(), + AttributeReference("y", IntegerType, nullable = false)()) + val query = NearestByJoin( + left, right, Inner, approx = true, numResults = 1, + rankingExpression = left.output(0) + right.output(0), + direction = NearestBySimilarity) + + assert(left.output.forall(!_.nullable), + "preconditions: left input attributes should start non-nullable") + assert(right.output.forall(!_.nullable), + "preconditions: right input attributes should start non-nullable") + assert(query.output.forall(_.nullable), + "NearestByJoin.output should declare every attribute nullable, regardless of the " + + "nullability of the underlying inputs") + } } From 8719eb151b96743671954e5f3c966994c1ddc810 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Mon, 4 May 2026 16:27:25 -0700 Subject: [PATCH 5/7] Code review 4. 
--- .../sql/catalyst/analysis/CheckAnalysis.scala | 18 ++++----- .../optimizer/RewriteNearestByJoin.scala | 6 ++- .../optimizer/RewriteNearestByJoinSuite.scala | 28 ++++++++----- .../sql-tests/results/join-nearest-by.sql.out | 40 +++++++++---------- .../apache/spark/sql/SQLQueryTestHelper.scala | 4 +- 5 files changed, 54 insertions(+), 42 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index e231fe20d6186..aef65c7532f82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -657,15 +657,15 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString messageParameters = Map.empty) } - // Reject streaming inputs early. The optimizer rewrite groups by a `__qid` derived - // from `MonotonicallyIncreasingID()` and feeds it to a global `Aggregate`, which - // Spark turns into a stateful streaming aggregation. Because MID restarts per - // micro-batch, `__qid` values collide across batches, and the stateful aggregate - // silently merges state from old batches into new rows that share the same key -- - // producing wrong top-K results. Failing at analysis time is clearer than letting - // this slip through. Streaming support is tracked as a follow-up; resolving it does - // not require streaming-aware MID and is likely to come from a different grouping - // strategy or a dedicated physical operator. + // Reject streaming inputs early. The optimizer rewrite is built around an + // unconditioned cross-product fed into a global `Aggregate` keyed by a per-row + // identifier (`__qid`). 
That shape doesn't compose cleanly with structured-streaming + // semantics: a stateful aggregate keyed by a freshly-generated identifier accumulates + // state indefinitely (every batch creates new keys, old keys never match again) and a + // cross-product against a streaming right side has no bounded state model today. + // Failing at analysis time is clearer than letting either fail at runtime. Streaming + // support is tracked as a follow-up; resolving it likely comes from a different + // grouping strategy or a dedicated physical operator. case j: NearestByJoin if j.isStreaming => j.failAnalysis( errorClass = "NEAREST_BY_JOIN.STREAMING_NOT_SUPPORTED", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala index fabd3203f3c98..8119d13ccc680 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.rules._ * [first(left.col0) AS left.col0, ..., first(left.colN-1) AS left.colN-1, * max_by(struct(right.*), expr, k) AS _matches] * +- Join LeftOuter - * :- Project [left.*, monotonically_increasing_id() AS __qid] + * :- Project [left.*, uuid() AS __qid] * : +- left * +- right * }}} @@ -68,11 +68,13 @@ import org.apache.spark.sql.catalyst.rules._ * makes the intended shape explicit and avoids round-tripping through subquery decorrelation. */ object RewriteNearestByJoin extends Rule[LogicalPlan] { + private lazy val random = new scala.util.Random() + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case j @ NearestByJoin(left, right, joinType, _, numResults, rankingExpression, direction) => // 1. 
Tag each left row with a unique id so that rows from the same left row can later be // grouped together after the cross-join with `right`. - val qidAlias = Alias(MonotonicallyIncreasingID(), "__qid")() + val qidAlias = Alias(Uuid(Some(random.nextLong())), "__qid")() val taggedLeft = Project(left.output :+ qidAlias, left) val qidAttr = qidAlias.toAttribute diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala index 580d0b07066ad..5f41d5e0ad001 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, CreateStruct, Inline, Literal, MonotonicallyIncreasingID, Rand} +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, CreateStruct, Inline, Literal, Rand, Uuid} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, First, MaxMinByK} import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, NearestByDistance, NearestBySimilarity, PlanTest} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Generate, Join, JoinHint, LocalRelation, NearestByJoin, Project} @@ -27,6 +27,14 @@ import org.apache.spark.sql.types.IntegerType class RewriteNearestByJoinSuite extends PlanTest { + // The rewrite synthesizes `Uuid(Some(<seed>))` for `__qid`, whose seed is fresh per call; + // expected plans below use `Uuid(Some(0L))`, and we normalize the actual plan's `Uuid` + // seeds to 0L before `comparePlans` so the structural shape is the only thing being + // compared, not
the (necessarily different) random seed values. + private def normalizeUuidSeed(plan: org.apache.spark.sql.catalyst.plans.logical.LogicalPlan) + : org.apache.spark.sql.catalyst.plans.logical.LogicalPlan = + plan.transformAllExpressions { case _: Uuid => Uuid(Some(0L)) } + private def expectedRewrite( left: LocalRelation, right: LocalRelation, @@ -34,7 +42,7 @@ class RewriteNearestByJoinSuite extends PlanTest { ranking: org.apache.spark.sql.catalyst.expressions.Expression, reverse: Boolean, outer: Boolean) = { - val qidAlias = Alias(MonotonicallyIncreasingID(), "__qid")() + val qidAlias = Alias(Uuid(Some(0L)), "__qid")() val taggedLeft = Project(left.output :+ qidAlias, left) val join = Join(taggedLeft, right, LeftOuter, None, JoinHint.NONE) @@ -77,7 +85,7 @@ class RewriteNearestByJoinSuite extends PlanTest { ranking = left.output(0) + right.output(0), reverse = false, outer = false) - comparePlans(rewritten, expected, checkAnalysis = false) + comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false) } test("distance, inner, k=3") { @@ -94,7 +102,7 @@ class RewriteNearestByJoinSuite extends PlanTest { ranking = left.output(0) - right.output(0), reverse = true, outer = false) - comparePlans(rewritten, expected, checkAnalysis = false) + comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false) } test("similarity, left outer, k=1") { @@ -111,7 +119,7 @@ class RewriteNearestByJoinSuite extends PlanTest { ranking = left.output(0) + right.output(0), reverse = false, outer = true) - comparePlans(rewritten, expected, checkAnalysis = false) + comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false) } test("distance, left outer, k=2") { @@ -128,7 +136,7 @@ class RewriteNearestByJoinSuite extends PlanTest { ranking = left.output(0) - right.output(0), reverse = true, outer = true) - comparePlans(rewritten, expected, checkAnalysis = false) + comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false) } 
test("EXACT (approx = false) produces the same rewrite as APPROX") { @@ -148,7 +156,7 @@ class RewriteNearestByJoinSuite extends PlanTest { ranking = left.output(0) + right.output(0), reverse = false, outer = false) - comparePlans(rewritten, expected, checkAnalysis = false) + comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false) } test("k = 1 (lower boundary)") { @@ -165,7 +173,7 @@ class RewriteNearestByJoinSuite extends PlanTest { ranking = left.output(0) + right.output(0), reverse = false, outer = false) - comparePlans(rewritten, expected, checkAnalysis = false) + comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false) } test("k = NearestByJoin.MaxNumResults (upper boundary)") { @@ -182,7 +190,7 @@ class RewriteNearestByJoinSuite extends PlanTest { ranking = left.output(0) + right.output(0), reverse = false, outer = false) - comparePlans(rewritten, expected, checkAnalysis = false) + comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false) } test("self-join: rewrite resolves duplicate ExprIds via DeduplicateRelations") { @@ -202,7 +210,7 @@ class RewriteNearestByJoinSuite extends PlanTest { ranking = t.output(0) + tDup.output(0), reverse = false, outer = false) - comparePlans(rewritten, expected, checkAnalysis = false) + comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false) } test("APPROX with nondeterministic ranking pre-materializes via Project") { diff --git a/sql/core/src/test/resources/sql-tests/results/join-nearest-by.sql.out b/sql/core/src/test/resources/sql-tests/results/join-nearest-by.sql.out index 3a872cab39665..ac515e98dbd6e 100644 --- a/sql/core/src/test/resources/sql-tests/results/join-nearest-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/join-nearest-by.sql.out @@ -280,14 +280,14 @@ AdaptiveSparkPlan isFinalPlan=false +- Project [user_id#x, product#x] +- Generate inline(__nearest_matches__#x), [user_id#x], false, [product#x, pscore#x] +- Filter 
((size(__nearest_matches__#x, false) > 0) AND isnotnull(__nearest_matches__#x)) - +- SortAggregate(key=[__qid#xL], functions=[first(user_id#x, false), max_by(named_struct(product, product#x, pscore, pscore#x), __ranking__#x, 1, false, 0, 0)]) - +- Sort [__qid#xL ASC NULLS FIRST], false, 0 - +- Exchange hashpartitioning(__qid#xL, 4), ENSURE_REQUIREMENTS, [plan_id=x] - +- SortAggregate(key=[__qid#xL], functions=[partial_first(user_id#x, false), partial_max_by(named_struct(product, product#x, pscore, pscore#x), __ranking__#x, 1, false, 0, 0)]) - +- Sort [__qid#xL ASC NULLS FIRST], false, 0 - +- Project [user_id#x, __qid#xL, product#x, pscore#x, (rand(0) + cast(pscore#x as double)) AS __ranking__#x] + +- SortAggregate(key=[__qid#x], functions=[first(user_id#x, false), max_by(named_struct(product, product#x, pscore, pscore#x), __ranking__#x, 1, false, 0, 0)]) + +- Sort [__qid#x ASC NULLS FIRST], false, 0 + +- Exchange hashpartitioning(__qid#x, 4), ENSURE_REQUIREMENTS, [plan_id=x] + +- SortAggregate(key=[__qid#x], functions=[partial_first(user_id#x, false), partial_max_by(named_struct(product, product#x, pscore, pscore#x), __ranking__#x, 1, false, 0, 0)]) + +- Sort [__qid#x ASC NULLS FIRST], false, 0 + +- Project [user_id#x, __qid#x, product#x, pscore#x, (rand(0) + cast(pscore#x as double)) AS __ranking__#x] +- BroadcastNestedLoopJoin BuildRight, LeftOuter - :- Project [col1#x AS user_id#x, monotonically_increasing_id() AS __qid#xL] + :- Project [col1#x AS user_id#x, uuid(Some(x)) AS __qid#x] : +- LocalTableScan [col1#x, col2#x] +- BroadcastExchange IdentityBroadcastMode, [plan_id=x] +- Project [col1#x AS product#x, col2#x AS pscore#x] @@ -308,14 +308,14 @@ AdaptiveSparkPlan isFinalPlan=false +- Project [user_id#x, product#x] +- Generate inline(__nearest_matches__#x), [user_id#x], false, [product#x, pscore#x] +- Filter ((size(__nearest_matches__#x, false) > 0) AND isnotnull(__nearest_matches__#x)) - +- SortAggregate(key=[__qid#xL], functions=[first(user_id#x, false), 
min_by(named_struct(product, product#x, pscore, pscore#x), abs((score#x - pscore#x)), 2, true, 0, 0)]) - +- Sort [__qid#xL ASC NULLS FIRST], false, 0 - +- Exchange hashpartitioning(__qid#xL, 4), ENSURE_REQUIREMENTS, [plan_id=x] - +- SortAggregate(key=[__qid#xL], functions=[partial_first(user_id#x, false), partial_min_by(named_struct(product, product#x, pscore, pscore#x), abs((score#x - pscore#x)), 2, true, 0, 0)]) - +- Sort [__qid#xL ASC NULLS FIRST], false, 0 + +- SortAggregate(key=[__qid#x], functions=[first(user_id#x, false), min_by(named_struct(product, product#x, pscore, pscore#x), abs((score#x - pscore#x)), 2, true, 0, 0)]) + +- Sort [__qid#x ASC NULLS FIRST], false, 0 + +- Exchange hashpartitioning(__qid#x, 4), ENSURE_REQUIREMENTS, [plan_id=x] + +- SortAggregate(key=[__qid#x], functions=[partial_first(user_id#x, false), partial_min_by(named_struct(product, product#x, pscore, pscore#x), abs((score#x - pscore#x)), 2, true, 0, 0)]) + +- Sort [__qid#x ASC NULLS FIRST], false, 0 +- BroadcastNestedLoopJoin BuildRight, LeftOuter :- Filter (user_id#x > 1) - : +- Project [col1#x AS user_id#x, col2#x AS score#x, monotonically_increasing_id() AS __qid#xL] + : +- Project [col1#x AS user_id#x, col2#x AS score#x, uuid(Some(x)) AS __qid#x] : +- LocalTableScan [col1#x, col2#x] +- BroadcastExchange IdentityBroadcastMode, [plan_id=x] +- Project [col1#x AS product#x, col2#x AS pscore#x] @@ -337,13 +337,13 @@ AdaptiveSparkPlan isFinalPlan=false +- Filter (isnotnull(product#x) AND NOT (product#x = C)) +- Generate inline(__nearest_matches__#x), [user_id#x], false, [product#x, pscore#x] +- Filter ((size(__nearest_matches__#x, false) > 0) AND isnotnull(__nearest_matches__#x)) - +- SortAggregate(key=[__qid#xL], functions=[first(user_id#x, false), min_by(named_struct(product, product#x, pscore, pscore#x), abs((score#x - pscore#x)), 2, true, 0, 0)]) - +- Sort [__qid#xL ASC NULLS FIRST], false, 0 - +- Exchange hashpartitioning(__qid#xL, 4), ENSURE_REQUIREMENTS, [plan_id=x] - +- 
SortAggregate(key=[__qid#xL], functions=[partial_first(user_id#x, false), partial_min_by(named_struct(product, product#x, pscore, pscore#x), abs((score#x - pscore#x)), 2, true, 0, 0)]) - +- Sort [__qid#xL ASC NULLS FIRST], false, 0 + +- SortAggregate(key=[__qid#x], functions=[first(user_id#x, false), min_by(named_struct(product, product#x, pscore, pscore#x), abs((score#x - pscore#x)), 2, true, 0, 0)]) + +- Sort [__qid#x ASC NULLS FIRST], false, 0 + +- Exchange hashpartitioning(__qid#x, 4), ENSURE_REQUIREMENTS, [plan_id=x] + +- SortAggregate(key=[__qid#x], functions=[partial_first(user_id#x, false), partial_min_by(named_struct(product, product#x, pscore, pscore#x), abs((score#x - pscore#x)), 2, true, 0, 0)]) + +- Sort [__qid#x ASC NULLS FIRST], false, 0 +- BroadcastNestedLoopJoin BuildRight, LeftOuter - :- Project [col1#x AS user_id#x, col2#x AS score#x, monotonically_increasing_id() AS __qid#xL] + :- Project [col1#x AS user_id#x, col2#x AS score#x, uuid(Some(x)) AS __qid#x] : +- LocalTableScan [col1#x, col2#x] +- BroadcastExchange IdentityBroadcastMode, [plan_id=x] +- Project [col1#x AS product#x, col2#x AS pscore#x] @@ -370,7 +370,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "_LEGACY_ERROR_TEMP_1211", "messageParameters" : { "joinType" : "LEFT OUTER", - "leftPlan" : "Project [col1#x AS user_id#x, col2#x AS score#x, monotonically_increasing_id() AS __qid#xL]\n+- LocalRelation [col1#x, col2#x]", + "leftPlan" : "Project [col1#x AS user_id#x, col2#x AS score#x, uuid(Some(x)) AS __qid#x]\n+- LocalRelation [col1#x, col2#x]", "rightPlan" : "Project [col1#x AS product#x, col2#x AS pscore#x]\n+- LocalRelation [col1#x, col2#x]" } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala index 3cf26aa94a5d1..8028970193acd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala @@ -50,6 +50,7 @@ trait SQLQueryTestHelper extends SQLConfHelper with Logging { protected def replaceNotIncludedMsg(line: String): String = { line.replaceAll("#\\d+", "#x") .replaceAll("plan_id=\\d+", "plan_id=x") + .replaceAll("uuid\\(Some\\(-?\\d+\\)\\)", "uuid(Some(x))") .replaceAll( s"Location.*$clsName/", s"Location $notIncludedMsg/{warehouse_dir}/") @@ -178,7 +179,8 @@ trait SQLQueryTestHelper extends SQLConfHelper with Logging { val msg = Option(e.getMessageParameters.get("traceback")).getOrElse("") (emptySchema, Seq(e.getClass.getName, msg)) case e: SparkThrowable with Throwable if e.getCondition != null => - (emptySchema, Seq(e.getClass.getName, getMessage(e, format))) + (emptySchema, Seq(e.getClass.getName, + getMessage(e, format).replaceAll("uuid\\(Some\\(-?\\d+\\)\\)", "uuid(Some(x))"))) case a: AnalysisException => // Do not output the logical plan tree which contains expression IDs. // Also implement a crude way of masking expression IDs in the error message From d1f13785dd7e2e5385d2a2bc5080cb8b18cd9cac Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Mon, 4 May 2026 17:04:02 -0700 Subject: [PATCH 6/7] Code review5 --- .../optimizer/RewriteNearestByJoin.scala | 22 +++++++++++-------- .../optimizer/RewriteNearestByJoinSuite.scala | 10 +++++++-- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala index 8119d13ccc680..3d45855cd60da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala @@ -36,14 +36,15 @@ import org.apache.spark.sql.catalyst.rules._ * * Rewritten Plan (SIMILARITY, INNER join type): * {{{ - * Generate 
inline(_matches), [N], outer=false, [right.col1, right.col2, ...] - * +- Aggregate [__qid], - * [first(left.col0) AS left.col0, ..., first(left.colN-1) AS left.colN-1, - * max_by(struct(right.*), expr, k) AS _matches] - * +- Join LeftOuter - * :- Project [left.*, uuid() AS __qid] - * : +- left - * +- right + * Project [left.*, right.*] + * +- Generate inline(_matches), [N], outer=false, [right.col1, right.col2, ...] + * +- Aggregate [__qid], + * [first(left.col0) AS left.col0, ..., first(left.colN-1) AS left.colN-1, + * max_by(struct(right.*), expr, k) AS _matches] + * +- Join LeftOuter + * :- Project [left.*, uuid() AS __qid] + * : +- left + * +- right * }}} * * For `DISTANCE`, `MIN_BY` is used instead of `MAX_BY`. For `LEFT OUTER`, the `Generate` is @@ -139,12 +140,15 @@ object RewriteNearestByJoin extends Rule[LogicalPlan] { AttributeReference(a.name, a.dataType, nullable = true, a.metadata)( exprId = a.exprId, qualifier = a.qualifier) } - Generate( + val generate = Generate( Inline(matchesAlias.toAttribute), unrequiredChildIndex = Seq(aggregate.output.indexOf(matchesAlias.toAttribute)), outer = joinType == LeftOuter, qualifier = None, generatorOutput = generatorOutput, child = aggregate) + + // 5. Final `Project` pinning the output schema to `NearestByJoin.output`. 
+ Project(j.output, generate) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala index 5f41d5e0ad001..650bdc7a6c358 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala @@ -60,15 +60,21 @@ class RewriteNearestByJoinSuite extends PlanTest { Seq(qidAlias.toAttribute), firstLeftAggs :+ matchesAlias, join) val generatorOutput = right.output.map { a => - AttributeReference(a.name, a.dataType, nullable = true)(qualifier = a.qualifier) + AttributeReference(a.name, a.dataType, nullable = true)( + exprId = a.exprId, qualifier = a.qualifier) } - Generate( + val generate = Generate( Inline(matchesAlias.toAttribute), unrequiredChildIndex = Seq(aggregate.output.indexOf(matchesAlias.toAttribute)), outer = outer, qualifier = None, generatorOutput = generatorOutput, child = aggregate) + // Mirror the rewrite's final Project that constrains the output schema to + // `NearestByJoin.output` (left and right widened to nullable). 
+ val expectedOutput = + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + Project(expectedOutput, generate) } test("similarity, inner, k=5") { From b692b267a975ea845e208e228bd44f66c2641448 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 4 May 2026 20:54:42 -0700 Subject: [PATCH 7/7] Apply suggestion from @gengliangwang --- .../spark/sql/catalyst/plans/logical/basicLogicalOperators.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 772c3a22b2770..8e9f264698caf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -2420,4 +2420,3 @@ object AsOfJoin { } } } -