diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 142420ee258ae..b87d018f2ab1d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -580,19 +580,18 @@ case class Union( allowMissingCol: Boolean = false) extends UnionBase { assert(!allowMissingCol || byName, "`allowMissingCol` can be true only if `byName` is true.") - override def maxRows: Option[Long] = { - var sum = BigInt(0) - children.foreach { child => - if (child.maxRows.isDefined) { - sum += child.maxRows.get - if (!sum.isValidLong) { - return None + override lazy val maxRows: Option[Long] = { + val sum = children.foldLeft(Option(BigInt(0))) { + case (Some(acc), child) => + child.maxRows match { + case Some(n) => + val newSum = acc + n + if (newSum.isValidLong) Some(newSum) else None + case None => None } - } else { - return None - } + case (None, _) => None } - Some(sum.toLong) + sum.map(_.toLong) } final override val nodePatterns: Seq[TreePattern] = Seq(UNION) @@ -600,19 +599,18 @@ case class Union( /** * Note the definition has assumption about how union is implemented physically. */ - override def maxRowsPerPartition: Option[Long] = { - var sum = BigInt(0) - children.foreach { child => - if (child.maxRowsPerPartition.isDefined) { - sum += child.maxRowsPerPartition.get - if (!sum.isValidLong) { - return None + override lazy val maxRowsPerPartition: Option[Long] = { + val sum = children.foldLeft(Option(BigInt(0))) { + case (Some(acc), child) => + child.maxRowsPerPartition match { + case Some(n) => + val newSum = acc + n + if (newSum.isValidLong) Some(newSum) else None + case None => None } - } else { - return None - } + case (None, _) => None } - Some(sum.toLong) + sum.map(_.toLong) } private def duplicatesResolvedPerBranch: Boolean = @@ -666,7 +664,7 @@ case class Join( hint: JoinHint) extends BinaryNode with PredicateHelper { - override def maxRows: Option[Long] = { + override lazy val maxRows: Option[Long] = { joinType match { case Inner | Cross | FullOuter | LeftOuter | RightOuter | LeftSingle if left.maxRows.isDefined && right.maxRows.isDefined =>