@@ -1732,49 +1732,72 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
17321732 val ignoreNulls = extractBoolean(children(3 ), " ignoreNulls" )
17331733 Some (Lead (children.head, children(1 ), children(2 ), ignoreNulls))
17341734
1735- case " bloom_filter_agg" if fun.getArgumentsCount == 3 =>
1736- // [col, expectedNumItems: Long, numBits: Long] or
1737- // [col, expectedNumItems: Long, fpp: Double]
1735+ case " bloom_filter_agg" if fun.getArgumentsCount == 4 =>
1736+ // [col, expectedNumItems: Long, numBits: Long, fpp: Double]
17381737 val children = fun.getArgumentsList.asScala.map(transformExpression)
17391738
1740- // Check expectedNumItems > 0L
1741- val expectedNumItemsExpr = children(1 )
1742- val expectedNumItems = expectedNumItemsExpr match {
1743- case Literal (l : Long , LongType ) => l
1739+ val fpp = children(3 ) match {
1740+ case DoubleLiteral (d) => d
17441741 case _ =>
1745- throw InvalidPlanInput (" Expected insertions must be long literal." )
1746- }
1747- if (expectedNumItems <= 0L ) {
1748- throw InvalidPlanInput (" Expected insertions must be positive." )
1742+ throw InvalidPlanInput (" False positive must be double literal." )
17491743 }
17501744
1751- val numberBitsOrFpp = children(2 )
1745+ if (fpp.isNaN) {
1746+ // Use expectedNumItems and numBits when `fpp.isNaN` is true.
1747+ // Check expectedNumItems > 0L
1748+ val expectedNumItemsExpr = children(1 )
1749+ expectedNumItemsExpr match {
1750+ case Literal (l : Long , LongType ) =>
1751+ if (l <= 0L ) {
1752+ throw InvalidPlanInput (" Expected insertions must be positive." )
1753+ }
1754+ case _ =>
1755+ throw InvalidPlanInput (" Expected insertions must be long literal." )
1756+ }
1757+ // Check numBits > 0L
1758+ val numBitsExpr = children(2 )
1759+ val numBits = numBitsExpr match {
1760+ case Literal (l : Long , LongType ) => l
1761+ case _ =>
1762+ throw InvalidPlanInput (" Number of bits must be long literal." )
1763+ }
1764+ if (numBits <= 0L ) {
1765+ throw InvalidPlanInput (" Number of bits must be positive." )
1766+ }
1767+ // Create BloomFilterAggregate with expectedNumItemsExpr and numBitsExpr.
1768+ Some (
1769+ new BloomFilterAggregate (children.head, expectedNumItemsExpr, numBitsExpr)
1770+ .toAggregateExpression())
17521771
1753- val numBitsExpr = numberBitsOrFpp match {
1754- case Literal (numBits : Long , LongType ) =>
1755- // Check numBits > 0L
1756- if (numBits <= 0L ) {
1757- throw InvalidPlanInput (" Number of bits must be positive." )
1758- }
1759- numberBitsOrFpp
1760- case DoubleLiteral (fpp) =>
1761- // Check fpp not NaN and in (0.0, 1.0).
1762- if (fpp.isNaN || fpp <= 0d || fpp >= 1d ) {
1763- throw InvalidPlanInput (" False positive probability must be within range (0.0, 1.0)" )
1764- }
1765- // Calculate numBits through expectedNumItems and fpp,
1766- // refer to `BloomFilter.optimalNumOfBits(long, double)`.
1767- val numBits = BloomFilterHelper .optimalNumOfBits(expectedNumItems, fpp)
1768- if (numBits <= 0L ) {
1769- throw InvalidPlanInput (" Number of bits must be positive" )
1770- }
1771- Literal (numBits, LongType )
1772- case _ =>
1773- throw InvalidPlanInput (" The 3rd parameter must be double or long literal." )
1772+ } else if (fpp <= 0d || fpp >= 1d ) {
1773+ // fpp must be in (0.0, 1.0).
1774+ throw InvalidPlanInput (" False positive probability must be within range (0.0, 1.0)" )
1775+ } else {
1776+ // Use expectedNumItems and fpp when `fpp.isNaN` is false and `fpp` is in (0.0, 1.0).
1777+ // Check expectedNumItems > 0L and extract expectedNumItems value.
1778+ val expectedNumItemsExpr = children(1 )
1779+ val expectedNumItems = expectedNumItemsExpr match {
1780+ case Literal (l : Long , LongType ) => l
1781+ case _ =>
1782+ throw InvalidPlanInput (" Expected insertions must be long literal." )
1783+ }
1784+ if (expectedNumItems <= 0L ) {
1785+ throw InvalidPlanInput (" Expected insertions must be positive." )
1786+ }
1787+
1788+ // Calculate numBits from expectedNumItems and fpp; numBits must be greater than 0.
1789+ val numBits = BloomFilterHelper .optimalNumOfBits(expectedNumItems, fpp)
1790+ if (numBits <= 0L ) {
1791+ throw InvalidPlanInput (" Number of bits must be positive" )
1792+ }
1793+ // Create BloomFilterAggregate with expectedNumItemsExpr and new numBits.
1794+ Some (
1795+ new BloomFilterAggregate (
1796+ children.head,
1797+ expectedNumItemsExpr,
1798+ Literal (numBits, LongType ))
1799+ .toAggregateExpression())
17741800 }
1775- Some (
1776- new BloomFilterAggregate (children.head, expectedNumItemsExpr, numBitsExpr)
1777- .toAggregateExpression())
17781801
17791802 case " window" if Seq (2 , 3 , 4 ).contains(fun.getArgumentsCount) =>
17801803 val children = fun.getArgumentsList.asScala.map(transformExpression)
0 commit comments