Skip to content

Commit dfbe1c4

Browse files
committed
pass 4
1 parent a154c51 commit dfbe1c4

File tree

2 files changed

+60
-41
lines changed

2 files changed

+60
-41
lines changed

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -652,11 +652,7 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo
652652
numBits: Long,
653653
fpp: Double): BloomFilter = {
654654

655-
val agg = if (!fpp.isNaN) {
656-
Column.fn("bloom_filter_agg", col, lit(expectedNumItems), lit(fpp))
657-
} else {
658-
Column.fn("bloom_filter_agg", col, lit(expectedNumItems), lit(numBits))
659-
}
655+
val agg = Column.fn("bloom_filter_agg", col, lit(expectedNumItems), lit(numBits), lit(fpp))
660656

661657
val ds = sparkSession.newDataset(BinaryEncoder) { builder =>
662658
builder.getProjectBuilder

connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala

Lines changed: 59 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1732,49 +1732,72 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
17321732
val ignoreNulls = extractBoolean(children(3), "ignoreNulls")
17331733
Some(Lead(children.head, children(1), children(2), ignoreNulls))
17341734

1735-
case "bloom_filter_agg" if fun.getArgumentsCount == 3 =>
1736-
// [col, expectedNumItems: Long, numBits: Long] or
1737-
// [col, expectedNumItems: Long, fpp: Double]
1735+
case "bloom_filter_agg" if fun.getArgumentsCount == 4 =>
1736+
// [col, expectedNumItems: Long, numBits: Long, fpp: Double]
17381737
val children = fun.getArgumentsList.asScala.map(transformExpression)
17391738

1740-
// Check expectedNumItems > 0L
1741-
val expectedNumItemsExpr = children(1)
1742-
val expectedNumItems = expectedNumItemsExpr match {
1743-
case Literal(l: Long, LongType) => l
1739+
val fpp = children(3) match {
1740+
case DoubleLiteral(d) => d
17441741
case _ =>
1745-
throw InvalidPlanInput("Expected insertions must be long literal.")
1746-
}
1747-
if (expectedNumItems <= 0L) {
1748-
throw InvalidPlanInput("Expected insertions must be positive.")
1742+
throw InvalidPlanInput("False positive must be double literal.")
17491743
}
17501744

1751-
val numberBitsOrFpp = children(2)
1745+
if (fpp.isNaN) {
1746+
// Use expectedNumItems and numBits when `fpp.isNaN` if true.
1747+
// Check expectedNumItems > 0L
1748+
val expectedNumItemsExpr = children(1)
1749+
expectedNumItemsExpr match {
1750+
case Literal(l: Long, LongType) =>
1751+
if (l <= 0L) {
1752+
throw InvalidPlanInput("Expected insertions must be positive.")
1753+
}
1754+
case _ =>
1755+
throw InvalidPlanInput("Expected insertions must be long literal.")
1756+
}
1757+
// Check numBits > 0L
1758+
val numBitsExpr = children(2)
1759+
val numBits = numBitsExpr match {
1760+
case Literal(l: Long, LongType) => l
1761+
case _ =>
1762+
throw InvalidPlanInput("Number of bits must be long literal.")
1763+
}
1764+
if (numBits <= 0L) {
1765+
throw InvalidPlanInput("Number of bits must be positive.")
1766+
}
1767+
// Create BloomFilterAggregate with expectedNumItemsExpr and numBitsExpr.
1768+
Some(
1769+
new BloomFilterAggregate(children.head, expectedNumItemsExpr, numBitsExpr)
1770+
.toAggregateExpression())
17521771

1753-
val numBitsExpr = numberBitsOrFpp match {
1754-
case Literal(numBits: Long, LongType) =>
1755-
// Check numBits > 0L
1756-
if (numBits <= 0L) {
1757-
throw InvalidPlanInput("Number of bits must be positive.")
1758-
}
1759-
numberBitsOrFpp
1760-
case DoubleLiteral(fpp) =>
1761-
// Check fpp not NaN and in (0.0, 1.0).
1762-
if (fpp.isNaN || fpp <= 0d || fpp >= 1d) {
1763-
throw InvalidPlanInput("False positive probability must be within range (0.0, 1.0)")
1764-
}
1765-
// Calculate numBits through expectedNumItems and fpp,
1766-
// refer to `BloomFilter.optimalNumOfBits(long, double)`.
1767-
val numBits = BloomFilterHelper.optimalNumOfBits(expectedNumItems, fpp)
1768-
if (numBits <= 0L) {
1769-
throw InvalidPlanInput("Number of bits must be positive")
1770-
}
1771-
Literal(numBits, LongType)
1772-
case _ =>
1773-
throw InvalidPlanInput("The 3rd parameter must be double or long literal.")
1772+
} else if (fpp <= 0d || fpp >= 1d) {
1773+
// fpp must in (0.0, 1.0).
1774+
throw InvalidPlanInput("False positive probability must be within range (0.0, 1.0)")
1775+
} else {
1776+
// Use expectedNumItems and fpp when `fpp.isNaN` if false and `fpp` in (0.0, 1.0).
1777+
// Check expectedNumItems > 0L and extract expectedNumItems value.
1778+
val expectedNumItemsExpr = children(1)
1779+
val expectedNumItems = expectedNumItemsExpr match {
1780+
case Literal(l: Long, LongType) => l
1781+
case _ =>
1782+
throw InvalidPlanInput("Expected insertions must be long literal.")
1783+
}
1784+
if (expectedNumItems <= 0L) {
1785+
throw InvalidPlanInput("Expected insertions must be positive.")
1786+
}
1787+
1788+
// Calculate numBits through expectedNumItems and fpp, numBits must be greater than 0.
1789+
val numBits = BloomFilterHelper.optimalNumOfBits(expectedNumItems, fpp)
1790+
if (numBits <= 0L) {
1791+
throw InvalidPlanInput("Number of bits must be positive")
1792+
}
1793+
// Create BloomFilterAggregate with expectedNumItemsExpr and new numBits.
1794+
Some(
1795+
new BloomFilterAggregate(
1796+
children.head,
1797+
expectedNumItemsExpr,
1798+
Literal(numBits, LongType))
1799+
.toAggregateExpression())
17741800
}
1775-
Some(
1776-
new BloomFilterAggregate(children.head, expectedNumItemsExpr, numBitsExpr)
1777-
.toAggregateExpression())
17781801

17791802
case "window" if Seq(2, 3, 4).contains(fun.getArgumentsCount) =>
17801803
val children = fun.getArgumentsList.asScala.map(transformExpression)

0 commit comments

Comments
 (0)