[SPARK-42664][CONNECT] Support bloomFilter function for DataFrameStatFunctions #42414
Contributor:
Maybe add a negative test case where mightContain evaluates to false?

Contributor (Author):
6ffbfa0 — Added checks for values that are definitely not included.
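A sketch of what such a negative check could look like (the column, values, and session setup are illustrative assumptions; df.stat.bloomFilter is the DataFrameStatFunctions API this PR brings to Connect):

// Sketch of a negative check in the spirit of 6ffbfa0 (names and values assumed).
// Bloom filters have no false negatives, so inserted values must hit; a value
// never inserted is expected -- though not guaranteed at fpp = 0.03 -- to miss,
// and real tests pick values known to miss.
val df = spark.range(1000).toDF("id")               // ids 0..999
val filter = df.stat.bloomFilter("id", 1000L, 0.03)
assert(filter.mightContain(500L))                   // inserted: always true
assert(!filter.mightContain(100000L))               // never inserted: should be false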
@@ -48,6 +48,7 @@ import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, Mu
 import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
 import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
 import org.apache.spark.sql.catalyst.plans.logical

@@ -78,6 +79,7 @@ import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.storage.CacheId
 import org.apache.spark.util.Utils
+import org.apache.spark.util.sketch.BloomFilterHelper

 final case class InvalidCommandInput(
   private val message: String = "",
@@ -1730,6 +1732,50 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
       val ignoreNulls = extractBoolean(children(3), "ignoreNulls")
       Some(Lead(children.head, children(1), children(2), ignoreNulls))

+    case "bloom_filter_agg" if fun.getArgumentsCount == 3 =>
+      // [col, expectedNumItems: Long, numBits: Long] or
+      // [col, expectedNumItems: Long, fpp: Double]
+      val children = fun.getArgumentsList.asScala.map(transformExpression)
+
+      // Check expectedNumItems > 0L
+      val expectedNumItemsExpr = children(1)
+      val expectedNumItems = expectedNumItemsExpr match {
Contributor (Author):
@hvanhovell Do you think we should check the validity of the input here? By checking here, the error messages can be kept exactly the same as in the existing API. Or perhaps we don't need to ensure that the error messages are the same as before?

Contributor:
We can do that in a follow-up.
+        case Literal(l: Long, LongType) => l
+        case _ =>
+          throw InvalidPlanInput("Expected insertions must be long literal.")
+      }
+      if (expectedNumItems <= 0L) {
+        throw InvalidPlanInput("Expected insertions must be positive.")
+      }
+
+      val numberBitsOrFpp = children(2)
+
+      val numBitsExpr = numberBitsOrFpp match {
+        case Literal(numBits: Long, LongType) =>
+          // Check numBits > 0L
+          if (numBits <= 0L) {
+            throw InvalidPlanInput("Number of bits must be positive.")
+          }
+          numberBitsOrFpp
+        case DoubleLiteral(fpp) =>
+          // Check fpp not NaN and in (0.0, 1.0).
+          if (fpp.isNaN || fpp <= 0d || fpp >= 1d) {
+            throw InvalidPlanInput("False positive probability must be within range (0.0, 1.0)")
+          }
+          // Calculate numBits through expectedNumItems and fpp,
+          // refer to `BloomFilter.optimalNumOfBits(long, double)`.
+          val numBits = BloomFilterHelper.optimalNumOfBits(expectedNumItems, fpp)
+          if (numBits <= 0L) {
+            throw InvalidPlanInput("Number of bits must be positive")
+          }
+          Literal(numBits, LongType)
+        case _ =>
+          throw InvalidPlanInput("The 3rd parameter must be double or long literal.")
+      }
+      Some(
+        new BloomFilterAggregate(children.head, expectedNumItemsExpr, numBitsExpr)
+          .toAggregateExpression())
+
     case "window" if Seq(2, 3, 4).contains(fun.getArgumentsCount) =>
       val children = fun.getArgumentsList.asScala.map(transformExpression)
       val timeCol = children.head
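For orientation, a minimal sketch of the two argument shapes this planner branch accepts at this stage of the PR (`df` and its column are illustrative assumptions; `Column.fn` is the helper discussed in the review thread further down):

// Shape 1: [col, expectedNumItems: Long, numBits: Long]
val byBits = Column.fn("bloom_filter_agg", df("id"), lit(1000000L), lit(8388608L))
// Shape 2: [col, expectedNumItems: Long, fpp: Double] -- numBits derived server-side
val byFpp = Column.fn("bloom_filter_agg", df("id"), lit(1000000L), lit(0.03))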
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.sketch
+
+/**
+ * `BloomFilterHelper` is used to bridge helper methods in `BloomFilter`.
+ */
+private[spark] object BloomFilterHelper {
+
+  def optimalNumOfBits(expectedNumItems: Long, fpp: Double): Long =
+    BloomFilter.optimalNumOfBits(expectedNumItems, fpp)
+}
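For reference, the sizing computation this helper exposes follows the standard Bloom filter derivation; below is a sketch under the assumption that Spark's BloomFilter.optimalNumOfBits matches the textbook formula m = -n * ln(p) / (ln 2)^2:

// Assumed restatement of the optimal-bits formula:
// m = -n * ln(p) / (ln 2)^2 bits for n expected items at false-positive rate p.
def optimalNumOfBitsSketch(expectedNumItems: Long, fpp: Double): Long =
  (-expectedNumItems * math.log(fpp) / (math.log(2) * math.log(2))).toLong

// Example: 1,000,000 items at fpp = 0.03 gives roughly 7.3 million bits (~0.9 MiB).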
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.trees.TernaryLike
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.{RUNTIME_BLOOM_FILTER_MAX_NUM_BITS, RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS}
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.sketch.BloomFilter

 /**

@@ -78,7 +79,7 @@ case class BloomFilterAggregate(
             "exprName" -> "estimatedNumItems or numBits"
           )
         )
-      case (LongType, LongType, LongType) =>
+      case (LongType | IntegerType | ShortType | ByteType | StringType, LongType, LongType) =>
         if (!estimatedNumItemsExpression.foldable) {
           DataTypeMismatch(
             errorSubClass = "NON_FOLDABLE_INPUT",
@@ -150,6 +151,15 @@
   // Mark as lazy so that `estimatedNumItems` is not evaluated during tree transformation.
   private lazy val estimatedNumItems: Long =
     Math.min(estimatedNumItemsExpression.eval().asInstanceOf[Number].longValue,
       SQLConf.get.getConf(RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS))

   // Mark as lazy so that `numBits` is not evaluated during tree transformation.
   private lazy val numBits: Long =
     Math.min(numBitsExpression.eval().asInstanceOf[Number].longValue,
       SQLConf.get.getConf(RUNTIME_BLOOM_FILTER_MAX_NUM_BITS))

+  // Mark as lazy so that `updater` is not evaluated during tree transformation.
+  private lazy val updater: BloomFilterUpdater = first.dataType match {
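The hunk above is cut off right after the `match {`. Given the widened type check earlier in this file (Long, Int, Short, Byte, String) and the new UTF8String import, a plausible shape for the dispatch — an assumption, not the PR's verbatim arms — is:

// Assumed sketch of the updater dispatch: integral inputs are hashed as longs,
// strings as their UTF-8 bytes (hence the new UTF8String import).
private lazy val updaterSketch: (BloomFilter, Any) => Boolean = first.dataType match {
  case LongType    => (bf, v) => bf.putLong(v.asInstanceOf[Long])
  case IntegerType => (bf, v) => bf.putLong(v.asInstanceOf[Int].toLong)
  case ShortType   => (bf, v) => bf.putLong(v.asInstanceOf[Short].toLong)
  case ByteType    => (bf, v) => bf.putLong(v.asInstanceOf[Byte].toLong)
  case StringType  => (bf, v) => bf.putBinary(v.asInstanceOf[UTF8String].getBytes)
}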
Contributor:
I don't really like the ambiguity here. Since we are managing this function ourselves, can we just have one way of invoking it? I kind of prefer Column.fn("bloom_filter_agg", col, lit(expectedNumItems), lit(numBits)). Alternatively you pass all three, where you pick either fpp or numItems and pass null for the other field. Another idea would be to have different names.

Contributor (Author):
Let me think about how to refactor.

Contributor (Author):
fe958a6 — changed to only use Column.fn("bloom_filter_agg", col, lit(expectedNumItems), lit(numBits)).
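Following that outcome, a sketch of how the client-side overloads could converge on the single agreed invocation (method bodies and plumbing are illustrative assumptions; `df` stands for the DataFrame owning the stat functions):

// Assumed sketch, not the PR's verbatim code: both public overloads end up
// issuing the same bloom_filter_agg call agreed on above.
def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = {
  // Derive the bit count on the client so the server sees one argument shape.
  val numBits = BloomFilterHelper.optimalNumOfBits(expectedNumItems, fpp)
  bloomFilter(col, expectedNumItems, numBits)
}

def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = {
  val agg = Column.fn("bloom_filter_agg", col, lit(expectedNumItems), lit(numBits))
  val bytes = df.select(agg).head().getAs[Array[Byte]](0)
  BloomFilter.readFrom(new java.io.ByteArrayInputStream(bytes))
}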