From 512958ba0054407124bd3dd34a32ff7a2690c063 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 12 Jul 2022 21:39:26 +0900 Subject: [PATCH 01/44] Basework --- .../sql/catalyst/plans/logical/object.scala | 21 + .../logical/pythonLogicalOperators.scala | 23 + ...UntypedFlatMapGroupsWithStateFunction.java | 15 + .../UntypedMapGroupsWithStateFunction.java | 15 + .../spark/sql/RelationalGroupedDataset.scala | 106 +++++ .../spark/sql/execution/SparkStrategies.scala | 44 ++ .../execution/python/PandasGroupUtils.scala | 4 +- .../streaming/IncrementalExecution.scala | 16 + .../PythonFlatMapGroupsWithStateExec.scala | 286 ++++++++++++ .../UntypedFlatMapGroupsWithStateExec.scala | 290 ++++++++++++ .../UntypedFlatMapGroupsWithStateSuite.scala | 438 ++++++++++++++++++ 11 files changed, 1257 insertions(+), 1 deletion(-) create mode 100644 sql/core/src/main/java/org/apache/spark/api/java/function/UntypedFlatMapGroupsWithStateFunction.java create mode 100644 sql/core/src/main/java/org/apache/spark/api/java/function/UntypedMapGroupsWithStateFunction.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/PythonFlatMapGroupsWithStateExec.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/UntypedFlatMapGroupsWithStateExec.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/UntypedFlatMapGroupsWithStateSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index e5fe07e2d950d..61bd361588be4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -551,6 +551,27 @@ case class FlatMapGroupsWithState( copy(child = newLeft, initialState = newRight) } +case class UntypedFlatMapGroupsWithState( + func: (Row, 
Iterator[Row], LogicalGroupState[Row]) => Iterator[Row], + groupingAttributes: Seq[Attribute], + outputAttrs: Seq[Attribute], + stateType: StructType, + outputMode: OutputMode, + isMapGroupsWithState: Boolean = false, + timeout: GroupStateTimeout, + child: LogicalPlan) extends UnaryNode { + if (isMapGroupsWithState) { + assert(outputMode == OutputMode.Update) + } + + override def output: Seq[Attribute] = outputAttrs + + override def producedAttributes: AttributeSet = AttributeSet(outputAttrs) + + override protected def withNewChildInternal( + newChild: LogicalPlan): UntypedFlatMapGroupsWithState = copy(child = newChild) +} + /** Factory for constructing new `FlatMapGroupsInR` nodes. */ object FlatMapGroupsInR { def apply( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala index c2f74b3508342..72593edb42aa0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF} import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} +import org.apache.spark.sql.types.StructType /** * FlatMap groups using a udf: pandas.Dataframe -> pandas.DataFrame. 
@@ -98,6 +100,27 @@ case class FlatMapCoGroupsInPandas( copy(left = newLeft, right = newRight) } +case class PythonFlatMapGroupsWithState( + functionExpr: Expression, + groupingAttributes: Seq[Attribute], + outputAttrs: Seq[Attribute], + stateType: StructType, + outputMode: OutputMode, + isMapGroupsWithState: Boolean = false, + timeout: GroupStateTimeout, + child: LogicalPlan) extends UnaryNode { + if (isMapGroupsWithState) { + assert(outputMode == OutputMode.Update) + } + + override def output: Seq[Attribute] = outputAttrs + + override def producedAttributes: AttributeSet = AttributeSet(outputAttrs) + + override protected def withNewChildInternal( + newChild: LogicalPlan): PythonFlatMapGroupsWithState = copy(child = newChild) +} + trait BaseEvalPython extends UnaryNode { def udfs: Seq[PythonUDF] diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/UntypedFlatMapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/UntypedFlatMapGroupsWithStateFunction.java new file mode 100644 index 0000000000000..e4634b15eca8c --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/api/java/function/UntypedFlatMapGroupsWithStateFunction.java @@ -0,0 +1,15 @@ +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.annotation.Experimental; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.streaming.GroupState; + +@Experimental +@Evolving +public interface UntypedFlatMapGroupsWithStateFunction extends Serializable { + Iterator call(Row key, Iterator values, GroupState state) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/UntypedMapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/UntypedMapGroupsWithStateFunction.java new file mode 100644 index 0000000000000..14167c84a8bed --- /dev/null +++ 
b/sql/core/src/main/java/org/apache/spark/api/java/function/UntypedMapGroupsWithStateFunction.java @@ -0,0 +1,15 @@ +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.annotation.Experimental; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.streaming.GroupState; + +@Experimental +@Evolving +public interface UntypedMapGroupsWithStateFunction extends Serializable { + Row call(Row key, Iterator values, GroupState state) throws Exception; +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 989ee32521871..ad19b2f067f39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -23,6 +23,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.SparkRuntimeException import org.apache.spark.annotation.Stable +import org.apache.spark.api.java.function.{UntypedFlatMapGroupsWithStateFunction, UntypedMapGroupsWithStateFunction} import org.apache.spark.api.python.PythonEvalType import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedFunction} @@ -33,6 +34,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression +import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode} import org.apache.spark.sql.types.{NumericType, StructType} /** @@ -620,6 +622,110 @@ class RelationalGroupedDataset protected[sql]( Dataset.ofRows(df.sparkSession, plan) } + def mapGroupsWithState( + func: 
UntypedMapGroupsWithStateFunction, + outputStructType: StructType, + stateStructType: StructType, + timeoutConf: GroupStateTimeout): DataFrame = { + mapGroupsWithState( + outputStructType, stateStructType, timeoutConf)( + (key: Row, it: Iterator[Row], s: GroupState[Row]) => func.call(key, it.asJava, s) + ) + } + + def mapGroupsWithState( + outputStructType: StructType, + stateStructType: StructType, + timeoutConf: GroupStateTimeout)( + func: (Row, Iterator[Row], GroupState[Row]) => Row): DataFrame = { + val flatMapFunc = (key: Row, it: Iterator[Row], s: GroupState[Row]) => + Iterator(func(key, it, s)) + + val groupingNamedExpressions = groupingExprs.map { + case ne: NamedExpression => ne + case other => Alias(other, other.toString)() + } + val groupingAttrs = groupingNamedExpressions.map(_.toAttribute) + val outputAttrs = outputStructType.toAttributes + val plan = UntypedFlatMapGroupsWithState( + flatMapFunc.asInstanceOf[(Row, Iterator[Row], LogicalGroupState[Row]) => Iterator[Row]], + groupingAttrs, + outputAttrs, + stateStructType, + OutputMode.Update(), + isMapGroupsWithState = true, + timeoutConf, + child = df.logicalPlan) + Dataset.ofRows(df.sparkSession, plan) + } + + def flatMapGroupsWithState( + func: UntypedFlatMapGroupsWithStateFunction, + outputStructType: StructType, + stateStructType: StructType, + outputMode: OutputMode, + timeoutConf: GroupStateTimeout): DataFrame = { + val f = (key: Row, it: Iterator[Row], s: GroupState[Row]) => + func.call(key, it.asJava, s).asScala + flatMapGroupsWithState(outputStructType, stateStructType, outputMode, timeoutConf)(f) + } + + def flatMapGroupsWithState( + outputStructType: StructType, + stateStructType: StructType, + outputMode: OutputMode, + timeoutConf: GroupStateTimeout)( + func: (Row, Iterator[Row], GroupState[Row]) => Iterator[Row]): DataFrame = { + if (outputMode != OutputMode.Append && outputMode != OutputMode.Update) { + throw new IllegalArgumentException("The output mode of function should be append or 
update") + } + val groupingNamedExpressions = groupingExprs.map { + case ne: NamedExpression => ne + case other => Alias(other, other.toString)() + } + val groupingAttrs = groupingNamedExpressions.map(_.toAttribute) + val outputAttrs = outputStructType.toAttributes + val plan = UntypedFlatMapGroupsWithState( + func.asInstanceOf[(Row, Iterator[Row], LogicalGroupState[Row]) => Iterator[Row]], + groupingAttrs, + outputAttrs, + stateStructType, + outputMode, + isMapGroupsWithState = false, + timeoutConf, + child = df.logicalPlan) + Dataset.ofRows(df.sparkSession, plan) + } + + // FIXME: probably we have to change the type as String for outputMode and timeoutConf to provide + // parameters from Python? + private[sql] def pythonFlatMapGroupsWithState( + func: PythonUDF, + outputStructType: StructType, + stateStructType: StructType, + outputMode: OutputMode, + timeoutConf: GroupStateTimeout): DataFrame = { + if (outputMode != OutputMode.Append && outputMode != OutputMode.Update) { + throw new IllegalArgumentException("The output mode of function should be append or update") + } + val groupingNamedExpressions = groupingExprs.map { + case ne: NamedExpression => ne + case other => Alias(other, other.toString)() + } + val groupingAttrs = groupingNamedExpressions.map(_.toAttribute) + val outputAttrs = outputStructType.toAttributes + val plan = PythonFlatMapGroupsWithState( + func, + groupingAttrs, + outputAttrs, + stateStructType, + outputMode, + isMapGroupsWithState = false, + timeoutConf, + child = df.logicalPlan) + Dataset.ofRows(df.sparkSession, plan) + } + override def toString: String = { val builder = new StringBuilder builder.append("RelationalGroupedDataset: [grouping expressions: [") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 6104104c7bea4..767fd8cf9f11f 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -684,6 +684,44 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + /** + * Strategy to convert [[UntypedFlatMapGroupsWithState]] logical operator to physical operator + * in streaming plans. Conversion for batch plans is handled by [[BasicOperators]]. + */ + object UntypedFlatMapGroupsWithStateStrategy extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case UntypedFlatMapGroupsWithState( + func, groupAttr, outputAttr, stateType, outputMode, _, timeout, child) => + val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION) + val execPlan = UntypedFlatMapGroupsWithStateExec( + func, groupAttr, outputAttr, stateType, None, stateVersion, outputMode, timeout, + batchTimestampMs = None, eventTimeWatermark = None, planLater(child) + ) + execPlan :: Nil + case _ => + Nil + } + } + + /** + * Strategy to convert [[PythonFlatMapGroupsWithState]] logical operator to physical operator + * in streaming plans. Conversion for batch plans is handled by [[BasicOperators]]. + */ + object PythonFlatMapGroupsWithStateStrategy extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case PythonFlatMapGroupsWithState( + func, groupAttr, outputAttr, stateType, outputMode, _, timeout, child) => + val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION) + val execPlan = PythonFlatMapGroupsWithStateExec( + func, groupAttr, outputAttr, stateType, None, stateVersion, outputMode, timeout, + batchTimestampMs = None, eventTimeWatermark = None, planLater(child) + ) + execPlan :: Nil + case _ => + Nil + } + } + /** * Strategy to convert EvalPython logical operator to physical operator. 
*/ @@ -793,6 +831,12 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { initialStateGroupAttrs, data, initialStateDataAttrs, output, timeout, hasInitialState, planLater(initialState), planLater(child) ) :: Nil + // FIXME: implement it! + case _: logical.UntypedFlatMapGroupsWithState => + throw new UnsupportedOperationException("Not yet implemented for batch query!") + // FIXME: implement it! + case _: PythonFlatMapGroupsWithState => + throw new UnsupportedOperationException("Not yet implemented for batch query!") case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) => execution.CoGroupExec( f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala index 2da0000dad4ef..1ffc6a64d7708 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala @@ -30,7 +30,9 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} /** * Base functionality for plans which execute grouped python udfs. */ -private[python] object PandasGroupUtils { +// FIXME: should we move PythonFlatMapGroupsWithStateExec to python package? +// private[python] +object PandasGroupUtils { /** * passes the data to the python runner and coverts the resulting * columnarbatch into internal rows. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 3f369ac5e973b..105df798e7059 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -62,6 +62,8 @@ class IncrementalExecution( StreamingJoinStrategy :: StatefulAggregationStrategy :: FlatMapGroupsWithStateStrategy :: + UntypedFlatMapGroupsWithStateStrategy :: + PythonFlatMapGroupsWithStateStrategy :: StreamingRelationStrategy :: StreamingDeduplicationStrategy :: StreamingGlobalLimitStrategy(outputMode) :: Nil @@ -210,6 +212,20 @@ class IncrementalExecution( hasInitialState = hasInitialState ) + case m: UntypedFlatMapGroupsWithStateExec => + m.copy( + stateInfo = Some(nextStatefulOperationStateInfo), + batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs), + eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs) + ) + + case m: PythonFlatMapGroupsWithStateExec => + m.copy( + stateInfo = Some(nextStatefulOperationStateInfo), + batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs), + eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs) + ) + case j: StreamingSymmetricHashJoinExec => j.copy( stateInfo = Some(nextStatefulOperationStateInfo), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/PythonFlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/PythonFlatMapGroupsWithStateExec.scala new file mode 100644 index 0000000000000..05c0a396a3e48 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/PythonFlatMapGroupsWithStateExec.scala @@ -0,0 +1,286 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import java.util.concurrent.TimeUnit.NANOSECONDS + +import org.apache.spark.api.python.ChainedPythonFunctions +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, PythonUDF, SortOrder, UnsafeRow} +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, EventTimeWatermark, ProcessingTimeTimeout} +import org.apache.spark.sql.catalyst.plans.physical.Distribution +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.python.PandasGroupUtils._ +import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP +import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreOps} +import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper.createStateManager +import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} +import org.apache.spark.sql.streaming.GroupStateTimeout.NoTimeout +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.ArrowUtils +import 
org.apache.spark.util.CompletionIterator + +case class PythonFlatMapGroupsWithStateExec( + func: Expression, + groupingAttributes: Seq[Attribute], + outAttributes: Seq[Attribute], + stateType: StructType, + stateInfo: Option[StatefulOperatorStateInfo], + stateFormatVersion: Int, + outputMode: OutputMode, + timeoutConf: GroupStateTimeout, + batchTimestampMs: Option[Long], + eventTimeWatermark: Option[Long], + child: SparkPlan) extends UnaryExecNode with StateStoreWriter with WatermarkSupport { + + override def output: Seq[Attribute] = outAttributes + private val isTimeoutEnabled = timeoutConf != NoTimeout + + private val sessionLocalTimeZone = conf.sessionLocalTimeZone + private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) + private val pythonFunction = func.asInstanceOf[PythonUDF].func + private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction))) + + private val watermarkPresent = child.output.exists { + case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true + case _ => false + } + + private val outputType = outAttributes.toStructType + private val keyEncoder = RowEncoder(groupingAttributes.toStructType) + .resolveAndBind(groupingAttributes) + private val valueEncoder = RowEncoder(child.output.toStructType).resolveAndBind(child.output) + private val stateEncoder = RowEncoder(stateType).resolveAndBind() + private val outputEncoder = RowEncoder(outputType).resolveAndBind(outAttributes) + + private[sql] val stateManager = + createStateManager(stateEncoder.asInstanceOf[ExpressionEncoder[Any]], isTimeoutEnabled, + stateFormatVersion) + + override def requiredChildDistribution: Seq[Distribution] = + StatefulOperatorPartitioning.getCompatibleDistribution( + groupingAttributes, getStateInfo, conf) :: Nil + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq( + groupingAttributes.map(SortOrder(_, Ascending))) + + override def keyExpressions: Seq[Attribute] = groupingAttributes + + override def shortName: 
String = "pythonFlatMapGroupsWithState" + + override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + timeoutConf match { + case ProcessingTimeTimeout => + true // Always run batches to process timeouts + case EventTimeTimeout => + // Process another non-data batch only if the watermark has changed in this executed plan + eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get + case _ => + false + } + } + + /** + * Process data by applying the user defined function on a per partition basis. + * + * @param iter - Iterator of the data rows + * @param store - associated state store for this partition + * @param processor - handle to the input processor object. + */ + def processDataWithPartition( + iter: Iterator[InternalRow], + store: StateStore, + processor: InputProcessor): CompletionIterator[InternalRow, Iterator[InternalRow]] = { + val allUpdatesTimeMs = longMetric("allUpdatesTimeMs") + val commitTimeMs = longMetric("commitTimeMs") + val timeoutLatencyMs = longMetric("allRemovalsTimeMs") + + val currentTimeNs = System.nanoTime + val updatesStartTimeNs = currentTimeNs + var timeoutProcessingStartTimeNs = currentTimeNs + + // If timeout is based on event time, then filter late data based on watermark + val filteredIter = watermarkPredicateForData match { + case Some(predicate) if timeoutConf == EventTimeTimeout => + applyRemovingRowsOlderThanWatermark(iter, predicate) + case _ => + iter + } + + val processedOutputIterator = processor.processNewData(filteredIter) + + val newDataProcessorIter = + CompletionIterator[InternalRow, Iterator[InternalRow]]( + processedOutputIterator, { + // Once the input is processed, mark the start time for timeout processing to measure + // it separately from the overall processing time. 
+ timeoutProcessingStartTimeNs = System.nanoTime + }) + + // SPARK-38320: Late-bind the timeout processing iterator so it is created *after* the input is + // processed (the input iterator is exhausted) and the state updates are written into the + // state store. Otherwise the iterator may not see the updates (e.g. with RocksDB state store). + val timeoutProcessorIter = new Iterator[InternalRow] { + private lazy val itr = getIterator() + override def hasNext = itr.hasNext + override def next() = itr.next() + private def getIterator(): Iterator[InternalRow] = + CompletionIterator[InternalRow, Iterator[InternalRow]](processor.processTimedOutState(), { + // Note: `timeoutLatencyMs` also includes the time the parent operator took for + // processing output returned through iterator. + timeoutLatencyMs += NANOSECONDS.toMillis(System.nanoTime - timeoutProcessingStartTimeNs) + }) + } + + // Generate an iterator that returns the rows grouped by the grouping function + // Note that this code ensures that the filtering for timeout occurs only after + // all the data has been processed. This is to ensure that the timeout information of all + // the keys with data is updated before they are processed for timeouts. + val outputIterator = newDataProcessorIter ++ timeoutProcessorIter + + // Return an iterator of all the rows generated by all the keys, such that when fully + // consumed, all the state updates will be committed by the state store + CompletionIterator[InternalRow, Iterator[InternalRow]](outputIterator, { + // Note: Due to the iterator lazy execution, this metric also captures the time taken + // by the downstream (consumer) operators in addition to the processing in this operator. 
+ allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) + commitTimeMs += timeTakenMs { + store.commit() + } + setStoreMetrics(store) + setOperatorMetrics() + }) + } + + override protected def doExecute(): RDD[InternalRow] = { + metrics // force lazy init at driver + + // Throw errors early if parameters are not as expected + timeoutConf match { + case ProcessingTimeTimeout => + require(batchTimestampMs.nonEmpty) + case EventTimeTimeout => + require(eventTimeWatermark.nonEmpty) // watermark value has been populated + require(watermarkExpression.nonEmpty) // input schema has watermark attribute + case _ => + } + + child.execute().mapPartitionsWithStateStore[InternalRow]( + getStateInfo, + groupingAttributes.toStructType, + stateManager.stateSchema, + numColsPrefixKey = 0, + session.sqlContext.sessionState, + Some(session.sqlContext.streams.stateStoreCoordinator) + ) { case (store: StateStore, singleIterator: Iterator[InternalRow]) => + val processor = new InputProcessor(store) + processDataWithPartition(singleIterator, store, processor) + } + } + + /** Helper class to update the state store */ + class InputProcessor(store: StateStore) { + private val keyDeserializer = keyEncoder.createDeserializer() + private val valueDeserializer = valueEncoder.createDeserializer() + private val outputSerializer = outputEncoder.createSerializer() + + // Metrics + private val numUpdatedStateRows = longMetric("numUpdatedStateRows") + private val numOutputRows = longMetric("numOutputRows") + private val numRemovedStateRows = longMetric("numRemovedStateRows") + + /** + * For every group, get the key, values and corresponding state and call the function, + * and return an iterator of rows + */ + def processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { + val (dedupAttributes, argOffsets) = resolveArgOffsets(child, groupingAttributes) + + val data = groupAndProject(dataIter, groupingAttributes, child.output, + dedupAttributes).map { case 
(keyRow, valueRowIter) => + + val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] + val stateData = stateManager.getState(store, keyUnsafeRow) + val groupedState = GroupStateImpl.createForStreaming( + Option(stateData.stateObj), // TODO: check whether the object is Row or not + batchTimestampMs.getOrElse(NO_TIMESTAMP), + eventTimeWatermark.getOrElse(NO_TIMESTAMP), + timeoutConf, + hasTimedOut = false, + watermarkPresent).asInstanceOf[GroupStateImpl[Row]] + + // UnsafeRow, Iterator[UnsafeRow], GroupStateImpl[Row] + (keyRow, valueRowIter, groupedState) + } + + // FIXME: need to construct the code to pass the iterator of (key, valueIter, GroupState) + // and receive an iterator of (outputs, state update). + + // FIXME: outputs should be produced to the downstream, with conversion from Row to + // InternalRow. + // FIXME: state updates should be reflected to the state store. + // FIXME: see UntypedFlatMapGroupsWithStateExec.callFunctionAndUpdateState for more details + + // FIXME: pretty sure this is dummy code + Iterator.empty + } + + /** Find the groups that have timeout set and are timing out right now, and call the function */ + def processTimedOutState(): Iterator[InternalRow] = { + if (isTimeoutEnabled) { + val timeoutThreshold = timeoutConf match { + case ProcessingTimeTimeout => batchTimestampMs.get + case EventTimeTimeout => eventTimeWatermark.get + case _ => + throw new IllegalStateException( + s"Cannot filter timed out keys for $timeoutConf") + } + + val data = stateManager.getAllState(store).filter { state => + state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold + }.map { stateData => + val groupedState = GroupStateImpl.createForStreaming( + Option(stateData.stateObj), // TODO: check whether the object is Row or not + batchTimestampMs.getOrElse(NO_TIMESTAMP), + eventTimeWatermark.getOrElse(NO_TIMESTAMP), + timeoutConf, + hasTimedOut = true, + watermarkPresent).asInstanceOf[GroupStateImpl[Row]] + + // UnsafeRow, 
Iterator[UnsafeRow], GroupStateImpl[Row] + (stateData.keyRow, Iterator.empty.asInstanceOf[Iterator[UnsafeRow]], groupedState) + } + + // FIXME: need to construct the code to pass the iterator of (key, valueIter, GroupState) + // and receive an iterator of (outputs, state update). + + // FIXME: outputs should be produced to the downstream, with conversion from Row to + // InternalRow. + // FIXME: state updates should be reflected to the state store. + + // FIXME: pretty sure this is dummy code + Iterator.empty + } else Iterator.empty + } + } + + override protected def withNewChildInternal( + newChild: SparkPlan): PythonFlatMapGroupsWithStateExec = copy(child = newChild) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/UntypedFlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/UntypedFlatMapGroupsWithStateExec.scala new file mode 100644 index 0000000000000..4214eebcc26c7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/UntypedFlatMapGroupsWithStateExec.scala @@ -0,0 +1,290 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.streaming + +import java.util.concurrent.TimeUnit.NANOSECONDS + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, SortOrder, UnsafeRow} +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, EventTimeWatermark, LogicalGroupState, ProcessingTimeTimeout} +import org.apache.spark.sql.catalyst.plans.physical.Distribution +import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP +import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreOps} +import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper.{createStateManager, StateData} +import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} +import org.apache.spark.sql.streaming.GroupStateTimeout.NoTimeout +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.CompletionIterator + +case class UntypedFlatMapGroupsWithStateExec( + func: (Row, Iterator[Row], LogicalGroupState[Row]) => Iterator[Row], + groupingAttributes: Seq[Attribute], + outAttributes: Seq[Attribute], + stateType: StructType, + stateInfo: Option[StatefulOperatorStateInfo], + stateFormatVersion: Int, + outputMode: OutputMode, + timeoutConf: GroupStateTimeout, + batchTimestampMs: Option[Long], + eventTimeWatermark: Option[Long], + child: SparkPlan) extends UnaryExecNode with StateStoreWriter with WatermarkSupport { + + override def output: Seq[Attribute] = outAttributes + private val isTimeoutEnabled = timeoutConf != NoTimeout + + private val watermarkPresent = child.output.exists { + case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true + case _ => false + } + + private val 
outputType = outAttributes.toStructType + private val keyEncoder = RowEncoder(groupingAttributes.toStructType) + .resolveAndBind(groupingAttributes) + private val valueEncoder = RowEncoder(child.output.toStructType).resolveAndBind(child.output) + private val stateEncoder = RowEncoder(stateType).resolveAndBind() + private val outputEncoder = RowEncoder(outputType).resolveAndBind(outAttributes) + + private[sql] val stateManager = + createStateManager(stateEncoder.asInstanceOf[ExpressionEncoder[Any]], isTimeoutEnabled, + stateFormatVersion) + + override def requiredChildDistribution: Seq[Distribution] = + StatefulOperatorPartitioning.getCompatibleDistribution( + groupingAttributes, getStateInfo, conf) :: Nil + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq( + groupingAttributes.map(SortOrder(_, Ascending))) + + override def keyExpressions: Seq[Attribute] = groupingAttributes + + override def shortName: String = "untypedFlatMapGroupsWithState" + + override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + timeoutConf match { + case ProcessingTimeTimeout => + true // Always run batches to process timeouts + case EventTimeTimeout => + // Process another non-data batch only if the watermark has changed in this executed plan + eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get + case _ => + false + } + } + + /** + * Process data by applying the user defined function on a per partition basis. + * + * @param iter - Iterator of the data rows + * @param store - associated state store for this partition + * @param processor - handle to the input processor object. 
+ */ + def processDataWithPartition( + iter: Iterator[InternalRow], + store: StateStore, + processor: InputProcessor): CompletionIterator[InternalRow, Iterator[InternalRow]] = { + val allUpdatesTimeMs = longMetric("allUpdatesTimeMs") + val commitTimeMs = longMetric("commitTimeMs") + val timeoutLatencyMs = longMetric("allRemovalsTimeMs") + + val currentTimeNs = System.nanoTime + val updatesStartTimeNs = currentTimeNs + var timeoutProcessingStartTimeNs = currentTimeNs + + // If timeout is based on event time, then filter late data based on watermark + val filteredIter = watermarkPredicateForData match { + case Some(predicate) if timeoutConf == EventTimeTimeout => + applyRemovingRowsOlderThanWatermark(iter, predicate) + case _ => + iter + } + + val processedOutputIterator = processor.processNewData(filteredIter) + + val newDataProcessorIter = + CompletionIterator[InternalRow, Iterator[InternalRow]]( + processedOutputIterator, { + // Once the input is processed, mark the start time for timeout processing to measure + // it separately from the overall processing time. + timeoutProcessingStartTimeNs = System.nanoTime + }) + + // SPARK-38320: Late-bind the timeout processing iterator so it is created *after* the input is + // processed (the input iterator is exhausted) and the state updates are written into the + // state store. Otherwise the iterator may not see the updates (e.g. with RocksDB state store). + val timeoutProcessorIter = new Iterator[InternalRow] { + private lazy val itr = getIterator() + override def hasNext = itr.hasNext + override def next() = itr.next() + private def getIterator(): Iterator[InternalRow] = + CompletionIterator[InternalRow, Iterator[InternalRow]](processor.processTimedOutState(), { + // Note: `timeoutLatencyMs` also includes the time the parent operator took for + // processing output returned through iterator. 
+ timeoutLatencyMs += NANOSECONDS.toMillis(System.nanoTime - timeoutProcessingStartTimeNs) + }) + } + + // Generate a iterator that returns the rows grouped by the grouping function + // Note that this code ensures that the filtering for timeout occurs only after + // all the data has been processed. This is to ensure that the timeout information of all + // the keys with data is updated before they are processed for timeouts. + val outputIterator = newDataProcessorIter ++ timeoutProcessorIter + + // Return an iterator of all the rows generated by all the keys, such that when fully + // consumed, all the state updates will be committed by the state store + CompletionIterator[InternalRow, Iterator[InternalRow]](outputIterator, { + // Note: Due to the iterator lazy execution, this metric also captures the time taken + // by the upstream (consumer) operators in addition to the processing in this operator. + allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) + commitTimeMs += timeTakenMs { + store.commit() + } + setStoreMetrics(store) + setOperatorMetrics() + }) + } + + override protected def doExecute(): RDD[InternalRow] = { + metrics // force lazy init at driver + + // Throw errors early if parameters are not as expected + timeoutConf match { + case ProcessingTimeTimeout => + require(batchTimestampMs.nonEmpty) + case EventTimeTimeout => + require(eventTimeWatermark.nonEmpty) // watermark value has been populated + require(watermarkExpression.nonEmpty) // input schema has watermark attribute + case _ => + } + + child.execute().mapPartitionsWithStateStore[InternalRow]( + getStateInfo, + groupingAttributes.toStructType, + stateManager.stateSchema, + numColsPrefixKey = 0, + session.sqlContext.sessionState, + Some(session.sqlContext.streams.stateStoreCoordinator) + ) { case (store: StateStore, singleIterator: Iterator[InternalRow]) => + val processor = new InputProcessor(store) + processDataWithPartition(singleIterator, store, processor) + } + 
} + + /** Helper class to update the state store */ + class InputProcessor(store: StateStore) { + private val keyDeserializer = keyEncoder.createDeserializer() + private val valueDeserializer = valueEncoder.createDeserializer() + private val outputSerializer = outputEncoder.createSerializer() + + // Metrics + private val numUpdatedStateRows = longMetric("numUpdatedStateRows") + private val numOutputRows = longMetric("numOutputRows") + private val numRemovedStateRows = longMetric("numRemovedStateRows") + + /** + * For every group, get the key, values and corresponding state and call the function, + * and return an iterator of rows + */ + def processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { + val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output) + groupedIter.flatMap { case (keyRow, valueRowIter) => + val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] + callFunctionAndUpdateState( + stateManager.getState(store, keyUnsafeRow), + valueRowIter, + hasTimedOut = false) + } + } + + /** Find the groups that have timeout set and are timing out right now, and call the function */ + def processTimedOutState(): Iterator[InternalRow] = { + if (isTimeoutEnabled) { + val timeoutThreshold = timeoutConf match { + case ProcessingTimeTimeout => batchTimestampMs.get + case EventTimeTimeout => eventTimeWatermark.get + case _ => + throw new IllegalStateException( + s"Cannot filter timed out keys for $timeoutConf") + } + val timingOutPairs = stateManager.getAllState(store).filter { state => + state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold + } + timingOutPairs.flatMap { stateData => + callFunctionAndUpdateState(stateData, Iterator.empty, hasTimedOut = true) + } + } else Iterator.empty + } + + /** + * Call the user function on a key's data, update the state store, and return the return data + * iterator. 
Note that the store updating is lazy, that is, the store will be updated only + * after the returned iterator is fully consumed. + * + * @param stateData All the data related to the state to be updated + * @param valueRowIter Iterator of values as rows, cannot be null, but can be empty + * @param hasTimedOut Whether this function is being called for a key timeout + */ + private def callFunctionAndUpdateState( + stateData: StateData, + valueRowIter: Iterator[InternalRow], + hasTimedOut: Boolean): Iterator[InternalRow] = { + val keyRowAsUntyped = keyDeserializer(stateData.keyRow) + val valueRowsIterAsUntyped = valueRowIter.map(valueDeserializer.apply) + + val groupState = GroupStateImpl.createForStreaming( + Option(stateData.stateObj), // TODO: check whether the object is Row or not + batchTimestampMs.getOrElse(NO_TIMESTAMP), + eventTimeWatermark.getOrElse(NO_TIMESTAMP), + timeoutConf, + hasTimedOut, + watermarkPresent).asInstanceOf[GroupStateImpl[Row]] + + // Call function, get the returned objects and convert them to rows + val mappedIterator = func(keyRowAsUntyped, valueRowsIterAsUntyped, groupState).map { row => + numOutputRows += 1 + outputSerializer(row) + } + + // When the iterator is consumed, then write changes to state + def onIteratorCompletion: Unit = { + if (groupState.isRemoved && !groupState.getTimeoutTimestampMs.isPresent()) { + stateManager.removeState(store, stateData.keyRow) + numRemovedStateRows += 1 + } else { + val currentTimeoutTimestamp = groupState.getTimeoutTimestampMs.orElse(NO_TIMESTAMP) + val hasTimeoutChanged = currentTimeoutTimestamp != stateData.timeoutTimestamp + val shouldWriteState = groupState.isUpdated || groupState.isRemoved || hasTimeoutChanged + + if (shouldWriteState) { + val updatedStateObj = if (groupState.exists) groupState.get else null + stateManager.putState(store, stateData.keyRow, updatedStateObj, currentTimeoutTimestamp) + numUpdatedStateRows += 1 + } + } + } + + // Return an iterator of rows such that fully consumed, 
the updated state value will be saved + CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator, onIteratorCompletion) + } + } + + override protected def withNewChildInternal( + newChild: SparkPlan): UntypedFlatMapGroupsWithStateExec = copy(child = newChild) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/UntypedFlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/UntypedFlatMapGroupsWithStateSuite.scala new file mode 100644 index 0000000000000..60b66fb1319ef --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/UntypedFlatMapGroupsWithStateSuite.scala @@ -0,0 +1,438 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import org.scalatest.exceptions.TestFailedException + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.plans.logical.{NoTimeout, ProcessingTimeTimeout} +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.functions.timestamp_seconds +import org.apache.spark.sql.streaming.GroupStateTimeout.EventTimeTimeout +import org.apache.spark.sql.streaming.util.StreamManualClock +import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType} + +class UntypedFlatMapGroupsWithStateSuite extends StateStoreMetricsTest { + + import testImplicits._ + + import FlatMapGroupsWithStateSuite._ + + /** + * Sample `flatMapGroupsWithState` function implementation. It maintains the max event time as + * state and set the timeout timestamp based on the current max event time seen. It returns the + * max event time in the state, or -1 if the state was removed by timeout. Timeout is 5sec. 
+ */ + val sampleTestFunction = + (key: Row, values: Iterator[Row], state: GroupState[Row]) => { + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } + assertCanGetWatermark { state.getCurrentWatermarkMs() >= -1 } + + // key: String, values: (String, Timestamp), state: Long, output: (String, Int) + val keyAsString = key.getString(0) + + val timeoutDelaySec = 5 + if (state.hasTimedOut) { + state.remove() + Iterator(Row(keyAsString, -1)) + } else { + val valuesSeq = values.toSeq + val maxEventTimeSec = math.max(valuesSeq.map(_.getTimestamp(1).getTime / 1000).max, + state.getOption.map(_.getLong(0)).getOrElse(0L)) + val timeoutTimestampSec = maxEventTimeSec + timeoutDelaySec + state.update(Row(maxEventTimeSec)) + state.setTimeoutTimestamp(timeoutTimestampSec * 1000) + Iterator(Row(keyAsString, maxEventTimeSec.toInt)) + } + } + + test("flatMapGroupsWithState - streaming") { + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count if state is defined, otherwise does not return anything + val stateFunc = (key: Row, values: Iterator[Row], state: GroupState[Row]) => { + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } + assertCannotGetWatermark { state.getCurrentWatermarkMs() } + + val count = state.getOption.map(_.getLong(0)).getOrElse(0L) + values.size + if (count == 3) { + state.remove() + Iterator.empty + } else { + state.update(Row(count)) + Iterator(Row(key.getString(0), count.toString)) + } + } + + val inputData = MemoryStream[String] + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("countAsString", StringType))) + val stateStructType = StructType(Seq(StructField("count", LongType))) + val result = + inputData.toDS() + .groupBy("value") + .flatMapGroupsWithState( + outputStructType, stateStructType, Update, GroupStateTimeout.NoTimeout)(stateFunc) + + testStream(result, Update)( + AddData(inputData, "a"), + CheckNewAnswer(("a", 
"1")), + assertNumStateRows(total = 1, updated = 1), + AddData(inputData, "a", "b"), + CheckNewAnswer(("a", "2"), ("b", "1")), + assertNumStateRows(total = 2, updated = 2), + StopStream, + StartStream(), + AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a + CheckNewAnswer(("b", "2")), + assertNumStateRows( + total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed = Some(Seq(1))), + StopStream, + StartStream(), + AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and + CheckNewAnswer(("a", "1"), ("c", "1")), + assertNumStateRows(total = 3, updated = 2) + ) + } + + test("flatMapGroupsWithState - streaming + func returns iterator that updates state lazily") { + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count if state is defined, otherwise does not return anything + // Additionally, it updates state lazily as the returned iterator get consumed + val stateFunc = (key: Row, values: Iterator[Row], state: GroupState[Row]) => { + values.flatMap { _ => + val count = state.getOption.map(_.getLong(0)).getOrElse(0L) + 1 + if (count == 3) { + state.remove() + None + } else { + state.update(Row(count)) + Some(Row(key.getString(0), count.toString)) + } + } + } + + val inputData = MemoryStream[String] + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("countAsString", StringType))) + val stateStructType = StructType(Seq(StructField("count", LongType))) + val result = + inputData.toDS() + .groupBy("value") + .flatMapGroupsWithState(outputStructType, stateStructType, Update, + GroupStateTimeout.NoTimeout)(stateFunc) + testStream(result, Update)( + AddData(inputData, "a", "a", "b"), + CheckNewAnswer(("a", "1"), ("a", "2"), ("b", "1")), + StopStream, + StartStream(), + AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a + CheckNewAnswer(("b", "2")), + 
StopStream, + StartStream(), + AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and + CheckNewAnswer(("a", "1"), ("c", "1")) + ) + } + + test("flatMapGroupsWithState - streaming + aggregation") { + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) + val stateFunc = (key: Row, values: Iterator[Row], state: GroupState[Row]) => { + + val keyAsString = key.getString(0) + + val count = state.getOption.map(_.getLong(0)).getOrElse(0L) + values.size + if (count == 3) { + state.remove() + Iterator(Row(keyAsString, "-1")) + } else { + state.update(Row(count)) + Iterator(Row(keyAsString, count.toString)) + } + } + + val inputData = MemoryStream[String] + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("countAsString", StringType))) + val stateStructType = StructType(Seq(StructField("count", LongType))) + val result = + inputData.toDS() + .groupBy("value") + .flatMapGroupsWithState(outputStructType, stateStructType, Append, + GroupStateTimeout.NoTimeout)(stateFunc) + .groupBy("key") + .count() + + testStream(result, Complete)( + AddData(inputData, "a"), + CheckNewAnswer(("a", 1)), + AddData(inputData, "a", "b"), + // mapGroups generates ("a", "2"), ("b", "1"); so increases counts of a and b by 1 + CheckNewAnswer(("a", 2), ("b", 1)), + StopStream, + StartStream(), + AddData(inputData, "a", "b"), + // mapGroups should remove state for "a" and generate ("a", "-1"), ("b", "2") ; + // so increment a and b by 1 + CheckNewAnswer(("a", 3), ("b", 2)), + StopStream, + StartStream(), + AddData(inputData, "a", "c"), + // mapGroups should recreate state for "a" and generate ("a", "1"), ("c", "1") ; + // so increment a and c by 1 + CheckNewAnswer(("a", 4), ("b", 2), ("c", 1)) + ) + } + + test("flatMapGroupsWithState - streaming with processing time timeout") { + // Function to maintain the count as 
state and set the proc. time timeout delay of 10 seconds. + // It returns the count if changed, or -1 if the state was removed by timeout. + val stateFunc = (key: Row, values: Iterator[Row], state: GroupState[Row]) => { + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } + assertCannotGetWatermark { state.getCurrentWatermarkMs() } + + val keyAsString = key.getString(0) + if (state.hasTimedOut) { + state.remove() + Iterator(Row(keyAsString, "-1")) + } else { + val count = state.getOption.map(_.getLong(0)).getOrElse(0L) + values.size + state.update(Row(count)) + state.setTimeoutDuration("10 seconds") + Iterator(Row(keyAsString, count.toString)) + } + } + + val clock = new StreamManualClock + val inputData = MemoryStream[String] + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("countAsString", StringType))) + val stateStructType = StructType(Seq(StructField("count", LongType))) + val result = + inputData.toDS() + .groupBy("value") + .flatMapGroupsWithState(outputStructType, stateStructType, Update, + ProcessingTimeTimeout)(stateFunc) + + testStream(result, Update)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputData, "a"), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("a", "1")), + assertNumStateRows(total = 1, updated = 1), + + AddData(inputData, "b"), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("b", "1")), + assertNumStateRows(total = 2, updated = 1), + + AddData(inputData, "b"), + AdvanceManualClock(10 * 1000), + CheckNewAnswer(("a", "-1"), ("b", "2")), + assertNumStateRows( + total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed = Some(Seq(1))), + + StopStream, + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + + AddData(inputData, "c"), + AdvanceManualClock(11 * 1000), + CheckNewAnswer(("b", "-1"), ("c", "1")), + assertNumStateRows( + total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed = 
Some(Seq(1))), + + AdvanceManualClock(12 * 1000), + AssertOnQuery { _ => clock.getTimeMillis() == 35000 }, + Execute { q => + failAfter(streamingTimeout) { + while (q.lastProgress.timestamp != "1970-01-01T00:00:35.000Z") { + Thread.sleep(1) + } + } + }, + CheckNewAnswer(("c", "-1")), + assertNumStateRows( + total = Seq(0), updated = Seq(0), droppedByWatermark = Seq(0), removed = Some(Seq(1))) + ) + } + + test("flatMapGroupsWithState - streaming w/ event time timeout + watermark") { + val inputData = MemoryStream[(String, Int)] + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("maxEventTimeSec", IntegerType))) + val stateStructType = StructType(Seq(StructField("maxEventTimeSec", LongType))) + val result = + inputData.toDS + .select($"_1".as("key"), timestamp_seconds($"_2").as("eventTime")) + .withWatermark("eventTime", "10 seconds") + .groupBy("key") + .flatMapGroupsWithState(outputStructType, stateStructType, Update, + EventTimeTimeout)(sampleTestFunction) + + testStream(result, Update)( + StartStream(), + + AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), + // Max event time = 15. Timeout timestamp for "a" = 15 + 5 = 20. Watermark = 15 - 10 = 5. + CheckNewAnswer(("a", 15)), // Output = max event time of a + + AddData(inputData, ("a", 4)), // Add data older than watermark for "a" + CheckNewAnswer(), // No output as data should get filtered by watermark + + AddData(inputData, ("a", 10)), // Add data newer than watermark for "a" + CheckNewAnswer(("a", 15)), // Max event time is still the same + // Timeout timestamp for "a" is still 20 as max event time for "a" is still 15. + // Watermark is still 5 as max event time for all data is still 15. + + AddData(inputData, ("b", 31)), // Add data newer than watermark for "b", not "a" + // Watermark = 31 - 10 = 21, so "a" should be timed out as timeout timestamp for "a" is 20. 
+ CheckNewAnswer(("a", -1), ("b", 31)) // State for "a" should timeout and emit -1 + ) + } + + test("mapGroupsWithState - streaming") { + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) + val stateFunc = (key: Row, values: Iterator[Row], state: GroupState[Row]) => { + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } + assertCannotGetWatermark { state.getCurrentWatermarkMs() } + + val keyAsString = key.getString(0) + val count = state.getOption.map(_.getLong(0)).getOrElse(0L) + values.size + if (count == 3) { + state.remove() + Row(keyAsString, "-1") + } else { + state.update(Row(count)) + Row(keyAsString, count.toString) + } + } + + val inputData = MemoryStream[String] + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("countAsString", StringType))) + val stateStructType = StructType(Seq(StructField("count", LongType))) + val result = + inputData.toDS() + .groupBy("value") + .mapGroupsWithState(outputStructType, stateStructType, + GroupStateTimeout.NoTimeout)(stateFunc) // Types = State: MyState, Out: (Str, Str) + + testStream(result, Update)( + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + assertNumStateRows(total = 1, updated = 1), + AddData(inputData, "a", "b"), + CheckNewAnswer(("a", "2"), ("b", "1")), + assertNumStateRows(total = 2, updated = 2), + StopStream, + StartStream(), + AddData(inputData, "a", "b"), // should remove state for "a" and return count as -1 + CheckNewAnswer(("a", "-1"), ("b", "2")), + assertNumStateRows(total = 1, updated = 1), + StopStream, + StartStream(), + AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 + CheckNewAnswer(("a", "1"), ("c", "1")), + assertNumStateRows(total = 3, updated = 2) + ) + } + + def testWithTimeout(timeoutConf: GroupStateTimeout): Unit = { + test("SPARK-20714: watermark does not fail 
query when timeout = " + timeoutConf) { + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) + // String, (String, Long), RunningCount(Long) + val stateFunc = + (key: Row, values: Iterator[Row], state: GroupState[Row]) => { + val keyAsString = key.getString(0) + if (state.hasTimedOut) { + state.remove() + Iterator(Row(keyAsString, "-1")) + } else { + val count = state.getOption.map(_.getLong(0)).getOrElse(0L) + values.size + state.update(Row(count)) + state.setTimeoutDuration("10 seconds") + Iterator(Row(keyAsString, count.toString)) + } + } + + val clock = new StreamManualClock + val inputData = MemoryStream[(String, Long)] + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("countAsString", StringType))) + val stateStructType = StructType(Seq(StructField("count", LongType))) + val result = + inputData.toDF().toDF("key", "time") + .selectExpr("key", "timestamp_seconds(time) as timestamp") + .withWatermark("timestamp", "10 second") + .groupBy("key") + .flatMapGroupsWithState(outputStructType, stateStructType, Update, + ProcessingTimeTimeout)(stateFunc) + + testStream(result, Update)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputData, ("a", 1L)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("a", "1")) + ) + } + } + testWithTimeout(NoTimeout) + testWithTimeout(ProcessingTimeTimeout) +} + +object UntypedFlatMapGroupsWithStateSuite { + + var failInTask = true + + def assertCanGetProcessingTime(predicate: => Boolean): Unit = { + if (!predicate) throw new TestFailedException("Could not get processing time", 20) + } + + def assertCanGetWatermark(predicate: => Boolean): Unit = { + if (!predicate) throw new TestFailedException("Could not get processing time", 20) + } + + def assertCannotGetWatermark(func: => Unit): Unit = { + try { + func + } catch { + case u: 
UnsupportedOperationException => + return + case _: Throwable => + throw new TestFailedException("Unexpected exception when trying to get watermark", 20) + } + throw new TestFailedException("Could get watermark when not expected", 20) + } +} From d36373bfe035cb055ae558bb55cbe6584ad7b331 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Thu, 14 Jul 2022 15:32:27 +0900 Subject: [PATCH 02/44] Add Python implementation --- .../spark/api/python/PythonRunner.scala | 2 + .../scala/org/apache/spark/util/Utils.scala | 2 +- dev/sparktestsupport/modules.py | 1 + python/pyspark/rdd.py | 2 + .../pyspark/sql/pandas/_typing/__init__.pyi | 11 +- python/pyspark/sql/pandas/functions.py | 2 + python/pyspark/sql/pandas/group_ops.py | 45 +- python/pyspark/sql/streaming/state.py | 182 ++++++++ .../test_pandas_grouped_map_with_state.py | 97 ++++ python/pyspark/sql/udf.py | 9 +- python/pyspark/worker.py | 95 +++- .../sql/streaming/GroupStateTimeout.java | 3 + .../sql/catalyst/plans/logical/object.scala | 21 - .../logical/pythonLogicalOperators.scala | 20 +- ...UntypedFlatMapGroupsWithStateFunction.java | 15 - .../UntypedMapGroupsWithStateFunction.java | 15 - .../spark/sql/RelationalGroupedDataset.scala | 92 +--- .../spark/sql/api/python/PythonSQLUtils.scala | 42 +- .../spark/sql/execution/SparkStrategies.scala | 36 +- .../python/ArrowPythonRunnerWithState.scala | 103 ++++ .../FlatMapGroupsInPandasWithStateExec.scala | 158 +++++++ .../execution/python/PandasGroupUtils.scala | 4 +- .../FlatMapGroupsWithStateExec.scala | 17 +- .../execution/streaming/GroupStateImpl.scala | 56 ++- .../streaming/IncrementalExecution.scala | 13 +- .../PythonFlatMapGroupsWithStateExec.scala | 286 ------------ .../UntypedFlatMapGroupsWithStateExec.scala | 290 ------------ .../execution/streaming/state/package.scala | 2 +- .../spark/sql/IntegratedUDFTestUtils.scala | 74 ++- .../apache/spark/sql/SQLQueryTestSuite.scala | 4 +- .../errors/QueryCompilationErrorsSuite.scala | 4 +- 
.../FlatMapGroupsInPandasWithStateSuite.scala | 434 +++++++++++++++++ .../UntypedFlatMapGroupsWithStateSuite.scala | 438 ------------------ .../continuous/ContinuousSuite.scala | 2 +- 34 files changed, 1342 insertions(+), 1235 deletions(-) create mode 100644 python/pyspark/sql/streaming/state.py create mode 100644 python/pyspark/sql/tests/test_pandas_grouped_map_with_state.py delete mode 100644 sql/core/src/main/java/org/apache/spark/api/java/function/UntypedFlatMapGroupsWithStateFunction.java delete mode 100644 sql/core/src/main/java/org/apache/spark/api/java/function/UntypedMapGroupsWithStateFunction.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/PythonFlatMapGroupsWithStateExec.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/UntypedFlatMapGroupsWithStateExec.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/UntypedFlatMapGroupsWithStateSuite.scala diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 5a13674e8bfbf..7b31fa93c32e5 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -53,6 +53,7 @@ private[spark] object PythonEvalType { val SQL_MAP_PANDAS_ITER_UDF = 205 val SQL_COGROUPED_MAP_PANDAS_UDF = 206 val SQL_MAP_ARROW_ITER_UDF = 207 + val SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE = 208 def toString(pythonEvalType: Int): String = pythonEvalType match { case NON_UDF => "NON_UDF" @@ -65,6 +66,7 @@ 
private[spark] object PythonEvalType { case SQL_MAP_PANDAS_ITER_UDF => "SQL_MAP_PANDAS_ITER_UDF" case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF" case SQL_MAP_ARROW_ITER_UDF => "SQL_MAP_ARROW_ITER_UDF" + case SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE => "SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE" } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index aef79c7882ca1..484a07c18ed0e 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -3051,7 +3051,7 @@ private[spark] object Utils extends Logging { * and return the trailing part after the last dollar sign in the middle */ @scala.annotation.tailrec - private def stripDollars(s: String): String = { + def stripDollars(s: String): String = { val lastDollarIndex = s.lastIndexOf('$') if (lastDollarIndex < s.length - 1) { // The last char is not a dollar sign diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 2b9d526937942..f9e2144d334e3 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -452,6 +452,7 @@ def __hash__(self): "pyspark.sql.tests.test_group", "pyspark.sql.tests.test_pandas_cogrouped_map", "pyspark.sql.tests.test_pandas_grouped_map", + "pyspark.sql.tests.test_pandas_grouped_map_with_state", "pyspark.sql.tests.test_pandas_map", "pyspark.sql.tests.test_arrow_map", "pyspark.sql.tests.test_pandas_udf", diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 7ef0014ae7518..5f4f4d494e13c 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -105,6 +105,7 @@ PandasMapIterUDFType, PandasCogroupedMapUDFType, ArrowMapIterUDFType, + PandasGroupedMapUDFWithStateType, ) from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import AtomicType, StructType @@ -147,6 +148,7 @@ class PythonEvalType: SQL_MAP_PANDAS_ITER_UDF: "PandasMapIterUDFType" = 205 
SQL_COGROUPED_MAP_PANDAS_UDF: "PandasCogroupedMapUDFType" = 206 SQL_MAP_ARROW_ITER_UDF: "ArrowMapIterUDFType" = 207 + SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: "PandasGroupedMapUDFWithStateType" = 208 def portable_hash(x: Hashable) -> int: diff --git a/python/pyspark/sql/pandas/_typing/__init__.pyi b/python/pyspark/sql/pandas/_typing/__init__.pyi index 27ac64a7238ba..82b861c51cf5c 100644 --- a/python/pyspark/sql/pandas/_typing/__init__.pyi +++ b/python/pyspark/sql/pandas/_typing/__init__.pyi @@ -22,19 +22,19 @@ from typing import ( Iterable, NewType, Tuple, - Type, TypeVar, Union, ) from typing_extensions import Protocol, Literal from types import FunctionType -from pyspark.sql._typing import LiteralType +import pyarrow from pandas.core.frame import DataFrame as PandasDataFrame from pandas.core.series import Series as PandasSeries from numpy import ndarray as NDArray -import pyarrow +from pyspark.sql._typing import LiteralType +from pyspark.sql.streaming.state import GroupStateImpl ArrayLike = NDArray DataFrameLike = PandasDataFrame @@ -51,6 +51,7 @@ PandasScalarIterUDFType = Literal[204] PandasMapIterUDFType = Literal[205] PandasCogroupedMapUDFType = Literal[206] ArrowMapIterUDFType = Literal[207] +PandasGroupedMapUDFWithStateType = Literal[208] class PandasVariadicScalarToScalarFunction(Protocol): def __call__(self, *_: DataFrameOrSeriesLike_) -> DataFrameOrSeriesLike_: ... @@ -253,9 +254,11 @@ PandasScalarIterFunction = Union[ PandasGroupedMapFunction = Union[ Callable[[DataFrameLike], DataFrameLike], - Callable[[Any, DataFrameLike], DataFrameLike], + Callable[[Tuple, DataFrameLike], DataFrameLike], ] +PandasGroupedMapFunctionWithState = Callable[[Tuple, DataFrameLike, GroupStateImpl], DataFrameLike] + class PandasVariadicGroupedAggFunction(Protocol): def __call__(self, *_: SeriesLike) -> LiteralType: ... 
diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py index 94fabdbb29590..1c6c2219edcec 100644 --- a/python/pyspark/sql/pandas/functions.py +++ b/python/pyspark/sql/pandas/functions.py @@ -369,6 +369,7 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, PythonEvalType.SQL_MAP_ARROW_ITER_UDF, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, None, ]: # None means it should infer the type from type hints. @@ -399,6 +400,7 @@ def _create_pandas_udf(f, returnType, evalType): ) elif evalType in [ PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, PythonEvalType.SQL_MAP_ARROW_ITER_UDF, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index 6178433573e9e..948fe5ce71355 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -15,18 +15,20 @@ # limitations under the License. 
# import sys -from typing import List, Union, TYPE_CHECKING +from typing import List, Union, TYPE_CHECKING, cast import warnings from pyspark.rdd import PythonEvalType from pyspark.sql.column import Column from pyspark.sql.dataframe import DataFrame -from pyspark.sql.types import StructType +from pyspark.sql.streaming.state import GroupStateTimeout +from pyspark.sql.types import StructType, _parse_datatype_string if TYPE_CHECKING: from pyspark.sql.pandas._typing import ( GroupedMapPandasUserDefinedFunction, PandasGroupedMapFunction, + PandasGroupedMapFunctionWithState, PandasCogroupedMapFunction, ) from pyspark.sql.group import GroupedData @@ -216,6 +218,45 @@ def applyInPandas( jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr()) return DataFrame(jdf, self.session) + def applyInPandasWithState( + self, + func: "PandasGroupedMapFunctionWithState", + outputStructType: Union[StructType, str], + stateStructType: Union[StructType, str], + outputMode: str, + timeoutConf: str, + ) -> DataFrame: + from pyspark.sql import GroupedData + from pyspark.sql.functions import pandas_udf + + assert isinstance(self, GroupedData) + assert timeoutConf in [ + GroupStateTimeout.NoTimeout, + GroupStateTimeout.ProcessingTimeTimeout, + GroupStateTimeout.EventTimeTimeout, + ] + + if isinstance(outputStructType, str): + outputStructType = cast(StructType, _parse_datatype_string(outputStructType)) + if isinstance(stateStructType, str): + stateStructType = cast(StructType, _parse_datatype_string(stateStructType)) + + udf = pandas_udf( + func, # type: ignore[call-overload] + returnType=outputStructType, + functionType=PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, + ) + df = self._df + udf_column = udf(*[df[col] for col in df.columns]) + jdf = self._jgd.applyInPandasWithState( + udf_column._jc.expr(), + self.session._jsparkSession.parseDataType(outputStructType.json()), + self.session._jsparkSession.parseDataType(stateStructType.json()), + outputMode, + timeoutConf, + ) + return 
DataFrame(jdf, self.session)
+
     def cogroup(self, other: "GroupedData") -> "PandasCogroupedOps":
         """
         Cogroups this group with another group so that we can run cogrouped operations.
diff --git a/python/pyspark/sql/streaming/state.py b/python/pyspark/sql/streaming/state.py
new file mode 100644
index 0000000000000..6281dbadba61b
--- /dev/null
+++ b/python/pyspark/sql/streaming/state.py
@@ -0,0 +1,208 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import datetime
+import json
+from typing import Tuple, Optional
+
+from pyspark.sql.types import Row, StructType, TimestampType
+
+__all__ = ["GroupStateImpl", "GroupStateTimeout"]
+
+
+class GroupStateTimeout:
+    """String constants naming the timeout modes accepted by applyInPandasWithState."""
+
+    NoTimeout: str = "NoTimeout"
+    ProcessingTimeTimeout: str = "ProcessingTimeTimeout"
+    EventTimeTimeout: str = "EventTimeTimeout"
+
+
+class GroupStateImpl:
+    """
+    Python-side holder of the state of one group in applyInPandasWithState.
+
+    The user function reads and mutates it; :meth:`json` serializes every field except the
+    state value itself.
+    """
+
+    NO_TIMESTAMP: int = -1
+
+    def __init__(
+        self,
+        # JVM Constructor
+        optionalValue: Row,
+        batchProcessingTimeMs: int,
+        eventTimeWatermarkMs: int,
+        timeoutConf: str,
+        hasTimedOut: bool,
+        watermarkPresent: bool,
+        # JVM internal state.
+        defined: bool,
+        updated: bool,
+        removed: bool,
+        timeoutTimestamp: int,
+        # Python internal state.
+        keySchema: StructType,
+    ) -> None:
+        self._value = optionalValue
+        self._batch_processing_time_ms = batchProcessingTimeMs
+        self._event_time_watermark_ms = eventTimeWatermarkMs
+
+        assert timeoutConf in [
+            GroupStateTimeout.NoTimeout,
+            GroupStateTimeout.ProcessingTimeTimeout,
+            GroupStateTimeout.EventTimeTimeout,
+        ]
+        self._timeout_conf = timeoutConf
+
+        self._has_timed_out = hasTimedOut
+        self._watermark_present = watermarkPresent
+
+        self._defined = defined
+        self._updated = updated
+        self._removed = removed
+        self._timeout_timestamp = timeoutTimestamp
+
+        self._key_schema = keySchema
+
+    @property
+    def exists(self) -> bool:
+        """Whether a state value is currently defined."""
+        return self._defined
+
+    @property
+    def get(self) -> Tuple:
+        """Return the state value as a tuple; raise ValueError when it is not defined."""
+        if self.exists:
+            return tuple(self._value)
+        else:
+            raise ValueError("State is either not defined or has already been removed")
+
+    @property
+    def getOption(self) -> Optional[Tuple]:
+        """Return the state value as a tuple, or None when it is not defined."""
+        if self.exists:
+            return tuple(self._value)
+        else:
+            return None
+
+    @property
+    def hasTimedOut(self) -> bool:
+        """The ``hasTimedOut`` flag this object was constructed with."""
+        return self._has_timed_out
+
+    def update(self, newValue: Tuple) -> None:
+        """Store ``newValue`` as the state and flag it as defined and updated."""
+        if newValue is None:
+            raise ValueError("'None' is not a valid state value")
+
+        self._value = Row(*newValue)
+        self._defined = True
+        self._updated = True
+        self._removed = False
+
+    def remove(self) -> None:
+        """Flag the state as removed and no longer defined."""
+        self._defined = False
+        self._updated = False
+        self._removed = True
+
+    def setTimeoutDuration(self, durationMs: int) -> None:
+        """Set the timeout ``durationMs`` ms after the batch processing time."""
+        if isinstance(durationMs, str):
+            # TODO(SPARK-XXXXX): Support string representation of durationMs.
+            raise ValueError("durationMs should be int but get :%s" % type(durationMs))
+
+        if self._timeout_conf != GroupStateTimeout.ProcessingTimeTimeout:
+            raise RuntimeError(
+                "Cannot set timeout duration without enabling processing time timeout in "
+                "applyInPandasWithState"
+            )
+
+        if durationMs <= 0:
+            raise ValueError("Timeout duration must be positive")
+        self._timeout_timestamp = durationMs + self._batch_processing_time_ms
+
+    # TODO(SPARK-XXXXX): Implement additionalDuration parameter.
+    def setTimeoutTimestamp(self, timestampMs: int) -> None:
+        """
+        Set the timeout timestamp in epoch milliseconds; requires EventTimeTimeout.
+
+        A ``datetime.datetime`` is also accepted and converted to epoch milliseconds.
+        """
+        if self._timeout_conf != GroupStateTimeout.EventTimeTimeout:
+            raise RuntimeError(
+                "Cannot set timeout timestamp without enabling event time timeout in "
+                "applyInPandasWithState"
+            )
+
+        if isinstance(timestampMs, datetime.datetime):
+            # TimestampType.toInternal returns microseconds since the epoch; keep milliseconds.
+            timestampMs = TimestampType().toInternal(timestampMs) // 1000
+
+        if timestampMs <= 0:
+            raise ValueError("Timeout timestamp must be positive")
+
+        if (
+            self._event_time_watermark_ms != GroupStateImpl.NO_TIMESTAMP
+            and timestampMs < self._event_time_watermark_ms
+        ):
+            raise ValueError(
+                "Timeout timestamp (%s) cannot be earlier than the "
+                "current watermark (%s)" % (timestampMs, self._event_time_watermark_ms)
+            )
+
+        self._timeout_timestamp = timestampMs
+
+    def getCurrentWatermarkMs(self) -> int:
+        """Return the event time watermark in ms; raise RuntimeError if none is present."""
+        if not self._watermark_present:
+            raise RuntimeError(
+                "Cannot get event time watermark timestamp without setting watermark before "
+                "applyInPandasWithState"
+            )
+        return self._event_time_watermark_ms
+
+    def getCurrentProcessingTimeMs(self) -> int:
+        """Return the batch processing time in ms this object was constructed with."""
+        return self._batch_processing_time_ms
+
+    def __str__(self) -> str:
+        if self.exists:
+            # self.get is a tuple: as a bare '%' operand it would be unpacked as format args.
+            return "GroupState(%s)" % ", ".join(map(str, self.get))
+        else:
+            return "GroupState()"
+
+    def json(self) -> str:
+        """Serialize every field except the state value (sent separately) to a JSON string."""
+        return json.dumps(
+            {
+                # Constructor
+                "optionalValue": None,  # Note that optionalValue will be manually serialized.
+                "batchProcessingTimeMs": self._batch_processing_time_ms,
+                "eventTimeWatermarkMs": self._event_time_watermark_ms,
+                "timeoutConf": self._timeout_conf,
+                "hasTimedOut": self._has_timed_out,
+                "watermarkPresent": self._watermark_present,
+                # JVM internal state.
+                "defined": self._defined,
+                "updated": self._updated,
+                "removed": self._removed,
+                "timeoutTimestamp": self._timeout_timestamp,
+            }
+        )
diff --git a/python/pyspark/sql/tests/test_pandas_grouped_map_with_state.py b/python/pyspark/sql/tests/test_pandas_grouped_map_with_state.py
new file mode 100644
index 0000000000000..a9a56c557fabd
--- /dev/null
+++ b/python/pyspark/sql/tests/test_pandas_grouped_map_with_state.py
@@ -0,0 +1,97 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# + +import unittest +from typing import cast + +from pyspark.sql.streaming.state import GroupStateTimeout, GroupStateImpl +from pyspark.sql.types import ( + LongType, + StringType, + StructType, + StructField, + Row, +) +from pyspark.testing.sqlutils import ( + ReusedSQLTestCase, + have_pandas, + have_pyarrow, + pandas_requirement_message, + pyarrow_requirement_message, +) + +if have_pandas: + import pandas as pd + +if have_pyarrow: + import pyarrow as pa # noqa: F401 + + +@unittest.skipIf( + not have_pandas or not have_pyarrow, + cast(str, pandas_requirement_message or pyarrow_requirement_message), +) +class GroupedMapInPandasWithStateTests(ReusedSQLTestCase): + def test_apply_in_pandas_with_state_basic(self): + df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") + + for q in self.spark.streams.active: + q.stop() + self.assertTrue(df.isStreaming) + + output_type = StructType( + [StructField("key", StringType()), StructField("countAsString", StringType())] + ) + state_type = StructType([StructField("c", LongType())]) + + def func(key, pdf, state): + assert isinstance(state, GroupStateImpl) + state.update((len(pdf),)) + assert state.get[0] == 1 + return pd.DataFrame({"key": [key[0]], "countAsString": [str(len(pdf))]}) + + def check_results(batch_df, _): + self.assertEqual( + set(batch_df.collect()), + {Row(key="hello", countAsString="1"), Row(key="this", countAsString="1")}, + ) + + q = ( + df.groupBy(df["value"]) + .applyInPandasWithState( + func, output_type, state_type, "Update", GroupStateTimeout.NoTimeout + ) + .writeStream.queryName("this_query") + .foreachBatch(check_results) + .start() + ) + + self.assertEqual(q.name, "this_query") + self.assertTrue(q.isActive) + q.processAllAvailable() + + +if __name__ == "__main__": + from pyspark.sql.tests.test_pandas_grouped_map_with_state import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", 
verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 6a01e399d0400..417896ab738c7 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -144,20 +144,23 @@ def returnType(self) -> DataType: "Invalid return type with scalar Pandas UDFs: %s is " "not supported" % str(self._returnType_placeholder) ) - elif self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: + elif ( + self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF + or self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE + ): if isinstance(self._returnType_placeholder, StructType): try: to_arrow_type(self._returnType_placeholder) except TypeError: raise NotImplementedError( "Invalid return type with grouped map Pandas UDFs or " - "at groupby.applyInPandas: %s is not supported" + "at groupby.applyInPandas(withState): %s is not supported" % str(self._returnType_placeholder) ) else: raise TypeError( "Invalid return type for grouped map Pandas " - "UDFs or at groupby.applyInPandas: return type must be a " + "UDFs or at groupby.applyInPandas(withState): return type must be a " "StructType." ) elif ( diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index c486b7bed1d81..6f63a6cd5c471 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -23,6 +23,7 @@ import time from inspect import currentframe, getframeinfo, getfullargspec import importlib +import json # 'resource' is a Unix specific module. 
has_resource_module = True @@ -62,6 +63,7 @@ from pyspark.sql.types import StructType from pyspark.util import fail_on_stopiteration, try_simplify_traceback from pyspark import shuffle +from pyspark.sql.streaming.state import GroupStateImpl pickleSer = CPickleSerializer() utf8_deserializer = UTF8Deserializer() @@ -207,6 +209,37 @@ def wrapped(key_series, value_series): return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))] +def wrap_grouped_map_pandas_udf_with_state(f, return_type, state): + def wrapped(key_series, value_series): + import pandas as pd + + key = tuple(s[0] for s in key_series) + if state.hasTimedOut: + # Timeout processing pass empty iterator. Here we return an empty DataFrame instead. + result = f(key, pd.DataFrame(columns=pd.concat(value_series, axis=1).columns), state) + else: + result = f(key, pd.concat(value_series, axis=1), state) + + if not isinstance(result, pd.DataFrame): + raise TypeError( + "Return type of the user-defined function should be " + "pandas.DataFrame, but is {}".format(type(result)) + ) + # the number of columns of result have to match the return type + # but it is fine for result to have no columns at all if it is empty + if not ( + len(result.columns) == len(return_type) or len(result.columns) == 0 and result.empty + ): + raise RuntimeError( + "Number of columns of the returned pandas.DataFrame " + "doesn't match specified schema. 
" + "Expected: {} Actual: {}".format(len(return_type), len(result.columns)) + ) + return result + + return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))] + + def wrap_grouped_agg_pandas_udf(f, return_type): arrow_return_type = to_arrow_type(return_type) @@ -281,7 +314,7 @@ def wrapped(begin_index, end_index, *series): return lambda *a: (wrapped(*a), arrow_return_type) -def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): +def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, state=None): num_arg = read_int(infile) arg_offsets = [read_int(infile) for i in range(num_arg)] chained_func = None @@ -311,6 +344,8 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: argspec = getfullargspec(chained_func) # signature was lost when wrapping it return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec) + elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: + return arg_offsets, wrap_grouped_map_pandas_udf_with_state(func, return_type, state) elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: argspec = getfullargspec(chained_func) # signature was lost when wrapping it return arg_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec) @@ -327,6 +362,9 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): def read_udfs(pickleSer, infile, eval_type): runner_conf = {} + # Used for state support in Structured Streaming. 
+ state = None + if eval_type in ( PythonEvalType.SQL_SCALAR_PANDAS_UDF, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, @@ -336,6 +374,7 @@ def read_udfs(pickleSer, infile, eval_type): PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, ): # Load conf used for pandas_udf evaluation @@ -345,6 +384,21 @@ def read_udfs(pickleSer, infile, eval_type): v = utf8_deserializer.loads(infile) runner_conf[k] = v + if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: + # 1. State properties + properties = json.loads(utf8_deserializer.loads(infile)) + + # 2. State key + length = read_int(infile) + row = None + if length > 0: + row = pickleSer.loads(infile.read(length)) + # 3. Schema for state key + key_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile))) + properties["optionalValue"] = row + + state = GroupStateImpl(keySchema=key_schema, **properties) + # NOTE: if timezone is set here, that implies respectSessionTimeZone is True timezone = runner_conf.get("spark.sql.session.timeZone", None) safecheck = ( @@ -438,7 +492,7 @@ def map_batch(batch): ) # profiling is not supported for UDF - return func, None, ser, ser + return func, None, ser, ser, state def extract_key_value_indexes(grouped_arg_offsets): """ @@ -486,6 +540,23 @@ def mapper(a): vals = [a[o] for o in parsed_offsets[0][1]] return f(keys, vals) + elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: + # We assume there is only one UDF here because grouped map doesn't + # support combining multiple UDFs. 
+ assert num_udfs == 1 + + # See PythonFlatMapGroupsWithStateExec for how arg_offsets are used to + # distinguish between grouping attributes and data attributes + arg_offsets, f = read_single_udf( + pickleSer, infile, eval_type, runner_conf, udf_index=0, state=state + ) + parsed_offsets = extract_key_value_indexes(arg_offsets) + + def mapper(a): + keys = [a[o] for o in parsed_offsets[0][0]] + vals = a # it's always all series. + return f(keys, vals) + elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: # We assume there is only one UDF here because cogrouped map doesn't # support combining multiple UDFs. @@ -519,11 +590,12 @@ def func(_, it): return map(mapper, it) # profiling is not supported for UDF - return func, None, ser, ser + return func, None, ser, ser, state def main(infile, outfile): faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None) + state = None try: if faulthandler_log_path: faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid())) @@ -585,6 +657,7 @@ def main(infile, outfile): ) # initialize global state + state = None taskContext = None if isBarrier: taskContext = BarrierTaskContext._getOrCreate() @@ -667,7 +740,9 @@ def main(infile, outfile): if eval_type == PythonEvalType.NON_UDF: func, profiler, deserializer, serializer = read_command(pickleSer, infile) else: - func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type) + func, profiler, deserializer, serializer, state = read_udfs( + pickleSer, infile, eval_type + ) init_time = time.time() @@ -722,6 +797,18 @@ def process(): # Mark the beginning of the accumulators section of the output write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) + + # Send GroupState back to JVM if exists. + if state is not None: + # 1. Send JSON-serialized GroupState + write_with_length(state.json().encode("utf-8"), outfile) + + # 2. Send pickled Row. 
+ if state._value is None: + write_int(0, outfile) + else: + write_with_length(pickleSer.dumps(state._key_schema.toInternal(state._value)), outfile) + write_int(len(_accumulatorRegistry), outfile) for (aid, accum) in _accumulatorRegistry.items(): pickleSer._write_with_length((aid, accum._value), outfile) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java index a814525f870c9..479a097713f51 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java @@ -32,6 +32,9 @@ @Experimental @Evolving public class GroupStateTimeout { + // NOTE: if you're adding new type of timeout, you should also fix the places below: + // - Scala: org.apache.spark.sql.api.python.PythonSQLUtils.getGroupStateTimeoutFromString + // - Python: pyspark.sql.streaming.state.GroupStateTimeout /** * Timeout based on processing time. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 61bd361588be4..e5fe07e2d950d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -551,27 +551,6 @@ case class FlatMapGroupsWithState( copy(child = newLeft, initialState = newRight) } -case class UntypedFlatMapGroupsWithState( - func: (Row, Iterator[Row], LogicalGroupState[Row]) => Iterator[Row], - groupingAttributes: Seq[Attribute], - outputAttrs: Seq[Attribute], - stateType: StructType, - outputMode: OutputMode, - isMapGroupsWithState: Boolean = false, - timeout: GroupStateTimeout, - child: LogicalPlan) extends UnaryNode { - if (isMapGroupsWithState) { - assert(outputMode == OutputMode.Update) - } - - override def output: Seq[Attribute] = outputAttrs - - override def producedAttributes: AttributeSet = AttributeSet(outputAttrs) - - override protected def withNewChildInternal( - newChild: LogicalPlan): UntypedFlatMapGroupsWithState = copy(child = newChild) -} - /** Factory for constructing new `FlatMapGroupsInR` nodes. */ object FlatMapGroupsInR { def apply( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala index 72593edb42aa0..67d072bc36824 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -100,7 +100,23 @@ case class FlatMapCoGroupsInPandas( copy(left = newLeft, right = newRight) } -case class PythonFlatMapGroupsWithState( +/** + * Similar with [[FlatMapGroupsWithState]]. 
Applies func to each unique group + * in `child`, based on the evaluation of `groupingAttributes`, + * while using state data. + * `functionExpr` is invoked with an pandas DataFrame representation and the + * grouping key (tuple). + * + * @param functionExpr function called on each group + * @param groupingAttributes used to group the data + * @param outputAttrs used to define the output rows + * @param stateType used to serialize/deserialize state before calling `functionExpr` + * @param outputMode the output mode of `func` + * @param isMapGroupsWithState whether it is created by the `mapGroupsWithState` method + * @param timeout used to timeout groups that have not received data in a while + * @param child logical plan of the underlying data + */ +case class FlatMapGroupsInPandasWithState( functionExpr: Expression, groupingAttributes: Seq[Attribute], outputAttrs: Seq[Attribute], @@ -118,7 +134,7 @@ case class PythonFlatMapGroupsWithState( override def producedAttributes: AttributeSet = AttributeSet(outputAttrs) override protected def withNewChildInternal( - newChild: LogicalPlan): PythonFlatMapGroupsWithState = copy(child = newChild) + newChild: LogicalPlan): FlatMapGroupsInPandasWithState = copy(child = newChild) } trait BaseEvalPython extends UnaryNode { diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/UntypedFlatMapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/UntypedFlatMapGroupsWithStateFunction.java deleted file mode 100644 index e4634b15eca8c..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/api/java/function/UntypedFlatMapGroupsWithStateFunction.java +++ /dev/null @@ -1,15 +0,0 @@ -package org.apache.spark.api.java.function; - -import java.io.Serializable; -import java.util.Iterator; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.annotation.Experimental; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.streaming.GroupState; - -@Experimental 
-@Evolving -public interface UntypedFlatMapGroupsWithStateFunction extends Serializable { - Iterator call(Row key, Iterator values, GroupState state) throws Exception; -} diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/UntypedMapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/UntypedMapGroupsWithStateFunction.java deleted file mode 100644 index 14167c84a8bed..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/api/java/function/UntypedMapGroupsWithStateFunction.java +++ /dev/null @@ -1,15 +0,0 @@ -package org.apache.spark.api.java.function; - -import java.io.Serializable; -import java.util.Iterator; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.annotation.Experimental; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.streaming.GroupState; - -@Experimental -@Evolving -public interface UntypedMapGroupsWithStateFunction extends Serializable { - Row call(Row key, Iterator values, GroupState state) throws Exception; -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index ad19b2f067f39..6c7b14b2334cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -23,7 +23,6 @@ import scala.collection.JavaConverters._ import org.apache.spark.SparkRuntimeException import org.apache.spark.annotation.Stable -import org.apache.spark.api.java.function.{UntypedFlatMapGroupsWithStateFunction, UntypedMapGroupsWithStateFunction} import org.apache.spark.api.python.PythonEvalType import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedFunction} @@ -31,10 +30,11 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions._ import 
org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression -import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode} +import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.{NumericType, StructType} /** @@ -622,89 +622,15 @@ class RelationalGroupedDataset protected[sql]( Dataset.ofRows(df.sparkSession, plan) } - def mapGroupsWithState( - func: UntypedMapGroupsWithStateFunction, - outputStructType: StructType, - stateStructType: StructType, - timeoutConf: GroupStateTimeout): DataFrame = { - mapGroupsWithState( - outputStructType, stateStructType, timeoutConf)( - (key: Row, it: Iterator[Row], s: GroupState[Row]) => func.call(key, it.asJava, s) - ) - } - - def mapGroupsWithState( - outputStructType: StructType, - stateStructType: StructType, - timeoutConf: GroupStateTimeout)( - func: (Row, Iterator[Row], GroupState[Row]) => Row): DataFrame = { - val flatMapFunc = (key: Row, it: Iterator[Row], s: GroupState[Row]) => - Iterator(func(key, it, s)) - - val groupingNamedExpressions = groupingExprs.map { - case ne: NamedExpression => ne - case other => Alias(other, other.toString)() - } - val groupingAttrs = groupingNamedExpressions.map(_.toAttribute) - val outputAttrs = outputStructType.toAttributes - val plan = UntypedFlatMapGroupsWithState( - flatMapFunc.asInstanceOf[(Row, Iterator[Row], LogicalGroupState[Row]) => Iterator[Row]], - groupingAttrs, - outputAttrs, - stateStructType, - OutputMode.Update(), - isMapGroupsWithState = true, - timeoutConf, - child = df.logicalPlan) - Dataset.ofRows(df.sparkSession, plan) - } - - def flatMapGroupsWithState( - func: UntypedFlatMapGroupsWithStateFunction, - 
outputStructType: StructType, - stateStructType: StructType, - outputMode: OutputMode, - timeoutConf: GroupStateTimeout): DataFrame = { - val f = (key: Row, it: Iterator[Row], s: GroupState[Row]) => - func.call(key, it.asJava, s).asScala - flatMapGroupsWithState(outputStructType, stateStructType, outputMode, timeoutConf)(f) - } - - def flatMapGroupsWithState( - outputStructType: StructType, - stateStructType: StructType, - outputMode: OutputMode, - timeoutConf: GroupStateTimeout)( - func: (Row, Iterator[Row], GroupState[Row]) => Iterator[Row]): DataFrame = { - if (outputMode != OutputMode.Append && outputMode != OutputMode.Update) { - throw new IllegalArgumentException("The output mode of function should be append or update") - } - val groupingNamedExpressions = groupingExprs.map { - case ne: NamedExpression => ne - case other => Alias(other, other.toString)() - } - val groupingAttrs = groupingNamedExpressions.map(_.toAttribute) - val outputAttrs = outputStructType.toAttributes - val plan = UntypedFlatMapGroupsWithState( - func.asInstanceOf[(Row, Iterator[Row], LogicalGroupState[Row]) => Iterator[Row]], - groupingAttrs, - outputAttrs, - stateStructType, - outputMode, - isMapGroupsWithState = false, - timeoutConf, - child = df.logicalPlan) - Dataset.ofRows(df.sparkSession, plan) - } - - // FIXME: probably we have to change the type as String for outputMode and timeoutConf to provide - // parameters from Python? 
- private[sql] def pythonFlatMapGroupsWithState( + private[sql] def applyInPandasWithState( func: PythonUDF, outputStructType: StructType, stateStructType: StructType, - outputMode: OutputMode, - timeoutConf: GroupStateTimeout): DataFrame = { + outputModeStr: String, + timeoutConfStr: String): DataFrame = { + val timeoutConf = org.apache.spark.sql.execution.streaming + .GroupStateImpl.groupStateTimeoutFromString(timeoutConfStr) + val outputMode = InternalOutputModes(outputModeStr) if (outputMode != OutputMode.Append && outputMode != OutputMode.Update) { throw new IllegalArgumentException("The output mode of function should be append or update") } @@ -714,7 +640,7 @@ class RelationalGroupedDataset protected[sql]( } val groupingAttrs = groupingNamedExpressions.map(_.toAttribute) val outputAttrs = outputStructType.toAttributes - val plan = PythonFlatMapGroupsWithState( + val plan = FlatMapGroupsInPandasWithState( func, groupingAttrs, outputAttrs, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index 2b74bcc38501a..258d8a87f8b48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -22,14 +22,15 @@ import java.net.Socket import java.nio.channels.Channels import java.util.Locale -import net.razorvine.pickle.Pickler +import net.razorvine.pickle.{Pickler, Unpickler} import org.apache.spark.api.python.DechunkedInputStream import org.apache.spark.internal.Logging import org.apache.spark.security.SocketAuthServer import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession} -import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry +import 
org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser @@ -37,12 +38,29 @@ import org.apache.spark.sql.execution.{ExplainMode, QueryExecution} import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.execution.python.EvaluatePython import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{DataType, StructType} private[sql] object PythonSQLUtils extends Logging { - private lazy val internalRowPickler = { + private def withInternalRowPickler(f: Pickler => Array[Byte]): Array[Byte] = { EvaluatePython.registerPicklers() - new Pickler(true, false) + val pickler = new Pickler(true, false) + val ret = try { + f(pickler) + } finally { + pickler.close() + } + ret + } + + private def withInternalRowUnpickler(f: Unpickler => Any): Any = { + EvaluatePython.registerPicklers() + val unpickler = new Unpickler + val ret = try { + f(unpickler) + } finally { + unpickler.close() + } + ret } def parseDataType(typeText: String): DataType = CatalystSqlParser.parseDataType(typeText) @@ -94,8 +112,18 @@ private[sql] object PythonSQLUtils extends Logging { def toPyRow(row: Row): Array[Byte] = { assert(row.isInstanceOf[GenericRowWithSchema]) - internalRowPickler.dumps(EvaluatePython.toJava( - CatalystTypeConverters.convertToCatalyst(row), row.schema)) + withInternalRowPickler(_.dumps(EvaluatePython.toJava( + CatalystTypeConverters.convertToCatalyst(row), row.schema))) + } + + def toJVMRow( + arr: Array[Byte], + returnType: StructType, + deserializer: ExpressionEncoder.Deserializer[Row]): Row = { + val fromJava = EvaluatePython.makeFromJava(returnType) + val internalRow = + fromJava(withInternalRowUnpickler(_.loads(arr))).asInstanceOf[InternalRow] + deserializer(internalRow) } def castTimestampNTZToLong(c: Column): Column 
= Column(CastTimestampNTZToLong(c.expr)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 767fd8cf9f11f..7ec47f469adde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -685,34 +685,15 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } /** - * Strategy to convert [[UntypedFlatMapGroupsWithState]] logical operator to physical operator + * Strategy to convert [[FlatMapGroupsInPandasWithState]] logical operator to physical operator * in streaming plans. Conversion for batch plans is handled by [[BasicOperators]]. */ - object UntypedFlatMapGroupsWithStateStrategy extends Strategy { + object FlatMapGroupsInPandasWithStateStrategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case UntypedFlatMapGroupsWithState( + case FlatMapGroupsInPandasWithState( func, groupAttr, outputAttr, stateType, outputMode, _, timeout, child) => val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION) - val execPlan = UntypedFlatMapGroupsWithStateExec( - func, groupAttr, outputAttr, stateType, None, stateVersion, outputMode, timeout, - batchTimestampMs = None, eventTimeWatermark = None, planLater(child) - ) - execPlan :: Nil - case _ => - Nil - } - } - - /** - * Strategy to convert [[UntypedFlatMapGroupsWithState]] logical operator to physical operator - * in streaming plans. Conversion for batch plans is handled by [[BasicOperators]]. 
- */ - object PythonFlatMapGroupsWithStateStrategy extends Strategy { - override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case PythonFlatMapGroupsWithState( - func, groupAttr, outputAttr, stateType, outputMode, _, timeout, child) => - val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION) - val execPlan = PythonFlatMapGroupsWithStateExec( + val execPlan = python.FlatMapGroupsInPandasWithStateExec( func, groupAttr, outputAttr, stateType, None, stateVersion, outputMode, timeout, batchTimestampMs = None, eventTimeWatermark = None, planLater(child) ) @@ -831,12 +812,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { initialStateGroupAttrs, data, initialStateDataAttrs, output, timeout, hasInitialState, planLater(initialState), planLater(child) ) :: Nil - // FIXME: implement it! - case _: logical.UntypedFlatMapGroupsWithState => - throw new UnsupportedOperationException("Not yet implemented for batch query!") - // FIXME: implement it! 
- case _: PythonFlatMapGroupsWithState => - throw new UnsupportedOperationException("Not yet implemented for batch query!") + case _: FlatMapGroupsInPandasWithState => + // TODO(SPARK-XXXXX): Implement batch support for applyInPandasWithState + throw new UnsupportedOperationException("applyInPandasWithState is unsupported.") case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) => execution.CoGroupExec( f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala new file mode 100644 index 0000000000000..95e993f0685b2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.python + +import java.io._ +import java.nio.charset.StandardCharsets + +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.api.python._ +import org.apache.spark.sql.Row +import org.apache.spark.sql.api.python.PythonSQLUtils +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.execution.streaming.GroupStateImpl +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * [[ArrowPythonRunner]] with [[org.apache.spark.sql.streaming.GroupState]]. + */ +class ArrowPythonRunnerWithState( + funcs: Seq[ChainedPythonFunctions], + evalType: Int, + argOffsets: Array[Array[Int]], + protected override val schema: StructType, + protected override val timeZoneId: String, + protected override val workerConf: Map[String, String], + oldState: GroupStateImpl[Row], + deserializer: ExpressionEncoder.Deserializer[Row], + stateType: StructType) + extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](funcs, evalType, argOffsets) + with PythonArrowInput + with PythonArrowOutput { + + override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback + + override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize + require( + bufferSize >= 4, + "Pandas execution requires more than 4 bytes. Please set higher buffer. " + + s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.") + + var newGroupState: GroupStateImpl[Row] = _ + + protected override def handleMetadataBeforeExec(stream: DataOutputStream): Unit = { + super.handleMetadataBeforeExec(stream) + + // 1. Send JSON-serialized GroupState + PythonRDD.writeUTF(oldState.json(), stream) + + // 2. 
Send pickled Row from the GroupState + val rowInState = oldState.getOption.map(PythonSQLUtils.toPyRow).getOrElse(Array.empty) + stream.writeInt(rowInState.length) + if (rowInState.length > 0) { + stream.write(rowInState) + } + + // 3. Send the state type to serialize the output state back from Python. + PythonRDD.writeUTF(stateType.json, stream) + } + + protected override def handleMetadataAfterExec(stream: DataInputStream): Unit = { + super.handleMetadataAfterExec(stream) + + implicit val formats = org.json4s.DefaultFormats + + // 1. Receive JSON-serialized GroupState + val jsonStr = new Array[Byte](stream.readInt()) + stream.readFully(jsonStr) + val properties = parse(new String(jsonStr, StandardCharsets.UTF_8)) + + // 2. Receive and deserialized pickled Row to JVM Row. + val length = stream.readInt() + val maybeRow = if (length > 0) { + val pickledRow = new Array[Byte](length) + stream.readFully(pickledRow) + Some(PythonSQLUtils.toJVMRow(pickledRow, stateType, deserializer)) + } else { + None + } + + // 3. Create a group state. + newGroupState = GroupStateImpl.fromJson(maybeRow, properties) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala new file mode 100644 index 0000000000000..85e03a262ec9e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.python + +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, PythonUDF, SortOrder} +import org.apache.spark.sql.catalyst.plans.physical.Distribution +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.python.PandasGroupUtils.{executePython, resolveArgOffsets} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP +import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper.StateData +import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.util.CompletionIterator + +/** + * Physical operator for executing + * [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandasWithState]] + * + * @param functionExpr function called on each group + * @param groupingAttributes used to group the data + * @param outAttributes used to define the output rows + * @param stateType used to serialize/deserialize state before calling `functionExpr` + * @param stateInfo `StatefulOperatorStateInfo` to identify the 
state store for a given operator. + * @param stateFormatVersion the version of state format. + * @param outputMode the output mode of `functionExpr` + * @param timeoutConf used to timeout groups that have not received data in a while + * @param batchTimestampMs processing timestamp of the current batch. + * @param eventTimeWatermark event time watermark for the current batch + * @param child logical plan of the underlying data + */ +case class FlatMapGroupsInPandasWithStateExec( + functionExpr: Expression, + groupingAttributes: Seq[Attribute], + outAttributes: Seq[Attribute], + stateType: StructType, + stateInfo: Option[StatefulOperatorStateInfo], + stateFormatVersion: Int, + outputMode: OutputMode, + timeoutConf: GroupStateTimeout, + batchTimestampMs: Option[Long], + eventTimeWatermark: Option[Long], + child: SparkPlan) extends UnaryExecNode with FlatMapGroupsWithStateExecBase { + + // TODO(SPARK-XXXXX): Add the support of initial state. + override protected val initialStateDeserializer: Expression = null + override protected val initialStateGroupAttrs: Seq[Attribute] = null + override protected val initialStateDataAttrs: Seq[Attribute] = null + override protected val initialState: SparkPlan = null + override protected val hasInitialState: Boolean = false + + override protected val stateEncoder: ExpressionEncoder[Any] = + RowEncoder(stateType).resolveAndBind().asInstanceOf[ExpressionEncoder[Any]] + + override def output: Seq[Attribute] = outAttributes + + private val sessionLocalTimeZone = conf.sessionLocalTimeZone + private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) + private val pythonFunction = functionExpr.asInstanceOf[PythonUDF].func + private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction))) + + override def requiredChildDistribution: Seq[Distribution] = + StatefulOperatorPartitioning.getCompatibleDistribution( + groupingAttributes, getStateInfo, conf) :: Nil + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = 
Seq( + groupingAttributes.map(SortOrder(_, Ascending))) + + override def shortName: String = "applyInPandasWithState" + + override protected def withNewChildInternal( + newChild: SparkPlan): FlatMapGroupsInPandasWithStateExec = copy(child = newChild) + + override def createInputProcessor( + store: StateStore): InputProcessor = new InputProcessor(store: StateStore) { + private val stateDeserializer = + stateEncoder.asInstanceOf[ExpressionEncoder[Row]].createDeserializer() + + private lazy val (dedupAttributes, argOffsets) = resolveArgOffsets(child, groupingAttributes) + + def callFunctionAndUpdateState( + stateData: StateData, + valueRowIter: Iterator[InternalRow], + hasTimedOut: Boolean): Iterator[InternalRow] = { + val groupedState = GroupStateImpl.createForStreaming( + Option(stateData.stateObj).map { r => assert(r.isInstanceOf[Row]); r }, + batchTimestampMs.getOrElse(NO_TIMESTAMP), + eventTimeWatermark.getOrElse(NO_TIMESTAMP), + timeoutConf, + hasTimedOut = hasTimedOut, + watermarkPresent).asInstanceOf[GroupStateImpl[Row]] + + val runner = new ArrowPythonRunnerWithState( + chainedFunc, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, + Array(argOffsets), + StructType.fromAttributes(dedupAttributes), + sessionLocalTimeZone, + pythonRunnerConf, + groupedState, + stateDeserializer, + stateType) + + val inputIter = + if (hasTimedOut) Iterator.single(Iterator.single(stateData.keyRow)) + else Iterator.single(valueRowIter) + + val ret = executePython(inputIter, output, runner).toArray + numOutputRows += ret.length + val newGroupState: GroupStateImpl[Row] = runner.newGroupState + assert(newGroupState != null) + + // When the iterator is consumed, then write changes to state + def onIteratorCompletion: Unit = { + if (newGroupState.isRemoved && !newGroupState.getTimeoutTimestampMs.isPresent()) { + stateManager.removeState(store, stateData.keyRow) + numRemovedStateRows += 1 + } else { + val currentTimeoutTimestamp = 
newGroupState.getTimeoutTimestampMs.orElse(NO_TIMESTAMP) + val hasTimeoutChanged = currentTimeoutTimestamp != stateData.timeoutTimestamp + val shouldWriteState = newGroupState.isUpdated || newGroupState.isRemoved || + hasTimeoutChanged + + if (shouldWriteState) { + val updatedStateObj = if (newGroupState.exists) newGroupState.get else null + stateManager.putState(store, stateData.keyRow, updatedStateObj, + currentTimeoutTimestamp) + numUpdatedStateRows += 1 + } + } + } + + // Return an iterator of rows such that fully consumed, the updated state value will + // be saved + CompletionIterator[InternalRow, Iterator[InternalRow]](ret.iterator, onIteratorCompletion) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala index 1ffc6a64d7708..2da0000dad4ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala @@ -30,9 +30,7 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} /** * Base functionality for plans which execute grouped python udfs. */ -// FIXME: should we move PythonFlatMapGroupsWithStateExec to python package? -// private[python] -object PandasGroupUtils { +private[python] object PandasGroupUtils { /** * passes the data to the python runner and coverts the resulting * columnarbatch into internal rows. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 790a652f21124..ee9449151758b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -56,7 +56,7 @@ trait FlatMapGroupsWithStateExecBase protected val batchTimestampMs: Option[Long] val eventTimeWatermark: Option[Long] - protected val isTimeoutEnabled: Boolean = timeoutConf != NoTimeout + private val isTimeoutEnabled = timeoutConf != NoTimeout protected val watermarkPresent: Boolean = child.output.exists { case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true case _ => false @@ -271,8 +271,7 @@ trait FlatMapGroupsWithStateExecBase */ def processNewDataWithInitialState( childDataIter: Iterator[InternalRow], - initStateIter: Iterator[InternalRow] - ): Iterator[InternalRow] = { + initStateIter: Iterator[InternalRow]): Iterator[InternalRow] = { if (!childDataIter.hasNext && !initStateIter.hasNext) return Iterator.empty @@ -284,7 +283,7 @@ trait FlatMapGroupsWithStateExecBase // Create a CoGroupedIterator that will group the two iterators together for every key group. new CoGroupedIterator( - groupedChildDataIter, groupedInitialStateIter, groupingAttributes).flatMap { + groupedChildDataIter, groupedInitialStateIter, groupingAttributes).flatMap { case (keyRow, valueRowIter, initialStateRowIter) => val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] var foundInitialStateForKey = false @@ -299,8 +298,8 @@ trait FlatMapGroupsWithStateExecBase // We apply the values for the key after applying the initial state. 
callFunctionAndUpdateState( stateManager.getState(store, keyUnsafeRow), - valueRowIter, - hasTimedOut = false + valueRowIter, + hasTimedOut = false ) } } @@ -334,9 +333,9 @@ trait FlatMapGroupsWithStateExecBase * @param hasTimedOut Whether this function is being called for a key timeout */ protected def callFunctionAndUpdateState( - stateData: StateData, - valueRowIter: Iterator[InternalRow], - hasTimedOut: Boolean): Iterator[InternalRow] + stateData: StateData, + valueRowIter: Iterator[InternalRow], + hasTimedOut: Boolean): Iterator[InternalRow] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala index b4f37125f4fa9..861ceabaf7f5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala @@ -20,6 +20,9 @@ package org.apache.spark.sql.execution.streaming import java.sql.Date import java.util.concurrent.TimeUnit +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.api.java.Optional import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, NoTimeout, ProcessingTimeTimeout} import org.apache.spark.sql.catalyst.util.IntervalUtils @@ -27,6 +30,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.streaming.GroupStateImpl._ import org.apache.spark.sql.streaming.{GroupStateTimeout, TestGroupState} import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils /** * Internal implementation of the [[TestGroupState]] interface. Methods are not thread-safe. @@ -39,7 +43,10 @@ import org.apache.spark.unsafe.types.UTF8String * @param hasTimedOut Whether the key for which this state wrapped is being created is * getting timed out or not. 
*/ -private[sql] class GroupStateImpl[S] private( +private[sql] class GroupStateImpl[S] private[sql]( + // NOTE:if you're adding new properties here, fix: + // - `json` and `fromJson` methods of this class in Scala + // - pyspark.sql.streaming.state.GroupStateImpl in Python optionalValue: Option[S], batchProcessingTimeMs: Long, eventTimeWatermarkMs: Long, @@ -173,6 +180,22 @@ private[sql] class GroupStateImpl[S] private( throw QueryExecutionErrors.cannotSetTimeoutTimestampError() } } + + private[sql] def json(): String = compact(render(new JObject( + // Constructor + "optionalValue" -> JNull :: // Note that optionalValue will be manually serialized. + "batchProcessingTimeMs" -> JLong(batchProcessingTimeMs) :: + "eventTimeWatermarkMs" -> JLong(eventTimeWatermarkMs) :: + "timeoutConf" -> JString(Utils.stripDollars(Utils.getSimpleName(timeoutConf.getClass))) :: + "hasTimedOut" -> JBool(hasTimedOut) :: + "watermarkPresent" -> JBool(watermarkPresent) :: + + // Internal state + "defined" -> JBool(defined) :: + "updated" -> JBool(updated) :: + "removed" -> JBool(removed) :: + "timeoutTimestamp" -> JLong(timeoutTimestamp) :: Nil + ))) } @@ -214,4 +237,35 @@ private[sql] object GroupStateImpl { hasTimedOut = false, watermarkPresent) } + + def groupStateTimeoutFromString(clazz: String): GroupStateTimeout = clazz match { + case "ProcessingTimeTimeout" => GroupStateTimeout.ProcessingTimeTimeout + case "EventTimeTimeout" => GroupStateTimeout.EventTimeTimeout + case "NoTimeout" => GroupStateTimeout.NoTimeout + case _ => throw new IllegalStateException("Invalid string for GroupStateTimeout: " + clazz) + } + + def fromJson[S](key: Option[S], json: JValue): GroupStateImpl[S] = { + implicit val formats = org.json4s.DefaultFormats + + val hmap = json.extract[Map[String, Any]] + + // Constructor + val newGroupState = new GroupStateImpl[S]( + key, + hmap("batchProcessingTimeMs").asInstanceOf[Number].longValue(), + hmap("eventTimeWatermarkMs").asInstanceOf[Number].longValue(), + 
groupStateTimeoutFromString(hmap("timeoutConf").asInstanceOf[String]), + hmap("hasTimedOut").asInstanceOf[Boolean], + hmap("watermarkPresent").asInstanceOf[Boolean]) + + // Internal state + newGroupState.defined = hmap("defined").asInstanceOf[Boolean] + newGroupState.updated = hmap("updated").asInstanceOf[Boolean] + newGroupState.removed = hmap("removed").asInstanceOf[Boolean] + newGroupState.timeoutTimestamp = + hmap("timeoutTimestamp").asInstanceOf[Number].longValue() + + newGroupState + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 105df798e7059..f386282a0b3e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.execution.{LocalLimitExec, QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode} import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, MergingSessionsExec, ObjectHashAggregateExec, SortAggregateExec, UpdatingSessionsExec} import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike +import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec import org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSourceV1 import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode @@ -62,8 +63,7 @@ class IncrementalExecution( StreamingJoinStrategy :: StatefulAggregationStrategy :: FlatMapGroupsWithStateStrategy :: - UntypedFlatMapGroupsWithStateStrategy :: - PythonFlatMapGroupsWithStateStrategy :: + FlatMapGroupsInPandasWithStateStrategy :: StreamingRelationStrategy :: StreamingDeduplicationStrategy :: StreamingGlobalLimitStrategy(outputMode) :: Nil @@ -212,14 +212,7 
@@ class IncrementalExecution( hasInitialState = hasInitialState ) - case m: UntypedFlatMapGroupsWithStateExec => - m.copy( - stateInfo = Some(nextStatefulOperationStateInfo), - batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs), - eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs) - ) - - case m: PythonFlatMapGroupsWithStateExec => + case m: FlatMapGroupsInPandasWithStateExec => m.copy( stateInfo = Some(nextStatefulOperationStateInfo), batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/PythonFlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/PythonFlatMapGroupsWithStateExec.scala deleted file mode 100644 index 05c0a396a3e48..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/PythonFlatMapGroupsWithStateExec.scala +++ /dev/null @@ -1,286 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.sql.execution.streaming - -import java.util.concurrent.TimeUnit.NANOSECONDS - -import org.apache.spark.api.python.ChainedPythonFunctions -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} -import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, PythonUDF, SortOrder, UnsafeRow} -import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, EventTimeWatermark, ProcessingTimeTimeout} -import org.apache.spark.sql.catalyst.plans.physical.Distribution -import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} -import org.apache.spark.sql.execution.python.PandasGroupUtils._ -import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP -import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreOps} -import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper.createStateManager -import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} -import org.apache.spark.sql.streaming.GroupStateTimeout.NoTimeout -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.ArrowUtils -import org.apache.spark.util.CompletionIterator - -case class PythonFlatMapGroupsWithStateExec( - func: Expression, - groupingAttributes: Seq[Attribute], - outAttributes: Seq[Attribute], - stateType: StructType, - stateInfo: Option[StatefulOperatorStateInfo], - stateFormatVersion: Int, - outputMode: OutputMode, - timeoutConf: GroupStateTimeout, - batchTimestampMs: Option[Long], - eventTimeWatermark: Option[Long], - child: SparkPlan) extends UnaryExecNode with StateStoreWriter with WatermarkSupport { - - override def output: Seq[Attribute] = outAttributes - private val isTimeoutEnabled = timeoutConf != NoTimeout - - private val sessionLocalTimeZone = conf.sessionLocalTimeZone - private val 
pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) - private val pythonFunction = func.asInstanceOf[PythonUDF].func - private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction))) - - private val watermarkPresent = child.output.exists { - case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true - case _ => false - } - - private val outputType = outAttributes.toStructType - private val keyEncoder = RowEncoder(groupingAttributes.toStructType) - .resolveAndBind(groupingAttributes) - private val valueEncoder = RowEncoder(child.output.toStructType).resolveAndBind(child.output) - private val stateEncoder = RowEncoder(stateType).resolveAndBind() - private val outputEncoder = RowEncoder(outputType).resolveAndBind(outAttributes) - - private[sql] val stateManager = - createStateManager(stateEncoder.asInstanceOf[ExpressionEncoder[Any]], isTimeoutEnabled, - stateFormatVersion) - - override def requiredChildDistribution: Seq[Distribution] = - StatefulOperatorPartitioning.getCompatibleDistribution( - groupingAttributes, getStateInfo, conf) :: Nil - - override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq( - groupingAttributes.map(SortOrder(_, Ascending))) - - override def keyExpressions: Seq[Attribute] = groupingAttributes - - override def shortName: String = "pythonFlatMapGroupsWithState" - - override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { - timeoutConf match { - case ProcessingTimeTimeout => - true // Always run batches to process timeouts - case EventTimeTimeout => - // Process another non-data batch only if the watermark has changed in this executed plan - eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get - case _ => - false - } - } - - /** - * Process data by applying the user defined function on a per partition basis. 
- * - * @param iter - Iterator of the data rows - * @param store - associated state store for this partition - * @param processor - handle to the input processor object. - */ - def processDataWithPartition( - iter: Iterator[InternalRow], - store: StateStore, - processor: InputProcessor): CompletionIterator[InternalRow, Iterator[InternalRow]] = { - val allUpdatesTimeMs = longMetric("allUpdatesTimeMs") - val commitTimeMs = longMetric("commitTimeMs") - val timeoutLatencyMs = longMetric("allRemovalsTimeMs") - - val currentTimeNs = System.nanoTime - val updatesStartTimeNs = currentTimeNs - var timeoutProcessingStartTimeNs = currentTimeNs - - // If timeout is based on event time, then filter late data based on watermark - val filteredIter = watermarkPredicateForData match { - case Some(predicate) if timeoutConf == EventTimeTimeout => - applyRemovingRowsOlderThanWatermark(iter, predicate) - case _ => - iter - } - - val processedOutputIterator = processor.processNewData(filteredIter) - - val newDataProcessorIter = - CompletionIterator[InternalRow, Iterator[InternalRow]]( - processedOutputIterator, { - // Once the input is processed, mark the start time for timeout processing to measure - // it separately from the overall processing time. - timeoutProcessingStartTimeNs = System.nanoTime - }) - - // SPARK-38320: Late-bind the timeout processing iterator so it is created *after* the input is - // processed (the input iterator is exhausted) and the state updates are written into the - // state store. Otherwise the iterator may not see the updates (e.g. with RocksDB state store). 
- val timeoutProcessorIter = new Iterator[InternalRow] { - private lazy val itr = getIterator() - override def hasNext = itr.hasNext - override def next() = itr.next() - private def getIterator(): Iterator[InternalRow] = - CompletionIterator[InternalRow, Iterator[InternalRow]](processor.processTimedOutState(), { - // Note: `timeoutLatencyMs` also includes the time the parent operator took for - // processing output returned through iterator. - timeoutLatencyMs += NANOSECONDS.toMillis(System.nanoTime - timeoutProcessingStartTimeNs) - }) - } - - // Generate a iterator that returns the rows grouped by the grouping function - // Note that this code ensures that the filtering for timeout occurs only after - // all the data has been processed. This is to ensure that the timeout information of all - // the keys with data is updated before they are processed for timeouts. - val outputIterator = newDataProcessorIter ++ timeoutProcessorIter - - // Return an iterator of all the rows generated by all the keys, such that when fully - // consumed, all the state updates will be committed by the state store - CompletionIterator[InternalRow, Iterator[InternalRow]](outputIterator, { - // Note: Due to the iterator lazy execution, this metric also captures the time taken - // by the upstream (consumer) operators in addition to the processing in this operator. 
- allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) - commitTimeMs += timeTakenMs { - store.commit() - } - setStoreMetrics(store) - setOperatorMetrics() - }) - } - - override protected def doExecute(): RDD[InternalRow] = { - metrics // force lazy init at driver - - // Throw errors early if parameters are not as expected - timeoutConf match { - case ProcessingTimeTimeout => - require(batchTimestampMs.nonEmpty) - case EventTimeTimeout => - require(eventTimeWatermark.nonEmpty) // watermark value has been populated - require(watermarkExpression.nonEmpty) // input schema has watermark attribute - case _ => - } - - child.execute().mapPartitionsWithStateStore[InternalRow]( - getStateInfo, - groupingAttributes.toStructType, - stateManager.stateSchema, - numColsPrefixKey = 0, - session.sqlContext.sessionState, - Some(session.sqlContext.streams.stateStoreCoordinator) - ) { case (store: StateStore, singleIterator: Iterator[InternalRow]) => - val processor = new InputProcessor(store) - processDataWithPartition(singleIterator, store, processor) - } - } - - /** Helper class to update the state store */ - class InputProcessor(store: StateStore) { - private val keyDeserializer = keyEncoder.createDeserializer() - private val valueDeserializer = valueEncoder.createDeserializer() - private val outputSerializer = outputEncoder.createSerializer() - - // Metrics - private val numUpdatedStateRows = longMetric("numUpdatedStateRows") - private val numOutputRows = longMetric("numOutputRows") - private val numRemovedStateRows = longMetric("numRemovedStateRows") - - /** - * For every group, get the key, values and corresponding state and call the function, - * and return an iterator of rows - */ - def processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { - val (dedupAttributes, argOffsets) = resolveArgOffsets(child, groupingAttributes) - - val data = groupAndProject(dataIter, groupingAttributes, child.output, - dedupAttributes).map { case 
(keyRow, valueRowIter) => - - val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] - val stateData = stateManager.getState(store, keyUnsafeRow) - val groupedState = GroupStateImpl.createForStreaming( - Option(stateData.stateObj), // TODO: check whether the object is Row or not - batchTimestampMs.getOrElse(NO_TIMESTAMP), - eventTimeWatermark.getOrElse(NO_TIMESTAMP), - timeoutConf, - hasTimedOut = false, - watermarkPresent).asInstanceOf[GroupStateImpl[Row]] - - // UnsafeRow, Iterator[UnsafeRow], GroupStateImpl[Row] - (keyRow, valueRowIter, groupedState) - } - - // FIXME: need to construct the code to pass the iterator of (key, valueIter, GroupState) - // and receive an iterator of (outputs, state update). - - // FIXME: outputs should be produced to the downstream, with conversion from Row to - // InternalRow. - // FIXME: state updates should be reflected to the state store. - // FIXME: refer UntypedFlatMapGroupsWithStateExec.callFunctionAndUpdateState for more details - - // FIXME: pretty sure this is a dummy code - Iterator.empty - } - - /** Find the groups that have timeout set and are timing out right now, and call the function */ - def processTimedOutState(): Iterator[InternalRow] = { - if (isTimeoutEnabled) { - val timeoutThreshold = timeoutConf match { - case ProcessingTimeTimeout => batchTimestampMs.get - case EventTimeTimeout => eventTimeWatermark.get - case _ => - throw new IllegalStateException( - s"Cannot filter timed out keys for $timeoutConf") - } - - val data = stateManager.getAllState(store).filter { state => - state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold - }.map { stateData => - val groupedState = GroupStateImpl.createForStreaming( - Option(stateData.stateObj), // TODO: check whether the object is Row or not - batchTimestampMs.getOrElse(NO_TIMESTAMP), - eventTimeWatermark.getOrElse(NO_TIMESTAMP), - timeoutConf, - hasTimedOut = true, - watermarkPresent).asInstanceOf[GroupStateImpl[Row]] - - // UnsafeRow, 
Iterator[UnsafeRow], GroupStateImpl[Row] - (stateData.keyRow, Iterator.empty.asInstanceOf[Iterator[UnsafeRow]], groupedState) - } - - // FIXME: need to construct the code to pass the iterator of (key, valueIter, GroupState) - // and receive an iterator of (outputs, state update). - - // FIXME: outputs should be produced to the downstream, with conversion from Row to - // InternalRow. - // FIXME: state updates should be reflected to the state store. - - // FIXME: pretty sure this is a dummy code - Iterator.empty - } else Iterator.empty - } - } - - override protected def withNewChildInternal( - newChild: SparkPlan): PythonFlatMapGroupsWithStateExec = copy(child = newChild) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/UntypedFlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/UntypedFlatMapGroupsWithStateExec.scala deleted file mode 100644 index 4214eebcc26c7..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/UntypedFlatMapGroupsWithStateExec.scala +++ /dev/null @@ -1,290 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.sql.execution.streaming - -import java.util.concurrent.TimeUnit.NANOSECONDS - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} -import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, SortOrder, UnsafeRow} -import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, EventTimeWatermark, LogicalGroupState, ProcessingTimeTimeout} -import org.apache.spark.sql.catalyst.plans.physical.Distribution -import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} -import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP -import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreOps} -import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper.{createStateManager, StateData} -import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} -import org.apache.spark.sql.streaming.GroupStateTimeout.NoTimeout -import org.apache.spark.sql.types.StructType -import org.apache.spark.util.CompletionIterator - -case class UntypedFlatMapGroupsWithStateExec( - func: (Row, Iterator[Row], LogicalGroupState[Row]) => Iterator[Row], - groupingAttributes: Seq[Attribute], - outAttributes: Seq[Attribute], - stateType: StructType, - stateInfo: Option[StatefulOperatorStateInfo], - stateFormatVersion: Int, - outputMode: OutputMode, - timeoutConf: GroupStateTimeout, - batchTimestampMs: Option[Long], - eventTimeWatermark: Option[Long], - child: SparkPlan) extends UnaryExecNode with StateStoreWriter with WatermarkSupport { - - override def output: Seq[Attribute] = outAttributes - private val isTimeoutEnabled = timeoutConf != NoTimeout - - private val watermarkPresent = child.output.exists { - case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true - case _ => false - } - - private val 
outputType = outAttributes.toStructType - private val keyEncoder = RowEncoder(groupingAttributes.toStructType) - .resolveAndBind(groupingAttributes) - private val valueEncoder = RowEncoder(child.output.toStructType).resolveAndBind(child.output) - private val stateEncoder = RowEncoder(stateType).resolveAndBind() - private val outputEncoder = RowEncoder(outputType).resolveAndBind(outAttributes) - - private[sql] val stateManager = - createStateManager(stateEncoder.asInstanceOf[ExpressionEncoder[Any]], isTimeoutEnabled, - stateFormatVersion) - - override def requiredChildDistribution: Seq[Distribution] = - StatefulOperatorPartitioning.getCompatibleDistribution( - groupingAttributes, getStateInfo, conf) :: Nil - - override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq( - groupingAttributes.map(SortOrder(_, Ascending))) - - override def keyExpressions: Seq[Attribute] = groupingAttributes - - override def shortName: String = "untypedFlatMapGroupsWithState" - - override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { - timeoutConf match { - case ProcessingTimeTimeout => - true // Always run batches to process timeouts - case EventTimeTimeout => - // Process another non-data batch only if the watermark has changed in this executed plan - eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get - case _ => - false - } - } - - /** - * Process data by applying the user defined function on a per partition basis. - * - * @param iter - Iterator of the data rows - * @param store - associated state store for this partition - * @param processor - handle to the input processor object. 
- */ - def processDataWithPartition( - iter: Iterator[InternalRow], - store: StateStore, - processor: InputProcessor): CompletionIterator[InternalRow, Iterator[InternalRow]] = { - val allUpdatesTimeMs = longMetric("allUpdatesTimeMs") - val commitTimeMs = longMetric("commitTimeMs") - val timeoutLatencyMs = longMetric("allRemovalsTimeMs") - - val currentTimeNs = System.nanoTime - val updatesStartTimeNs = currentTimeNs - var timeoutProcessingStartTimeNs = currentTimeNs - - // If timeout is based on event time, then filter late data based on watermark - val filteredIter = watermarkPredicateForData match { - case Some(predicate) if timeoutConf == EventTimeTimeout => - applyRemovingRowsOlderThanWatermark(iter, predicate) - case _ => - iter - } - - val processedOutputIterator = processor.processNewData(filteredIter) - - val newDataProcessorIter = - CompletionIterator[InternalRow, Iterator[InternalRow]]( - processedOutputIterator, { - // Once the input is processed, mark the start time for timeout processing to measure - // it separately from the overall processing time. - timeoutProcessingStartTimeNs = System.nanoTime - }) - - // SPARK-38320: Late-bind the timeout processing iterator so it is created *after* the input is - // processed (the input iterator is exhausted) and the state updates are written into the - // state store. Otherwise the iterator may not see the updates (e.g. with RocksDB state store). - val timeoutProcessorIter = new Iterator[InternalRow] { - private lazy val itr = getIterator() - override def hasNext = itr.hasNext - override def next() = itr.next() - private def getIterator(): Iterator[InternalRow] = - CompletionIterator[InternalRow, Iterator[InternalRow]](processor.processTimedOutState(), { - // Note: `timeoutLatencyMs` also includes the time the parent operator took for - // processing output returned through iterator. 
- timeoutLatencyMs += NANOSECONDS.toMillis(System.nanoTime - timeoutProcessingStartTimeNs) - }) - } - - // Generate a iterator that returns the rows grouped by the grouping function - // Note that this code ensures that the filtering for timeout occurs only after - // all the data has been processed. This is to ensure that the timeout information of all - // the keys with data is updated before they are processed for timeouts. - val outputIterator = newDataProcessorIter ++ timeoutProcessorIter - - // Return an iterator of all the rows generated by all the keys, such that when fully - // consumed, all the state updates will be committed by the state store - CompletionIterator[InternalRow, Iterator[InternalRow]](outputIterator, { - // Note: Due to the iterator lazy execution, this metric also captures the time taken - // by the upstream (consumer) operators in addition to the processing in this operator. - allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) - commitTimeMs += timeTakenMs { - store.commit() - } - setStoreMetrics(store) - setOperatorMetrics() - }) - } - - override protected def doExecute(): RDD[InternalRow] = { - metrics // force lazy init at driver - - // Throw errors early if parameters are not as expected - timeoutConf match { - case ProcessingTimeTimeout => - require(batchTimestampMs.nonEmpty) - case EventTimeTimeout => - require(eventTimeWatermark.nonEmpty) // watermark value has been populated - require(watermarkExpression.nonEmpty) // input schema has watermark attribute - case _ => - } - - child.execute().mapPartitionsWithStateStore[InternalRow]( - getStateInfo, - groupingAttributes.toStructType, - stateManager.stateSchema, - numColsPrefixKey = 0, - session.sqlContext.sessionState, - Some(session.sqlContext.streams.stateStoreCoordinator) - ) { case (store: StateStore, singleIterator: Iterator[InternalRow]) => - val processor = new InputProcessor(store) - processDataWithPartition(singleIterator, store, processor) - } - 
} - - /** Helper class to update the state store */ - class InputProcessor(store: StateStore) { - private val keyDeserializer = keyEncoder.createDeserializer() - private val valueDeserializer = valueEncoder.createDeserializer() - private val outputSerializer = outputEncoder.createSerializer() - - // Metrics - private val numUpdatedStateRows = longMetric("numUpdatedStateRows") - private val numOutputRows = longMetric("numOutputRows") - private val numRemovedStateRows = longMetric("numRemovedStateRows") - - /** - * For every group, get the key, values and corresponding state and call the function, - * and return an iterator of rows - */ - def processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { - val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output) - groupedIter.flatMap { case (keyRow, valueRowIter) => - val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] - callFunctionAndUpdateState( - stateManager.getState(store, keyUnsafeRow), - valueRowIter, - hasTimedOut = false) - } - } - - /** Find the groups that have timeout set and are timing out right now, and call the function */ - def processTimedOutState(): Iterator[InternalRow] = { - if (isTimeoutEnabled) { - val timeoutThreshold = timeoutConf match { - case ProcessingTimeTimeout => batchTimestampMs.get - case EventTimeTimeout => eventTimeWatermark.get - case _ => - throw new IllegalStateException( - s"Cannot filter timed out keys for $timeoutConf") - } - val timingOutPairs = stateManager.getAllState(store).filter { state => - state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold - } - timingOutPairs.flatMap { stateData => - callFunctionAndUpdateState(stateData, Iterator.empty, hasTimedOut = true) - } - } else Iterator.empty - } - - /** - * Call the user function on a key's data, update the state store, and return the return data - * iterator. 
Note that the store updating is lazy, that is, the store will be updated only - * after the returned iterator is fully consumed. - * - * @param stateData All the data related to the state to be updated - * @param valueRowIter Iterator of values as rows, cannot be null, but can be empty - * @param hasTimedOut Whether this function is being called for a key timeout - */ - private def callFunctionAndUpdateState( - stateData: StateData, - valueRowIter: Iterator[InternalRow], - hasTimedOut: Boolean): Iterator[InternalRow] = { - val keyRowAsUntyped = keyDeserializer(stateData.keyRow) - val valueRowsIterAsUntyped = valueRowIter.map(valueDeserializer.apply) - - val groupState = GroupStateImpl.createForStreaming( - Option(stateData.stateObj), // TODO: check whether the object is Row or not - batchTimestampMs.getOrElse(NO_TIMESTAMP), - eventTimeWatermark.getOrElse(NO_TIMESTAMP), - timeoutConf, - hasTimedOut, - watermarkPresent).asInstanceOf[GroupStateImpl[Row]] - - // Call function, get the returned objects and convert them to rows - val mappedIterator = func(keyRowAsUntyped, valueRowsIterAsUntyped, groupState).map { row => - numOutputRows += 1 - outputSerializer(row) - } - - // When the iterator is consumed, then write changes to state - def onIteratorCompletion: Unit = { - if (groupState.isRemoved && !groupState.getTimeoutTimestampMs.isPresent()) { - stateManager.removeState(store, stateData.keyRow) - numRemovedStateRows += 1 - } else { - val currentTimeoutTimestamp = groupState.getTimeoutTimestampMs.orElse(NO_TIMESTAMP) - val hasTimeoutChanged = currentTimeoutTimestamp != stateData.timeoutTimestamp - val shouldWriteState = groupState.isUpdated || groupState.isRemoved || hasTimeoutChanged - - if (shouldWriteState) { - val updatedStateObj = if (groupState.exists) groupState.get else null - stateManager.putState(store, stateData.keyRow, updatedStateObj, currentTimeoutTimestamp) - numUpdatedStateRows += 1 - } - } - } - - // Return an iterator of rows such that fully consumed, 
the updated state value will be saved - CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator, onIteratorCompletion) - } - } - - override protected def withNewChildInternal( - newChild: SparkPlan): UntypedFlatMapGroupsWithStateExec = copy(child = newChild) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 01ff72bac7bcc..022fd1239ce4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -49,7 +49,7 @@ package object state { } /** Map each partition of an RDD along with data in a [[StateStore]]. */ - private[streaming] def mapPartitionsWithStateStore[U: ClassTag]( + def mapPartitionsWithStateStore[U: ClassTag]( stateInfo: StatefulOperatorStateInfo, keySchema: StructType, valueSchema: StructType, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala index 827cfcf32fead..3c41f6b47b5ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import java.nio.charset.StandardCharsets import java.nio.file.{Files, Paths} import scala.collection.JavaConverters._ @@ -31,7 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ExprId, Pyth import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.execution.python.UserDefinedPythonFunction import org.apache.spark.sql.expressions.SparkUserDefinedFunction -import org.apache.spark.sql.types.{DataType, IntegerType, StringType} +import org.apache.spark.sql.types.{DataType, IntegerType, NullType, StringType} /** * This object targets 
to integrate various UDF test cases so that Scalar UDF, Python UDF, @@ -190,7 +191,7 @@ object IntegratedUDFTestUtils extends SQLHelper { throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.") } - private lazy val pandasFunc: Array[Byte] = if (shouldTestScalarPandasUDFs) { + private lazy val pandasFunc: Array[Byte] = if (shouldTestPandasUDFs) { var binaryPandasFunc: Array[Byte] = null withTempPath { path => Process( @@ -213,7 +214,7 @@ object IntegratedUDFTestUtils extends SQLHelper { throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.") } - private lazy val pandasGroupedAggFunc: Array[Byte] = if (shouldTestGroupedAggPandasUDFs) { + private lazy val pandasGroupedAggFunc: Array[Byte] = if (shouldTestPandasUDFs) { var binaryPandasFunc: Array[Byte] = null withTempPath { path => Process( @@ -235,6 +236,34 @@ object IntegratedUDFTestUtils extends SQLHelper { throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.") } + private def createPandasGroupedMapFuncWithState(pythonScript: String): Array[Byte] = { + if (shouldTestPandasUDFs) { + var binaryPandasFunc: Array[Byte] = null + withTempPath { codePath => + Files.write(codePath.toPath, pythonScript.getBytes(StandardCharsets.UTF_8)) + withTempPath { path => + Process( + Seq( + pythonExec, + "-c", + "from pyspark.serializers import CloudPickleSerializer; " + + s"f = open('$path', 'wb');" + + s"exec(open('$codePath', 'r').read());" + + "f.write(CloudPickleSerializer().dumps((" + + "func, tpe)))"), + None, + "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!! + binaryPandasFunc = Files.readAllBytes(path.toPath) + } + } + assert(binaryPandasFunc != null) + binaryPandasFunc + } else { + throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.") + } + } + + // Make sure this map stays mutable - this map gets updated later in Python runners. 
private val workerEnv = new java.util.HashMap[String, String]() workerEnv.put("PYTHONPATH", s"$pysparkPythonPath:$pythonPath") @@ -251,11 +280,9 @@ object IntegratedUDFTestUtils extends SQLHelper { lazy val shouldTestPythonUDFs: Boolean = isPythonAvailable && isPySparkAvailable - lazy val shouldTestScalarPandasUDFs: Boolean = + lazy val shouldTestPandasUDFs: Boolean = isPythonAvailable && isPandasAvailable && isPyArrowAvailable - lazy val shouldTestGroupedAggPandasUDFs: Boolean = shouldTestScalarPandasUDFs - /** * A base trait for various UDFs defined in this object. */ @@ -420,6 +447,41 @@ object IntegratedUDFTestUtils extends SQLHelper { val prettyName: String = "Grouped Aggregate Pandas UDF" } + /** + * Arbitrary stateful processing in Python is used for + * `DataFrame.groupBy.applyInPandasWithState`. It requires `pythonScript` to + * define `func` (Python function) and `tpe` (`StructType` for state key). + * + * Virtually equivalent to: + * + * {{{ + * # exec defines 'func' and 'tpe' (struct type for state key) + * exec(pythonScript) + * + * # ... are filled when this UDF is invoked, see also 'PythonFlatMapGroupsWithStateSuite'. + * df.groupBy(...).applyInPandasWithState(func, ..., tpe, ..., ...) + * }}} + */ + case class TestGroupedMapPandasUDFWithState(name: String, pythonScript: String) extends TestUDF { + private[IntegratedUDFTestUtils] lazy val udf = new UserDefinedPythonFunction( + name = name, + func = SimplePythonFunction( + command = createPandasGroupedMapFuncWithState(pythonScript), + envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]], + pythonIncludes = List.empty[String].asJava, + pythonExec = pythonExec, + pythonVer = pythonVer, + broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava, + accumulator = null), + dataType = NullType, // This is not respected. 
+ pythonEvalType = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, + udfDeterministic = true) + + def apply(exprs: Column*): Column = udf(exprs: _*) + + val prettyName: String = "Grouped Map Pandas UDF with State" + } + /** * A Scala UDF that takes one column, casts into string, executes the * Scala native function, and casts back to the type of input column. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index cca9bb6741f68..a662caea74a94 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -244,7 +244,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper /* Do nothing */ } case udfTestCase: UDFTest - if udfTestCase.udf.isInstanceOf[TestScalarPandasUDF] && !shouldTestScalarPandasUDFs => + if udfTestCase.udf.isInstanceOf[TestScalarPandasUDF] && !shouldTestPandasUDFs => ignore(s"${testCase.name} is skipped because pyspark," + s"pandas and/or pyarrow were not available in [$pythonExec].") { /* Do nothing */ @@ -433,7 +433,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper if udfTestCase.udf.isInstanceOf[TestPythonUDF] && shouldTestPythonUDFs => s"${testCase.name}${System.lineSeparator()}Python: $pythonVer${System.lineSeparator()}" case udfTestCase: UDFTest - if udfTestCase.udf.isInstanceOf[TestScalarPandasUDF] && shouldTestScalarPandasUDFs => + if udfTestCase.udf.isInstanceOf[TestScalarPandasUDF] && shouldTestPandasUDFs => s"${testCase.name}${System.lineSeparator()}" + s"Python: $pythonVer Pandas: $pandasVer PyArrow: $pyarrowVer${System.lineSeparator()}" case _ => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala index 00c774e2d1bee..92aadb6779e9e 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala @@ -128,7 +128,7 @@ class QueryCompilationErrorsSuite test("INVALID_PANDAS_UDF_PLACEMENT: Using aggregate function with grouped aggregate pandas UDF") { import IntegratedUDFTestUtils._ - assume(shouldTestGroupedAggPandasUDFs) + assume(shouldTestPandasUDFs) val df = Seq( (536361, "85123A", 2, 17850), @@ -180,7 +180,7 @@ class QueryCompilationErrorsSuite test("UNSUPPORTED_FEATURE: Using pandas UDF aggregate expression with pivot") { import IntegratedUDFTestUtils._ - assume(shouldTestGroupedAggPandasUDFs) + assume(shouldTestPandasUDFs) val df = Seq( (536361, "85123A", 2, 17850), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala new file mode 100644 index 0000000000000..606d5d56744c8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala @@ -0,0 +1,434 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.sql.IntegratedUDFTestUtils._ +import org.apache.spark.sql.catalyst.expressions.PythonUDF +import org.apache.spark.sql.catalyst.plans.logical.{NoTimeout, ProcessingTimeTimeout} +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Complete, Update} +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.functions.timestamp_seconds +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.util.StreamManualClock +import org.apache.spark.sql.types._ + +class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { + + import testImplicits._ + + test("flatMapGroupsWithState - streaming") { + assume(shouldTestPandasUDFs) + + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count if state is defined, otherwise does not return anything + val pythonScript = + """ + |import pandas as pd + |from pyspark.sql.types import StructType, StructField, StringType + | + |tpe = StructType([ + | StructField("key", StringType()), + | StructField("countAsString", StringType())]) + | + |def func(key, pdf, state): + | assert state.getCurrentProcessingTimeMs() >= 0 + | try: + | state.getCurrentWatermarkMs() + | assert False + | except RuntimeError as e: + | assert "watermark" in str(e) + | + | count = state.getOption + | if count is None: + | count = 0 + | else: + | count = count[0] + | count += len(pdf) + | if count == 3: + | state.remove() + | return pd.DataFrame() + | else: + | state.update((count,)) + | return pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) + |""".stripMargin + val pythonFunc = TestGroupedMapPandasUDFWithState( + name = "pandas_grouped_map_with_state", pythonScript = pythonScript) + + val inputData = MemoryStream[String] + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("countAsString", StringType))) + 
val stateStructType = StructType(Seq(StructField("count", LongType))) + val inputDataDS = inputData.toDS() + val result = + inputDataDS + .groupBy("value") + .applyInPandasWithState( + pythonFunc(inputDataDS("value")).expr.asInstanceOf[PythonUDF], + outputStructType, + stateStructType, + "Update", + "NoTimeout") + + testStream(result, Update)( + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + assertNumStateRows(total = 1, updated = 1), + AddData(inputData, "a", "b"), + CheckNewAnswer(("a", "2"), ("b", "1")), + assertNumStateRows(total = 2, updated = 2), + StopStream, + StartStream(), + AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a + CheckNewAnswer(("b", "2")), + assertNumStateRows( + total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed = Some(Seq(1))), + StopStream, + StartStream(), + AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and + CheckNewAnswer(("a", "1"), ("c", "1")), + assertNumStateRows(total = 3, updated = 2) + ) + } + + test("flatMapGroupsWithState - streaming + aggregation") { + assume(shouldTestPandasUDFs) + + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) + val pythonScript = + """ + |import pandas as pd + |from pyspark.sql.types import StructType, StructField, StringType + | + |tpe = StructType([ + | StructField("key", StringType()), + | StructField("countAsString", StringType())]) + | + |def func(key, pdf, state): + | count = state.getOption + | if count is None: + | count = 0 + | else: + | count = count[0] + | count += len(pdf) + | if count == 3: + | state.remove() + | return pd.DataFrame({'key': [key[0]], 'countAsString': [str(-1)]}) + | else: + | state.update((count,)) + | return pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) + |""".stripMargin + val pythonFunc = TestGroupedMapPandasUDFWithState( + name 
= "pandas_grouped_map_with_state", pythonScript = pythonScript) + + val inputData = MemoryStream[String] + val inputDataDS = inputData.toDS + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("countAsString", StringType))) + val stateStructType = StructType(Seq(StructField("count", LongType))) + val result = + inputDataDS + .groupBy("value") + .applyInPandasWithState( + pythonFunc(inputDataDS("value")).expr.asInstanceOf[PythonUDF], + outputStructType, + stateStructType, + "Append", + "NoTimeout") + .groupBy("key") + .count() + + testStream(result, Complete)( + AddData(inputData, "a"), + CheckNewAnswer(("a", 1)), + AddData(inputData, "a", "b"), + // mapGroups generates ("a", "2"), ("b", "1"); so increases counts of a and b by 1 + CheckNewAnswer(("a", 2), ("b", 1)), + StopStream, + StartStream(), + AddData(inputData, "a", "b"), + // mapGroups should remove state for "a" and generate ("a", "-1"), ("b", "2") ; + // so increment a and b by 1 + CheckNewAnswer(("a", 3), ("b", 2)), + StopStream, + StartStream(), + AddData(inputData, "a", "c"), + // mapGroups should recreate state for "a" and generate ("a", "1"), ("c", "1") ; + // so increment a and c by 1 + CheckNewAnswer(("a", 4), ("b", 2), ("c", 1)) + ) + } + + test("flatMapGroupsWithState - streaming with processing time timeout") { + assume(shouldTestPandasUDFs) + + // Function to maintain the count as state and set the proc. time timeout delay of 10 seconds. + // It returns the count if changed, or -1 if the state was removed by timeout. 
+ val pythonScript = + """ + |import pandas as pd + |from pyspark.sql.types import StructType, StructField, StringType + | + |tpe = StructType([ + | StructField("key", StringType()), + | StructField("countAsString", StringType())]) + | + |def func(key, pdf, state): + | assert state.getCurrentProcessingTimeMs() >= 0 + | try: + | state.getCurrentWatermarkMs() + | assert False + | except RuntimeError as e: + | assert "watermark" in str(e) + | + | if state.hasTimedOut: + | state.remove() + | return pd.DataFrame({'key': [key[0]], 'countAsString': [str(-1)]}) + | else: + | count = state.getOption + | if count is None: + | count = 0 + | else: + | count = count[0] + | count += len(pdf) + | state.update((count,)) + | state.setTimeoutDuration(10000) + | return pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) + |""".stripMargin + val pythonFunc = TestGroupedMapPandasUDFWithState( + name = "pandas_grouped_map_with_state", pythonScript = pythonScript) + + val clock = new StreamManualClock + val inputData = MemoryStream[String] + val inputDataDS = inputData.toDS + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("countAsString", StringType))) + val stateStructType = StructType(Seq(StructField("count", LongType))) + val result = + inputDataDS + .groupBy("value") + .applyInPandasWithState( + pythonFunc(inputDataDS("value")).expr.asInstanceOf[PythonUDF], + outputStructType, + stateStructType, + "Update", + "ProcessingTimeTimeout") + + testStream(result, Update)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputData, "a"), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("a", "1")), + assertNumStateRows(total = 1, updated = 1), + + AddData(inputData, "b"), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("b", "1")), + assertNumStateRows(total = 2, updated = 1), + + AddData(inputData, "b"), + AdvanceManualClock(10 * 1000), + CheckNewAnswer(("a", "-1"), ("b", "2")), + assertNumStateRows( + 
total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed = Some(Seq(1))), + + StopStream, + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + + AddData(inputData, "c"), + AdvanceManualClock(11 * 1000), + CheckNewAnswer(("b", "-1"), ("c", "1")), + assertNumStateRows( + total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed = Some(Seq(1))), + + AdvanceManualClock(12 * 1000), + AssertOnQuery { _ => clock.getTimeMillis() == 35000 }, + Execute { q => + failAfter(streamingTimeout) { + while (q.lastProgress.timestamp != "1970-01-01T00:00:35.000Z") { + Thread.sleep(1) + } + } + }, + CheckNewAnswer(("c", "-1")), + assertNumStateRows( + total = Seq(0), updated = Seq(0), droppedByWatermark = Seq(0), removed = Some(Seq(1))) + ) + } + + test("flatMapGroupsWithState - streaming w/ event time timeout + watermark") { + assume(shouldTestPandasUDFs) + + // timestamp_seconds assumes the base timezone is UTC. However, the provided function + // localizes it. 
Therefore, this test assumes the timezone is in UTC + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") { + val pythonScript = + """ + |import calendar + |import os + |import datetime + |import pandas as pd + |from pyspark.sql.types import StructType, StringType, StructField, IntegerType + | + |tpe = StructType([ + | StructField("key", StringType()), + | StructField("maxEventTimeSec", IntegerType())]) + | + |def func(key, pdf, state): + | assert state.getCurrentProcessingTimeMs() >= 0 + | assert state.getCurrentWatermarkMs() >= -1 + | + | timeout_delay_sec = 5 + | if state.hasTimedOut: + | state.remove() + | return pd.DataFrame({'key': [key[0]], 'maxEventTimeSec': [-1]}) + | else: + | m = state.getOption + | if m is None: + | m = 0 + | else: + | m = m[0] + | + | pser = pdf.eventTime.apply( + | lambda dt: (int(calendar.timegm(dt.utctimetuple()) + dt.microsecond))) + | max_event_time_sec = int(max(pser.max(), m)) + | timeout_timestamp_sec = max_event_time_sec + timeout_delay_sec + | state.update((max_event_time_sec,)) + | state.setTimeoutTimestamp(timeout_timestamp_sec * 1000) + | return pd.DataFrame({'key': [key[0]], 'maxEventTimeSec': [max_event_time_sec]}) + |""".stripMargin + val pythonFunc = TestGroupedMapPandasUDFWithState( + name = "pandas_grouped_map_with_state", pythonScript = pythonScript) + + val inputData = MemoryStream[(String, Int)] + val inputDataDF = + inputData.toDF.select($"_1".as("key"), timestamp_seconds($"_2").as("eventTime")) + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("maxEventTimeSec", IntegerType))) + val stateStructType = StructType(Seq(StructField("maxEventTimeSec", LongType))) + val result = + inputDataDF + .withWatermark("eventTime", "10 seconds") + .groupBy("key") + .applyInPandasWithState( + pythonFunc(inputDataDF("key"), inputDataDF("eventTime")).expr.asInstanceOf[PythonUDF], + outputStructType, + stateStructType, + "Update", + "EventTimeTimeout") + + testStream(result, Update)( + 
StartStream(), + + AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), + // Max event time = 15. Timeout timestamp for "a" = 15 + 5 = 20. Watermark = 15 - 10 = 5. + CheckNewAnswer(("a", 15)), // Output = max event time of a + + AddData(inputData, ("a", 4)), // Add data older than watermark for "a" + CheckNewAnswer(), // No output as data should get filtered by watermark + + AddData(inputData, ("a", 10)), // Add data newer than watermark for "a" + CheckNewAnswer(("a", 15)), // Max event time is still the same + // Timeout timestamp for "a" is still 20 as max event time for "a" is still 15. + // Watermark is still 5 as max event time for all data is still 15. + + AddData(inputData, ("b", 31)), // Add data newer than watermark for "b", not "a" + // Watermark = 31 - 10 = 21, so "a" should be timed out as timeout timestamp for "a" is 20. + CheckNewAnswer(("a", -1), ("b", 31)) // State for "a" should timeout and emit -1 + ) + } + } + + def testWithTimeout(timeoutConf: GroupStateTimeout): Unit = { + test("SPARK-20714: watermark does not fail query when timeout = " + timeoutConf) { + assume(shouldTestPandasUDFs) + + // timestamp_seconds assumes the base timezone is UTC. However, the provided function + // localizes it. 
Therefore, this test assumes the timezone is in UTC + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") { + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) + // String, (String, Long), RunningCount(Long) + val pythonScript = + """ + |import pandas as pd + |from pyspark.sql.types import StructType, StructField, StringType + | + |tpe = StructType([ + | StructField("key", StringType()), + | StructField("countAsString", StringType())]) + | + |def func(key, pdf, state): + | if state.hasTimedOut: + | state.remove() + | return pd.DataFrame({'key': [key[0]], 'countAsString': [str(-1)]}) + | else: + | count = state.getOption + | if count is None: + | count = 0 + | else: + | count = count[0] + | count += len(pdf) + | state.update((count,)) + | state.setTimeoutDuration(10000) + | return pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) + |""".stripMargin + val pythonFunc = TestGroupedMapPandasUDFWithState( + name = "pandas_grouped_map_with_state", pythonScript = pythonScript) + + val clock = new StreamManualClock + val inputData = MemoryStream[(String, Long)] + val inputDataDF = inputData + .toDF.toDF("key", "time") + .selectExpr("key", "timestamp_seconds(time) as timestamp") + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("countAsString", StringType))) + val stateStructType = StructType(Seq(StructField("count", LongType))) + val result = + inputDataDF + .withWatermark("timestamp", "10 second") + .groupBy("key") + .applyInPandasWithState( + pythonFunc(inputDataDF("key"), inputDataDF("timestamp")).expr.asInstanceOf[PythonUDF], + outputStructType, + stateStructType, + "Update", + "ProcessingTimeTimeout") + + testStream(result, Update)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputData, ("a", 1L)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("a", 
"1")) + ) + } + } + } + testWithTimeout(NoTimeout) + testWithTimeout(ProcessingTimeTimeout) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/UntypedFlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/UntypedFlatMapGroupsWithStateSuite.scala deleted file mode 100644 index 60b66fb1319ef..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/UntypedFlatMapGroupsWithStateSuite.scala +++ /dev/null @@ -1,438 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.streaming - -import org.scalatest.exceptions.TestFailedException - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.plans.logical.{NoTimeout, ProcessingTimeTimeout} -import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} -import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.functions.timestamp_seconds -import org.apache.spark.sql.streaming.GroupStateTimeout.EventTimeTimeout -import org.apache.spark.sql.streaming.util.StreamManualClock -import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType} - -class UntypedFlatMapGroupsWithStateSuite extends StateStoreMetricsTest { - - import testImplicits._ - - import FlatMapGroupsWithStateSuite._ - - /** - * Sample `flatMapGroupsWithState` function implementation. It maintains the max event time as - * state and set the timeout timestamp based on the current max event time seen. It returns the - * max event time in the state, or -1 if the state was removed by timeout. Timeout is 5sec. 
- */ - val sampleTestFunction = - (key: Row, values: Iterator[Row], state: GroupState[Row]) => { - assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } - assertCanGetWatermark { state.getCurrentWatermarkMs() >= -1 } - - // key: String, values: (String, Timestamp), state: Long, output: (String, Int) - val keyAsString = key.getString(0) - - val timeoutDelaySec = 5 - if (state.hasTimedOut) { - state.remove() - Iterator(Row(keyAsString, -1)) - } else { - val valuesSeq = values.toSeq - val maxEventTimeSec = math.max(valuesSeq.map(_.getTimestamp(1).getTime / 1000).max, - state.getOption.map(_.getLong(0)).getOrElse(0L)) - val timeoutTimestampSec = maxEventTimeSec + timeoutDelaySec - state.update(Row(maxEventTimeSec)) - state.setTimeoutTimestamp(timeoutTimestampSec * 1000) - Iterator(Row(keyAsString, maxEventTimeSec.toInt)) - } - } - - test("flatMapGroupsWithState - streaming") { - // Function to maintain running count up to 2, and then remove the count - // Returns the data and the count if state is defined, otherwise does not return anything - val stateFunc = (key: Row, values: Iterator[Row], state: GroupState[Row]) => { - assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } - assertCannotGetWatermark { state.getCurrentWatermarkMs() } - - val count = state.getOption.map(_.getLong(0)).getOrElse(0L) + values.size - if (count == 3) { - state.remove() - Iterator.empty - } else { - state.update(Row(count)) - Iterator(Row(key.getString(0), count.toString)) - } - } - - val inputData = MemoryStream[String] - val outputStructType = StructType( - Seq( - StructField("key", StringType), - StructField("countAsString", StringType))) - val stateStructType = StructType(Seq(StructField("count", LongType))) - val result = - inputData.toDS() - .groupBy("value") - .flatMapGroupsWithState( - outputStructType, stateStructType, Update, GroupStateTimeout.NoTimeout)(stateFunc) - - testStream(result, Update)( - AddData(inputData, "a"), - CheckNewAnswer(("a", 
"1")), - assertNumStateRows(total = 1, updated = 1), - AddData(inputData, "a", "b"), - CheckNewAnswer(("a", "2"), ("b", "1")), - assertNumStateRows(total = 2, updated = 2), - StopStream, - StartStream(), - AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a - CheckNewAnswer(("b", "2")), - assertNumStateRows( - total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed = Some(Seq(1))), - StopStream, - StartStream(), - AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and - CheckNewAnswer(("a", "1"), ("c", "1")), - assertNumStateRows(total = 3, updated = 2) - ) - } - - test("flatMapGroupsWithState - streaming + func returns iterator that updates state lazily") { - // Function to maintain running count up to 2, and then remove the count - // Returns the data and the count if state is defined, otherwise does not return anything - // Additionally, it updates state lazily as the returned iterator get consumed - val stateFunc = (key: Row, values: Iterator[Row], state: GroupState[Row]) => { - values.flatMap { _ => - val count = state.getOption.map(_.getLong(0)).getOrElse(0L) + 1 - if (count == 3) { - state.remove() - None - } else { - state.update(Row(count)) - Some(Row(key.getString(0), count.toString)) - } - } - } - - val inputData = MemoryStream[String] - val outputStructType = StructType( - Seq( - StructField("key", StringType), - StructField("countAsString", StringType))) - val stateStructType = StructType(Seq(StructField("count", LongType))) - val result = - inputData.toDS() - .groupBy("value") - .flatMapGroupsWithState(outputStructType, stateStructType, Update, - GroupStateTimeout.NoTimeout)(stateFunc) - testStream(result, Update)( - AddData(inputData, "a", "a", "b"), - CheckNewAnswer(("a", "1"), ("a", "2"), ("b", "1")), - StopStream, - StartStream(), - AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a - CheckNewAnswer(("b", "2")), - 
StopStream, - StartStream(), - AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and - CheckNewAnswer(("a", "1"), ("c", "1")) - ) - } - - test("flatMapGroupsWithState - streaming + aggregation") { - // Function to maintain running count up to 2, and then remove the count - // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) - val stateFunc = (key: Row, values: Iterator[Row], state: GroupState[Row]) => { - - val keyAsString = key.getString(0) - - val count = state.getOption.map(_.getLong(0)).getOrElse(0L) + values.size - if (count == 3) { - state.remove() - Iterator(Row(keyAsString, "-1")) - } else { - state.update(Row(count)) - Iterator(Row(keyAsString, count.toString)) - } - } - - val inputData = MemoryStream[String] - val outputStructType = StructType( - Seq( - StructField("key", StringType), - StructField("countAsString", StringType))) - val stateStructType = StructType(Seq(StructField("count", LongType))) - val result = - inputData.toDS() - .groupBy("value") - .flatMapGroupsWithState(outputStructType, stateStructType, Append, - GroupStateTimeout.NoTimeout)(stateFunc) - .groupBy("key") - .count() - - testStream(result, Complete)( - AddData(inputData, "a"), - CheckNewAnswer(("a", 1)), - AddData(inputData, "a", "b"), - // mapGroups generates ("a", "2"), ("b", "1"); so increases counts of a and b by 1 - CheckNewAnswer(("a", 2), ("b", 1)), - StopStream, - StartStream(), - AddData(inputData, "a", "b"), - // mapGroups should remove state for "a" and generate ("a", "-1"), ("b", "2") ; - // so increment a and b by 1 - CheckNewAnswer(("a", 3), ("b", 2)), - StopStream, - StartStream(), - AddData(inputData, "a", "c"), - // mapGroups should recreate state for "a" and generate ("a", "1"), ("c", "1") ; - // so increment a and c by 1 - CheckNewAnswer(("a", 4), ("b", 2), ("c", 1)) - ) - } - - test("flatMapGroupsWithState - streaming with processing time timeout") { - // Function to maintain the count as 
state and set the proc. time timeout delay of 10 seconds. - // It returns the count if changed, or -1 if the state was removed by timeout. - val stateFunc = (key: Row, values: Iterator[Row], state: GroupState[Row]) => { - assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } - assertCannotGetWatermark { state.getCurrentWatermarkMs() } - - val keyAsString = key.getString(0) - if (state.hasTimedOut) { - state.remove() - Iterator(Row(keyAsString, "-1")) - } else { - val count = state.getOption.map(_.getLong(0)).getOrElse(0L) + values.size - state.update(Row(count)) - state.setTimeoutDuration("10 seconds") - Iterator(Row(keyAsString, count.toString)) - } - } - - val clock = new StreamManualClock - val inputData = MemoryStream[String] - val outputStructType = StructType( - Seq( - StructField("key", StringType), - StructField("countAsString", StringType))) - val stateStructType = StructType(Seq(StructField("count", LongType))) - val result = - inputData.toDS() - .groupBy("value") - .flatMapGroupsWithState(outputStructType, stateStructType, Update, - ProcessingTimeTimeout)(stateFunc) - - testStream(result, Update)( - StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), - AddData(inputData, "a"), - AdvanceManualClock(1 * 1000), - CheckNewAnswer(("a", "1")), - assertNumStateRows(total = 1, updated = 1), - - AddData(inputData, "b"), - AdvanceManualClock(1 * 1000), - CheckNewAnswer(("b", "1")), - assertNumStateRows(total = 2, updated = 1), - - AddData(inputData, "b"), - AdvanceManualClock(10 * 1000), - CheckNewAnswer(("a", "-1"), ("b", "2")), - assertNumStateRows( - total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed = Some(Seq(1))), - - StopStream, - StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), - - AddData(inputData, "c"), - AdvanceManualClock(11 * 1000), - CheckNewAnswer(("b", "-1"), ("c", "1")), - assertNumStateRows( - total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed = 
Some(Seq(1))), - - AdvanceManualClock(12 * 1000), - AssertOnQuery { _ => clock.getTimeMillis() == 35000 }, - Execute { q => - failAfter(streamingTimeout) { - while (q.lastProgress.timestamp != "1970-01-01T00:00:35.000Z") { - Thread.sleep(1) - } - } - }, - CheckNewAnswer(("c", "-1")), - assertNumStateRows( - total = Seq(0), updated = Seq(0), droppedByWatermark = Seq(0), removed = Some(Seq(1))) - ) - } - - test("flatMapGroupsWithState - streaming w/ event time timeout + watermark") { - val inputData = MemoryStream[(String, Int)] - val outputStructType = StructType( - Seq( - StructField("key", StringType), - StructField("maxEventTimeSec", IntegerType))) - val stateStructType = StructType(Seq(StructField("maxEventTimeSec", LongType))) - val result = - inputData.toDS - .select($"_1".as("key"), timestamp_seconds($"_2").as("eventTime")) - .withWatermark("eventTime", "10 seconds") - .groupBy("key") - .flatMapGroupsWithState(outputStructType, stateStructType, Update, - EventTimeTimeout)(sampleTestFunction) - - testStream(result, Update)( - StartStream(), - - AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), - // Max event time = 15. Timeout timestamp for "a" = 15 + 5 = 20. Watermark = 15 - 10 = 5. - CheckNewAnswer(("a", 15)), // Output = max event time of a - - AddData(inputData, ("a", 4)), // Add data older than watermark for "a" - CheckNewAnswer(), // No output as data should get filtered by watermark - - AddData(inputData, ("a", 10)), // Add data newer than watermark for "a" - CheckNewAnswer(("a", 15)), // Max event time is still the same - // Timeout timestamp for "a" is still 20 as max event time for "a" is still 15. - // Watermark is still 5 as max event time for all data is still 15. - - AddData(inputData, ("b", 31)), // Add data newer than watermark for "b", not "a" - // Watermark = 31 - 10 = 21, so "a" should be timed out as timeout timestamp for "a" is 20. 
- CheckNewAnswer(("a", -1), ("b", 31)) // State for "a" should timeout and emit -1 - ) - } - - test("mapGroupsWithState - streaming") { - // Function to maintain running count up to 2, and then remove the count - // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) - val stateFunc = (key: Row, values: Iterator[Row], state: GroupState[Row]) => { - assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } - assertCannotGetWatermark { state.getCurrentWatermarkMs() } - - val keyAsString = key.getString(0) - val count = state.getOption.map(_.getLong(0)).getOrElse(0L) + values.size - if (count == 3) { - state.remove() - Row(keyAsString, "-1") - } else { - state.update(Row(count)) - Row(keyAsString, count.toString) - } - } - - val inputData = MemoryStream[String] - val outputStructType = StructType( - Seq( - StructField("key", StringType), - StructField("countAsString", StringType))) - val stateStructType = StructType(Seq(StructField("count", LongType))) - val result = - inputData.toDS() - .groupBy("value") - .mapGroupsWithState(outputStructType, stateStructType, - GroupStateTimeout.NoTimeout)(stateFunc) // Types = State: MyState, Out: (Str, Str) - - testStream(result, Update)( - AddData(inputData, "a"), - CheckNewAnswer(("a", "1")), - assertNumStateRows(total = 1, updated = 1), - AddData(inputData, "a", "b"), - CheckNewAnswer(("a", "2"), ("b", "1")), - assertNumStateRows(total = 2, updated = 2), - StopStream, - StartStream(), - AddData(inputData, "a", "b"), // should remove state for "a" and return count as -1 - CheckNewAnswer(("a", "-1"), ("b", "2")), - assertNumStateRows(total = 1, updated = 1), - StopStream, - StartStream(), - AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 - CheckNewAnswer(("a", "1"), ("c", "1")), - assertNumStateRows(total = 3, updated = 2) - ) - } - - def testWithTimeout(timeoutConf: GroupStateTimeout): Unit = { - test("SPARK-20714: watermark does not fail 
query when timeout = " + timeoutConf) { - // Function to maintain running count up to 2, and then remove the count - // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) - // String, (String, Long), RunningCount(Long) - val stateFunc = - (key: Row, values: Iterator[Row], state: GroupState[Row]) => { - val keyAsString = key.getString(0) - if (state.hasTimedOut) { - state.remove() - Iterator(Row(keyAsString, "-1")) - } else { - val count = state.getOption.map(_.getLong(0)).getOrElse(0L) + values.size - state.update(Row(count)) - state.setTimeoutDuration("10 seconds") - Iterator(Row(keyAsString, count.toString)) - } - } - - val clock = new StreamManualClock - val inputData = MemoryStream[(String, Long)] - val outputStructType = StructType( - Seq( - StructField("key", StringType), - StructField("countAsString", StringType))) - val stateStructType = StructType(Seq(StructField("count", LongType))) - val result = - inputData.toDF().toDF("key", "time") - .selectExpr("key", "timestamp_seconds(time) as timestamp") - .withWatermark("timestamp", "10 second") - .groupBy("key") - .flatMapGroupsWithState(outputStructType, stateStructType, Update, - ProcessingTimeTimeout)(stateFunc) - - testStream(result, Update)( - StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), - AddData(inputData, ("a", 1L)), - AdvanceManualClock(1 * 1000), - CheckNewAnswer(("a", "1")) - ) - } - } - testWithTimeout(NoTimeout) - testWithTimeout(ProcessingTimeTimeout) -} - -object UntypedFlatMapGroupsWithStateSuite { - - var failInTask = true - - def assertCanGetProcessingTime(predicate: => Boolean): Unit = { - if (!predicate) throw new TestFailedException("Could not get processing time", 20) - } - - def assertCanGetWatermark(predicate: => Boolean): Unit = { - if (!predicate) throw new TestFailedException("Could not get processing time", 20) - } - - def assertCannotGetWatermark(func: => Unit): Unit = { - try { - func - } catch { - case u: 
UnsupportedOperationException => - return - case _: Throwable => - throw new TestFailedException("Unexpected exception when trying to get watermark", 20) - } - throw new TestFailedException("Could get watermark when not expected", 20) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 26c201d5921ed..fc6b51dce790b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -279,7 +279,7 @@ class ContinuousSuite extends ContinuousSuiteBase { Seq(TestScalaUDF("udf"), TestPythonUDF("udf"), TestScalarPandasUDF("udf")).foreach { udf => test(s"continuous mode with various UDFs - ${udf.prettyName}") { assume( - shouldTestScalarPandasUDFs && udf.isInstanceOf[TestScalarPandasUDF] || + shouldTestPandasUDFs && udf.isInstanceOf[TestScalarPandasUDF] || shouldTestPythonUDFs && udf.isInstanceOf[TestPythonUDF] || udf.isInstanceOf[TestScalaUDF]) From f754fd9e6cd714bb45aa34d5661be930a0fe85e0 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Wed, 27 Jul 2022 13:55:33 +0900 Subject: [PATCH 03/44] Reorder key attributes from deduplicated data attributes --- .../FlatMapGroupsInPandasWithStateExec.scala | 16 ++++++++++++---- .../FlatMapGroupsInPandasWithStateSuite.scala | 8 ++++---- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala index 85e03a262ec9e..ff4ec3cae8488 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala 
@@ -20,7 +20,7 @@ import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} -import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, PythonUDF, SortOrder} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.Distribution import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.python.PandasGroupUtils.{executePython, resolveArgOffsets} @@ -121,9 +121,17 @@ case class FlatMapGroupsInPandasWithStateExec( stateDeserializer, stateType) - val inputIter = - if (hasTimedOut) Iterator.single(Iterator.single(stateData.keyRow)) - else Iterator.single(valueRowIter) + val inputIter = if (hasTimedOut) { + lazy val unsafeProj = UnsafeProjection.create( + dedupAttributes, groupingAttributes ++ dedupAttributes) + val joinedKeyRow = unsafeProj( + new JoinedRow( + stateData.keyRow, + new GenericInternalRow(Array.fill(dedupAttributes.length)(null: Any)))) + Iterator.single(Iterator.single(joinedKeyRow)) + } else { + Iterator.single(valueRowIter) + } val ret = executePython(inputIter, output, runner).toArray numOutputRows += ret.length diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala index 606d5d56744c8..03d3fd6dcff1e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala @@ -31,7 +31,7 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { import testImplicits._ - test("flatMapGroupsWithState - streaming") { + test("applyInPandasWithState - streaming") { 
assume(shouldTestPandasUDFs) // Function to maintain running count up to 2, and then remove the count @@ -107,7 +107,7 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { ) } - test("flatMapGroupsWithState - streaming + aggregation") { + test("applyInPandasWithState - streaming + aggregation") { assume(shouldTestPandasUDFs) // Function to maintain running count up to 2, and then remove the count @@ -178,7 +178,7 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { ) } - test("flatMapGroupsWithState - streaming with processing time timeout") { + test("applyInPandasWithState - streaming with processing time timeout") { assume(shouldTestPandasUDFs) // Function to maintain the count as state and set the proc. time timeout delay of 10 seconds. @@ -277,7 +277,7 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { ) } - test("flatMapGroupsWithState - streaming w/ event time timeout + watermark") { + test("applyInPandasWithState - streaming w/ event time timeout + watermark") { assume(shouldTestPandasUDFs) // timestamp_seconds assumes the base timezone is UTC. 
However, the provided function From 5194e0c0368ea103d9ebe8bae256df5bfbaf3ee8 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Thu, 28 Jul 2022 02:35:47 +0900 Subject: [PATCH 04/44] Apply suggestions from code review --- .../python/FlatMapGroupsInPandasWithStateExec.scala | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala index ff4ec3cae8488..0d6b6932f77bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -78,6 +78,9 @@ case class FlatMapGroupsInPandasWithStateExec( private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) private val pythonFunction = functionExpr.asInstanceOf[PythonUDF].func private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction))) + private lazy val (dedupAttributes, argOffsets) = resolveArgOffsets(child, groupingAttributes) + private lazy val unsafeProj = UnsafeProjection.create( + dedupAttributes, groupingAttributes ++ dedupAttributes) override def requiredChildDistribution: Seq[Distribution] = StatefulOperatorPartitioning.getCompatibleDistribution( @@ -96,8 +99,6 @@ case class FlatMapGroupsInPandasWithStateExec( private val stateDeserializer = stateEncoder.asInstanceOf[ExpressionEncoder[Row]].createDeserializer() - private lazy val (dedupAttributes, argOffsets) = resolveArgOffsets(child, groupingAttributes) - def callFunctionAndUpdateState( stateData: StateData, valueRowIter: Iterator[InternalRow], @@ -122,8 +123,6 @@ case class FlatMapGroupsInPandasWithStateExec( stateType) val inputIter = if (hasTimedOut) { - lazy val unsafeProj = UnsafeProjection.create( - dedupAttributes, groupingAttributes ++ dedupAttributes) 
val joinedKeyRow = unsafeProj( new JoinedRow( stateData.keyRow, From 1301ee5bcce93bfb3075e4ba69ded530ace2f4c6 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Thu, 11 Aug 2022 10:36:52 +0900 Subject: [PATCH 05/44] Refactoring a bit to respect the column order --- python/pyspark/worker.py | 28 ++++++------------- .../python/FlatMapCoGroupsInPandasExec.scala | 4 +-- .../python/FlatMapGroupsInPandasExec.scala | 2 +- .../FlatMapGroupsInPandasWithStateExec.scala | 8 +++--- .../execution/python/PandasGroupUtils.scala | 7 +++-- .../sql/execution/python/PythonUDFSuite.scala | 2 +- 6 files changed, 20 insertions(+), 31 deletions(-) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 6f63a6cd5c471..bdedc88e92a37 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -523,38 +523,26 @@ def extract_key_value_indexes(grouped_arg_offsets): idx += offsets_len return parsed - if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: - # We assume there is only one UDF here because grouped map doesn't - # support combining multiple UDFs. - assert num_udfs == 1 - - # See FlatMapGroupsInPandasExec for how arg_offsets are used to - # distinguish between grouping attributes and data attributes - arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0) - parsed_offsets = extract_key_value_indexes(arg_offsets) - - # Create function like this: - # mapper a: f([a[0]], [a[0], a[1]]) - def mapper(a): - keys = [a[o] for o in parsed_offsets[0][0]] - vals = [a[o] for o in parsed_offsets[0][1]] - return f(keys, vals) - - elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: + if eval_type in ( + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, + ): # We assume there is only one UDF here because grouped map doesn't # support combining multiple UDFs. 
assert num_udfs == 1 - # See PythonFlatMapGroupsWithStateExec for how arg_offsets are used to + # See FlatMapGroupsInPandas(WithState)Exec for how arg_offsets are used to # distinguish between grouping attributes and data attributes arg_offsets, f = read_single_udf( pickleSer, infile, eval_type, runner_conf, udf_index=0, state=state ) parsed_offsets = extract_key_value_indexes(arg_offsets) + # Create function like this: + # mapper a: f([a[0]], [a[0], a[1]]) def mapper(a): keys = [a[o] for o in parsed_offsets[0][0]] - vals = a # it's always all series. + vals = [a[o] for o in parsed_offsets[0][1]] return f(keys, vals) elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala index e830ea6b54662..b39787b12a484 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala @@ -78,8 +78,8 @@ case class FlatMapCoGroupsInPandasExec( override protected def doExecute(): RDD[InternalRow] = { - val (leftDedup, leftArgOffsets) = resolveArgOffsets(left, leftGroup) - val (rightDedup, rightArgOffsets) = resolveArgOffsets(right, rightGroup) + val (leftDedup, leftArgOffsets) = resolveArgOffsets(left.output, leftGroup) + val (rightDedup, rightArgOffsets) = resolveArgOffsets(right.output, rightGroup) // Map cogrouped rows to ArrowPythonRunner results, Only execute if partition is not empty left.execute().zipPartitions(right.execute()) { (leftData, rightData) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 3a3a6022f9985..f0e815e966e79 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -75,7 +75,7 @@ case class FlatMapGroupsInPandasExec( override protected def doExecute(): RDD[InternalRow] = { val inputRDD = child.execute() - val (dedupAttributes, argOffsets) = resolveArgOffsets(child, groupingAttributes) + val (dedupAttributes, argOffsets) = resolveArgOffsets(child.output, groupingAttributes) // Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala index 0d6b6932f77bc..4ccd44a0b4297 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -78,9 +78,9 @@ case class FlatMapGroupsInPandasWithStateExec( private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) private val pythonFunction = functionExpr.asInstanceOf[PythonUDF].func private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction))) - private lazy val (dedupAttributes, argOffsets) = resolveArgOffsets(child, groupingAttributes) - private lazy val unsafeProj = UnsafeProjection.create( - dedupAttributes, groupingAttributes ++ dedupAttributes) + private lazy val (dedupAttributes, argOffsets) = resolveArgOffsets( + groupingAttributes ++ child.output, groupingAttributes) + private lazy val unsafeProj = UnsafeProjection.create(dedupAttributes, child.output) override def requiredChildDistribution: Seq[Distribution] = StatefulOperatorPartitioning.getCompatibleDistribution( @@ -129,7 +129,7 @@ 
case class FlatMapGroupsInPandasWithStateExec( new GenericInternalRow(Array.fill(dedupAttributes.length)(null: Any)))) Iterator.single(Iterator.single(joinedKeyRow)) } else { - Iterator.single(valueRowIter) + Iterator.single(valueRowIter.map(unsafeProj)) } val ret = executePython(inputIter, output, runner).toArray diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala index 2da0000dad4ef..078876664062d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala @@ -24,7 +24,7 @@ import org.apache.spark.TaskContext import org.apache.spark.api.python.BasePythonRunner import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} -import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan} +import org.apache.spark.sql.execution.GroupedIterator import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} /** @@ -88,9 +88,10 @@ private[python] object PandasGroupUtils { * argOffsets[argOffsets[0]+2 .. 
] is the arg offsets for data attributes */ def resolveArgOffsets( - child: SparkPlan, groupingAttributes: Seq[Attribute]): (Seq[Attribute], Array[Int]) = { + attributes: Seq[Attribute], + groupingAttributes: Seq[Attribute]): (Seq[Attribute], Array[Int]) = { - val dataAttributes = child.output.drop(groupingAttributes.length) + val dataAttributes = attributes.drop(groupingAttributes.length) val groupingIndicesInData = groupingAttributes.map { attribute => dataAttributes.indexWhere(attribute.semanticEquals) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala index 4ad7f90105373..42e4b1accde72 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala @@ -73,7 +73,7 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession { } test("SPARK-39962: Global aggregation of Pandas UDF should respect the column order") { - assume(shouldTestGroupedAggPandasUDFs) + assume(shouldTestPythonUDFs) val df = Seq[(java.lang.Integer, java.lang.Integer)]((1, null)).toDF("a", "b") val pandasTestUDF = TestGroupedAggPandasUDF(name = "pandas_udf") From 135a8266269853e2f5db1b964c6383077185f196 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Mon, 15 Aug 2022 18:48:10 +0900 Subject: [PATCH 06/44] WIP Changes to execute in pipelined manner --- .../spark/api/python/PythonRunner.scala | 9 +- python/pyspark/serializers.py | 1 + python/pyspark/sql/pandas/serializers.py | 113 ++++++- python/pyspark/sql/streaming/state.py | 11 +- python/pyspark/worker.py | 89 +++-- sql/core/pom.xml | 27 ++ .../python/ArrowPythonRunnerWithState.scala | 305 +++++++++++++++--- .../FlatMapGroupsInPandasWithStateExec.scala | 152 ++++++--- .../execution/python/PythonArrowOutput.scala | 7 + .../FlatMapGroupsWithStateExec.scala | 2 +- 
.../execution/streaming/GroupStateImpl.scala | 9 +- .../spark/sql/streaming/TestGroupState.scala | 3 + test-applyinpandaswithstate.py | 94 ++++++ 13 files changed, 675 insertions(+), 147 deletions(-) create mode 100644 test-applyinpandaswithstate.py diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 7b31fa93c32e5..7616b7616587e 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -539,13 +539,15 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( // Timing data from worker val bootTime = stream.readLong() val initTime = stream.readLong() + val funcInitTime = stream.readLong() val finishTime = stream.readLong() val boot = bootTime - startTime val init = initTime - bootTime - val finish = finishTime - initTime + val funcInit = funcInitTime - initTime + val finish = finishTime - funcInitTime val total = finishTime - startTime - logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, - init, finish)) + logInfo("Times: total = %s, boot = %s, init = %s, func_init = %s, finish = %s" + .format(total, boot, init, funcInit, finish)) val memoryBytesSpilled = stream.readLong() val diskBytesSpilled = stream.readLong() context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled) @@ -782,6 +784,7 @@ private[spark] object SpecialLengths { val END_OF_STREAM = -4 val NULL = -5 val START_ARROW_STREAM = -6 + val START_STATE_UPDATE = -7 } private[spark] object BarrierTaskContextMessageProtocol { diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 8c5a941f376d2..c6eef13c5bf58 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -85,6 +85,7 @@ class SpecialLengths: END_OF_STREAM = -4 NULL = -5 START_ARROW_STREAM = -6 + START_STATE_UPDATE = -7 class Serializer: diff --git 
a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 992e82b403a1b..b1ac4c4f9492c 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -18,8 +18,10 @@ """ Serializers for PyArrow and pandas conversions. See `pyspark.serializers` for more details. """ +import sys -from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer +from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer, CPickleSerializer +from pyspark.sql.types import StringType, StructType, BinaryType, StructField class SpecialLengths: @@ -29,6 +31,7 @@ class SpecialLengths: END_OF_STREAM = -4 NULL = -5 START_ARROW_STREAM = -6 + START_STATE_UPDATE = -7 class ArrowCollectSerializer(Serializer): @@ -371,3 +374,111 @@ def load_stream(self, stream): raise ValueError( "Invalid number of pandas.DataFrames in group {0}".format(dataframes_in_group) ) + + +class ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer): + + def __init__(self, timezone, safecheck, assign_cols_by_name): + super(ApplyInPandasWithStateSerializer, self).__init__( + timezone, safecheck, assign_cols_by_name) + self.pickleSer = CPickleSerializer() + self.utf8_deserializer = UTF8Deserializer() + + def arrow_to_pandas(self, arrow_column): + return super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(arrow_column) + + def load_stream(self, stream): + import pyarrow as pa + import json + from pyspark.sql.types import StructType + from pyspark.sql.streaming.state import GroupStateImpl + + batches = ArrowStreamPandasUDFSerializer.load_stream(self, stream) + for batch in batches: + # FIXME: can we leverage schema here? doesn't work well so... 
+ state_info_col = batch[-1][0] + + state_info_col_properties = state_info_col['properties'] + state_info_col_key_schema = state_info_col['keySchema'] + state_info_col_key_row = state_info_col['keyRow'] + state_info_col_object_schema = state_info_col['objectSchema'] + state_info_col_object = state_info_col['object'] + + state_properties = json.loads(state_info_col_properties) + state_key_schema = StructType.fromJson(json.loads(state_info_col_key_schema)) + state_key_row = self.pickleSer.loads(state_info_col_key_row) + state_object_schema = StructType.fromJson(json.loads(state_info_col_object_schema)) + if state_info_col_object: + state_object = self.pickleSer.loads(state_info_col_object) + else: + state_object = None + state_properties["optionalValue"] = state_object + + state = GroupStateImpl(key=state_key_row, keySchema=state_key_schema, + valueSchema=state_object_schema, **state_properties) + + state_column_dropped_series = batch[0:-1] + first_row_dropped_series = [x.iloc[1:].reset_index(drop=True) for x in state_column_dropped_series] + # state info + yield (first_row_dropped_series, state, ) + + def dump_stream(self, iterator, stream): + """ + Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent. + This should be sent after creating the first record batch so in case of an error, it can + be sent back to the JVM before the Arrow stream starts. 
+ """ + + def init_stream_yield_batches(): + import pandas as pd + from pyspark.sql.pandas.types import to_arrow_type + + should_write_start_length = True + for data in iterator: + packaged_result = data[0] + + pdf = packaged_result[0][0].reset_index(drop=True) + state = packaged_result[0][-1] + return_schema = packaged_result[1] + + new_empty_row = pd.DataFrame(dict.fromkeys(pdf.columns), index=[0]) + + # Concatenate new_row with df + pdf_with_empty_row = pd.concat([new_empty_row, pdf[:]], axis=0).reset_index(drop=True) + + state_properties = state.json().encode("utf-8") + state_key_schema = state._key_schema.json().encode("utf-8") + state_key_row = self.pickleSer.dumps(state._key_schema.toInternal(state._key)) + state_object_schema = state._value_schema.json().encode("utf-8") + state_object = self.pickleSer.dumps(state._value_schema.toInternal(state._value)) + + state_dict = { + '__state__properties': [state_properties, ] + [None, ] * len(pdf), + '__state__keySchema': [state_key_schema, ] + [None, ] * len(pdf), + '__state__keyRow': [state_key_row, ] + [None, ] * len(pdf), + '__state__objectSchema': [state_object_schema, ] + [None, ] * len(pdf), + '__state__object': [state_object, ] + [None, ] * len(pdf), + } + + state_pdf = pd.DataFrame.from_dict(state_dict) + + state_df_type = StructType([ + StructField('__state__properties', StringType()), + StructField('__state__keySchema', StringType()), + StructField('__state__keyRow', BinaryType()), + StructField('__state__objectSchema', StringType()), + StructField('__state__object', BinaryType()), + ]) + + state_pdf_arrow_type = to_arrow_type(state_df_type) + + batch = self._create_batch([ + (pdf_with_empty_row, return_schema), + (state_pdf, state_pdf_arrow_type)]) + + if should_write_start_length: + write_int(SpecialLengths.START_ARROW_STREAM, stream) + should_write_start_length = False + yield batch + + return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream) diff --git 
a/python/pyspark/sql/streaming/state.py b/python/pyspark/sql/streaming/state.py index 6281dbadba61b..6e05512d36304 100644 --- a/python/pyspark/sql/streaming/state.py +++ b/python/pyspark/sql/streaming/state.py @@ -45,10 +45,14 @@ def __init__( defined: bool, updated: bool, removed: bool, + timeoutUpdated: bool, timeoutTimestamp: int, # Python internal state. + key: Row, keySchema: StructType, + valueSchema: StructType, ) -> None: + self._key = key self._value = optionalValue self._batch_processing_time_ms = batchProcessingTimeMs self._event_time_watermark_ms = eventTimeWatermarkMs @@ -67,8 +71,10 @@ def __init__( self._updated = updated self._removed = removed self._timeout_timestamp = timeoutTimestamp + self._timeout_updated = timeoutUpdated self._key_schema = keySchema + self._value_schema = valueSchema @property def exists(self) -> bool: @@ -120,6 +126,7 @@ def setTimeoutDuration(self, durationMs: int) -> None: if durationMs <= 0: raise ValueError("Timeout duration must be positive") self._timeout_timestamp = durationMs + self._batch_processing_time_ms + self._timeout_updated = True # TODO(SPARK-XXXXX): Implement additionalDuration parameter. 
def setTimeoutTimestamp(self, timestampMs: int) -> None: @@ -145,6 +152,7 @@ def setTimeoutTimestamp(self, timestampMs: int) -> None: ) self._timeout_timestamp = timestampMs + self._timeout_updated = True def getCurrentWatermarkMs(self) -> int: if not self._watermark_present: @@ -159,7 +167,7 @@ def getCurrentProcessingTimeMs(self) -> int: def __str__(self) -> str: if self.exists: - return "GroupState(%s)" % self.get + return "GroupState(%s)" % (self.get, ) else: return "GroupState()" @@ -178,5 +186,6 @@ def json(self) -> str: "updated": self._updated, "removed": self._removed, "timeoutTimestamp": self._timeout_timestamp, + "timeoutUpdated": self._timeout_updated, } ) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index bdedc88e92a37..caaf1051a3615 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -57,22 +57,22 @@ from pyspark.sql.pandas.serializers import ( ArrowStreamPandasUDFSerializer, CogroupUDFSerializer, - ArrowStreamUDFSerializer, + ArrowStreamUDFSerializer, ApplyInPandasWithStateSerializer, ) from pyspark.sql.pandas.types import to_arrow_type from pyspark.sql.types import StructType from pyspark.util import fail_on_stopiteration, try_simplify_traceback from pyspark import shuffle -from pyspark.sql.streaming.state import GroupStateImpl pickleSer = CPickleSerializer() utf8_deserializer = UTF8Deserializer() -def report_times(outfile, boot, init, finish): +def report_times(outfile, boot, init, func_init, finish): write_int(SpecialLengths.TIMING_DATA, outfile) write_long(int(1000 * boot), outfile) write_long(int(1000 * init), outfile) + write_long(int(1000 * func_init), outfile) write_long(int(1000 * finish), outfile) @@ -209,11 +209,11 @@ def wrapped(key_series, value_series): return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))] -def wrap_grouped_map_pandas_udf_with_state(f, return_type, state): - def wrapped(key_series, value_series): +def wrap_grouped_map_pandas_udf_with_state(f, return_type): + def 
wrapped(key_series, value_series, state): import pandas as pd - key = tuple(s[0] for s in key_series) + key = tuple(s.head(1).at[0] for s in key_series) if state.hasTimedOut: # Timeout processing pass empty iterator. Here we return an empty DataFrame instead. result = f(key, pd.DataFrame(columns=pd.concat(value_series, axis=1).columns), state) @@ -235,9 +235,10 @@ def wrapped(key_series, value_series): "doesn't match specified schema. " "Expected: {} Actual: {}".format(len(return_type), len(result.columns)) ) - return result - return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))] + return (result, state, ) + + return lambda k, v, s: [(wrapped(k, v, s), to_arrow_type(return_type))] def wrap_grouped_agg_pandas_udf(f, return_type): @@ -314,7 +315,7 @@ def wrapped(begin_index, end_index, *series): return lambda *a: (wrapped(*a), arrow_return_type) -def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, state=None): +def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): num_arg = read_int(infile) arg_offsets = [read_int(infile) for i in range(num_arg)] chained_func = None @@ -345,7 +346,7 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, state= argspec = getfullargspec(chained_func) # signature was lost when wrapping it return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: - return arg_offsets, wrap_grouped_map_pandas_udf_with_state(func, return_type, state) + return arg_offsets, wrap_grouped_map_pandas_udf_with_state(func, return_type) elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: argspec = getfullargspec(chained_func) # signature was lost when wrapping it return arg_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec) @@ -362,9 +363,6 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, state= def read_udfs(pickleSer, infile, eval_type): 
runner_conf = {} - # Used for state support in Structured Streaming. - state = None - if eval_type in ( PythonEvalType.SQL_SCALAR_PANDAS_UDF, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, @@ -384,21 +382,6 @@ def read_udfs(pickleSer, infile, eval_type): v = utf8_deserializer.loads(infile) runner_conf[k] = v - if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: - # 1. State properties - properties = json.loads(utf8_deserializer.loads(infile)) - - # 2. State key - length = read_int(infile) - row = None - if length > 0: - row = pickleSer.loads(infile.read(length)) - # 3. Schema for state key - key_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile))) - properties["optionalValue"] = row - - state = GroupStateImpl(keySchema=key_schema, **properties) - # NOTE: if timezone is set here, that implies respectSessionTimeZone is True timezone = runner_conf.get("spark.sql.session.timeZone", None) safecheck = ( @@ -417,6 +400,8 @@ def read_udfs(pickleSer, infile, eval_type): ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name) elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF: ser = ArrowStreamUDFSerializer() + elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: + ser = ApplyInPandasWithStateSerializer(timezone, safecheck, assign_cols_by_name) else: # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of # pandas Series. See SPARK-27240. @@ -492,7 +477,7 @@ def map_batch(batch): ) # profiling is not supported for UDF - return func, None, ser, ser, state + return func, None, ser, ser def extract_key_value_indexes(grouped_arg_offsets): """ @@ -525,7 +510,6 @@ def extract_key_value_indexes(grouped_arg_offsets): if eval_type in ( PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, ): # We assume there is only one UDF here because grouped map doesn't # support combining multiple UDFs. 
@@ -534,7 +518,7 @@ def extract_key_value_indexes(grouped_arg_offsets): # See FlatMapGroupsInPandas(WithState)Exec for how arg_offsets are used to # distinguish between grouping attributes and data attributes arg_offsets, f = read_single_udf( - pickleSer, infile, eval_type, runner_conf, udf_index=0, state=state + pickleSer, infile, eval_type, runner_conf, udf_index=0 ) parsed_offsets = extract_key_value_indexes(arg_offsets) @@ -545,6 +529,24 @@ def mapper(a): vals = [a[o] for o in parsed_offsets[0][1]] return f(keys, vals) + elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: + # We assume there is only one UDF here because grouped map doesn't + # support combining multiple UDFs. + assert num_udfs == 1 + + # See FlatMapGroupsInPandas(WithState)Exec for how arg_offsets are used to + # distinguish between grouping attributes and data attributes + arg_offsets, f = read_single_udf( + pickleSer, infile, eval_type, runner_conf, udf_index=0 + ) + parsed_offsets = extract_key_value_indexes(arg_offsets) + + def mapper(a): + keys = [a[0][o] for o in parsed_offsets[0][0]] + vals = [a[0][o] for o in parsed_offsets[0][1]] + state = a[1] + return f(keys, vals, state) + elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: # We assume there is only one UDF here because cogrouped map doesn't # support combining multiple UDFs. 
@@ -578,12 +580,11 @@ def func(_, it): return map(mapper, it) # profiling is not supported for UDF - return func, None, ser, ser, state + return func, None, ser, ser def main(infile, outfile): faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None) - state = None try: if faulthandler_log_path: faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid())) @@ -645,7 +646,6 @@ def main(infile, outfile): ) # initialize global state - state = None taskContext = None if isBarrier: taskContext = BarrierTaskContext._getOrCreate() @@ -724,15 +724,18 @@ def main(infile, outfile): broadcast_sock_file.close() _accumulatorRegistry.clear() + + init_time = time.time() + eval_type = read_int(infile) if eval_type == PythonEvalType.NON_UDF: func, profiler, deserializer, serializer = read_command(pickleSer, infile) else: - func, profiler, deserializer, serializer, state = read_udfs( + func, profiler, deserializer, serializer = read_udfs( pickleSer, infile, eval_type ) - init_time = time.time() + func_init_time = time.time() def process(): iterator = deserializer.load_stream(infile) @@ -778,25 +781,15 @@ def process(): faulthandler.disable() faulthandler_log_file.close() os.remove(faulthandler_log_path) + finish_time = time.time() - report_times(outfile, boot_time, init_time, finish_time) + report_times(outfile, boot_time, init_time, func_init_time, finish_time) write_long(shuffle.MemoryBytesSpilled, outfile) write_long(shuffle.DiskBytesSpilled, outfile) # Mark the beginning of the accumulators section of the output write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) - # Send GroupState back to JVM if exists. - if state is not None: - # 1. Send JSON-serialized GroupState - write_with_length(state.json().encode("utf-8"), outfile) - - # 2. Send pickled Row. 
- if state._value is None: - write_int(0, outfile) - else: - write_with_length(pickleSer.dumps(state._key_schema.toInternal(state._value)), outfile) - write_int(len(_accumulatorRegistry), outfile) for (aid, accum) in _accumulatorRegistry.items(): pickleSer._write_with_length((aid, accum._value), outfile) diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 7203fc591081a..5bb7708c0c6b7 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -267,6 +267,33 @@ + + + + org.scalastyle + scalastyle-maven-plugin + 1.0.0 + + true + false + false + false + false + ${basedir}/src/main/scala + ${basedir}/src/test/scala + ../../scalastyle-config.xml + ${basedir}/target/scalastyle-output.xml + ${project.build.sourceEncoding} + ${project.reporting.outputEncoding} + + + + + check + + + + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala index 95e993f0685b2..354601c9c33f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala @@ -18,20 +18,31 @@ package org.apache.spark.sql.execution.python import java.io._ -import java.nio.charset.StandardCharsets +import java.net.Socket +import java.util.concurrent.atomic.AtomicBoolean +import scala.collection.JavaConverters._ + +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter} import org.json4s._ import org.json4s.jackson.JsonMethods._ +import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.api.python._ import org.apache.spark.sql.Row import org.apache.spark.sql.api.python.PythonSQLUtils import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import 
org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow, UnsafeProjection} +import org.apache.spark.sql.execution.arrow.ArrowWriter import org.apache.spark.sql.execution.streaming.GroupStateImpl import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils /** * [[ArrowPythonRunner]] with [[org.apache.spark.sql.streaming.GroupState]]. @@ -40,15 +51,18 @@ class ArrowPythonRunnerWithState( funcs: Seq[ChainedPythonFunctions], evalType: Int, argOffsets: Array[Array[Int]], - protected override val schema: StructType, - protected override val timeZoneId: String, - protected override val workerConf: Map[String, String], - oldState: GroupStateImpl[Row], - deserializer: ExpressionEncoder.Deserializer[Row], - stateType: StructType) - extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](funcs, evalType, argOffsets) - with PythonArrowInput - with PythonArrowOutput { + inputSchema: StructType, + timeZoneId: String, + workerConf: Map[String, String], + keyEncoder: ExpressionEncoder[Row], + stateEncoder: ExpressionEncoder[Row], + keySchema: StructType, + valueSchema: StructType, + stateSchema: StructType) + extends BasePythonRunner[ + (InternalRow, GroupStateImpl[Row], Iterator[InternalRow]), + (InternalRow, GroupStateImpl[Row], Iterator[InternalRow])]( + funcs, evalType, argOffsets) { override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback @@ -58,46 +72,253 @@ class ArrowPythonRunnerWithState( "Pandas execution requires more than 4 bytes. Please set higher buffer. 
" + s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.") - var newGroupState: GroupStateImpl[Row] = _ - - protected override def handleMetadataBeforeExec(stream: DataOutputStream): Unit = { - super.handleMetadataBeforeExec(stream) + val schemaWithState = inputSchema.add("!__state__!", + StructType( + Array( + StructField("properties", StringType), + StructField("keySchema", StringType), + StructField("keyRow", BinaryType), + StructField("objectSchema", StringType), + StructField("object", BinaryType) + ) + ) + ) - // 1. Send JSON-serialized GroupState - PythonRDD.writeUTF(oldState.json(), stream) + val keyRowSerializer = keyEncoder.createSerializer() + val keyRowDeserializer = keyEncoder.createDeserializer() + val stateRowSerializer = stateEncoder.createSerializer() + val stateRowDeserializer = stateEncoder.createDeserializer() - // 2. Send pickled Row from the GroupState - val rowInState = oldState.getOption.map(PythonSQLUtils.toPyRow).getOrElse(Array.empty) - stream.writeInt(rowInState.length) - if (rowInState.length > 0) { - stream.write(rowInState) + protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = { + // Write config for the worker as a number of key -> value pairs of strings + stream.writeInt(workerConf.size) + for ((k, v) <- workerConf) { + PythonRDD.writeUTF(k, stream) + PythonRDD.writeUTF(v, stream) } - - // 3. Send the state type to serialize the output state back from Python. 
- PythonRDD.writeUTF(stateType.json, stream) } - protected override def handleMetadataAfterExec(stream: DataInputStream): Unit = { - super.handleMetadataAfterExec(stream) + protected override def newWriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[(InternalRow, GroupStateImpl[Row], Iterator[InternalRow])], + partitionIndex: Int, + context: TaskContext): WriterThread = { + new WriterThread(env, worker, inputIterator, partitionIndex, context) { + + protected override def writeCommand(dataOut: DataOutputStream): Unit = { + handleMetadataBeforeExec(dataOut) + PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) + } + + private def buildStateInfoRow( + keyRow: InternalRow, + groupState: GroupStateImpl[Row]): InternalRow = { + val keyRowAsPublicRow = keyRowDeserializer.apply(keyRow) + val stateUnderlyingRow = new GenericInternalRow( + Array[Any]( + UTF8String.fromString(groupState.json()), + UTF8String.fromString(keySchema.json), + PythonSQLUtils.toPyRow(keyRowAsPublicRow), + UTF8String.fromString(stateSchema.json), + groupState.getOption.map(PythonSQLUtils.toPyRow).orNull + ) + ) + new GenericInternalRow(Array[Any](stateUnderlyingRow)) + } + + protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + val arrowSchema = ArrowUtils.toArrowSchema(schemaWithState, timeZoneId) + val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"stdout writer for $pythonExec", 0, Long.MaxValue) + val root = VectorSchemaRoot.create(arrowSchema, allocator) + + Utils.tryWithSafeFinally { + val nullDataRow = new GenericInternalRow(Array.fill(inputSchema.length)(null: Any)) + val nullStateInfoRow = new GenericInternalRow(Array.fill(1)(null: Any)) - implicit val formats = org.json4s.DefaultFormats + val arrowWriter = ArrowWriter.create(root) + val writer = new ArrowStreamWriter(root, null, dataOut) + writer.start() - // 1. 
Receive JSON-serialized GroupState - val jsonStr = new Array[Byte](stream.readInt()) - stream.readFully(jsonStr) - val properties = parse(new String(jsonStr, StandardCharsets.UTF_8)) + val joinedRow = new JoinedRow + while (inputIterator.hasNext) { + val (keyRow, groupState, dataIter) = inputIterator.next() - // 2. Receive and deserialized pickled Row to JVM Row. - val length = stream.readInt() - val maybeRow = if (length > 0) { - val pickledRow = new Array[Byte](length) - stream.readFully(pickledRow) - Some(PythonSQLUtils.toJVMRow(pickledRow, stateType, deserializer)) - } else { - None + // Provide state info row in the first row + val stateInfoRow = buildStateInfoRow(keyRow, groupState) + joinedRow.withLeft(nullDataRow).withRight(stateInfoRow) + arrowWriter.write(joinedRow) + + // Continue providing remaining data rows + while (dataIter.hasNext) { + val dataRow = dataIter.next() + joinedRow.withLeft(dataRow).withRight(nullStateInfoRow) + arrowWriter.write(joinedRow) + } + + arrowWriter.finish() + writer.writeBatch() + arrowWriter.reset() + } + // end writes footer to the output stream and doesn't clean any resources. + // It could throw exception if the output stream is closed, so it should be + // in the try block. + writer.end() + } { + // If we close root and allocator in TaskCompletionListener, there could be a race + // condition where the writer thread keeps writing to the VectorSchemaRoot while + // it's being closed by the TaskCompletion listener. + // Closing root and allocator here is cleaner because root and allocator is owned + // by the writer thread and is only visible to the writer thread. + // + // If the writer thread is interrupted by TaskCompletionListener, it should either + // (1) in the try block, in which case it will get an InterruptedException when + // performing io, and goes into the finally block or (2) in the finally block, + // in which case it will ignore the interruption and close the resources. 
+ root.close() + allocator.close() + } + } } + } + + protected def newReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + pid: Option[Int], + releasedOrClosed: AtomicBoolean, + context: TaskContext): Iterator[(InternalRow, GroupStateImpl[Row], Iterator[InternalRow])] = { + + new ReaderIterator( + stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) { + + private val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"stdin reader for $pythonExec", 0, Long.MaxValue) + + private var reader: ArrowStreamReader = _ + private var root: VectorSchemaRoot = _ + private var schema: StructType = _ + private var vectors: Array[ColumnVector] = _ - // 3. Create a group state. - newGroupState = GroupStateImpl.fromJson(maybeRow, properties) + context.addTaskCompletionListener[Unit] { _ => + if (reader != null) { + reader.close(false) + } + allocator.close() + } + + private var batchLoaded = true + + protected override def read(): (InternalRow, GroupStateImpl[Row], Iterator[InternalRow]) = { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + try { + if (reader != null && batchLoaded) { + batchLoaded = reader.loadNextBatch() + if (batchLoaded) { + val batch = new ColumnarBatch(vectors) + batch.setNumRows(root.getRowCount) + deserializeColumnarBatch(batch) + } else { + reader.close(false) + allocator.close() + // Reach end of stream. Call `read()` again to read control data. + read() + } + } else { + stream.readInt() match { + case SpecialLengths.START_ARROW_STREAM => + reader = new ArrowStreamReader(stream, allocator) + root = reader.getVectorSchemaRoot() + // FIXME: should we validate schema here with value schema and state schema? 
+ schema = ArrowUtils.fromArrowSchema(root.getSchema()) + vectors = root.getFieldVectors().asScala.map { vector => + new ArrowColumnVector(vector) + }.toArray[ColumnVector] + read() + case SpecialLengths.TIMING_DATA => + handleTimingData() + read() + case SpecialLengths.PYTHON_EXCEPTION_THROWN => + throw handlePythonException() + case SpecialLengths.END_OF_DATA_SECTION => + handleEndOfDataSection() + null + } + } + } catch handleException + } + + private def deserializeColumnarBatch( + batch: ColumnarBatch): (InternalRow, GroupStateImpl[Row], Iterator[InternalRow]) = { + // this should at least have one row for state + assert(batch.numRows() > 0) + assert(schema.length == 2) + + val dataAttributes = schema(0).dataType.asInstanceOf[StructType].toAttributes + val stateInfoAttribute = schema(1).dataType.asInstanceOf[StructType].toAttributes + + val unsafeProjForStateInfo = UnsafeProjection.create( + stateInfoAttribute, stateInfoAttribute) + val unsafeProjForData = UnsafeProjection.create( + dataAttributes, dataAttributes) + + val structVectorForState = batch.column(1).asInstanceOf[ArrowColumnVector] + val outputVectorsForState = schema(1).dataType.asInstanceOf[StructType] + .indices.map(structVectorForState.getChild) + val flattenedBatchForState = new ColumnarBatch(outputVectorsForState.toArray) + flattenedBatchForState.setNumRows(1) + + val rowForStateInfo = unsafeProjForStateInfo(flattenedBatchForState.getRow(0)) + + // UDF returns a StructType column in ColumnarBatch, select the children here + val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] + val outputVectors = schema(0).dataType.asInstanceOf[StructType] + .indices.map(structVector.getChild) + val flattenedBatch = new ColumnarBatch(outputVectors.toArray) + flattenedBatch.setNumRows(batch.numRows()) + + val rowIterator = flattenedBatch.rowIterator.asScala + // drop first row as it's reserved for state + assert(rowIterator.hasNext) + rowIterator.next() + + // FIXME: we rely on known schema for 
state info, but would we want to access this by + // column name? + /* + Array( + StructField("properties", StringType), + StructField("keySchema", StringType), + StructField("keyRow", BinaryType), + StructField("objectSchema", StringType), + StructField("object", BinaryType) + ) + */ + implicit val formats = org.json4s.DefaultFormats + + val propertiesAsJson = parse(rowForStateInfo.getUTF8String(0).toString) + // FIXME: keySchema is probably not needed as we already know it... let's check whether + // it is needed for python worker, and if it does not, remove it. Or double check? + val pickledKeyRow = rowForStateInfo.getBinary(2) + val keyRowAsGenericRow = PythonSQLUtils.toJVMRow(pickledKeyRow, keySchema, + keyRowDeserializer) + val keyRowAsInternalRow = keyRowSerializer.apply(keyRowAsGenericRow) + val maybeObjectRow = if (rowForStateInfo.isNullAt(4)) { + None + } else { + val pickledRow = rowForStateInfo.getBinary(4) + Some(PythonSQLUtils.toJVMRow(pickledRow, stateSchema, stateRowDeserializer)) + } + + val newGroupState = GroupStateImpl.fromJson(maybeObjectRow, propertiesAsJson) + + (keyRowAsInternalRow, newGroupState, rowIterator.map(unsafeProjForData)) + } + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala index 4ccd44a0b4297..f1c1acd73890f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -16,14 +16,16 @@ */ package org.apache.spark.sql.execution.python +import org.apache.spark.TaskContext import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, 
RowEncoder} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, ProcessingTimeTimeout} import org.apache.spark.sql.catalyst.plans.physical.Distribution -import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} -import org.apache.spark.sql.execution.python.PandasGroupUtils.{executePython, resolveArgOffsets} +import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.python.PandasGroupUtils.resolveArgOffsets import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper.StateData @@ -62,6 +64,8 @@ case class FlatMapGroupsInPandasWithStateExec( eventTimeWatermark: Option[Long], child: SparkPlan) extends UnaryExecNode with FlatMapGroupsWithStateExecBase { + private val keySchema: StructType = groupingAttributes.toStructType + // TODO(SPARK-XXXXX): Add the support of initial state. 
override protected val initialStateDeserializer: Expression = null override protected val initialStateGroupAttrs: Seq[Attribute] = null @@ -72,6 +76,9 @@ case class FlatMapGroupsInPandasWithStateExec( override protected val stateEncoder: ExpressionEncoder[Any] = RowEncoder(stateType).resolveAndBind().asInstanceOf[ExpressionEncoder[Any]] + private val keyEncoder: ExpressionEncoder[Row] = + RowEncoder(keySchema).resolveAndBind() + override def output: Seq[Attribute] = outAttributes private val sessionLocalTimeZone = conf.sessionLocalTimeZone @@ -96,20 +103,53 @@ case class FlatMapGroupsInPandasWithStateExec( override def createInputProcessor( store: StateStore): InputProcessor = new InputProcessor(store: StateStore) { - private val stateDeserializer = - stateEncoder.asInstanceOf[ExpressionEncoder[Row]].createDeserializer() - def callFunctionAndUpdateState( - stateData: StateData, - valueRowIter: Iterator[InternalRow], - hasTimedOut: Boolean): Iterator[InternalRow] = { - val groupedState = GroupStateImpl.createForStreaming( - Option(stateData.stateObj).map { r => assert(r.isInstanceOf[Row]); r }, - batchTimestampMs.getOrElse(NO_TIMESTAMP), - eventTimeWatermark.getOrElse(NO_TIMESTAMP), - timeoutConf, - hasTimedOut = hasTimedOut, - watermarkPresent).asInstanceOf[GroupStateImpl[Row]] + /** + * For every group, get the key, values and corresponding state and call the function, + * and return an iterator of rows + */ + override def processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { + val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output) + val processIter = groupedIter.map { case (keyRow, valueRowIter) => + val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] + val stateData = stateManager.getState(store, keyUnsafeRow) + (keyUnsafeRow, stateData, valueRowIter.map(unsafeProj)) + } + + process(processIter, hasTimedOut = false) + } + + /** Find the groups that have timeout set and are timing out right now, and call the function */ + 
override def processTimedOutState(): Iterator[InternalRow] = { + if (isTimeoutEnabled) { + val timeoutThreshold = timeoutConf match { + case ProcessingTimeTimeout => batchTimestampMs.get + case EventTimeTimeout => eventTimeWatermark.get + case _ => + throw new IllegalStateException( + s"Cannot filter timed out keys for $timeoutConf") + } + val timingOutPairs = stateManager.getAllState(store).filter { state => + state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold + } + + val processIter = timingOutPairs.map { stateData => + val joinedKeyRow = unsafeProj( + new JoinedRow( + stateData.keyRow, + new GenericInternalRow(Array.fill(dedupAttributes.length)(null: Any)))) + + (stateData.keyRow, stateData, Iterator.single(joinedKeyRow)) + } + + process(processIter, hasTimedOut = true) + } else Iterator.empty + } + + private def process( + iter: Iterator[(InternalRow, StateData, Iterator[InternalRow])], + hasTimedOut: Boolean): Iterator[InternalRow] = { + val keyUnsafeProj = UnsafeProjection.create(keySchema) val runner = new ArrowPythonRunnerWithState( chainedFunc, @@ -118,48 +158,60 @@ case class FlatMapGroupsInPandasWithStateExec( StructType.fromAttributes(dedupAttributes), sessionLocalTimeZone, pythonRunnerConf, - groupedState, - stateDeserializer, + keyEncoder, + stateEncoder.asInstanceOf[ExpressionEncoder[Row]], + groupingAttributes.toStructType, + child.output.toStructType, stateType) - val inputIter = if (hasTimedOut) { - val joinedKeyRow = unsafeProj( - new JoinedRow( - stateData.keyRow, - new GenericInternalRow(Array.fill(dedupAttributes.length)(null: Any)))) - Iterator.single(Iterator.single(joinedKeyRow)) - } else { - Iterator.single(valueRowIter.map(unsafeProj)) + val context = TaskContext.get() + val processIter = iter.map { case (keyRow, stateData, valueIter) => + val groupedState = GroupStateImpl.createForStreaming( + Option(stateData.stateObj).map { r => assert(r.isInstanceOf[Row]); r }, + 
batchTimestampMs.getOrElse(NO_TIMESTAMP), + eventTimeWatermark.getOrElse(NO_TIMESTAMP), + timeoutConf, + hasTimedOut = hasTimedOut, + watermarkPresent).asInstanceOf[GroupStateImpl[Row]] + (keyRow, groupedState, valueIter) } + runner.compute(processIter, context.partitionId(), context).flatMap { + case (keyRow, newGroupState, outputIter) => + val keyUnsafeRow = keyUnsafeProj(keyRow) + + // When the iterator is consumed, then write changes to state + def onIteratorCompletion: Unit = { + if (newGroupState.isRemoved && !newGroupState.getTimeoutTimestampMs.isPresent()) { + stateManager.removeState(store, keyUnsafeRow) + numRemovedStateRows += 1 + } else { + val currentTimeoutTimestamp = newGroupState.getTimeoutTimestampMs + .orElse(NO_TIMESTAMP) + val shouldWriteState = newGroupState.isUpdated || newGroupState.isRemoved || + newGroupState.isTimeoutUpdated + + if (shouldWriteState) { + val updatedStateObj = if (newGroupState.exists) newGroupState.get else null + stateManager.putState(store, keyUnsafeRow, updatedStateObj, + currentTimeoutTimestamp) + numUpdatedStateRows += 1 + } + } + } - val ret = executePython(inputIter, output, runner).toArray - numOutputRows += ret.length - val newGroupState: GroupStateImpl[Row] = runner.newGroupState - assert(newGroupState != null) - - // When the iterator is consumed, then write changes to state - def onIteratorCompletion: Unit = { - if (newGroupState.isRemoved && !newGroupState.getTimeoutTimestampMs.isPresent()) { - stateManager.removeState(store, stateData.keyRow) - numRemovedStateRows += 1 - } else { - val currentTimeoutTimestamp = newGroupState.getTimeoutTimestampMs.orElse(NO_TIMESTAMP) - val hasTimeoutChanged = currentTimeoutTimestamp != stateData.timeoutTimestamp - val shouldWriteState = newGroupState.isUpdated || newGroupState.isRemoved || - hasTimeoutChanged - - if (shouldWriteState) { - val updatedStateObj = if (newGroupState.exists) newGroupState.get else null - stateManager.putState(store, stateData.keyRow, 
updatedStateObj, - currentTimeoutTimestamp) - numUpdatedStateRows += 1 + CompletionIterator[InternalRow, Iterator[InternalRow]]( + outputIter, onIteratorCompletion).map { row => + numOutputRows += 1 + row } - } } + } - // Return an iterator of rows such that fully consumed, the updated state value will - // be saved - CompletionIterator[InternalRow, Iterator[InternalRow]](ret.iterator, onIteratorCompletion) + override protected def callFunctionAndUpdateState( + stateData: StateData, + valueRowIter: Iterator[InternalRow], + hasTimedOut: Boolean): Iterator[InternalRow] = { + throw new UnsupportedOperationException("Should not reach here!") } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala index 339f114539c28..c8398f2316b7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala @@ -41,6 +41,10 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[ protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: StructType): OUT + protected def handleStateUpdate(stream: DataInputStream): Unit = { + new IllegalStateException("Should not reach here!") + } + protected def newReaderIterator( stream: DataInputStream, writerThread: WriterThread, @@ -103,6 +107,9 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[ new ArrowColumnVector(vector) }.toArray[ColumnVector] read() + case SpecialLengths.START_STATE_UPDATE => + handleStateUpdate(stream) + read() case SpecialLengths.TIMING_DATA => handleTimingData() read() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 
ee9449151758b..1071e522f6e74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -56,7 +56,7 @@ trait FlatMapGroupsWithStateExecBase protected val batchTimestampMs: Option[Long] val eventTimeWatermark: Option[Long] - private val isTimeoutEnabled = timeoutConf != NoTimeout + protected val isTimeoutEnabled = timeoutConf != NoTimeout protected val watermarkPresent: Boolean = child.output.exists { case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true case _ => false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala index 861ceabaf7f5a..1b7e0bf3c4e18 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala @@ -59,6 +59,7 @@ private[sql] class GroupStateImpl[S] private[sql]( private var updated: Boolean = false // whether value has been updated (but not removed) private var removed: Boolean = false // whether value has been removed private var timeoutTimestamp: Long = NO_TIMESTAMP + private var timeoutUpdated: Boolean = false // ========= Public API ========= override def exists: Boolean = defined @@ -103,6 +104,7 @@ private[sql] class GroupStateImpl[S] private[sql]( throw new IllegalArgumentException("Timeout duration must be positive") } timeoutTimestamp = durationMs + batchProcessingTimeMs + timeoutUpdated = true } override def setTimeoutDuration(duration: String): Unit = { @@ -120,6 +122,7 @@ private[sql] class GroupStateImpl[S] private[sql]( s"current watermark ($eventTimeWatermarkMs)") } timeoutTimestamp = timestampMs + timeoutUpdated = true } override def setTimeoutTimestamp(timestampMs: Long, 
additionalDuration: String): Unit = { @@ -158,6 +161,8 @@ private[sql] class GroupStateImpl[S] private[sql]( override def isUpdated: Boolean = updated + override def isTimeoutUpdated: Boolean = timeoutUpdated + override def getTimeoutTimestampMs: Optional[Long] = { if (timeoutTimestamp != NO_TIMESTAMP) { Optional.of(timeoutTimestamp) @@ -194,7 +199,8 @@ private[sql] class GroupStateImpl[S] private[sql]( "defined" -> JBool(defined) :: "updated" -> JBool(updated) :: "removed" -> JBool(removed) :: - "timeoutTimestamp" -> JLong(timeoutTimestamp) :: Nil + "timeoutTimestamp" -> JLong(timeoutTimestamp) :: + "timeoutUpdated" -> JBool(timeoutUpdated) :: Nil ))) } @@ -265,6 +271,7 @@ private[sql] object GroupStateImpl { newGroupState.removed = hmap("removed").asInstanceOf[Boolean] newGroupState.timeoutTimestamp = hmap("timeoutTimestamp").asInstanceOf[Number].longValue() + newGroupState.timeoutUpdated = hmap("timeoutUpdated").asInstanceOf[Boolean] newGroupState } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/TestGroupState.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/TestGroupState.scala index d53d6087d677c..346bde5df3e24 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/TestGroupState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/TestGroupState.scala @@ -123,6 +123,9 @@ trait TestGroupState[S] extends GroupState[S] { /** Whether the state has been updated but not removed */ def isUpdated: Boolean + /** FIXME: ... */ + def isTimeoutUpdated: Boolean + /** * Returns the timestamp if `setTimeoutTimestamp()` is called. 
* Or, returns batch processing time + the duration when diff --git a/test-applyinpandaswithstate.py b/test-applyinpandaswithstate.py new file mode 100644 index 0000000000000..03233b9e2cc6e --- /dev/null +++ b/test-applyinpandaswithstate.py @@ -0,0 +1,94 @@ +import calendar +import os +import datetime +import pandas as pd +from pyspark.sql import SparkSession +from pyspark.sql import Row + +def user_func(key, pdf, state): + timeout_delay_sec = 10 + + print('=' * 80) + print(key) + print(pdf) + print(state.getOption) + print(state.hasTimedOut) + print('=' * 80) + + if state.hasTimedOut: + state.remove() + return pd.DataFrame({'key1': [], 'key2': [], 'maxTimestampSeenMs': [], 'average': []}) + else: + prev_state = state.getOption + if prev_state is None: + prev_sum = 0 + prev_count = 0 + prev_max_timestamp_seen_sec = 0 # should be -Inf or something along with + else: + # FIXME: Is it better UX to access the state object as tuple instead of Row or dict at least? + prev_sum = prev_state[0] + prev_count = prev_state[1] + prev_max_timestamp_seen_sec = prev_state[2] + + new_sum = prev_sum + int(pdf.value.sum()) + new_count = prev_count + len(pdf) + + # TODO: now it's taking second precision - lower down to millisecond + # print(key) + # print(pdf) + pser = pdf.timestamp.apply(lambda dt: (int(calendar.timegm(dt.utctimetuple()) + dt.microsecond))) + #print(pser) + new_max_event_time_sec = int(max(pser.max(), prev_max_timestamp_seen_sec)) + timeout_timestamp_sec = new_max_event_time_sec + timeout_delay_sec + + # FIXME: Is it better UX to access the state object as tuple instead of Row or dict at least? 
+ state.update((new_sum, new_count, new_max_event_time_sec,)) + state.setTimeoutTimestamp(timeout_timestamp_sec * 1000) + return pd.DataFrame({'key1': [key[0]], 'key2': [key[1]], 'maxTimestampSeenMs': [new_max_event_time_sec * 1000], 'average': [new_sum * 1.0 / new_count]}) + + +spark = SparkSession \ + .builder \ + .appName("Python ApplyInPandasWithState example") \ + .config("spark.sql.shuffle.partitions", 1) \ + .getOrCreate() + +rate_stream = ( + spark.readStream + .format('rate') + .option('numPartitions', 1) + .option('rowsPerSecond', 500000) + .load() +) + +output_struct = 'key1 string, key2 long, maxTimestampSeenMs long, average double' +state_struct = 'sum long, count long, maxTimestampSeenSec long' + +# desired_group_keys = 100 +desired_group_keys = 100000 +key1_expr = "(case when value % 5 = 0 then 'a' when value % 5 = 1 then 'b' when value % 5 = 2 then 'c' when value % 5 = 3 then 'd' else 'e' end) AS key1" +key2_expr = f"ceil(value / 5) % {desired_group_keys / 5} AS key2" + +# schema from rate source: 'timestamp' - TimestampType, 'value' - LongType +custom_session_window_stream = ( + rate_stream + # TODO: how many groups we want to track? 
+ .selectExpr("timestamp", key1_expr, key2_expr, "value") + .withWatermark('timestamp', '0 seconds') + .groupby('key1', 'key2') + .applyInPandasWithState(user_func, outputStructType=output_struct, + stateStructType=state_struct, outputMode='update', timeoutConf='EventTimeTimeout') + .selectExpr('maxTimestampSeenMs * 1000', 'key1', 'key2', 'average') +) + +query = ( + custom_session_window_stream + .writeStream + .trigger(processingTime='0 seconds') + .outputMode('Update') + .format('console') + .start() +) + +query.awaitTermination() + From 9282e5c4de5dedea23d510a73850989babc3134b Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 18 Aug 2022 14:01:50 +0900 Subject: [PATCH 07/44] WIP further optimization --- python/pyspark/sql/pandas/serializers.py | 35 ++++++++----------- .../python/ArrowPythonRunnerWithState.scala | 26 +++++++------- .../FlatMapGroupsInPandasWithStateExec.scala | 8 ++--- 3 files changed, 31 insertions(+), 38 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index b1ac4c4f9492c..5bfc64aec9d2f 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -21,9 +21,9 @@ import sys from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer, CPickleSerializer +from pyspark.sql.pandas.types import to_arrow_type from pyspark.sql.types import StringType, StructType, BinaryType, StructField - class SpecialLengths: END_OF_DATA_SECTION = -1 PYTHON_EXCEPTION_THROWN = -2 @@ -384,6 +384,14 @@ def __init__(self, timezone, safecheck, assign_cols_by_name): self.pickleSer = CPickleSerializer() self.utf8_deserializer = UTF8Deserializer() + self.state_df_type = StructType([ + StructField('__state__properties', StringType()), + StructField('__state__keyRow', BinaryType()), + StructField('__state__object', BinaryType()), + ]) + + self.state_pdf_arrow_type = to_arrow_type(self.state_df_type) + def arrow_to_pandas(self, arrow_column): return 
super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(arrow_column) @@ -431,7 +439,6 @@ def dump_stream(self, iterator, stream): def init_stream_yield_batches(): import pandas as pd - from pyspark.sql.pandas.types import to_arrow_type should_write_start_length = True for data in iterator: @@ -447,34 +454,22 @@ def init_stream_yield_batches(): pdf_with_empty_row = pd.concat([new_empty_row, pdf[:]], axis=0).reset_index(drop=True) state_properties = state.json().encode("utf-8") - state_key_schema = state._key_schema.json().encode("utf-8") state_key_row = self.pickleSer.dumps(state._key_schema.toInternal(state._key)) - state_object_schema = state._value_schema.json().encode("utf-8") state_object = self.pickleSer.dumps(state._value_schema.toInternal(state._value)) + len_pdf = len(pdf) + none_array = [None, ] * len_pdf state_dict = { - '__state__properties': [state_properties, ] + [None, ] * len(pdf), - '__state__keySchema': [state_key_schema, ] + [None, ] * len(pdf), - '__state__keyRow': [state_key_row, ] + [None, ] * len(pdf), - '__state__objectSchema': [state_object_schema, ] + [None, ] * len(pdf), - '__state__object': [state_object, ] + [None, ] * len(pdf), + '__state__properties': [state_properties, ] + none_array, + '__state__keyRow': [state_key_row, ] + none_array, + '__state__object': [state_object, ] + none_array, } state_pdf = pd.DataFrame.from_dict(state_dict) - state_df_type = StructType([ - StructField('__state__properties', StringType()), - StructField('__state__keySchema', StringType()), - StructField('__state__keyRow', BinaryType()), - StructField('__state__objectSchema', StringType()), - StructField('__state__object', BinaryType()), - ]) - - state_pdf_arrow_type = to_arrow_type(state_df_type) - batch = self._create_batch([ (pdf_with_empty_row, return_schema), - (state_pdf, state_pdf_arrow_type)]) + (state_pdf, self.state_pdf_arrow_type)]) if should_write_start_length: write_int(SpecialLengths.START_ARROW_STREAM, stream) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala index 354601c9c33f1..9212ff80e763c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala @@ -203,6 +203,7 @@ class ArrowPythonRunnerWithState( private var root: VectorSchemaRoot = _ private var schema: StructType = _ private var vectors: Array[ColumnVector] = _ + private var unsafeProjForData: UnsafeProjection = _ context.addTaskCompletionListener[Unit] { _ => if (reader != null) { @@ -237,6 +238,12 @@ class ArrowPythonRunnerWithState( root = reader.getVectorSchemaRoot() // FIXME: should we validate schema here with value schema and state schema? schema = ArrowUtils.fromArrowSchema(root.getSchema()) + + val dataAttributes = schema(0).dataType.asInstanceOf[StructType].toAttributes + val stateInfoAttribute = schema(1).dataType.asInstanceOf[StructType].toAttributes + + unsafeProjForData = UnsafeProjection.create(dataAttributes, dataAttributes) + vectors = root.getFieldVectors().asScala.map { vector => new ArrowColumnVector(vector) }.toArray[ColumnVector] @@ -260,21 +267,13 @@ class ArrowPythonRunnerWithState( assert(batch.numRows() > 0) assert(schema.length == 2) - val dataAttributes = schema(0).dataType.asInstanceOf[StructType].toAttributes - val stateInfoAttribute = schema(1).dataType.asInstanceOf[StructType].toAttributes - - val unsafeProjForStateInfo = UnsafeProjection.create( - stateInfoAttribute, stateInfoAttribute) - val unsafeProjForData = UnsafeProjection.create( - dataAttributes, dataAttributes) - val structVectorForState = batch.column(1).asInstanceOf[ArrowColumnVector] val outputVectorsForState = schema(1).dataType.asInstanceOf[StructType] .indices.map(structVectorForState.getChild) val flattenedBatchForState = 
new ColumnarBatch(outputVectorsForState.toArray) flattenedBatchForState.setNumRows(1) - val rowForStateInfo = unsafeProjForStateInfo(flattenedBatchForState.getRow(0)) + val rowForStateInfo = flattenedBatchForState.getRow(0) // UDF returns a StructType column in ColumnarBatch, select the children here val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] @@ -290,12 +289,11 @@ class ArrowPythonRunnerWithState( // FIXME: we rely on known schema for state info, but would we want to access this by // column name? + // Received state information does not need schemas - this class already knows them. /* Array( StructField("properties", StringType), - StructField("keySchema", StringType), StructField("keyRow", BinaryType), - StructField("objectSchema", StringType), StructField("object", BinaryType) ) */ @@ -304,14 +302,14 @@ class ArrowPythonRunnerWithState( val propertiesAsJson = parse(rowForStateInfo.getUTF8String(0).toString) // FIXME: keySchema is probably not needed as we already know it... let's check whether // it is needed for python worker, and if it does not, remove it. Or double check? 
- val pickledKeyRow = rowForStateInfo.getBinary(2) + val pickledKeyRow = rowForStateInfo.getBinary(1) val keyRowAsGenericRow = PythonSQLUtils.toJVMRow(pickledKeyRow, keySchema, keyRowDeserializer) val keyRowAsInternalRow = keyRowSerializer.apply(keyRowAsGenericRow) - val maybeObjectRow = if (rowForStateInfo.isNullAt(4)) { + val maybeObjectRow = if (rowForStateInfo.isNullAt(2)) { None } else { - val pickledRow = rowForStateInfo.getBinary(4) + val pickledRow = rowForStateInfo.getBinary(2) Some(PythonSQLUtils.toJVMRow(pickledRow, stateSchema, stateRowDeserializer)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala index f1c1acd73890f..1765b400af7fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -104,6 +104,8 @@ case class FlatMapGroupsInPandasWithStateExec( override def createInputProcessor( store: StateStore): InputProcessor = new InputProcessor(store: StateStore) { + private val keyUnsafeProj = UnsafeProjection.create(keySchema) + /** * For every group, get the key, values and corresponding state and call the function, * and return an iterator of rows @@ -147,10 +149,8 @@ case class FlatMapGroupsInPandasWithStateExec( } private def process( - iter: Iterator[(InternalRow, StateData, Iterator[InternalRow])], - hasTimedOut: Boolean): Iterator[InternalRow] = { - val keyUnsafeProj = UnsafeProjection.create(keySchema) - + iter: Iterator[(InternalRow, StateData, Iterator[InternalRow])], + hasTimedOut: Boolean): Iterator[InternalRow] = { val runner = new ArrowPythonRunnerWithState( chainedFunc, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, From a792c98cbe8355cb9f8cc8c5fa2137e05b8223fc Mon Sep 17 00:00:00 2001 From: 
Jungtaek Lim Date: Thu, 18 Aug 2022 18:22:33 +0900 Subject: [PATCH 08/44] WIP comments for more tunes --- python/pyspark/sql/pandas/serializers.py | 19 +++++++++++-------- .../python/ArrowPythonRunnerWithState.scala | 7 ++++--- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 5bfc64aec9d2f..8050f7c787bfc 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -384,10 +384,11 @@ def __init__(self, timezone, safecheck, assign_cols_by_name): self.pickleSer = CPickleSerializer() self.utf8_deserializer = UTF8Deserializer() + # FIXME: result_state_df_type? self.state_df_type = StructType([ - StructField('__state__properties', StringType()), - StructField('__state__keyRow', BinaryType()), - StructField('__state__object', BinaryType()), + StructField('properties', StringType()), + StructField('keyRow', BinaryType()), + StructField('object', BinaryType()), ]) self.state_pdf_arrow_type = to_arrow_type(self.state_df_type) @@ -412,10 +413,12 @@ def load_stream(self, stream): state_info_col_object_schema = state_info_col['objectSchema'] state_info_col_object = state_info_col['object'] - state_properties = json.loads(state_info_col_properties) + # FIXME: schemas can be retrieved as metadata since they are applied for all data state_key_schema = StructType.fromJson(json.loads(state_info_col_key_schema)) - state_key_row = self.pickleSer.loads(state_info_col_key_row) state_object_schema = StructType.fromJson(json.loads(state_info_col_object_schema)) + + state_properties = json.loads(state_info_col_properties) + state_key_row = self.pickleSer.loads(state_info_col_key_row) if state_info_col_object: state_object = self.pickleSer.loads(state_info_col_object) else: @@ -460,9 +463,9 @@ def init_stream_yield_batches(): len_pdf = len(pdf) none_array = [None, ] * len_pdf state_dict = { - '__state__properties': [state_properties, ] + 
none_array, - '__state__keyRow': [state_key_row, ] + none_array, - '__state__object': [state_object, ] + none_array, + 'properties': [state_properties, ] + none_array, + 'keyRow': [state_key_row, ] + none_array, + 'object': [state_object, ] + none_array, } state_pdf = pd.DataFrame.from_dict(state_dict) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala index 9212ff80e763c..614bf59292bb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala @@ -76,6 +76,8 @@ class ArrowPythonRunnerWithState( StructType( Array( StructField("properties", StringType), + // FIXME: don't need to send the key row in state separately if there is any data + // FIXME: same: don't need to send the key schema as we know the schema StructField("keySchema", StringType), StructField("keyRow", BinaryType), StructField("objectSchema", StringType), @@ -240,7 +242,6 @@ class ArrowPythonRunnerWithState( schema = ArrowUtils.fromArrowSchema(root.getSchema()) val dataAttributes = schema(0).dataType.asInstanceOf[StructType].toAttributes - val stateInfoAttribute = schema(1).dataType.asInstanceOf[StructType].toAttributes unsafeProjForData = UnsafeProjection.create(dataAttributes, dataAttributes) @@ -300,9 +301,9 @@ class ArrowPythonRunnerWithState( implicit val formats = org.json4s.DefaultFormats val propertiesAsJson = parse(rowForStateInfo.getUTF8String(0).toString) - // FIXME: keySchema is probably not needed as we already know it... let's check whether - // it is needed for python worker, and if it does not, remove it. Or double check? 
val pickledKeyRow = rowForStateInfo.getBinary(1) + // FIXME: we convert key as byte array -> generic Row -> internal Row -> unsafe Row + // is there any util to skip a part of conversion? val keyRowAsGenericRow = PythonSQLUtils.toJVMRow(pickledKeyRow, keySchema, keyRowDeserializer) val keyRowAsInternalRow = keyRowSerializer.apply(keyRowAsGenericRow) From 27e7af9e4d0c6fff9f12bab5986435e7a450d3db Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 18 Aug 2022 19:55:28 +0900 Subject: [PATCH 09/44] WIP further tune... --- python/pyspark/sql/pandas/serializers.py | 27 +++++------ python/pyspark/sql/streaming/state.py | 5 +- python/pyspark/worker.py | 7 ++- .../python/ArrowPythonRunnerWithState.scala | 46 +++++++------------ .../FlatMapGroupsInPandasWithStateExec.scala | 14 ++---- 5 files changed, 39 insertions(+), 60 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 8050f7c787bfc..31634647c5525 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -378,20 +378,21 @@ def load_stream(self, stream): class ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer): - def __init__(self, timezone, safecheck, assign_cols_by_name): + def __init__(self, timezone, safecheck, assign_cols_by_name, state_object_schema): super(ApplyInPandasWithStateSerializer, self).__init__( timezone, safecheck, assign_cols_by_name) self.pickleSer = CPickleSerializer() self.utf8_deserializer = UTF8Deserializer() + self.state_object_schema = state_object_schema # FIXME: result_state_df_type? 
- self.state_df_type = StructType([ + self.result_state_df_type = StructType([ StructField('properties', StringType()), - StructField('keyRow', BinaryType()), + StructField('keyRowAsUnsafe', BinaryType()), StructField('object', BinaryType()), ]) - self.state_pdf_arrow_type = to_arrow_type(self.state_df_type) + self.result_state_pdf_arrow_type = to_arrow_type(self.result_state_df_type) def arrow_to_pandas(self, arrow_column): return super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(arrow_column) @@ -408,25 +409,19 @@ def load_stream(self, stream): state_info_col = batch[-1][0] state_info_col_properties = state_info_col['properties'] - state_info_col_key_schema = state_info_col['keySchema'] - state_info_col_key_row = state_info_col['keyRow'] + state_info_col_key_row = state_info_col['keyRowAsUnsafe'] state_info_col_object_schema = state_info_col['objectSchema'] state_info_col_object = state_info_col['object'] - # FIXME: schemas can be retrieved as metadata since they are applied for all data - state_key_schema = StructType.fromJson(json.loads(state_info_col_key_schema)) - state_object_schema = StructType.fromJson(json.loads(state_info_col_object_schema)) - state_properties = json.loads(state_info_col_properties) - state_key_row = self.pickleSer.loads(state_info_col_key_row) if state_info_col_object: state_object = self.pickleSer.loads(state_info_col_object) else: state_object = None state_properties["optionalValue"] = state_object - state = GroupStateImpl(key=state_key_row, keySchema=state_key_schema, - valueSchema=state_object_schema, **state_properties) + state = GroupStateImpl(keyAsUnsafe=state_info_col_key_row, + valueSchema=self.state_object_schema, **state_properties) state_column_dropped_series = batch[0:-1] first_row_dropped_series = [x.iloc[1:].reset_index(drop=True) for x in state_column_dropped_series] @@ -457,14 +452,14 @@ def init_stream_yield_batches(): pdf_with_empty_row = pd.concat([new_empty_row, pdf[:]], axis=0).reset_index(drop=True) 
state_properties = state.json().encode("utf-8") - state_key_row = self.pickleSer.dumps(state._key_schema.toInternal(state._key)) + state_key_row_as_binary = state._keyAsUnsafe state_object = self.pickleSer.dumps(state._value_schema.toInternal(state._value)) len_pdf = len(pdf) none_array = [None, ] * len_pdf state_dict = { 'properties': [state_properties, ] + none_array, - 'keyRow': [state_key_row, ] + none_array, + 'keyRowAsUnsafe': [state_key_row_as_binary, ] + none_array, 'object': [state_object, ] + none_array, } @@ -472,7 +467,7 @@ def init_stream_yield_batches(): batch = self._create_batch([ (pdf_with_empty_row, return_schema), - (state_pdf, self.state_pdf_arrow_type)]) + (state_pdf, self.result_state_pdf_arrow_type)]) if should_write_start_length: write_int(SpecialLengths.START_ARROW_STREAM, stream) diff --git a/python/pyspark/sql/streaming/state.py b/python/pyspark/sql/streaming/state.py index 6e05512d36304..0f43bc1ea1520 100644 --- a/python/pyspark/sql/streaming/state.py +++ b/python/pyspark/sql/streaming/state.py @@ -48,11 +48,10 @@ def __init__( timeoutUpdated: bool, timeoutTimestamp: int, # Python internal state. 
- key: Row, - keySchema: StructType, + keyAsUnsafe: bytes, valueSchema: StructType, ) -> None: - self._key = key + self._keyAsUnsafe = keyAsUnsafe self._value = optionalValue self._batch_processing_time_ms = batchProcessingTimeMs self._event_time_watermark_ms = eventTimeWatermarkMs diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index caaf1051a3615..f52c3b56692d6 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -382,6 +382,10 @@ def read_udfs(pickleSer, infile, eval_type): v = utf8_deserializer.loads(infile) runner_conf[k] = v + state_object_schema = None + if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: + state_object_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile))) + # NOTE: if timezone is set here, that implies respectSessionTimeZone is True timezone = runner_conf.get("spark.sql.session.timeZone", None) safecheck = ( @@ -401,7 +405,7 @@ def read_udfs(pickleSer, infile, eval_type): elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF: ser = ArrowStreamUDFSerializer() elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: - ser = ApplyInPandasWithStateSerializer(timezone, safecheck, assign_cols_by_name) + ser = ApplyInPandasWithStateSerializer(timezone, safecheck, assign_cols_by_name, state_object_schema) else: # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of # pandas Series. See SPARK-27240. 
@@ -646,6 +650,7 @@ def main(infile, outfile): ) # initialize global state + state_schema = None taskContext = None if isBarrier: taskContext = BarrierTaskContext._getOrCreate() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala index 614bf59292bb0..3992721b17cc1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala @@ -27,14 +27,14 @@ import org.apache.arrow.vector.VectorSchemaRoot import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter} import org.json4s._ import org.json4s.jackson.JsonMethods._ - import org.apache.spark.{SparkEnv, TaskContext} + import org.apache.spark.api.python._ import org.apache.spark.sql.Row import org.apache.spark.sql.api.python.PythonSQLUtils import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.arrow.ArrowWriter import org.apache.spark.sql.execution.streaming.GroupStateImpl import org.apache.spark.sql.internal.SQLConf @@ -54,14 +54,13 @@ class ArrowPythonRunnerWithState( inputSchema: StructType, timeZoneId: String, workerConf: Map[String, String], - keyEncoder: ExpressionEncoder[Row], stateEncoder: ExpressionEncoder[Row], keySchema: StructType, valueSchema: StructType, stateSchema: StructType) extends BasePythonRunner[ - (InternalRow, GroupStateImpl[Row], Iterator[InternalRow]), - (InternalRow, GroupStateImpl[Row], Iterator[InternalRow])]( + (UnsafeRow, GroupStateImpl[Row], Iterator[InternalRow]), + (UnsafeRow, 
GroupStateImpl[Row], Iterator[InternalRow])]( funcs, evalType, argOffsets) { override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback @@ -76,18 +75,12 @@ class ArrowPythonRunnerWithState( StructType( Array( StructField("properties", StringType), - // FIXME: don't need to send the key row in state separately if there is any data - // FIXME: same: don't need to send the key schema as we know the schema - StructField("keySchema", StringType), - StructField("keyRow", BinaryType), - StructField("objectSchema", StringType), + StructField("keyRowAsUnsafe", BinaryType), StructField("object", BinaryType) ) ) ) - val keyRowSerializer = keyEncoder.createSerializer() - val keyRowDeserializer = keyEncoder.createDeserializer() val stateRowSerializer = stateEncoder.createSerializer() val stateRowDeserializer = stateEncoder.createDeserializer() @@ -98,12 +91,13 @@ class ArrowPythonRunnerWithState( PythonRDD.writeUTF(k, stream) PythonRDD.writeUTF(v, stream) } + PythonRDD.writeUTF(stateSchema.json, stream) } protected override def newWriterThread( env: SparkEnv, worker: Socket, - inputIterator: Iterator[(InternalRow, GroupStateImpl[Row], Iterator[InternalRow])], + inputIterator: Iterator[(UnsafeRow, GroupStateImpl[Row], Iterator[InternalRow])], partitionIndex: Int, context: TaskContext): WriterThread = { new WriterThread(env, worker, inputIterator, partitionIndex, context) { @@ -114,15 +108,12 @@ class ArrowPythonRunnerWithState( } private def buildStateInfoRow( - keyRow: InternalRow, + keyRow: UnsafeRow, groupState: GroupStateImpl[Row]): InternalRow = { - val keyRowAsPublicRow = keyRowDeserializer.apply(keyRow) val stateUnderlyingRow = new GenericInternalRow( Array[Any]( UTF8String.fromString(groupState.json()), - UTF8String.fromString(keySchema.json), - PythonSQLUtils.toPyRow(keyRowAsPublicRow), - UTF8String.fromString(stateSchema.json), + keyRow.getBytes, groupState.getOption.map(PythonSQLUtils.toPyRow).orNull ) ) @@ -193,7 +184,7 @@ class 
ArrowPythonRunnerWithState( worker: Socket, pid: Option[Int], releasedOrClosed: AtomicBoolean, - context: TaskContext): Iterator[(InternalRow, GroupStateImpl[Row], Iterator[InternalRow])] = { + context: TaskContext): Iterator[(UnsafeRow, GroupStateImpl[Row], Iterator[InternalRow])] = { new ReaderIterator( stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) { @@ -216,7 +207,7 @@ class ArrowPythonRunnerWithState( private var batchLoaded = true - protected override def read(): (InternalRow, GroupStateImpl[Row], Iterator[InternalRow]) = { + protected override def read(): (UnsafeRow, GroupStateImpl[Row], Iterator[InternalRow]) = { if (writerThread.exception.isDefined) { throw writerThread.exception.get } @@ -263,7 +254,7 @@ class ArrowPythonRunnerWithState( } private def deserializeColumnarBatch( - batch: ColumnarBatch): (InternalRow, GroupStateImpl[Row], Iterator[InternalRow]) = { + batch: ColumnarBatch): (UnsafeRow, GroupStateImpl[Row], Iterator[InternalRow]) = { // this should at least have one row for state assert(batch.numRows() > 0) assert(schema.length == 2) @@ -294,19 +285,16 @@ class ArrowPythonRunnerWithState( /* Array( StructField("properties", StringType), - StructField("keyRow", BinaryType), + StructField("keyRowAsUnsafe", BinaryType), StructField("object", BinaryType) ) */ implicit val formats = org.json4s.DefaultFormats val propertiesAsJson = parse(rowForStateInfo.getUTF8String(0).toString) - val pickledKeyRow = rowForStateInfo.getBinary(1) - // FIXME: we convert key as byte array -> generic Row -> internal Row -> unsafe Row - // is there any util to skip a part of conversion? 
- val keyRowAsGenericRow = PythonSQLUtils.toJVMRow(pickledKeyRow, keySchema, - keyRowDeserializer) - val keyRowAsInternalRow = keyRowSerializer.apply(keyRowAsGenericRow) + val keyRowAsUnsafeAsBinary = rowForStateInfo.getBinary(1) + val keyRowAsUnsafe = new UnsafeRow(keySchema.fields.length) + keyRowAsUnsafe.pointTo(keyRowAsUnsafeAsBinary, keyRowAsUnsafeAsBinary.length) val maybeObjectRow = if (rowForStateInfo.isNullAt(2)) { None } else { @@ -316,7 +304,7 @@ class ArrowPythonRunnerWithState( val newGroupState = GroupStateImpl.fromJson(maybeObjectRow, propertiesAsJson) - (keyRowAsInternalRow, newGroupState, rowIterator.map(unsafeProjForData)) + (keyRowAsUnsafe, newGroupState, rowIterator.map(unsafeProjForData)) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala index 1765b400af7fb..2ed5dcbaed58f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -76,9 +76,6 @@ case class FlatMapGroupsInPandasWithStateExec( override protected val stateEncoder: ExpressionEncoder[Any] = RowEncoder(stateType).resolveAndBind().asInstanceOf[ExpressionEncoder[Any]] - private val keyEncoder: ExpressionEncoder[Row] = - RowEncoder(keySchema).resolveAndBind() - override def output: Seq[Attribute] = outAttributes private val sessionLocalTimeZone = conf.sessionLocalTimeZone @@ -104,8 +101,6 @@ case class FlatMapGroupsInPandasWithStateExec( override def createInputProcessor( store: StateStore): InputProcessor = new InputProcessor(store: StateStore) { - private val keyUnsafeProj = UnsafeProjection.create(keySchema) - /** * For every group, get the key, values and corresponding state and call the function, * and return an iterator of rows @@ -149,7 
+144,7 @@ case class FlatMapGroupsInPandasWithStateExec( } private def process( - iter: Iterator[(InternalRow, StateData, Iterator[InternalRow])], + iter: Iterator[(UnsafeRow, StateData, Iterator[InternalRow])], hasTimedOut: Boolean): Iterator[InternalRow] = { val runner = new ArrowPythonRunnerWithState( chainedFunc, @@ -158,7 +153,6 @@ case class FlatMapGroupsInPandasWithStateExec( StructType.fromAttributes(dedupAttributes), sessionLocalTimeZone, pythonRunnerConf, - keyEncoder, stateEncoder.asInstanceOf[ExpressionEncoder[Row]], groupingAttributes.toStructType, child.output.toStructType, @@ -177,12 +171,10 @@ case class FlatMapGroupsInPandasWithStateExec( } runner.compute(processIter, context.partitionId(), context).flatMap { case (keyRow, newGroupState, outputIter) => - val keyUnsafeRow = keyUnsafeProj(keyRow) - // When the iterator is consumed, then write changes to state def onIteratorCompletion: Unit = { if (newGroupState.isRemoved && !newGroupState.getTimeoutTimestampMs.isPresent()) { - stateManager.removeState(store, keyUnsafeRow) + stateManager.removeState(store, keyRow) numRemovedStateRows += 1 } else { val currentTimeoutTimestamp = newGroupState.getTimeoutTimestampMs @@ -192,7 +184,7 @@ case class FlatMapGroupsInPandasWithStateExec( if (shouldWriteState) { val updatedStateObj = if (newGroupState.exists) newGroupState.get else null - stateManager.putState(store, keyUnsafeRow, updatedStateObj, + stateManager.putState(store, keyRow, updatedStateObj, currentTimeoutTimestamp) numUpdatedStateRows += 1 } From 04a6b98e26f0c653937316919d1e30b1ab0424a2 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 18 Aug 2022 20:14:56 +0900 Subject: [PATCH 10/44] WIP done more tune! 
didn't do any of pandas/arrow side tunes --- python/pyspark/sql/pandas/serializers.py | 2 -- python/pyspark/sql/streaming/state.py | 1 - 2 files changed, 3 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 31634647c5525..6ec6742335ce8 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -400,7 +400,6 @@ def arrow_to_pandas(self, arrow_column): def load_stream(self, stream): import pyarrow as pa import json - from pyspark.sql.types import StructType from pyspark.sql.streaming.state import GroupStateImpl batches = ArrowStreamPandasUDFSerializer.load_stream(self, stream) @@ -410,7 +409,6 @@ def load_stream(self, stream): state_info_col_properties = state_info_col['properties'] state_info_col_key_row = state_info_col['keyRowAsUnsafe'] - state_info_col_object_schema = state_info_col['objectSchema'] state_info_col_object = state_info_col['object'] state_properties = json.loads(state_info_col_properties) diff --git a/python/pyspark/sql/streaming/state.py b/python/pyspark/sql/streaming/state.py index 0f43bc1ea1520..8b883c4b5b809 100644 --- a/python/pyspark/sql/streaming/state.py +++ b/python/pyspark/sql/streaming/state.py @@ -72,7 +72,6 @@ def __init__( self._timeout_timestamp = timeoutTimestamp self._timeout_updated = timeoutUpdated - self._key_schema = keySchema self._value_schema = valueSchema @property From 765f4d3519da7c0fe4224742a00a03742af54611 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Fri, 19 Aug 2022 14:50:03 +0900 Subject: [PATCH 11/44] WIP avoid adding additional empty row for state, empty row will be added only when there is no data --- python/pyspark/sql/pandas/serializers.py | 82 ++++++++++++++----- python/pyspark/worker.py | 6 +- sql/core/pom.xml | 39 +++++++++ .../sql/execution/arrow/ArrowWriter.scala | 2 +- .../python/ArrowPythonRunnerWithState.scala | 78 ++++++++++++------ 5 files changed, 159 insertions(+), 48 deletions(-) diff 
--git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 6ec6742335ce8..70aa04b1c97b3 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -22,7 +22,8 @@ from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer, CPickleSerializer from pyspark.sql.pandas.types import to_arrow_type -from pyspark.sql.types import StringType, StructType, BinaryType, StructField +from pyspark.sql.types import StringType, StructType, BinaryType, StructField, BooleanType + class SpecialLengths: END_OF_DATA_SECTION = -1 @@ -248,6 +249,7 @@ def create_array(s, t): arrs = [] for s, t in series: + print("==== <_create_batch> s: %s t: %s" % (s, t, ), file=sys.stderr) if t is not None and pa.types.is_struct(t): if not isinstance(s, pd.DataFrame): raise ValueError( @@ -390,22 +392,52 @@ def __init__(self, timezone, safecheck, assign_cols_by_name, state_object_schema StructField('properties', StringType()), StructField('keyRowAsUnsafe', BinaryType()), StructField('object', BinaryType()), + StructField('isEmptyData', BooleanType()), ]) self.result_state_pdf_arrow_type = to_arrow_type(self.result_state_df_type) - def arrow_to_pandas(self, arrow_column): - return super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(arrow_column) - def load_stream(self, stream): import pyarrow as pa import json from pyspark.sql.streaming.state import GroupStateImpl - batches = ArrowStreamPandasUDFSerializer.load_stream(self, stream) + batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) + for batch in batches: - # FIXME: can we leverage schema here? doesn't work well so... 
- state_info_col = batch[-1][0] + print("=== batch: %s type(batch): %s" % (batch, type(batch), ), file=sys.stderr) + + batch_schema = batch.schema + + print("=== batch_schema: %s type(batch_schema): %s" % (batch_schema, type(batch_schema), ), file=sys.stderr) + + batch_columns = batch.columns + data_columns = batch_columns[0:-1] + state_column = batch_columns[-1] + + print("=== data_columns: %s state_column: %s" % (data_columns, state_column, ), file=sys.stderr) + + data_schema = pa.schema([batch_schema[i] for i in range(0, len(batch_schema) - 1)]) + state_schema = pa.schema([batch_schema[-1], ]) + + print("=== data_schema: %s state_schema: %s" % (data_schema, state_schema, ), file=sys.stderr) + + data_batch = pa.RecordBatch.from_arrays(data_columns, schema=data_schema) + state_batch = pa.RecordBatch.from_arrays([state_column, ], schema=state_schema) + + print("=== data_batch: %s state_batch: %s" % (data_batch, state_batch, ), file=sys.stderr) + + data_arrow = pa.Table.from_batches([data_batch]).itercolumns() + state_arrow = pa.Table.from_batches([state_batch]).itercolumns() + + print("=== data_arrow_columns: %s state_arrow_columns: %s" % (data_arrow, state_arrow, ), file=sys.stderr) + + data_pandas = [self.arrow_to_pandas(c) for c in data_arrow] + state_pandas = [self.arrow_to_pandas(c) for c in state_arrow][0] + + print("=== data_pandas: %s type(data_pandas): %s state_pandas: %s type(state_pandas): %s" % (data_pandas, type(data_pandas), state_pandas, type(state_pandas), ), file=sys.stderr) + + state_info_col = state_pandas.iloc[0] state_info_col_properties = state_info_col['properties'] state_info_col_key_row = state_info_col['keyRowAsUnsafe'] @@ -421,10 +453,10 @@ def load_stream(self, stream): state = GroupStateImpl(keyAsUnsafe=state_info_col_key_row, valueSchema=self.state_object_schema, **state_properties) - state_column_dropped_series = batch[0:-1] - first_row_dropped_series = [x.iloc[1:].reset_index(drop=True) for x in state_column_dropped_series] + 
print("=== data_pandas: %s state: %s" % (data_pandas, state, ), file=sys.stderr) + # state info - yield (first_row_dropped_series, state, ) + yield (data_pandas, state, ) def dump_stream(self, iterator, stream): """ @@ -435,36 +467,46 @@ def dump_stream(self, iterator, stream): def init_stream_yield_batches(): import pandas as pd + import pyarrow as pa should_write_start_length = True for data in iterator: packaged_result = data[0] - pdf = packaged_result[0][0].reset_index(drop=True) + pdf = packaged_result[0][0] state = packaged_result[0][-1] return_schema = packaged_result[1] - new_empty_row = pd.DataFrame(dict.fromkeys(pdf.columns), index=[0]) + # FIXME: arrow type to pandas type + # FIXME: probably also need to check columns to validate? + + print("==== pdf: %s len(pdf): %s" % (pdf, len(pdf), ), file=sys.stderr) - # Concatenate new_row with df - pdf_with_empty_row = pd.concat([new_empty_row, pdf[:]], axis=0).reset_index(drop=True) + empty_data = len(pdf) == 0 + if empty_data: + # if returned DataFrame is empty with no column information, just create a new + # DataFrame with empty row with column information + pdf = pd.DataFrame(dict.fromkeys(pa.schema(return_schema).names), index=[0]) + + print("==== pdf: %s state: %s return_schema: %s" % (pdf, state, return_schema, ), file=sys.stderr) state_properties = state.json().encode("utf-8") state_key_row_as_binary = state._keyAsUnsafe state_object = self.pickleSer.dumps(state._value_schema.toInternal(state._value)) - len_pdf = len(pdf) - none_array = [None, ] * len_pdf state_dict = { - 'properties': [state_properties, ] + none_array, - 'keyRowAsUnsafe': [state_key_row_as_binary, ] + none_array, - 'object': [state_object, ] + none_array, + 'properties': [state_properties, ], + 'keyRowAsUnsafe': [state_key_row_as_binary, ], + 'object': [state_object, ], + 'isEmptyData': [empty_data, ], } state_pdf = pd.DataFrame.from_dict(state_dict) + print("==== pdf: %s return_schema: %s state_pdf: %s result_state arrow_schema: %s" % 
(pdf, return_schema, state_pdf, self.result_state_pdf_arrow_type, ), file=sys.stderr) + batch = self._create_batch([ - (pdf_with_empty_row, return_schema), + (pdf, return_schema), (state_pdf, self.result_state_pdf_arrow_type)]) if should_write_start_length: diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index f52c3b56692d6..8269083c72a0c 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -213,7 +213,11 @@ def wrap_grouped_map_pandas_udf_with_state(f, return_type): def wrapped(key_series, value_series, state): import pandas as pd - key = tuple(s.head(1).at[0] for s in key_series) + print("=== key_series: %s value_series: %s state: %s" % (key_series, value_series, state, ), file=sys.stderr) + + key = tuple(s[0] for s in key_series) + print("=== key: %s" % (key, ), file=sys.stderr) + if state.hasTimedOut: # Timeout processing pass empty iterator. Here we return an empty DataFrame instead. result = f(key, pd.DataFrame(columns=pd.concat(value_series, axis=1).columns), state) diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 5bb7708c0c6b7..736fd515d35d2 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -269,6 +269,45 @@ + + org.apache.maven.plugins + maven-checkstyle-plugin + 3.1.2 + + true + false + true + + ${basedir}/src/main/java + ${basedir}/src/main/scala + + + ${basedir}/src/test/java + + dev/checkstyle.xml + ${basedir}/target/checkstyle-output.xml + ${project.build.sourceEncoding} + ${project.reporting.outputEncoding} + + + + + com.puppycrawl.tools + checkstyle + 8.43 + + + + + + check + + + + org.scalastyle scalastyle-maven-plugin diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 7abca5f0e3320..34e128a4925f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ 
-44,7 +44,7 @@ object ArrowWriter { new ArrowWriter(root, children.toArray) } - private def createFieldWriter(vector: ValueVector): ArrowFieldWriter = { + private[sql] def createFieldWriter(vector: ValueVector): ArrowFieldWriter = { val field = vector.getField() (ArrowUtils.fromArrowField(field), vector) match { case (BooleanType, vector: BitVector) => new BooleanWriter(vector) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala index 3992721b17cc1..c3f5683c223cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala @@ -27,15 +27,16 @@ import org.apache.arrow.vector.VectorSchemaRoot import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter} import org.json4s._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.api.python._ import org.apache.spark.sql.Row import org.apache.spark.sql.api.python.PythonSQLUtils import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.arrow.ArrowWriter +import org.apache.spark.sql.execution.arrow.ArrowWriter.createFieldWriter import org.apache.spark.sql.execution.streaming.GroupStateImpl import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -81,6 +82,8 @@ class ArrowPythonRunnerWithState( ) ) + logWarning(s"DEBUG: schemaWithState: ${schemaWithState}") + val stateRowSerializer = 
stateEncoder.createSerializer() val stateRowDeserializer = stateEncoder.createDeserializer() @@ -122,37 +125,55 @@ class ArrowPythonRunnerWithState( protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { val arrowSchema = ArrowUtils.toArrowSchema(schemaWithState, timeZoneId) + + logWarning(s"DEBUG: arrowSchema: ${arrowSchema}") + val allocator = ArrowUtils.rootAllocator.newChildAllocator( s"stdout writer for $pythonExec", 0, Long.MaxValue) val root = VectorSchemaRoot.create(arrowSchema, allocator) Utils.tryWithSafeFinally { - val nullDataRow = new GenericInternalRow(Array.fill(inputSchema.length)(null: Any)) - val nullStateInfoRow = new GenericInternalRow(Array.fill(1)(null: Any)) + val arrowWriterForData = { + val children = root.getFieldVectors().asScala.dropRight(1).map { vector => + vector.allocateNew() + createFieldWriter(vector) + } + + new ArrowWriter(root, children.toArray) + } + val arrowWriterForState = { + val children = root.getFieldVectors().asScala.takeRight(1).map { vector => + vector.allocateNew() + createFieldWriter(vector) + } + new ArrowWriter(root, children.toArray) + } - val arrowWriter = ArrowWriter.create(root) val writer = new ArrowStreamWriter(root, null, dataOut) writer.start() - val joinedRow = new JoinedRow while (inputIterator.hasNext) { val (keyRow, groupState, dataIter) = inputIterator.next() + assert(dataIter.hasNext, "should have at least one data row!") + // Provide state info row in the first row val stateInfoRow = buildStateInfoRow(keyRow, groupState) - joinedRow.withLeft(nullDataRow).withRight(stateInfoRow) - arrowWriter.write(joinedRow) + arrowWriterForState.write(stateInfoRow) // Continue providing remaining data rows while (dataIter.hasNext) { val dataRow = dataIter.next() - joinedRow.withLeft(dataRow).withRight(nullStateInfoRow) - arrowWriter.write(joinedRow) + arrowWriterForData.write(dataRow) } - arrowWriter.finish() + // DO NOT CHANGE THE ORDER OF FINISH! 
We are picking up the number of rows from data + // side, as we know there is at least one data row. + arrowWriterForState.finish() + arrowWriterForData.finish() writer.writeBatch() - arrowWriter.reset() + arrowWriterForState.reset() + arrowWriterForData.reset() } // end writes footer to the output stream and doesn't clean any resources. // It could throw exception if the output stream is closed, so it should be @@ -267,18 +288,6 @@ class ArrowPythonRunnerWithState( val rowForStateInfo = flattenedBatchForState.getRow(0) - // UDF returns a StructType column in ColumnarBatch, select the children here - val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] - val outputVectors = schema(0).dataType.asInstanceOf[StructType] - .indices.map(structVector.getChild) - val flattenedBatch = new ColumnarBatch(outputVectors.toArray) - flattenedBatch.setNumRows(batch.numRows()) - - val rowIterator = flattenedBatch.rowIterator.asScala - // drop first row as it's reserved for state - assert(rowIterator.hasNext) - rowIterator.next() - // FIXME: we rely on known schema for state info, but would we want to access this by // column name? // Received state information does not need schemas - this class already knows them. 
@@ -286,7 +295,8 @@ class ArrowPythonRunnerWithState( Array( StructField("properties", StringType), StructField("keyRowAsUnsafe", BinaryType), - StructField("object", BinaryType) + StructField("object", BinaryType), + StructField('isEmptyData', BooleanType) ) */ implicit val formats = org.json4s.DefaultFormats @@ -301,10 +311,26 @@ class ArrowPythonRunnerWithState( val pickledRow = rowForStateInfo.getBinary(2) Some(PythonSQLUtils.toJVMRow(pickledRow, stateSchema, stateRowDeserializer)) } + val isEmptyData = rowForStateInfo.getBoolean(3) val newGroupState = GroupStateImpl.fromJson(maybeObjectRow, propertiesAsJson) - (keyRowAsUnsafe, newGroupState, rowIterator.map(unsafeProjForData)) + val rowIterator = if (isEmptyData) { + logWarning("DEBUG: no data is available") + Iterator.empty + } else { + // UDF returns a StructType column in ColumnarBatch, select the children here + val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] + val outputVectors = schema(0).dataType.asInstanceOf[StructType] + .indices.map(structVector.getChild) + val flattenedBatch = new ColumnarBatch(outputVectors.toArray) + flattenedBatch.setNumRows(batch.numRows()) + + val rowIterator = flattenedBatch.rowIterator.asScala + rowIterator.map(unsafeProjForData) + } + + (keyRowAsUnsafe, newGroupState, rowIterator) } } } From 9e1122527db230d6dcd27c6933edc28c908832a3 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Fri, 19 Aug 2022 15:04:27 +0900 Subject: [PATCH 12/44] WIP remove debug log --- python/pyspark/sql/pandas/serializers.py | 29 ++----------------- python/pyspark/worker.py | 4 --- .../python/ArrowPythonRunnerWithState.scala | 6 ---- 3 files changed, 2 insertions(+), 37 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 70aa04b1c97b3..52f8eae957e15 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -249,7 +249,6 @@ def create_array(s, t): arrs = [] for s, t in 
series: - print("==== <_create_batch> s: %s t: %s" % (s, t, ), file=sys.stderr) if t is not None and pa.types.is_struct(t): if not isinstance(s, pd.DataFrame): raise ValueError( @@ -387,7 +386,6 @@ def __init__(self, timezone, safecheck, assign_cols_by_name, state_object_schema self.utf8_deserializer = UTF8Deserializer() self.state_object_schema = state_object_schema - # FIXME: result_state_df_type? self.result_state_df_type = StructType([ StructField('properties', StringType()), StructField('keyRowAsUnsafe', BinaryType()), @@ -405,38 +403,23 @@ def load_stream(self, stream): batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) for batch in batches: - print("=== batch: %s type(batch): %s" % (batch, type(batch), ), file=sys.stderr) - batch_schema = batch.schema - - print("=== batch_schema: %s type(batch_schema): %s" % (batch_schema, type(batch_schema), ), file=sys.stderr) + data_schema = pa.schema([batch_schema[i] for i in range(0, len(batch_schema) - 1)]) + state_schema = pa.schema([batch_schema[-1], ]) batch_columns = batch.columns data_columns = batch_columns[0:-1] state_column = batch_columns[-1] - print("=== data_columns: %s state_column: %s" % (data_columns, state_column, ), file=sys.stderr) - - data_schema = pa.schema([batch_schema[i] for i in range(0, len(batch_schema) - 1)]) - state_schema = pa.schema([batch_schema[-1], ]) - - print("=== data_schema: %s state_schema: %s" % (data_schema, state_schema, ), file=sys.stderr) - data_batch = pa.RecordBatch.from_arrays(data_columns, schema=data_schema) state_batch = pa.RecordBatch.from_arrays([state_column, ], schema=state_schema) - print("=== data_batch: %s state_batch: %s" % (data_batch, state_batch, ), file=sys.stderr) - data_arrow = pa.Table.from_batches([data_batch]).itercolumns() state_arrow = pa.Table.from_batches([state_batch]).itercolumns() - print("=== data_arrow_columns: %s state_arrow_columns: %s" % (data_arrow, state_arrow, ), file=sys.stderr) - data_pandas = [self.arrow_to_pandas(c) 
for c in data_arrow] state_pandas = [self.arrow_to_pandas(c) for c in state_arrow][0] - print("=== data_pandas: %s type(data_pandas): %s state_pandas: %s type(state_pandas): %s" % (data_pandas, type(data_pandas), state_pandas, type(state_pandas), ), file=sys.stderr) - state_info_col = state_pandas.iloc[0] state_info_col_properties = state_info_col['properties'] @@ -453,8 +436,6 @@ def load_stream(self, stream): state = GroupStateImpl(keyAsUnsafe=state_info_col_key_row, valueSchema=self.state_object_schema, **state_properties) - print("=== data_pandas: %s state: %s" % (data_pandas, state, ), file=sys.stderr) - # state info yield (data_pandas, state, ) @@ -480,16 +461,12 @@ def init_stream_yield_batches(): # FIXME: arrow type to pandas type # FIXME: probably also need to check columns to validate? - print("==== pdf: %s len(pdf): %s" % (pdf, len(pdf), ), file=sys.stderr) - empty_data = len(pdf) == 0 if empty_data: # if returned DataFrame is empty with no column information, just create a new # DataFrame with empty row with column information pdf = pd.DataFrame(dict.fromkeys(pa.schema(return_schema).names), index=[0]) - print("==== pdf: %s state: %s return_schema: %s" % (pdf, state, return_schema, ), file=sys.stderr) - state_properties = state.json().encode("utf-8") state_key_row_as_binary = state._keyAsUnsafe state_object = self.pickleSer.dumps(state._value_schema.toInternal(state._value)) @@ -503,8 +480,6 @@ def init_stream_yield_batches(): state_pdf = pd.DataFrame.from_dict(state_dict) - print("==== pdf: %s return_schema: %s state_pdf: %s result_state arrow_schema: %s" % (pdf, return_schema, state_pdf, self.result_state_pdf_arrow_type, ), file=sys.stderr) - batch = self._create_batch([ (pdf, return_schema), (state_pdf, self.result_state_pdf_arrow_type)]) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 8269083c72a0c..093ca40e920cc 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -213,11 +213,7 @@ def 
wrap_grouped_map_pandas_udf_with_state(f, return_type): def wrapped(key_series, value_series, state): import pandas as pd - print("=== key_series: %s value_series: %s state: %s" % (key_series, value_series, state, ), file=sys.stderr) - key = tuple(s[0] for s in key_series) - print("=== key: %s" % (key, ), file=sys.stderr) - if state.hasTimedOut: # Timeout processing pass empty iterator. Here we return an empty DataFrame instead. result = f(key, pd.DataFrame(columns=pd.concat(value_series, axis=1).columns), state) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala index c3f5683c223cc..71bfa5785a709 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala @@ -82,8 +82,6 @@ class ArrowPythonRunnerWithState( ) ) - logWarning(s"DEBUG: schemaWithState: ${schemaWithState}") - val stateRowSerializer = stateEncoder.createSerializer() val stateRowDeserializer = stateEncoder.createDeserializer() @@ -125,9 +123,6 @@ class ArrowPythonRunnerWithState( protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { val arrowSchema = ArrowUtils.toArrowSchema(schemaWithState, timeZoneId) - - logWarning(s"DEBUG: arrowSchema: ${arrowSchema}") - val allocator = ArrowUtils.rootAllocator.newChildAllocator( s"stdout writer for $pythonExec", 0, Long.MaxValue) val root = VectorSchemaRoot.create(arrowSchema, allocator) @@ -316,7 +311,6 @@ class ArrowPythonRunnerWithState( val newGroupState = GroupStateImpl.fromJson(maybeObjectRow, propertiesAsJson) val rowIterator = if (isEmptyData) { - logWarning("DEBUG: no data is available") Iterator.empty } else { // UDF returns a StructType column in ColumnarBatch, select the children here From f33d97805ebe5ad17d5ffb6d4d618e5a438586b1 
Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Sat, 27 Aug 2022 15:06:55 +0900 Subject: [PATCH 13/44] WIP hack around to see the possibility of perf gain on binpacking --- python/pyspark/sql/pandas/serializers.py | 23 +++++- .../python/ArrowPythonRunnerWithState.scala | 78 ++++++++++--------- .../FlatMapGroupsInPandasWithStateExec.scala | 38 +++++---- 3 files changed, 87 insertions(+), 52 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 52f8eae957e15..b406720b1c9b3 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -451,6 +451,15 @@ def init_stream_yield_batches(): import pyarrow as pa should_write_start_length = True + + # FIXME: we are now very specific to the test code which always produces 1 output + # to experiment bin-packing. We are also assuming that all states & outputs do not + # grow that much if we pack to one. In reality we may need to try bin-packing to + # specific number of rows or size of data. 
+ + pdfs = [] + state_pdfs = [] + for data in iterator: packaged_result = data[0] @@ -480,13 +489,23 @@ def init_stream_yield_batches(): state_pdf = pd.DataFrame.from_dict(state_dict) + pdfs.append(pdf) + state_pdfs.append(state_pdf) + + assert len(pdfs) == len(state_pdfs) + + if len(pdfs) > 0: + merged_pdf = pd.concat(pdfs, ignore_index=True) + merged_state_pdf = pd.concat(state_pdfs, ignore_index=True) + batch = self._create_batch([ - (pdf, return_schema), - (state_pdf, self.result_state_pdf_arrow_type)]) + (merged_pdf, return_schema), + (merged_state_pdf, self.result_state_pdf_arrow_type)]) if should_write_start_length: write_int(SpecialLengths.START_ARROW_STREAM, stream) should_write_start_length = False + yield batch return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala index 71bfa5785a709..01decf0d7314b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala @@ -61,7 +61,7 @@ class ArrowPythonRunnerWithState( stateSchema: StructType) extends BasePythonRunner[ (UnsafeRow, GroupStateImpl[Row], Iterator[InternalRow]), - (UnsafeRow, GroupStateImpl[Row], Iterator[InternalRow])]( + (Iterator[(UnsafeRow, GroupStateImpl[Row])], Iterator[InternalRow])]( funcs, evalType, argOffsets) { override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback @@ -200,7 +200,8 @@ class ArrowPythonRunnerWithState( worker: Socket, pid: Option[Int], releasedOrClosed: AtomicBoolean, - context: TaskContext): Iterator[(UnsafeRow, GroupStateImpl[Row], Iterator[InternalRow])] = { + context: TaskContext) + : Iterator[(Iterator[(UnsafeRow, GroupStateImpl[Row])], Iterator[InternalRow])] = { 
new ReaderIterator( stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) { @@ -223,7 +224,8 @@ class ArrowPythonRunnerWithState( private var batchLoaded = true - protected override def read(): (UnsafeRow, GroupStateImpl[Row], Iterator[InternalRow]) = { + protected override def read() + : (Iterator[(UnsafeRow, GroupStateImpl[Row])], Iterator[InternalRow]) = { if (writerThread.exception.isDefined) { throw writerThread.exception.get } @@ -269,8 +271,12 @@ class ArrowPythonRunnerWithState( } catch handleException } - private def deserializeColumnarBatch( - batch: ColumnarBatch): (UnsafeRow, GroupStateImpl[Row], Iterator[InternalRow]) = { + // FIXME: we are now very specific to the test code which always produces 1 output + // to experiment bin-packing. We are also assuming that all states & outputs do not + // grow that much if we pack to one. In reality we may need to try bin-packing to + // specific number of rows or size of data. + private def deserializeColumnarBatch(batch: ColumnarBatch) + : (Iterator[(UnsafeRow, GroupStateImpl[Row])], Iterator[InternalRow]) = { // this should at least have one row for state assert(batch.numRows() > 0) assert(schema.length == 2) @@ -279,40 +285,42 @@ class ArrowPythonRunnerWithState( val outputVectorsForState = schema(1).dataType.asInstanceOf[StructType] .indices.map(structVectorForState.getChild) val flattenedBatchForState = new ColumnarBatch(outputVectorsForState.toArray) - flattenedBatchForState.setNumRows(1) + flattenedBatchForState.setNumRows(batch.numRows()) + + val rowIteratorForState = flattenedBatchForState.rowIterator().asScala.map { row => + implicit val formats = org.json4s.DefaultFormats + + // FIXME: we rely on known schema for state info, but would we want to access this by + // column name? + // Received state information does not need schemas - this class already knows them. 
+ /* + Array( + StructField("properties", StringType), + StructField("keyRowAsUnsafe", BinaryType), + StructField("object", BinaryType), + StructField('isEmptyData', BooleanType) + ) + */ + val propertiesAsJson = parse(row.getUTF8String(0).toString) + val keyRowAsUnsafeAsBinary = row.getBinary(1) + val keyRowAsUnsafe = new UnsafeRow(keySchema.fields.length) + keyRowAsUnsafe.pointTo(keyRowAsUnsafeAsBinary, keyRowAsUnsafeAsBinary.length) + val maybeObjectRow = if (row.isNullAt(2)) { + None + } else { + val pickledRow = row.getBinary(2) + Some(PythonSQLUtils.toJVMRow(pickledRow, stateSchema, stateRowDeserializer)) + } - val rowForStateInfo = flattenedBatchForState.getRow(0) + // FIXME: does not hold true for experiment + // val isEmptyData = rowForStateInfo.getBoolean(3) - // FIXME: we rely on known schema for state info, but would we want to access this by - // column name? - // Received state information does not need schemas - this class already knows them. - /* - Array( - StructField("properties", StringType), - StructField("keyRowAsUnsafe", BinaryType), - StructField("object", BinaryType), - StructField('isEmptyData', BooleanType) - ) - */ - implicit val formats = org.json4s.DefaultFormats - - val propertiesAsJson = parse(rowForStateInfo.getUTF8String(0).toString) - val keyRowAsUnsafeAsBinary = rowForStateInfo.getBinary(1) - val keyRowAsUnsafe = new UnsafeRow(keySchema.fields.length) - keyRowAsUnsafe.pointTo(keyRowAsUnsafeAsBinary, keyRowAsUnsafeAsBinary.length) - val maybeObjectRow = if (rowForStateInfo.isNullAt(2)) { - None - } else { - val pickledRow = rowForStateInfo.getBinary(2) - Some(PythonSQLUtils.toJVMRow(pickledRow, stateSchema, stateRowDeserializer)) + (keyRowAsUnsafe, GroupStateImpl.fromJson(maybeObjectRow, propertiesAsJson)) } - val isEmptyData = rowForStateInfo.getBoolean(3) - val newGroupState = GroupStateImpl.fromJson(maybeObjectRow, propertiesAsJson) + val rowForStateInfo = flattenedBatchForState.getRow(0) - val rowIterator = if (isEmptyData) { - 
Iterator.empty - } else { + val rowIteratorForOutput = { // UDF returns a StructType column in ColumnarBatch, select the children here val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] val outputVectors = schema(0).dataType.asInstanceOf[StructType] @@ -324,7 +332,7 @@ class ArrowPythonRunnerWithState( rowIterator.map(unsafeProjForData) } - (keyRowAsUnsafe, newGroupState, rowIterator) + (rowIteratorForState, rowIteratorForOutput) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala index 2ed5dcbaed58f..e5e71b580b37c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -159,6 +159,7 @@ case class FlatMapGroupsInPandasWithStateExec( stateType) val context = TaskContext.get() + val processIter = iter.map { case (keyRow, stateData, valueIter) => val groupedState = GroupStateImpl.createForStreaming( Option(stateData.stateObj).map { r => assert(r.isInstanceOf[Row]); r }, @@ -169,24 +170,31 @@ case class FlatMapGroupsInPandasWithStateExec( watermarkPresent).asInstanceOf[GroupStateImpl[Row]] (keyRow, groupedState, valueIter) } + // FIXME: we are now very specific to the test code which always produces 1 output + // to experiment bin-packing. We are also assuming that all states & outputs do not + // grow that much if we pack to one. In reality we may need to try bin-packing to + // specific number of rows or size of data. 
runner.compute(processIter, context.partitionId(), context).flatMap { - case (keyRow, newGroupState, outputIter) => + case (stateIter, outputIter) => // When the iterator is consumed, then write changes to state + // state does not affect each others, hence when to update does not affect to the result def onIteratorCompletion: Unit = { - if (newGroupState.isRemoved && !newGroupState.getTimeoutTimestampMs.isPresent()) { - stateManager.removeState(store, keyRow) - numRemovedStateRows += 1 - } else { - val currentTimeoutTimestamp = newGroupState.getTimeoutTimestampMs - .orElse(NO_TIMESTAMP) - val shouldWriteState = newGroupState.isUpdated || newGroupState.isRemoved || - newGroupState.isTimeoutUpdated - - if (shouldWriteState) { - val updatedStateObj = if (newGroupState.exists) newGroupState.get else null - stateManager.putState(store, keyRow, updatedStateObj, - currentTimeoutTimestamp) - numUpdatedStateRows += 1 + stateIter.foreach { case (keyRow, newGroupState) => + if (newGroupState.isRemoved && !newGroupState.getTimeoutTimestampMs.isPresent()) { + stateManager.removeState(store, keyRow) + numRemovedStateRows += 1 + } else { + val currentTimeoutTimestamp = newGroupState.getTimeoutTimestampMs + .orElse(NO_TIMESTAMP) + val shouldWriteState = newGroupState.isUpdated || newGroupState.isRemoved || + newGroupState.isTimeoutUpdated + + if (shouldWriteState) { + val updatedStateObj = if (newGroupState.exists) newGroupState.get else null + stateManager.putState(store, keyRow, updatedStateObj, + currentTimeoutTimestamp) + numUpdatedStateRows += 1 + } } } } From 8604fdfeaba95dd3ee51290ac25d7942b407c42d Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Sat, 27 Aug 2022 15:07:30 +0900 Subject: [PATCH 14/44] WIP proper work to apply binpacking on python worker -> executor --- python/pyspark/sql/pandas/serializers.py | 89 +++++++++++------ .../python/ArrowPythonRunnerWithState.scala | 96 ++++++++++--------- .../FlatMapGroupsInPandasWithStateExec.scala | 4 - 3 files changed, 111 
insertions(+), 78 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index b406720b1c9b3..c214a379cb64d 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -390,7 +390,6 @@ def __init__(self, timezone, safecheck, assign_cols_by_name, state_object_schema StructField('properties', StringType()), StructField('keyRowAsUnsafe', BinaryType()), StructField('object', BinaryType()), - StructField('isEmptyData', BooleanType()), ]) self.result_state_pdf_arrow_type = to_arrow_type(self.result_state_df_type) @@ -446,35 +445,59 @@ def dump_stream(self, iterator, stream): be sent back to the JVM before the Arrow stream starts. """ + def construct_record_batch(pdfs, pdf_data_cnt, pdf_schema, state_pdfs, state_data_cnt): + import pandas as pd + import pyarrow as pa + + max_data_cnt = max(pdf_data_cnt, state_data_cnt) + + empty_row_cnt_in_data = max_data_cnt - pdf_data_cnt + empty_row_cnt_in_state = max_data_cnt - state_data_cnt + + empty_rows_pdf = pd.DataFrame( + dict.fromkeys(pa.schema(pdf_schema).names), + index=[x for x in range(0, empty_row_cnt_in_data)]) + empty_rows_state = pd.DataFrame( + columns=['properties', 'keyRowAsUnsafe', 'object'], + index=[x for x in range(0, empty_row_cnt_in_state)]) + + pdfs.append(empty_rows_pdf) + state_pdfs.append(empty_rows_state) + + merged_pdf = pd.concat(pdfs, ignore_index=True) + merged_state_pdf = pd.concat(state_pdfs, ignore_index=True) + + return self._create_batch([ + (merged_pdf, pdf_schema), + (merged_state_pdf, self.result_state_pdf_arrow_type)]) + def init_stream_yield_batches(): import pandas as pd import pyarrow as pa should_write_start_length = True - # FIXME: we are now very specific to the test code which always produces 1 output - # to experiment bin-packing. We are also assuming that all states & outputs do not - # grow that much if we pack to one. 
In reality we may need to try bin-packing to - # specific number of rows or size of data. - pdfs = [] state_pdfs = [] + return_schema = None + + pdf_data_cnt = 0 + state_data_cnt = 0 for data in iterator: packaged_result = data[0] pdf = packaged_result[0][0] state = packaged_result[0][-1] + # this won't change across batches return_schema = packaged_result[1] # FIXME: arrow type to pandas type # FIXME: probably also need to check columns to validate? - empty_data = len(pdf) == 0 - if empty_data: - # if returned DataFrame is empty with no column information, just create a new - # DataFrame with empty row with column information - pdf = pd.DataFrame(dict.fromkeys(pa.schema(return_schema).names), index=[0]) + if len(pdf) > 0: + pdf_data_cnt += len(pdf) + pdfs.append(pdf) state_properties = state.json().encode("utf-8") state_key_row_as_binary = state._keyAsUnsafe @@ -484,27 +507,39 @@ def init_stream_yield_batches(): 'properties': [state_properties, ], 'keyRowAsUnsafe': [state_key_row_as_binary, ], 'object': [state_object, ], - 'isEmptyData': [empty_data, ], } state_pdf = pd.DataFrame.from_dict(state_dict) - pdfs.append(pdf) state_pdfs.append(state_pdf) - - assert len(pdfs) == len(state_pdfs) - - if len(pdfs) > 0: - merged_pdf = pd.concat(pdfs, ignore_index=True) - merged_state_pdf = pd.concat(state_pdfs, ignore_index=True) - - batch = self._create_batch([ - (merged_pdf, return_schema), - (merged_state_pdf, self.result_state_pdf_arrow_type)]) - - if should_write_start_length: - write_int(SpecialLengths.START_ARROW_STREAM, stream) - should_write_start_length = False + state_data_cnt = 1 + + max_data_cnt = max(pdf_data_cnt, state_data_cnt) + # FIXME: what would be the best criteria for threshold? 
+ if max_data_cnt > 10000: + batch = construct_record_batch(pdfs, pdf_data_cnt, return_schema, + state_pdfs, state_data_cnt) + + pdfs = [] + state_pdfs = [] + pdf_data_cnt = 0 + state_data_cnt = 0 + + if should_write_start_length: + write_int(SpecialLengths.START_ARROW_STREAM, stream) + should_write_start_length = False + + yield batch + + # end of loop, we may have remaining data + if pdf_data_cnt > 0 or state_data_cnt > 0: + batch = construct_record_batch(pdfs, pdf_data_cnt, return_schema, + state_pdfs, state_data_cnt) + + pdfs = [] + state_pdfs = [] + pdf_data_cnt = 0 + state_data_cnt = 0 yield batch diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala index 01decf0d7314b..83ebf27f1c77c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala @@ -271,56 +271,13 @@ class ArrowPythonRunnerWithState( } catch handleException } - // FIXME: we are now very specific to the test code which always produces 1 output - // to experiment bin-packing. We are also assuming that all states & outputs do not - // grow that much if we pack to one. In reality we may need to try bin-packing to - // specific number of rows or size of data. 
private def deserializeColumnarBatch(batch: ColumnarBatch) : (Iterator[(UnsafeRow, GroupStateImpl[Row])], Iterator[InternalRow]) = { // this should at least have one row for state assert(batch.numRows() > 0) assert(schema.length == 2) - val structVectorForState = batch.column(1).asInstanceOf[ArrowColumnVector] - val outputVectorsForState = schema(1).dataType.asInstanceOf[StructType] - .indices.map(structVectorForState.getChild) - val flattenedBatchForState = new ColumnarBatch(outputVectorsForState.toArray) - flattenedBatchForState.setNumRows(batch.numRows()) - - val rowIteratorForState = flattenedBatchForState.rowIterator().asScala.map { row => - implicit val formats = org.json4s.DefaultFormats - - // FIXME: we rely on known schema for state info, but would we want to access this by - // column name? - // Received state information does not need schemas - this class already knows them. - /* - Array( - StructField("properties", StringType), - StructField("keyRowAsUnsafe", BinaryType), - StructField("object", BinaryType), - StructField('isEmptyData', BooleanType) - ) - */ - val propertiesAsJson = parse(row.getUTF8String(0).toString) - val keyRowAsUnsafeAsBinary = row.getBinary(1) - val keyRowAsUnsafe = new UnsafeRow(keySchema.fields.length) - keyRowAsUnsafe.pointTo(keyRowAsUnsafeAsBinary, keyRowAsUnsafeAsBinary.length) - val maybeObjectRow = if (row.isNullAt(2)) { - None - } else { - val pickledRow = row.getBinary(2) - Some(PythonSQLUtils.toJVMRow(pickledRow, stateSchema, stateRowDeserializer)) - } - - // FIXME: does not hold true for experiment - // val isEmptyData = rowForStateInfo.getBoolean(3) - - (keyRowAsUnsafe, GroupStateImpl.fromJson(maybeObjectRow, propertiesAsJson)) - } - - val rowForStateInfo = flattenedBatchForState.getRow(0) - - val rowIteratorForOutput = { + def constructIterForData(batch: ColumnarBatch): Iterator[InternalRow] = { // UDF returns a StructType column in ColumnarBatch, select the children here val structVector = 
batch.column(0).asInstanceOf[ArrowColumnVector] val outputVectors = schema(0).dataType.asInstanceOf[StructType] @@ -328,11 +285,56 @@ class ArrowPythonRunnerWithState( val flattenedBatch = new ColumnarBatch(outputVectors.toArray) flattenedBatch.setNumRows(batch.numRows()) - val rowIterator = flattenedBatch.rowIterator.asScala - rowIterator.map(unsafeProjForData) + flattenedBatch.rowIterator.asScala.flatMap { row => + if (row.isNullAt(0)) { + None + } else { + Some(unsafeProjForData(row)) + } + } + } + + def constructIterForState( + batch: ColumnarBatch): Iterator[(UnsafeRow, GroupStateImpl[Row])] = { + val structVectorForState = batch.column(1).asInstanceOf[ArrowColumnVector] + val outputVectorsForState = schema(1).dataType.asInstanceOf[StructType] + .indices.map(structVectorForState.getChild) + val flattenedBatchForState = new ColumnarBatch(outputVectorsForState.toArray) + flattenedBatchForState.setNumRows(batch.numRows()) + + flattenedBatchForState.rowIterator().asScala.flatMap { row => + implicit val formats = org.json4s.DefaultFormats + + // FIXME: we rely on known schema for state info, but would we want to access this by + // column name? + // Received state information does not need schemas - this class already knows them. 
+ /* + Array( + StructField("properties", StringType), + StructField("keyRowAsUnsafe", BinaryType), + StructField("object", BinaryType), + ) + */ + if (row.isNullAt(0)) { + None + } else { + val propertiesAsJson = parse(row.getUTF8String(0).toString) + val keyRowAsUnsafeAsBinary = row.getBinary(1) + val keyRowAsUnsafe = new UnsafeRow(keySchema.fields.length) + keyRowAsUnsafe.pointTo(keyRowAsUnsafeAsBinary, keyRowAsUnsafeAsBinary.length) + val maybeObjectRow = if (row.isNullAt(2)) { + None + } else { + val pickledRow = row.getBinary(2) + Some(PythonSQLUtils.toJVMRow(pickledRow, stateSchema, stateRowDeserializer)) + } + + Some(keyRowAsUnsafe, GroupStateImpl.fromJson(maybeObjectRow, propertiesAsJson)) + } + } } - (rowIteratorForState, rowIteratorForOutput) + (constructIterForState(batch), constructIterForData(batch)) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala index e5e71b580b37c..d4c5d769945e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -170,10 +170,6 @@ case class FlatMapGroupsInPandasWithStateExec( watermarkPresent).asInstanceOf[GroupStateImpl[Row]] (keyRow, groupedState, valueIter) } - // FIXME: we are now very specific to the test code which always produces 1 output - // to experiment bin-packing. We are also assuming that all states & outputs do not - // grow that much if we pack to one. In reality we may need to try bin-packing to - // specific number of rows or size of data. 
runner.compute(processIter, context.partitionId(), context).flatMap { case (stateIter, outputIter) => // When the iterator is consumed, then write changes to state From 0d024e0cbc45878b7db6255fbce2ea847b80c7d5 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Sat, 27 Aug 2022 15:08:06 +0900 Subject: [PATCH 15/44] WIP fix silly bug --- python/pyspark/sql/pandas/serializers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index c214a379cb64d..00b09d093f440 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -512,7 +512,7 @@ def init_stream_yield_batches(): state_pdf = pd.DataFrame.from_dict(state_dict) state_pdfs.append(state_pdf) - state_data_cnt = 1 + state_data_cnt += 1 max_data_cnt = max(pdf_data_cnt, state_data_cnt) # FIXME: what would be the best criteria for threshold? From 43c623bc8408d6fda9413885062edbc989305d56 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Sat, 27 Aug 2022 15:08:32 +0900 Subject: [PATCH 16/44] WIP another silly bugfix on migration --- python/pyspark/sql/pandas/serializers.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 00b09d093f440..c13a303672b6f 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -541,6 +541,10 @@ def init_stream_yield_batches(): pdf_data_cnt = 0 state_data_cnt = 0 + if should_write_start_length: + write_int(SpecialLengths.START_ARROW_STREAM, stream) + should_write_start_length = False + yield batch return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream) From af1725ab357a9169fcc7d85b6c58412110618f4d Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Sat, 27 Aug 2022 15:09:06 +0900 Subject: [PATCH 17/44] WIP apply binpacking for executor -> python worker as well --- 
python/pyspark/sql/pandas/serializers.py | 42 ++++++++------ .../python/ArrowPythonRunnerWithState.scala | 56 ++++++++++++++++--- 2 files changed, 74 insertions(+), 24 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index c13a303672b6f..3a3c38ee17839 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -413,30 +413,37 @@ def load_stream(self, stream): data_batch = pa.RecordBatch.from_arrays(data_columns, schema=data_schema) state_batch = pa.RecordBatch.from_arrays([state_column, ], schema=state_schema) - data_arrow = pa.Table.from_batches([data_batch]).itercolumns() state_arrow = pa.Table.from_batches([state_batch]).itercolumns() - - data_pandas = [self.arrow_to_pandas(c) for c in data_arrow] state_pandas = [self.arrow_to_pandas(c) for c in state_arrow][0] - state_info_col = state_pandas.iloc[0] + # FIXME: data_batch: should be "zero-copy" split + for state_idx in range(0, len(state_pandas)): + state_info_col = state_pandas.iloc[0] - state_info_col_properties = state_info_col['properties'] - state_info_col_key_row = state_info_col['keyRowAsUnsafe'] - state_info_col_object = state_info_col['object'] + state_info_col_properties = state_info_col['properties'] + state_info_col_key_row = state_info_col['keyRowAsUnsafe'] + state_info_col_object = state_info_col['object'] - state_properties = json.loads(state_info_col_properties) - if state_info_col_object: - state_object = self.pickleSer.loads(state_info_col_object) - else: - state_object = None - state_properties["optionalValue"] = state_object + data_start_offset = state_info_col['startOffset'] + num_data_rows = state_info_col['numRows'] + + state_properties = json.loads(state_info_col_properties) + if state_info_col_object: + state_object = self.pickleSer.loads(state_info_col_object) + else: + state_object = None + state_properties["optionalValue"] = state_object + + state = 
GroupStateImpl(keyAsUnsafe=state_info_col_key_row, + valueSchema=self.state_object_schema, **state_properties) + + data_batch_for_group = data_batch.slice(data_start_offset, num_data_rows) + data_arrow = pa.Table.from_batches([data_batch_for_group]).itercolumns() - state = GroupStateImpl(keyAsUnsafe=state_info_col_key_row, - valueSchema=self.state_object_schema, **state_properties) + data_pandas = [self.arrow_to_pandas(c) for c in data_arrow] - # state info - yield (data_pandas, state, ) + # state info + yield (data_pandas, state, ) def dump_stream(self, iterator, stream): """ @@ -516,6 +523,7 @@ def init_stream_yield_batches(): max_data_cnt = max(pdf_data_cnt, state_data_cnt) # FIXME: what would be the best criteria for threshold? + # currently we arbitrarily set this via number of rows if max_data_cnt > 10000: batch = construct_record_batch(pdfs, pdf_data_cnt, return_schema, state_pdfs, state_data_cnt) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala index 83ebf27f1c77c..a236551e5dbc4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala @@ -77,7 +77,9 @@ class ArrowPythonRunnerWithState( Array( StructField("properties", StringType), StructField("keyRowAsUnsafe", BinaryType), - StructField("object", BinaryType) + StructField("object", BinaryType), + StructField("startOffset", IntegerType), + StructField("numRows", IntegerType) ) ) ) @@ -110,12 +112,22 @@ class ArrowPythonRunnerWithState( private def buildStateInfoRow( keyRow: UnsafeRow, - groupState: GroupStateImpl[Row]): InternalRow = { + groupState: GroupStateImpl[Row], + startOffset: Int, + numRows: Int): InternalRow = { + // FIXME: document the schema + // - properties + // - keyRowAsUnsafe + // - state 
object as Row + // - startOffset + // - numRows val stateUnderlyingRow = new GenericInternalRow( Array[Any]( UTF8String.fromString(groupState.json()), keyRow.getBytes, - groupState.getOption.map(PythonSQLUtils.toPyRow).orNull + groupState.getOption.map(PythonSQLUtils.toPyRow).orNull, + startOffset, + numRows ) ) new GenericInternalRow(Array[Any](stateUnderlyingRow)) @@ -147,21 +159,50 @@ class ArrowPythonRunnerWithState( val writer = new ArrowStreamWriter(root, null, dataOut) writer.start() + var numRowsForCurGroup = 0 + var startOffsetForCurGroup = 0 + var totalNumRowsForBatch = 0 + while (inputIterator.hasNext) { val (keyRow, groupState, dataIter) = inputIterator.next() assert(dataIter.hasNext, "should have at least one data row!") - // Provide state info row in the first row - val stateInfoRow = buildStateInfoRow(keyRow, groupState) - arrowWriterForState.write(stateInfoRow) + numRowsForCurGroup = 0 - // Continue providing remaining data rows + // Provide data rows while (dataIter.hasNext) { val dataRow = dataIter.next() arrowWriterForData.write(dataRow) + numRowsForCurGroup += 1 + totalNumRowsForBatch += 1 } + // Provide state info row in the first row + val stateInfoRow = buildStateInfoRow(keyRow, groupState, startOffsetForCurGroup, + numRowsForCurGroup) + arrowWriterForState.write(stateInfoRow) + + // FIXME: threshold as number of rows + // if we want to go with size, + // arrowWriterForState.sizeInBytes() + arrowWriterForState.sizeInBytes() + if (totalNumRowsForBatch > 10000) { + // DO NOT CHANGE THE ORDER OF FINISH! We are picking up the number of rows from data + // side, as we know there is at least one data row. + arrowWriterForState.finish() + arrowWriterForData.finish() + writer.writeBatch() + arrowWriterForState.reset() + arrowWriterForData.reset() + + startOffsetForCurGroup = 0 + totalNumRowsForBatch = 0 + } + } + + if (numRowsForCurGroup > 0) { + // need to flush remaining batch + // DO NOT CHANGE THE ORDER OF FINISH! 
We are picking up the number of rows from data // side, as we know there is at least one data row. arrowWriterForState.finish() @@ -170,6 +211,7 @@ class ArrowPythonRunnerWithState( arrowWriterForState.reset() arrowWriterForData.reset() } + // end writes footer to the output stream and doesn't clean any resources. // It could throw exception if the output stream is closed, so it should be // in the try block. From 31e9687a53ef2ab90daa1673b57c0b8f7aa0920b Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Sat, 27 Aug 2022 15:09:44 +0900 Subject: [PATCH 18/44] WIP fix silly bug --- python/pyspark/sql/pandas/serializers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 3a3c38ee17839..01f023199b890 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -418,7 +418,7 @@ def load_stream(self, stream): # FIXME: data_batch: should be "zero-copy" split for state_idx in range(0, len(state_pandas)): - state_info_col = state_pandas.iloc[0] + state_info_col = state_pandas.iloc[state_idx] state_info_col_properties = state_info_col['properties'] state_info_col_key_row = state_info_col['keyRowAsUnsafe'] From cad77a2ed4b97c91f1f0701a29fbc7b9279a56fd Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Sat, 27 Aug 2022 16:38:46 +0900 Subject: [PATCH 19/44] WIP fix another silly bug --- python/pyspark/sql/pandas/serializers.py | 6 +++++- .../sql/execution/python/ArrowPythonRunnerWithState.scala | 3 +++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 01f023199b890..b9370d7bacafc 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -416,10 +416,13 @@ def load_stream(self, stream): state_arrow = pa.Table.from_batches([state_batch]).itercolumns() state_pandas = [self.arrow_to_pandas(c) 
for c in state_arrow][0] - # FIXME: data_batch: should be "zero-copy" split for state_idx in range(0, len(state_pandas)): state_info_col = state_pandas.iloc[state_idx] + if not state_info_col: + # no more data with grouping key + state + break + state_info_col_properties = state_info_col['properties'] state_info_col_key_row = state_info_col['keyRowAsUnsafe'] state_info_col_object = state_info_col['object'] @@ -438,6 +441,7 @@ def load_stream(self, stream): valueSchema=self.state_object_schema, **state_properties) data_batch_for_group = data_batch.slice(data_start_offset, num_data_rows) + data_arrow = pa.Table.from_batches([data_batch_for_group]).itercolumns() data_pandas = [self.arrow_to_pandas(c) for c in data_arrow] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala index a236551e5dbc4..3689fd72329c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala @@ -183,6 +183,9 @@ class ArrowPythonRunnerWithState( numRowsForCurGroup) arrowWriterForState.write(stateInfoRow) + // start offset for next group would be same as the total number of rows for batch + startOffsetForCurGroup = totalNumRowsForBatch + // FIXME: threshold as number of rows // if we want to go with size, // arrowWriterForState.sizeInBytes() + arrowWriterForState.sizeInBytes() From c3da9966fc472616a1d9e46db816ab6d637cecd2 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Mon, 29 Aug 2022 16:04:01 +0900 Subject: [PATCH 20/44] WIP batching per specified size, with sampling --- python/pyspark/sql/pandas/serializers.py | 28 +++++++++++++++---- python/pyspark/sql/pandas/serializers.py.rej | 11 ++++++++ python/pyspark/worker.py | 19 ++++++++++++- .../apache/spark/sql/internal/SQLConf.scala | 20 
+++++++++++++ .../python/ArrowPythonRunnerWithState.scala | 24 ++++++++++++---- .../FlatMapGroupsInPandasWithStateExec.scala | 16 +++++++++-- 6 files changed, 105 insertions(+), 13 deletions(-) create mode 100644 python/pyspark/sql/pandas/serializers.py.rej diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index b9370d7bacafc..7ef4822c1ea3e 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -379,7 +379,8 @@ def load_stream(self, stream): class ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer): - def __init__(self, timezone, safecheck, assign_cols_by_name, state_object_schema): + def __init__(self, timezone, safecheck, assign_cols_by_name, state_object_schema, + softLimitBytesPerBatch, minDataCountForSample): super(ApplyInPandasWithStateSerializer, self).__init__( timezone, safecheck, assign_cols_by_name) self.pickleSer = CPickleSerializer() @@ -393,6 +394,8 @@ def __init__(self, timezone, safecheck, assign_cols_by_name, state_object_schema ]) self.result_state_pdf_arrow_type = to_arrow_type(self.result_state_df_type) + self.softLimitBytesPerBatch = softLimitBytesPerBatch + self.minDataCountForSample = minDataCountForSample def load_stream(self, stream): import pyarrow as pa @@ -495,6 +498,11 @@ def init_stream_yield_batches(): pdf_data_cnt = 0 state_data_cnt = 0 + sampled_data_size_per_row = 0 + sampled_state_size = 0 + # FIXME: sample with empty state size separately? + sampled_empty_state_size = 0 + for data in iterator: packaged_result = data[0] @@ -525,10 +533,20 @@ def init_stream_yield_batches(): state_pdfs.append(state_pdf) state_data_cnt += 1 - max_data_cnt = max(pdf_data_cnt, state_data_cnt) - # FIXME: what would be the best criteria for threshold? 
- # currently we arbitrarily set this via number of rows - if max_data_cnt > 10000: + # FIXME: threshold of sample data + if sampled_data_size_per_row == 0 and pdf_data_cnt > self.minDataCountForSample: + memory_usages = [p.memory_usage(deep=True).sum() for p in pdfs] + sampled_data_size_per_row = sum(memory_usages) / pdf_data_cnt + + # FIXME: threshold of sample data + if sampled_state_size == 0 and state_data_cnt > self.minDataCountForSample: + memory_usages = [p.memory_usage(deep=True).sum() for p in state_pdfs] + sampled_state_size = sum(memory_usages) / state_data_cnt + + # This effectively works after the sampling has completed, size we multiply by 0 + # if the sampling is still in progress. + if (sampled_data_size_per_row * pdf_data_cnt) + \ + (sampled_state_size * state_data_cnt) >= self.softLimitBytesPerBatch: batch = construct_record_batch(pdfs, pdf_data_cnt, return_schema, state_pdfs, state_data_cnt) diff --git a/python/pyspark/sql/pandas/serializers.py.rej b/python/pyspark/sql/pandas/serializers.py.rej new file mode 100644 index 0000000000000..6da4d4ee13ece --- /dev/null +++ b/python/pyspark/sql/pandas/serializers.py.rej @@ -0,0 +1,11 @@ +diff a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py (rejected hunks) +@@ -443,7 +443,8 @@ class CogroupUDFSerializer(ArrowStreamPandasUDFSerializer): + + class ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer): + +- def __init__(self, timezone, safecheck, assign_cols_by_name, state_object_schema): ++ def __init__(self, timezone, safecheck, assign_cols_by_name, state_object_schema, ++ softLimitBytesPerBatch, minDataCountForSample): + # Since we have a contract that first row in a arrow batch is a state row, we should be + # super careful on splitting the batch. 
+ super(ApplyInPandasWithStateSerializer, self).__init__( diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 093ca40e920cc..54e7355449a80 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -383,6 +383,8 @@ def read_udfs(pickleSer, infile, eval_type): runner_conf[k] = v state_object_schema = None + softLimitBytesPerBatchInApplyInPandasWithState = None + minDataCountForSampleInApplyInPandasWithState = None if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: state_object_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile))) @@ -400,12 +402,27 @@ def read_udfs(pickleSer, infile, eval_type): == "true" ) + softLimitBytesPerBatchInApplyInPandasWithState = runner_conf.get( + "spark.sql.execution.applyInPandasWithState.softLimitSizePerBatch", (64 * 1024 * 1024) + ) + softLimitBytesPerBatchInApplyInPandasWithState = \ + int(softLimitBytesPerBatchInApplyInPandasWithState) + + minDataCountForSampleInApplyInPandasWithState = runner_conf.get( + "spark.sql.execution.applyInPandasWithState.minDataCountForSample", 100 + ) + minDataCountForSampleInApplyInPandasWithState = \ + int(minDataCountForSampleInApplyInPandasWithState) + if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name) elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF: ser = ArrowStreamUDFSerializer() elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: - ser = ApplyInPandasWithStateSerializer(timezone, safecheck, assign_cols_by_name, state_object_schema) + ser = ApplyInPandasWithStateSerializer(timezone, safecheck, assign_cols_by_name, + state_object_schema, + softLimitBytesPerBatchInApplyInPandasWithState, + minDataCountForSampleInApplyInPandasWithState) else: # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of # pandas Series. See SPARK-27240. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index de25c19a26eb8..04758684c897e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2705,6 +2705,20 @@ object SQLConf { .booleanConf .createWithDefault(false) + val MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH = + buildConf("spark.sql.execution.applyInPandasWithState.softLimitSizePerBatch") + // FIXME: doc + .version("3.4.0") + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("64MB") + + val MAP_PANDAS_UDF_WITH_STATE_MIN_DATA_COUNT_FOR_SAMPLE = + buildConf("spark.sql.execution.applyInPandasWithState.minDataCountForSample") + // FIXME: doc + .version("3.4.0") + .intConf + .createWithDefault(100) + val REPLACE_EXCEPT_WITH_FILTER = buildConf("spark.sql.optimizer.replaceExceptWithFilter") .internal() .doc("When true, the apply function of the rule verifies whether the right node of the" + @@ -4529,6 +4543,12 @@ class SQLConf extends Serializable with Logging { def arrowSafeTypeConversion: Boolean = getConf(SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION) + def softLimitBytesPerBatchInApplyInPandasWithState: Long = + getConf(SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH) + + def minDataCountForSampleInApplyInPandasWithState: Int = + getConf(SQLConf.MAP_PANDAS_UDF_WITH_STATE_MIN_DATA_COUNT_FOR_SAMPLE) + def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER) def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala index 3689fd72329c9..ad4a769b6f565 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala @@ -58,7 +58,9 @@ class ArrowPythonRunnerWithState( stateEncoder: ExpressionEncoder[Row], keySchema: StructType, valueSchema: StructType, - stateSchema: StructType) + stateSchema: StructType, + softLimitBytesPerBatch: Long, + minDataCountForSample: Int) extends BasePythonRunner[ (UnsafeRow, GroupStateImpl[Row], Iterator[InternalRow]), (Iterator[(UnsafeRow, GroupStateImpl[Row])], Iterator[InternalRow])]( @@ -66,6 +68,7 @@ class ArrowPythonRunnerWithState( override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback + // FIXME: should we use this instead? override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize require( bufferSize >= 4, @@ -160,9 +163,12 @@ class ArrowPythonRunnerWithState( writer.start() var numRowsForCurGroup = 0 + var numStatesForCurGroup = 0 var startOffsetForCurGroup = 0 var totalNumRowsForBatch = 0 + var sampledDataSizePerRow = 0 + while (inputIterator.hasNext) { val (keyRow, groupState, dataIter) = inputIterator.next() @@ -182,14 +188,22 @@ class ArrowPythonRunnerWithState( val stateInfoRow = buildStateInfoRow(keyRow, groupState, startOffsetForCurGroup, numRowsForCurGroup) arrowWriterForState.write(stateInfoRow) + numStatesForCurGroup += 1 // start offset for next group would be same as the total number of rows for batch startOffsetForCurGroup = totalNumRowsForBatch - // FIXME: threshold as number of rows - // if we want to go with size, - // arrowWriterForState.sizeInBytes() + arrowWriterForState.sizeInBytes() - if (totalNumRowsForBatch > 10000) { + // FIXME: threshold of sample data + if (sampledDataSizePerRow == 0 && totalNumRowsForBatch > minDataCountForSample) { + sampledDataSizePerRow = arrowWriterForData.sizeInBytes() / totalNumRowsForBatch + } + + // This effectively works after the sampling has completed, size we 
multiply by 0 + // if the sampling is still in progress. + // FIXME: ignore state size for now, as we expect more number of data rather than + // number of state. + // FIXME: sample with empty state size separately? + if (sampledDataSizePerRow * totalNumRowsForBatch >= softLimitBytesPerBatch) { // DO NOT CHANGE THE ORDER OF FINISH! We are picking up the number of rows from data // side, as we know there is at least one data row. arrowWriterForState.finish() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala index d4c5d769945e1..42271b774db07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.python import org.apache.spark.TaskContext + import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow @@ -30,6 +31,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper.StateData import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.ArrowUtils @@ -78,8 +80,16 @@ case class FlatMapGroupsInPandasWithStateExec( override def output: Seq[Attribute] = outAttributes + private val softLimitBytesPerBatch = conf.softLimitBytesPerBatchInApplyInPandasWithState + private val minDataCountForSample = 
conf.minDataCountForSampleInApplyInPandasWithState + private val sessionLocalTimeZone = conf.sessionLocalTimeZone - private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) + private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) + + (SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH.key -> + softLimitBytesPerBatch.toString) + + (SQLConf.MAP_PANDAS_UDF_WITH_STATE_MIN_DATA_COUNT_FOR_SAMPLE.key -> + minDataCountForSample.toString) + private val pythonFunction = functionExpr.asInstanceOf[PythonUDF].func private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction))) private lazy val (dedupAttributes, argOffsets) = resolveArgOffsets( @@ -156,7 +166,9 @@ case class FlatMapGroupsInPandasWithStateExec( stateEncoder.asInstanceOf[ExpressionEncoder[Row]], groupingAttributes.toStructType, child.output.toStructType, - stateType) + stateType, + softLimitBytesPerBatch, + minDataCountForSample) val context = TaskContext.get() From cfb27807a3580985d04e622e1be86b5e973b21f8 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Mon, 29 Aug 2022 17:19:12 +0900 Subject: [PATCH 21/44] WIP introduce DBR-only change --- .../spark/sql/execution/arrow/ArrowWriter.scala | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 34e128a4925f6..bd27ad59bf039 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -98,6 +98,16 @@ class ArrowWriter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter]) { count += 1 } + def sizeInBytes(): Int = { + var i = 0 + var bytes = 0 + while (i < fields.size) { + bytes += fields(i).getSizeInBytes() + i += 1 + } + bytes + } + def finish(): Unit = { root.setRowCount(count) fields.foreach(_.finish()) @@ -136,6 
+146,10 @@ private[arrow] abstract class ArrowFieldWriter { valueVector.setValueCount(count) } + def getSizeInBytes(): Int = { + valueVector.getBufferSizeFor(count) + } + def reset(): Unit = { valueVector.reset() count = 0 From 228b140ee1e0a69e5627b9bfe6a89343e5a577d6 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Mon, 29 Aug 2022 18:22:37 +0900 Subject: [PATCH 22/44] WIP debugging now... --- python/pyspark/sql/pandas/serializers.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 7ef4822c1ea3e..a3af0485b5b31 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -413,12 +413,16 @@ def load_stream(self, stream): data_columns = batch_columns[0:-1] state_column = batch_columns[-1] + print("== data_columns: %s state_column: %s" % (data_columns, state_column, ), file=sys.stderr) + data_batch = pa.RecordBatch.from_arrays(data_columns, schema=data_schema) state_batch = pa.RecordBatch.from_arrays([state_column, ], schema=state_schema) state_arrow = pa.Table.from_batches([state_batch]).itercolumns() state_pandas = [self.arrow_to_pandas(c) for c in state_arrow][0] + print("== data_batch: %s state_batch: %s" % (data_batch, state_batch, ), file=sys.stderr) + for state_idx in range(0, len(state_pandas)): state_info_col = state_pandas.iloc[state_idx] @@ -450,6 +454,9 @@ def load_stream(self, stream): data_pandas = [self.arrow_to_pandas(c) for c in data_arrow] # state info + + print("== data_pandas: %s state: %s" % (data_pandas, state, ), file=sys.stderr) + yield (data_pandas, state, ) def dump_stream(self, iterator, stream): From ee4ed57954264165ebdceb2a304798fd66cd436f Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 30 Aug 2022 16:49:13 +0900 Subject: [PATCH 23/44] WIP still debugging... 
weirdness happened --- python/pyspark/sql/pandas/serializers.py | 21 +++++++++++++++---- python/pyspark/sql/pandas/serializers.py.rej | 11 ---------- python/pyspark/worker.py | 8 ++++++- .../apache/spark/sql/internal/SQLConf.scala | 14 +++++++++++-- .../python/ArrowPythonRunnerWithState.scala | 12 +++++++++-- .../FlatMapGroupsInPandasWithStateExec.scala | 15 ++++++++++--- .../FlatMapGroupsInPandasWithStateSuite.scala | 2 ++ 7 files changed, 60 insertions(+), 23 deletions(-) delete mode 100644 python/pyspark/sql/pandas/serializers.py.rej diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index a3af0485b5b31..abc42285e595c 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -19,6 +19,7 @@ Serializers for PyArrow and pandas conversions. See `pyspark.serializers` for more details. """ import sys +import time from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer, CPickleSerializer from pyspark.sql.pandas.types import to_arrow_type @@ -380,7 +381,7 @@ def load_stream(self, stream): class ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer): def __init__(self, timezone, safecheck, assign_cols_by_name, state_object_schema, - softLimitBytesPerBatch, minDataCountForSample): + softLimitBytesPerBatch, minDataCountForSample, softTimeoutMillisPurgeBatch): super(ApplyInPandasWithStateSerializer, self).__init__( timezone, safecheck, assign_cols_by_name) self.pickleSer = CPickleSerializer() @@ -396,6 +397,7 @@ def __init__(self, timezone, safecheck, assign_cols_by_name, state_object_schema self.result_state_pdf_arrow_type = to_arrow_type(self.result_state_df_type) self.softLimitBytesPerBatch = softLimitBytesPerBatch self.minDataCountForSample = minDataCountForSample + self.softTimeoutMillisPurgeBatch = softTimeoutMillisPurgeBatch def load_stream(self, stream): import pyarrow as pa @@ -405,15 +407,19 @@ def load_stream(self, stream): batches = 
super(ArrowStreamPandasSerializer, self).load_stream(stream) for batch in batches: + print("== batch: %s" % (batch, ), file=sys.stderr) + batch_schema = batch.schema data_schema = pa.schema([batch_schema[i] for i in range(0, len(batch_schema) - 1)]) state_schema = pa.schema([batch_schema[-1], ]) + print("== batch_schema: %s data_schema: %s state_schema: %s" % (batch_schema, data_schema, state_schema, ), file=sys.stderr) + batch_columns = batch.columns data_columns = batch_columns[0:-1] state_column = batch_columns[-1] - print("== data_columns: %s state_column: %s" % (data_columns, state_column, ), file=sys.stderr) + print("== batch_columns: %s data_columns: %s state_column: %s" % (batch_columns, data_columns, state_column, ), file=sys.stderr) data_batch = pa.RecordBatch.from_arrays(data_columns, schema=data_schema) state_batch = pa.RecordBatch.from_arrays([state_column, ], schema=state_schema) @@ -510,6 +516,8 @@ def init_stream_yield_batches(): # FIXME: sample with empty state size separately? sampled_empty_state_size = 0 + last_purged_time_ns = time.time_ns() + for data in iterator: packaged_result = data[0] @@ -552,8 +560,12 @@ def init_stream_yield_batches(): # This effectively works after the sampling has completed, size we multiply by 0 # if the sampling is still in progress. 
- if (sampled_data_size_per_row * pdf_data_cnt) + \ - (sampled_state_size * state_data_cnt) >= self.softLimitBytesPerBatch: + batch_over_limit_on_size = (sampled_data_size_per_row * pdf_data_cnt) + \ + (sampled_state_size * state_data_cnt) >= self.softLimitBytesPerBatch + cur_time_ns = time.time_ns() + is_timed_out_on_purge = ((cur_time_ns - last_purged_time_ns) // 1000000) >= \ + self.softTimeoutMillisPurgeBatch + if batch_over_limit_on_size or is_timed_out_on_purge: batch = construct_record_batch(pdfs, pdf_data_cnt, return_schema, state_pdfs, state_data_cnt) @@ -561,6 +573,7 @@ def init_stream_yield_batches(): state_pdfs = [] pdf_data_cnt = 0 state_data_cnt = 0 + last_purged_time_ns = cur_time_ns if should_write_start_length: write_int(SpecialLengths.START_ARROW_STREAM, stream) diff --git a/python/pyspark/sql/pandas/serializers.py.rej b/python/pyspark/sql/pandas/serializers.py.rej deleted file mode 100644 index 6da4d4ee13ece..0000000000000 --- a/python/pyspark/sql/pandas/serializers.py.rej +++ /dev/null @@ -1,11 +0,0 @@ -diff a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py (rejected hunks) -@@ -443,7 +443,8 @@ class CogroupUDFSerializer(ArrowStreamPandasUDFSerializer): - - class ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer): - -- def __init__(self, timezone, safecheck, assign_cols_by_name, state_object_schema): -+ def __init__(self, timezone, safecheck, assign_cols_by_name, state_object_schema, -+ softLimitBytesPerBatch, minDataCountForSample): - # Since we have a contract that first row in a arrow batch is a state row, we should be - # super careful on splitting the batch. 
- super(ApplyInPandasWithStateSerializer, self).__init__( diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 54e7355449a80..4fad720ed6035 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -413,6 +413,11 @@ def read_udfs(pickleSer, infile, eval_type): ) minDataCountForSampleInApplyInPandasWithState = \ int(minDataCountForSampleInApplyInPandasWithState) + softTimeoutMillisPurgeBatchInApplyInPandasWithState = runner_conf.get( + "spark.sql.execution.applyInPandasWithState.softTimeoutPurgeBatch", 100 + ) + softTimeoutMillisPurgeBatchInApplyInPandasWithState = \ + int(softTimeoutMillisPurgeBatchInApplyInPandasWithState) if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name) @@ -422,7 +427,8 @@ def read_udfs(pickleSer, infile, eval_type): ser = ApplyInPandasWithStateSerializer(timezone, safecheck, assign_cols_by_name, state_object_schema, softLimitBytesPerBatchInApplyInPandasWithState, - minDataCountForSampleInApplyInPandasWithState) + minDataCountForSampleInApplyInPandasWithState, + softTimeoutMillisPurgeBatchInApplyInPandasWithState) else: # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of # pandas Series. See SPARK-27240. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 04758684c897e..521b2b3a897dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2719,6 +2719,13 @@ object SQLConf { .intConf .createWithDefault(100) + val MAP_PANDAS_UDF_WITH_STATE_SOFT_TIMEOUT_PURGE_BATCH = + buildConf("spark.sql.execution.applyInPandasWithState.softTimeoutPurgeBatch") + // FIXME: doc + .version("3.4.0") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("100ms") + val REPLACE_EXCEPT_WITH_FILTER = buildConf("spark.sql.optimizer.replaceExceptWithFilter") .internal() .doc("When true, the apply function of the rule verifies whether the right node of the" + @@ -4544,10 +4551,13 @@ class SQLConf extends Serializable with Logging { def arrowSafeTypeConversion: Boolean = getConf(SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION) def softLimitBytesPerBatchInApplyInPandasWithState: Long = - getConf(SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH) + getConf(SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH) def minDataCountForSampleInApplyInPandasWithState: Int = - getConf(SQLConf.MAP_PANDAS_UDF_WITH_STATE_MIN_DATA_COUNT_FOR_SAMPLE) + getConf(SQLConf.MAP_PANDAS_UDF_WITH_STATE_MIN_DATA_COUNT_FOR_SAMPLE) + + def softTimeoutMillisPurgeBatchInApplyInPandasWithState: Long = + getConf(SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_TIMEOUT_PURGE_BATCH) def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala index ad4a769b6f565..f2b8f1cf31c01 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala @@ -60,7 +60,8 @@ class ArrowPythonRunnerWithState( valueSchema: StructType, stateSchema: StructType, softLimitBytesPerBatch: Long, - minDataCountForSample: Int) + minDataCountForSample: Int, + softTimeoutMillsPurgeBatch: Long) extends BasePythonRunner[ (UnsafeRow, GroupStateImpl[Row], Iterator[InternalRow]), (Iterator[(UnsafeRow, GroupStateImpl[Row])], Iterator[InternalRow])]( @@ -169,7 +170,11 @@ class ArrowPythonRunnerWithState( var sampledDataSizePerRow = 0 + var lastBatchPurgedMillis = System.currentTimeMillis() + while (inputIterator.hasNext) { + logWarning(s" writer content so far: ") + val (keyRow, groupState, dataIter) = inputIterator.next() assert(dataIter.hasNext, "should have at least one data row!") @@ -179,6 +184,7 @@ class ArrowPythonRunnerWithState( // Provide data rows while (dataIter.hasNext) { val dataRow = dataIter.next() + logWarning(s" dataRow: $dataRow") arrowWriterForData.write(dataRow) numRowsForCurGroup += 1 totalNumRowsForBatch += 1 @@ -203,7 +209,8 @@ class ArrowPythonRunnerWithState( // FIXME: ignore state size for now, as we expect more number of data rather than // number of state. // FIXME: sample with empty state size separately? - if (sampledDataSizePerRow * totalNumRowsForBatch >= softLimitBytesPerBatch) { + if (sampledDataSizePerRow * totalNumRowsForBatch >= softLimitBytesPerBatch || + System.currentTimeMillis() - lastBatchPurgedMillis > softTimeoutMillsPurgeBatch) { // DO NOT CHANGE THE ORDER OF FINISH! We are picking up the number of rows from data // side, as we know there is at least one data row. 
arrowWriterForState.finish() @@ -214,6 +221,7 @@ class ArrowPythonRunnerWithState( startOffsetForCurGroup = 0 totalNumRowsForBatch = 0 + lastBatchPurgedMillis = System.currentTimeMillis() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala index 42271b774db07..7078bf3acc2c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -82,19 +82,27 @@ case class FlatMapGroupsInPandasWithStateExec( private val softLimitBytesPerBatch = conf.softLimitBytesPerBatchInApplyInPandasWithState private val minDataCountForSample = conf.minDataCountForSampleInApplyInPandasWithState + private val softTimeoutMillsPurgeBatch = conf.softTimeoutMillisPurgeBatchInApplyInPandasWithState private val sessionLocalTimeZone = conf.sessionLocalTimeZone private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) + (SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH.key -> softLimitBytesPerBatch.toString) + (SQLConf.MAP_PANDAS_UDF_WITH_STATE_MIN_DATA_COUNT_FOR_SAMPLE.key -> - minDataCountForSample.toString) + minDataCountForSample.toString) + + (SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_TIMEOUT_PURGE_BATCH.key -> + softTimeoutMillsPurgeBatch.toString) private val pythonFunction = functionExpr.asInstanceOf[PythonUDF].func private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction))) private lazy val (dedupAttributes, argOffsets) = resolveArgOffsets( groupingAttributes ++ child.output, groupingAttributes) + + logWarning(s" dedupAttributes: $dedupAttributes / child.output: ${child.output} / argOffsets: $argOffsets") + private lazy val unsafeProj = UnsafeProjection.create(dedupAttributes, child.output) + private lazy val 
unsafeProjForTimeoutDummyRow = + UnsafeProjection.create(dedupAttributes, groupingAttributes ++ child.output) override def requiredChildDistribution: Seq[Distribution] = StatefulOperatorPartitioning.getCompatibleDistribution( @@ -141,7 +149,7 @@ case class FlatMapGroupsInPandasWithStateExec( } val processIter = timingOutPairs.map { stateData => - val joinedKeyRow = unsafeProj( + val joinedKeyRow = unsafeProjForTimeoutDummyRow( new JoinedRow( stateData.keyRow, new GenericInternalRow(Array.fill(dedupAttributes.length)(null: Any)))) @@ -168,7 +176,8 @@ case class FlatMapGroupsInPandasWithStateExec( child.output.toStructType, stateType, softLimitBytesPerBatch, - minDataCountForSample) + minDataCountForSample, + softTimeoutMillsPurgeBatch) val context = TaskContext.get() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala index 03d3fd6dcff1e..372d36ba76f89 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala @@ -107,6 +107,8 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { ) } + // FIXME: we haven't had any test to produce multiple outputs + test("applyInPandasWithState - streaming + aggregation") { assume(shouldTestPandasUDFs) From 4045ab36c3bac1a57d56227acde86edaeea315fd Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 30 Aug 2022 17:31:03 +0900 Subject: [PATCH 24/44] WIP small fix --- .../execution/python/ArrowPythonRunnerWithState.scala | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala index 
f2b8f1cf31c01..52ef3b7ffe20b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala @@ -164,8 +164,8 @@ class ArrowPythonRunnerWithState( writer.start() var numRowsForCurGroup = 0 - var numStatesForCurGroup = 0 var startOffsetForCurGroup = 0 + var numStatesForCurBatch = 0 var totalNumRowsForBatch = 0 var sampledDataSizePerRow = 0 @@ -194,7 +194,7 @@ class ArrowPythonRunnerWithState( val stateInfoRow = buildStateInfoRow(keyRow, groupState, startOffsetForCurGroup, numRowsForCurGroup) arrowWriterForState.write(stateInfoRow) - numStatesForCurGroup += 1 + numStatesForCurBatch += 1 // start offset for next group would be same as the total number of rows for batch startOffsetForCurGroup = totalNumRowsForBatch @@ -213,6 +213,9 @@ class ArrowPythonRunnerWithState( System.currentTimeMillis() - lastBatchPurgedMillis > softTimeoutMillsPurgeBatch) { // DO NOT CHANGE THE ORDER OF FINISH! We are picking up the number of rows from data // side, as we know there is at least one data row. + + logWarning(" purging batch!") + arrowWriterForState.finish() arrowWriterForData.finish() writer.writeBatch() @@ -221,6 +224,7 @@ class ArrowPythonRunnerWithState( startOffsetForCurGroup = 0 totalNumRowsForBatch = 0 + numStatesForCurBatch = 0 lastBatchPurgedMillis = System.currentTimeMillis() } } @@ -230,6 +234,9 @@ class ArrowPythonRunnerWithState( // DO NOT CHANGE THE ORDER OF FINISH! We are picking up the number of rows from data // side, as we know there is at least one data row. + + logWarning(" purging batch!") + arrowWriterForState.finish() arrowWriterForData.finish() writer.writeBatch() From 2d115abfd16f55e7671082d90d55459c50ec1ea2 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 30 Aug 2022 17:46:00 +0900 Subject: [PATCH 25/44] WIP fix a serious bug... 
make sure all columns in Arrow RecordBatch have same number of elements --- python/pyspark/sql/pandas/serializers.py | 11 ---- .../python/ArrowPythonRunnerWithState.scala | 30 ++++----- .../FlatMapGroupsInPandasWithStateExec.scala | 6 +- .../FlatMapGroupsInPandasWithStateSuite.scala | 63 ++++++++++++++++++- 4 files changed, 78 insertions(+), 32 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index abc42285e595c..e5d592c875661 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -407,28 +407,20 @@ def load_stream(self, stream): batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) for batch in batches: - print("== batch: %s" % (batch, ), file=sys.stderr) - batch_schema = batch.schema data_schema = pa.schema([batch_schema[i] for i in range(0, len(batch_schema) - 1)]) state_schema = pa.schema([batch_schema[-1], ]) - print("== batch_schema: %s data_schema: %s state_schema: %s" % (batch_schema, data_schema, state_schema, ), file=sys.stderr) - batch_columns = batch.columns data_columns = batch_columns[0:-1] state_column = batch_columns[-1] - print("== batch_columns: %s data_columns: %s state_column: %s" % (batch_columns, data_columns, state_column, ), file=sys.stderr) - data_batch = pa.RecordBatch.from_arrays(data_columns, schema=data_schema) state_batch = pa.RecordBatch.from_arrays([state_column, ], schema=state_schema) state_arrow = pa.Table.from_batches([state_batch]).itercolumns() state_pandas = [self.arrow_to_pandas(c) for c in state_arrow][0] - print("== data_batch: %s state_batch: %s" % (data_batch, state_batch, ), file=sys.stderr) - for state_idx in range(0, len(state_pandas)): state_info_col = state_pandas.iloc[state_idx] @@ -460,9 +452,6 @@ def load_stream(self, stream): data_pandas = [self.arrow_to_pandas(c) for c in data_arrow] # state info - - print("== data_pandas: %s state: %s" % (data_pandas, state, ), file=sys.stderr) - yield 
(data_pandas, state, ) def dump_stream(self, iterator, stream): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala index 52ef3b7ffe20b..8ceaf9dc7ef73 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala @@ -109,6 +109,9 @@ class ArrowPythonRunnerWithState( context: TaskContext): WriterThread = { new WriterThread(env, worker, inputIterator, partitionIndex, context) { + private val EMPTY_STATE_INFO_ROW = + new GenericInternalRow(Array[Any](null, null, null, null, null)) + protected override def writeCommand(dataOut: DataOutputStream): Unit = { handleMetadataBeforeExec(dataOut) PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) @@ -165,16 +168,14 @@ class ArrowPythonRunnerWithState( var numRowsForCurGroup = 0 var startOffsetForCurGroup = 0 - var numStatesForCurBatch = 0 var totalNumRowsForBatch = 0 + var totalNumStatesForBatch = 0 var sampledDataSizePerRow = 0 var lastBatchPurgedMillis = System.currentTimeMillis() while (inputIterator.hasNext) { - logWarning(s" writer content so far: ") - val (keyRow, groupState, dataIter) = inputIterator.next() assert(dataIter.hasNext, "should have at least one data row!") @@ -184,7 +185,6 @@ class ArrowPythonRunnerWithState( // Provide data rows while (dataIter.hasNext) { val dataRow = dataIter.next() - logWarning(s" dataRow: $dataRow") arrowWriterForData.write(dataRow) numRowsForCurGroup += 1 totalNumRowsForBatch += 1 @@ -194,7 +194,7 @@ class ArrowPythonRunnerWithState( val stateInfoRow = buildStateInfoRow(keyRow, groupState, startOffsetForCurGroup, numRowsForCurGroup) arrowWriterForState.write(stateInfoRow) - numStatesForCurBatch += 1 + totalNumStatesForBatch += 1 // start offset for next group would be same as the total 
number of rows for batch startOffsetForCurGroup = totalNumRowsForBatch @@ -211,10 +211,10 @@ class ArrowPythonRunnerWithState( // FIXME: sample with empty state size separately? if (sampledDataSizePerRow * totalNumRowsForBatch >= softLimitBytesPerBatch || System.currentTimeMillis() - lastBatchPurgedMillis > softTimeoutMillsPurgeBatch) { - // DO NOT CHANGE THE ORDER OF FINISH! We are picking up the number of rows from data - // side, as we know there is at least one data row. - - logWarning(" purging batch!") + val remainingEmptyStateRows = totalNumRowsForBatch - totalNumStatesForBatch + (0 until remainingEmptyStateRows).foreach { _ => + arrowWriterForState.write(EMPTY_STATE_INFO_ROW) + } arrowWriterForState.finish() arrowWriterForData.finish() @@ -224,7 +224,7 @@ class ArrowPythonRunnerWithState( startOffsetForCurGroup = 0 totalNumRowsForBatch = 0 - numStatesForCurBatch = 0 + totalNumStatesForBatch = 0 lastBatchPurgedMillis = System.currentTimeMillis() } } @@ -232,12 +232,12 @@ class ArrowPythonRunnerWithState( if (numRowsForCurGroup > 0) { // need to flush remaining batch - // DO NOT CHANGE THE ORDER OF FINISH! We are picking up the number of rows from data - // side, as we know there is at least one data row. 
- - logWarning(" purging batch!") + val remainingEmptyStateRows = totalNumRowsForBatch - totalNumStatesForBatch + (0 until remainingEmptyStateRows).foreach { _ => + arrowWriterForState.write(EMPTY_STATE_INFO_ROW) + } - arrowWriterForState.finish() + arrowWriterForState.finish() arrowWriterForData.finish() writer.writeBatch() arrowWriterForState.reset() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala index 7078bf3acc2c6..18d5f95815952 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -98,11 +98,7 @@ case class FlatMapGroupsInPandasWithStateExec( private lazy val (dedupAttributes, argOffsets) = resolveArgOffsets( groupingAttributes ++ child.output, groupingAttributes) - logWarning(s" dedupAttributes: $dedupAttributes / child.output: ${child.output} / argOffsets: $argOffsets") - private lazy val unsafeProj = UnsafeProjection.create(dedupAttributes, child.output) - private lazy val unsafeProjForTimeoutDummyRow = - UnsafeProjection.create(dedupAttributes, groupingAttributes ++ child.output) override def requiredChildDistribution: Seq[Distribution] = StatefulOperatorPartitioning.getCompatibleDistribution( @@ -149,7 +145,7 @@ case class FlatMapGroupsInPandasWithStateExec( } val processIter = timingOutPairs.map { stateData => - val joinedKeyRow = unsafeProjForTimeoutDummyRow( + val joinedKeyRow = unsafeProj( new JoinedRow( stateData.keyRow, new GenericInternalRow(Array.fill(dedupAttributes.length)(null: Any)))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala index 
372d36ba76f89..3aae51bc7e51d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala @@ -107,7 +107,68 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { ) } - // FIXME: we haven't had any test to produce multiple outputs + test("applyInPandasWithState - streaming, multiple groups in partition, " + + "multiple outputs per grouping key") { + assume(shouldTestPandasUDFs) + + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count if state is defined, otherwise does not return anything + val pythonScript = + """ + |import pandas as pd + |from pyspark.sql.types import StructType, StructField, StringType + | + |tpe = StructType([ + | StructField("key", StringType()), + | StructField("countAsString", StringType())]) + | + |def func(key, pdf, state): + | count = state.getOption + | if count is None: + | count = 0 + | else: + | count = count[0] + | count += len(pdf) + | state.update((count,)) + | return pdf.rename(columns={"value": "key"}).assign(countAsString=str(count)) + |""".stripMargin + val pythonFunc = TestGroupedMapPandasUDFWithState( + name = "pandas_grouped_map_with_state", pythonScript = pythonScript) + + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val inputData = MemoryStream[String] + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("countAsString", StringType))) + val stateStructType = StructType(Seq(StructField("count", LongType))) + val inputDataDS = inputData.toDS() + val result = + inputDataDS + .groupBy("value") + .applyInPandasWithState( + pythonFunc(inputDataDS("value")).expr.asInstanceOf[PythonUDF], + outputStructType, + stateStructType, + "Update", + "NoTimeout") + + testStream(result, Update)( + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + 
assertNumStateRows(total = 1, updated = 1), + AddData(inputData, "a", "a", "a", "b"), + CheckNewAnswer(("a", "4"), ("a", "4"), ("a", "4"), ("b", "1")), + assertNumStateRows(total = 2, updated = 2), + StopStream, + StartStream(), + AddData(inputData, "b", "c", "d", "e", "f", "g"), + CheckNewAnswer(("b", "2"), ("c", "1"), ("d", "1"), ("e", "1"), + ("f", "1"), ("g", "1")), + assertNumStateRows(total = 7, updated = 6) + ) + } + } test("applyInPandasWithState - streaming + aggregation") { assume(shouldTestPandasUDFs) From 3e7d7853cf0d9fea950f9868497a099da74c7052 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 30 Aug 2022 20:24:15 +0900 Subject: [PATCH 26/44] WIP strengthen test --- .../FlatMapGroupsInPandasWithStateSuite.scala | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala index 3aae51bc7e51d..671621aa43b3e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala @@ -111,15 +111,14 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { "multiple outputs per grouping key") { assume(shouldTestPandasUDFs) - // Function to maintain running count up to 2, and then remove the count - // Returns the data and the count if state is defined, otherwise does not return anything val pythonScript = """ |import pandas as pd - |from pyspark.sql.types import StructType, StructField, StringType + |from pyspark.sql.types import IntegerType, StructType, StructField, StringType | |tpe = StructType([ | StructField("key", StringType()), + | StructField("value", IntegerType()), | StructField("countAsString", StringType())]) | |def func(key, pdf, state): @@ -128,44 +127,45 @@ 
class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { | count = 0 | else: | count = count[0] - | count += len(pdf) + | count = count + len(pdf) | state.update((count,)) - | return pdf.rename(columns={"value": "key"}).assign(countAsString=str(count)) + | return pdf.assign(countAsString=str(count)) |""".stripMargin val pythonFunc = TestGroupedMapPandasUDFWithState( name = "pandas_grouped_map_with_state", pythonScript = pythonScript) withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - val inputData = MemoryStream[String] + val inputData = MemoryStream[(String, Int)] val outputStructType = StructType( Seq( StructField("key", StringType), + StructField("value", IntegerType), StructField("countAsString", StringType))) val stateStructType = StructType(Seq(StructField("count", LongType))) - val inputDataDS = inputData.toDS() + val inputDataDS = inputData.toDS().selectExpr("_1 AS key", "_2 AS value") val result = inputDataDS - .groupBy("value") + .groupBy("key") .applyInPandasWithState( - pythonFunc(inputDataDS("value")).expr.asInstanceOf[PythonUDF], + pythonFunc(inputDataDS("key"), inputDataDS("value")).expr.asInstanceOf[PythonUDF], outputStructType, stateStructType, "Update", "NoTimeout") + .select("key", "value", "countAsString") testStream(result, Update)( - AddData(inputData, "a"), - CheckNewAnswer(("a", "1")), + AddData(inputData, ("a", 1)), + CheckNewAnswer(("a", 1, "1")), assertNumStateRows(total = 1, updated = 1), - AddData(inputData, "a", "a", "a", "b"), - CheckNewAnswer(("a", "4"), ("a", "4"), ("a", "4"), ("b", "1")), + AddData(inputData, ("a", 2), ("a", 3), ("b", 1)), + CheckNewAnswer(("a", 2, "3"), ("a", 3, "3"), ("b", 1, "1")), assertNumStateRows(total = 2, updated = 2), StopStream, StartStream(), - AddData(inputData, "b", "c", "d", "e", "f", "g"), - CheckNewAnswer(("b", "2"), ("c", "1"), ("d", "1"), ("e", "1"), - ("f", "1"), ("g", "1")), - assertNumStateRows(total = 7, updated = 6) + AddData(inputData, ("b", 2), ("c", 1), ("d", 1), 
("e", 1)), + CheckNewAnswer(("b", 2, "2"), ("c", 1, "1"), ("d", 1, "1"), ("e", 1, "1")), + assertNumStateRows(total = 5, updated = 4) ) } } From 029dae76d7c137c9a91f9c3cf9e744850b31e4e2 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Fri, 2 Sep 2022 11:30:17 +0900 Subject: [PATCH 27/44] WIP documenting the changes for pipelining and bin-packing... not yet done --- .../python/ArrowPythonRunnerWithState.scala | 188 ++++++++++++------ .../FlatMapGroupsInPandasWithStateExec.scala | 27 +-- .../execution/streaming/GroupStateImpl.scala | 4 +- 3 files changed, 139 insertions(+), 80 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala index 8ceaf9dc7ef73..cf78b5fea3760 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala @@ -58,7 +58,7 @@ class ArrowPythonRunnerWithState( stateEncoder: ExpressionEncoder[Row], keySchema: StructType, valueSchema: StructType, - stateSchema: StructType, + stateValueSchema: StructType, softLimitBytesPerBatch: Long, minDataCountForSample: Int, softTimeoutMillsPurgeBatch: Long) @@ -67,38 +67,46 @@ class ArrowPythonRunnerWithState( (Iterator[(UnsafeRow, GroupStateImpl[Row])], Iterator[InternalRow])]( funcs, evalType, argOffsets) { + import ArrowPythonRunnerWithState._ + override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback - // FIXME: should we use this instead? override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize require( bufferSize >= 4, "Pandas execution requires more than 4 bytes. Please set higher buffer. 
" + s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.") - val schemaWithState = inputSchema.add("!__state__!", - StructType( - Array( - StructField("properties", StringType), - StructField("keyRowAsUnsafe", BinaryType), - StructField("object", BinaryType), - StructField("startOffset", IntegerType), - StructField("numRows", IntegerType) - ) + private val stateMetadataSchema = StructType( + Array( + StructField("properties", StringType), + StructField("keyRowAsUnsafe", BinaryType), + StructField("object", BinaryType), + StructField("startOffset", IntegerType), + StructField("numRows", IntegerType) ) ) - val stateRowSerializer = stateEncoder.createSerializer() - val stateRowDeserializer = stateEncoder.createDeserializer() + private val schemaWithState = inputSchema.add("!__state__!", stateMetadataSchema) + + private val stateRowDeserializer = stateEncoder.createDeserializer() + + private val workerConfWithRunnerConfs = workerConf + + (SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH.key -> + softLimitBytesPerBatch.toString) + + (SQLConf.MAP_PANDAS_UDF_WITH_STATE_MIN_DATA_COUNT_FOR_SAMPLE.key -> + minDataCountForSample.toString) + + (SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_TIMEOUT_PURGE_BATCH.key -> + softTimeoutMillsPurgeBatch.toString) protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = { // Write config for the worker as a number of key -> value pairs of strings - stream.writeInt(workerConf.size) - for ((k, v) <- workerConf) { + stream.writeInt(workerConfWithRunnerConfs.size) + for ((k, v) <- workerConfWithRunnerConfs) { PythonRDD.writeUTF(k, stream) PythonRDD.writeUTF(v, stream) } - PythonRDD.writeUTF(stateSchema.json, stream) + PythonRDD.writeUTF(stateValueSchema.json, stream) } protected override def newWriterThread( @@ -109,9 +117,6 @@ class ArrowPythonRunnerWithState( context: TaskContext): WriterThread = { new WriterThread(env, worker, inputIterator, partitionIndex, context) { - private val EMPTY_STATE_INFO_ROW = - new 
GenericInternalRow(Array[Any](null, null, null, null, null)) - protected override def writeCommand(dataOut: DataOutputStream): Unit = { handleMetadataBeforeExec(dataOut) PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) @@ -122,12 +127,7 @@ class ArrowPythonRunnerWithState( groupState: GroupStateImpl[Row], startOffset: Int, numRows: Int): InternalRow = { - // FIXME: document the schema - // - properties - // - keyRowAsUnsafe - // - state object as Row - // - startOffset - // - numRows + // NOTE: see ArrowPythonRunnerWithState.STATE_METADATA_SCHEMA val stateUnderlyingRow = new GenericInternalRow( Array[Any]( UTF8String.fromString(groupState.json()), @@ -141,12 +141,21 @@ class ArrowPythonRunnerWithState( } protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + // We initialize all columns in data & state metadata for Arrow RecordBatch. val arrowSchema = ArrowUtils.toArrowSchema(schemaWithState, timeZoneId) val allocator = ArrowUtils.rootAllocator.newChildAllocator( s"stdout writer for $pythonExec", 0, Long.MaxValue) val root = VectorSchemaRoot.create(arrowSchema, allocator) Utils.tryWithSafeFinally { + // We logically group the columns by family and initialize writer separately, since it's + // lot more easier and probably performant to write the row directly rather than + // projecting the row to match up with the overall schema. + // The number of data rows and state metadata rows can be different which seems to matter + // for Arrow RecordBatch, so we append empty rows to cover it. + // We always produce at least one data row per grouping key whereas we only produce one + // state metadata row per grouping key, so we only need to fill up the empty rows in + // state metadata side. 
val arrowWriterForData = { val children = root.getFieldVectors().asScala.dropRight(1).map { vector => vector.allocateNew() @@ -166,13 +175,45 @@ class ArrowPythonRunnerWithState( val writer = new ArrowStreamWriter(root, null, dataOut) writer.start() + // We apply bin-packing the data from multiple groups into one Arrow RecordBatch to + // gain the performance. In many cases, the amount of data per grouping key is quite + // small, which does not seem to maximize the benefits of using Arrow. + // + // We have to split the record batch down to each group in Python worker to convert the + // data for group to Pandas, but hopefully, Arrow RecordBatch provides the way to split + // the range of data and give a view, say, "zero-copy". To help splitting the range for + // data, we provide the "start offset" and the "number of data" in the state metadata. + // + // FIXME: probably need to change this as "hard limit" when addressing scalability. Worth + // noting that we may need to break down the data into chunks for a specific group + // having "small" number of data, because we also do bin-packing as well. Maybe we could + // concatenate these chunks in Python worker (serializer), with some hints e.g. + // We can get the information - the number of data in the chunk before reading. + // + // Pretty sure we don't bin-pack all groups into a single record batch. We have a soft + // limit on the size - it's not a hard limit since we allow current group to write all + // data even it's going to exceed the limit. + // + // We perform some basic sampling for data to guess the size of the data very roughly, + // and simply multiply by the number of data to estimate the size. We extract the size of + // data from the record batch rather than UnsafeRow, as we don't hold the memory for + // UnsafeRow once we write to the record batch. If there is a memory bound here, it + // should come from record batch. 
+ // + // In the meanwhile, we don't also want to let the current record batch collect the data + // indefinitely, since we are pipelining the process between executor and python worker. + // Python worker won't process any data if executor is not yet finalized a record + // batch, which defeats the purpose of pipelining. To address this, we also introduce + // timeout for constructing a record batch. This is a soft limit indeed as same as limit + // on the size - we allow current group to write all data even it's timed-out. + + // FIXME: Maybe better if we can extract out the batching logic into a separate class. var numRowsForCurGroup = 0 var startOffsetForCurGroup = 0 var totalNumRowsForBatch = 0 var totalNumStatesForBatch = 0 var sampledDataSizePerRow = 0 - var lastBatchPurgedMillis = System.currentTimeMillis() while (inputIterator.hasNext) { @@ -185,35 +226,50 @@ class ArrowPythonRunnerWithState( // Provide data rows while (dataIter.hasNext) { val dataRow = dataIter.next() + // TODO: if we think there will be non-small amount of data per grouping key, + // we could probably try out "dictionary encoding" for the optimization + // of storing same grouping keys multiple times. This may complicate the logic, as + // in IPC streaming format, DictionaryBatch will be provided separately along with + // RecordBatch, and I'm not sure whether the record batch can be directly converted + // to Pandas DataFrame / Series if the record batch refers to the dictionary batch. 
+ // https://arrow.apache.org/docs/format/Columnar.html#dictionary-encoded-layout + // https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format arrowWriterForData.write(dataRow) numRowsForCurGroup += 1 totalNumRowsForBatch += 1 } - // Provide state info row in the first row + // Provide state metadata row val stateInfoRow = buildStateInfoRow(keyRow, groupState, startOffsetForCurGroup, numRowsForCurGroup) arrowWriterForState.write(stateInfoRow) totalNumStatesForBatch += 1 - // start offset for next group would be same as the total number of rows for batch + // The start offset for next group would be same as the total number of rows for batch, + // unless the next group starts with new batch. startOffsetForCurGroup = totalNumRowsForBatch - // FIXME: threshold of sample data + // FIXME: Do we need to come up with sampling "across record batches"? + // FIXME: Do we need to also come up with the size of state metadata as well? + // FIXME: Do we need to separate the case of "state with value" vs + // "state without value" on sampling? + + // Currently, this only works when the number of rows are greater than the minimum + // data count for sampling. And we technically have no way to pick some rows from + // record batch and measure the size of data, hence we leverage all data in current + // record batch. We only sample once as it could be costly. if (sampledDataSizePerRow == 0 && totalNumRowsForBatch > minDataCountForSample) { sampledDataSizePerRow = arrowWriterForData.sizeInBytes() / totalNumRowsForBatch } - // This effectively works after the sampling has completed, size we multiply by 0 - // if the sampling is still in progress. - // FIXME: ignore state size for now, as we expect more number of data rather than - // number of state. - // FIXME: sample with empty state size separately? + // The soft-limit on size effectively works after the sampling has completed, since we + // multiply the number of rows by 0 if the sampling is still in progress. 
The + // soft-limit on timeout always applies. if (sampledDataSizePerRow * totalNumRowsForBatch >= softLimitBytesPerBatch || System.currentTimeMillis() - lastBatchPurgedMillis > softTimeoutMillsPurgeBatch) { val remainingEmptyStateRows = totalNumRowsForBatch - totalNumStatesForBatch (0 until remainingEmptyStateRows).foreach { _ => - arrowWriterForState.write(EMPTY_STATE_INFO_ROW) + arrowWriterForState.write(EMPTY_STATE_METADATA_ROW) } arrowWriterForState.finish() @@ -230,11 +286,10 @@ class ArrowPythonRunnerWithState( } if (numRowsForCurGroup > 0) { - // need to flush remaining batch - + // We still have some rows in the current record batch. Need to flush them as well. val remainingEmptyStateRows = totalNumRowsForBatch - totalNumStatesForBatch (0 until remainingEmptyStateRows).foreach { _ => - arrowWriterForState.write(EMPTY_STATE_INFO_ROW) + arrowWriterForState.write(EMPTY_STATE_METADATA_ROW) } arrowWriterForState.finish() @@ -321,11 +376,9 @@ class ArrowPythonRunnerWithState( case SpecialLengths.START_ARROW_STREAM => reader = new ArrowStreamReader(stream, allocator) root = reader.getVectorSchemaRoot() - // FIXME: should we validate schema here with value schema and state schema? schema = ArrowUtils.fromArrowSchema(root.getSchema()) val dataAttributes = schema(0).dataType.asInstanceOf[StructType].toAttributes - unsafeProjForData = UnsafeProjection.create(dataAttributes, dataAttributes) vectors = root.getFieldVectors().asScala.map { vector => @@ -347,13 +400,16 @@ class ArrowPythonRunnerWithState( private def deserializeColumnarBatch(batch: ColumnarBatch) : (Iterator[(UnsafeRow, GroupStateImpl[Row])], Iterator[InternalRow]) = { - // this should at least have one row for state + // This should at least have one row for state. Also, we ensure that all columns across + // data and state metadata have same number of rows, which is required by Arrow record + // batch. 
assert(batch.numRows() > 0) assert(schema.length == 2) def constructIterForData(batch: ColumnarBatch): Iterator[InternalRow] = { // UDF returns a StructType column in ColumnarBatch, select the children here val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] + // FIXME: should we validate schema here with value schema? val outputVectors = schema(0).dataType.asInstanceOf[StructType] .indices.map(structVector.getChild) val flattenedBatch = new ColumnarBatch(outputVectors.toArray) @@ -361,8 +417,10 @@ class ArrowPythonRunnerWithState( flattenedBatch.rowIterator.asScala.flatMap { row => if (row.isNullAt(0)) { + // The entire row in record batch seems to be for state metadata. None } else { + // FIXME: would it work without this projection? Some(unsafeProjForData(row)) } } @@ -370,28 +428,28 @@ class ArrowPythonRunnerWithState( def constructIterForState( batch: ColumnarBatch): Iterator[(UnsafeRow, GroupStateImpl[Row])] = { - val structVectorForState = batch.column(1).asInstanceOf[ArrowColumnVector] - val outputVectorsForState = schema(1).dataType.asInstanceOf[StructType] - .indices.map(structVectorForState.getChild) - val flattenedBatchForState = new ColumnarBatch(outputVectorsForState.toArray) + // UDF returns a StructType column in ColumnarBatch, select the children here + val structVector = batch.column(1).asInstanceOf[ArrowColumnVector] + // FIXME: should we validate schema here with state metadata schema? + val outputVectors = schema(1).dataType.asInstanceOf[StructType] + .indices.map(structVector.getChild) + val flattenedBatchForState = new ColumnarBatch(outputVectors.toArray) flattenedBatchForState.setNumRows(batch.numRows()) flattenedBatchForState.rowIterator().asScala.flatMap { row => implicit val formats = org.json4s.DefaultFormats - // FIXME: we rely on known schema for state info, but would we want to access this by - // column name? - // Received state information does not need schemas - this class already knows them. 
- /* - Array( - StructField("properties", StringType), - StructField("keyRowAsUnsafe", BinaryType), - StructField("object", BinaryType), - ) - */ if (row.isNullAt(0)) { + // The entire row in record batch seems to be for data. None } else { + // Received state metadata does not need schema - this class already knows them. + // Array( + // StructField("properties", StringType), + // StructField("keyRowAsUnsafe", BinaryType), + // StructField("object", BinaryType), + // ) + // TODO: Do we want to rely on the column name rather than the ordinal for safety? val propertiesAsJson = parse(row.getUTF8String(0).toString) val keyRowAsUnsafeAsBinary = row.getBinary(1) val keyRowAsUnsafe = new UnsafeRow(keySchema.fields.length) @@ -399,8 +457,9 @@ class ArrowPythonRunnerWithState( val maybeObjectRow = if (row.isNullAt(2)) { None } else { - val pickledRow = row.getBinary(2) - Some(PythonSQLUtils.toJVMRow(pickledRow, stateSchema, stateRowDeserializer)) + val pickledStateValue = row.getBinary(2) + Some(PythonSQLUtils.toJVMRow(pickledStateValue, stateValueSchema, + stateRowDeserializer)) } Some(keyRowAsUnsafe, GroupStateImpl.fromJson(maybeObjectRow, propertiesAsJson)) @@ -413,3 +472,18 @@ class ArrowPythonRunnerWithState( } } } + +object ArrowPythonRunnerWithState { + val STATE_METADATA_SCHEMA: StructType = StructType( + Array( + StructField("properties", StringType), + StructField("keyRowAsUnsafe", BinaryType), + StructField("object", BinaryType), + StructField("startOffset", IntegerType), + StructField("numRows", IntegerType) + ) + ) + + // To avoid initializing a new row for empty state metadata row. 
+ val EMPTY_STATE_METADATA_ROW = new GenericInternalRow(Array[Any](null, null, null, null, null)) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala index 18d5f95815952..6971aa8bbd155 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -80,18 +80,8 @@ case class FlatMapGroupsInPandasWithStateExec( override def output: Seq[Attribute] = outAttributes - private val softLimitBytesPerBatch = conf.softLimitBytesPerBatchInApplyInPandasWithState - private val minDataCountForSample = conf.minDataCountForSampleInApplyInPandasWithState - private val softTimeoutMillsPurgeBatch = conf.softTimeoutMillisPurgeBatchInApplyInPandasWithState - private val sessionLocalTimeZone = conf.sessionLocalTimeZone - private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) + - (SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH.key -> - softLimitBytesPerBatch.toString) + - (SQLConf.MAP_PANDAS_UDF_WITH_STATE_MIN_DATA_COUNT_FOR_SAMPLE.key -> - minDataCountForSample.toString) + - (SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_TIMEOUT_PURGE_BATCH.key -> - softTimeoutMillsPurgeBatch.toString) + private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) private val pythonFunction = functionExpr.asInstanceOf[PythonUDF].func private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction))) @@ -115,10 +105,6 @@ case class FlatMapGroupsInPandasWithStateExec( override def createInputProcessor( store: StateStore): InputProcessor = new InputProcessor(store: StateStore) { - /** - * For every group, get the key, values and corresponding state and call the function, - * and return an iterator of rows - */ override def 
processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output) val processIter = groupedIter.map { case (keyRow, valueRowIter) => @@ -130,7 +116,6 @@ case class FlatMapGroupsInPandasWithStateExec( process(processIter, hasTimedOut = false) } - /** Find the groups that have timeout set and are timing out right now, and call the function */ override def processTimedOutState(): Iterator[InternalRow] = { if (isTimeoutEnabled) { val timeoutThreshold = timeoutConf match { @@ -171,9 +156,9 @@ case class FlatMapGroupsInPandasWithStateExec( groupingAttributes.toStructType, child.output.toStructType, stateType, - softLimitBytesPerBatch, - minDataCountForSample, - softTimeoutMillsPurgeBatch) + conf.softLimitBytesPerBatchInApplyInPandasWithState, + conf.minDataCountForSampleInApplyInPandasWithState, + conf.softTimeoutMillisPurgeBatchInApplyInPandasWithState) val context = TaskContext.get() @@ -189,8 +174,8 @@ case class FlatMapGroupsInPandasWithStateExec( } runner.compute(processIter, context.partitionId(), context).flatMap { case (stateIter, outputIter) => - // When the iterator is consumed, then write changes to state - // state does not affect each others, hence when to update does not affect to the result + // When the iterator is consumed, then write changes to state. + // state does not affect each others, hence when to update does not affect to the result. 
def onIteratorCompletion: Unit = { stateIter.foreach { case (keyRow, newGroupState) => if (newGroupState.isRemoved && !newGroupState.getTimeoutTimestampMs.isPresent()) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala index 1b7e0bf3c4e18..8b220cb202957 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala @@ -251,14 +251,14 @@ private[sql] object GroupStateImpl { case _ => throw new IllegalStateException("Invalid string for GroupStateTimeout: " + clazz) } - def fromJson[S](key: Option[S], json: JValue): GroupStateImpl[S] = { + def fromJson[S](value: Option[S], json: JValue): GroupStateImpl[S] = { implicit val formats = org.json4s.DefaultFormats val hmap = json.extract[Map[String, Any]] // Constructor val newGroupState = new GroupStateImpl[S]( - key, + value, hmap("batchProcessingTimeMs").asInstanceOf[Number].longValue(), hmap("eventTimeWatermarkMs").asInstanceOf[Number].longValue(), groupStateTimeoutFromString(hmap("timeoutConf").asInstanceOf[String]), From d7ecaf944774a2d418863e4526a3ee327d4b807a Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Fri, 2 Sep 2022 11:45:22 +0900 Subject: [PATCH 28/44] WIP sync --- .../python/ArrowPythonRunnerWithState.scala | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala index cf78b5fea3760..f70142fb22057 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala @@ -77,17 +77,7 @@ class 
ArrowPythonRunnerWithState( "Pandas execution requires more than 4 bytes. Please set higher buffer. " + s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.") - private val stateMetadataSchema = StructType( - Array( - StructField("properties", StringType), - StructField("keyRowAsUnsafe", BinaryType), - StructField("object", BinaryType), - StructField("startOffset", IntegerType), - StructField("numRows", IntegerType) - ) - ) - - private val schemaWithState = inputSchema.add("!__state__!", stateMetadataSchema) + private val schemaWithState = inputSchema.add("!__state__!", STATE_METADATA_SCHEMA) private val stateRowDeserializer = stateEncoder.createDeserializer() From 6a6dd205b257864f019af617d33c055d0ac88027 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Fri, 2 Sep 2022 15:56:03 +0900 Subject: [PATCH 29/44] WIP start with is_last_chunk since it's easier to implement... several tests are failing yet (and the existing test code is hard to make it work with multiple calls) --- python/pyspark/sql/pandas/serializers.py | 50 +++++++++++------ python/pyspark/worker.py | 14 +++-- .../python/ArrowPythonRunnerWithState.scala | 55 +++++++++++++++++-- .../FlatMapGroupsInPandasWithStateExec.scala | 2 - .../FlatMapGroupsInPandasWithStateSuite.scala | 11 ++-- 5 files changed, 98 insertions(+), 34 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index e5d592c875661..24386b7538b01 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -406,6 +406,7 @@ def load_stream(self, stream): batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) + state_for_current_group = None for batch in batches: batch_schema = batch.schema data_schema = pa.schema([batch_schema[i] for i in range(0, len(batch_schema) - 1)]) @@ -434,6 +435,7 @@ def load_stream(self, stream): data_start_offset = state_info_col['startOffset'] num_data_rows = state_info_col['numRows'] + 
is_last_chunk = state_info_col['isLastChunk'] state_properties = json.loads(state_info_col_properties) if state_info_col_object: @@ -442,17 +444,30 @@ def load_stream(self, stream): state_object = None state_properties["optionalValue"] = state_object - state = GroupStateImpl(keyAsUnsafe=state_info_col_key_row, - valueSchema=self.state_object_schema, **state_properties) + if state_for_current_group: + # use the state, we already have state for same group and there should be some + # data in same group being processed earlier + state = state_for_current_group + else: + # there is no state being stored for same group, construct one + state = GroupStateImpl(keyAsUnsafe=state_info_col_key_row, + valueSchema=self.state_object_schema, **state_properties) + + if is_last_chunk: + # discard the state being cached for same group + state_for_current_group = None + elif not state_for_current_group: + # there's no cached state but expected to have additional data in same group + # cache the current state + state_for_current_group = state data_batch_for_group = data_batch.slice(data_start_offset, num_data_rows) - data_arrow = pa.Table.from_batches([data_batch_for_group]).itercolumns() data_pandas = [self.arrow_to_pandas(c) for c in data_arrow] # state info - yield (data_pandas, state, ) + yield (data_pandas, state, is_last_chunk, ) def dump_stream(self, iterator, stream): """ @@ -511,7 +526,8 @@ def init_stream_yield_batches(): packaged_result = data[0] pdf = packaged_result[0][0] - state = packaged_result[0][-1] + state = packaged_result[0][1] + is_last_chunk = packaged_result[0][2] # this won't change across batches return_schema = packaged_result[1] @@ -522,20 +538,22 @@ def init_stream_yield_batches(): pdf_data_cnt += len(pdf) pdfs.append(pdf) - state_properties = state.json().encode("utf-8") - state_key_row_as_binary = state._keyAsUnsafe - state_object = self.pickleSer.dumps(state._value_schema.toInternal(state._value)) + if is_last_chunk: + # pick up state for only last 
chunk as state should have been updated so far + state_properties = state.json().encode("utf-8") + state_key_row_as_binary = state._keyAsUnsafe + state_object = self.pickleSer.dumps(state._value_schema.toInternal(state._value)) - state_dict = { - 'properties': [state_properties, ], - 'keyRowAsUnsafe': [state_key_row_as_binary, ], - 'object': [state_object, ], - } + state_dict = { + 'properties': [state_properties, ], + 'keyRowAsUnsafe': [state_key_row_as_binary, ], + 'object': [state_object, ], + } - state_pdf = pd.DataFrame.from_dict(state_dict) + state_pdf = pd.DataFrame.from_dict(state_dict) - state_pdfs.append(state_pdf) - state_data_cnt += 1 + state_pdfs.append(state_pdf) + state_data_cnt += 1 # FIXME: threshold of sample data if sampled_data_size_per_row == 0 and pdf_data_cnt > self.minDataCountForSample: diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 4fad720ed6035..011da20487024 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -210,15 +210,16 @@ def wrapped(key_series, value_series): def wrap_grouped_map_pandas_udf_with_state(f, return_type): - def wrapped(key_series, value_series, state): + def wrapped(key_series, value_series, state, is_last_chunk): import pandas as pd key = tuple(s[0] for s in key_series) if state.hasTimedOut: # Timeout processing pass empty iterator. Here we return an empty DataFrame instead. 
- result = f(key, pd.DataFrame(columns=pd.concat(value_series, axis=1).columns), state) + result = f(key, pd.DataFrame(columns=pd.concat(value_series, axis=1).columns), + state, is_last_chunk) else: - result = f(key, pd.concat(value_series, axis=1), state) + result = f(key, pd.concat(value_series, axis=1), state, is_last_chunk) if not isinstance(result, pd.DataFrame): raise TypeError( @@ -236,9 +237,9 @@ def wrapped(key_series, value_series, state): "Expected: {} Actual: {}".format(len(return_type), len(result.columns)) ) - return (result, state, ) + return (result, state, is_last_chunk, ) - return lambda k, v, s: [(wrapped(k, v, s), to_arrow_type(return_type))] + return lambda k, v, s, l: [(wrapped(k, v, s, l), to_arrow_type(return_type))] def wrap_grouped_agg_pandas_udf(f, return_type): @@ -572,7 +573,8 @@ def mapper(a): keys = [a[0][o] for o in parsed_offsets[0][0]] vals = [a[0][o] for o in parsed_offsets[0][1]] state = a[1] - return f(keys, vals, state) + is_last_chunk = a[2] + return f(keys, vals, state, is_last_chunk) elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: # We assume there is only one UDF here because cogrouped map doesn't diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala index f70142fb22057..21aeaf6a02341 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala @@ -116,7 +116,8 @@ class ArrowPythonRunnerWithState( keyRow: UnsafeRow, groupState: GroupStateImpl[Row], startOffset: Int, - numRows: Int): InternalRow = { + numRows: Int, + isLastChunk: Boolean): InternalRow = { // NOTE: see ArrowPythonRunnerWithState.STATE_METADATA_SCHEMA val stateUnderlyingRow = new GenericInternalRow( Array[Any]( @@ -124,7 +125,8 @@ class 
ArrowPythonRunnerWithState( keyRow.getBytes, groupState.getOption.map(PythonSQLUtils.toPyRow).orNull, startOffset, - numRows + numRows, + isLastChunk ) ) new GenericInternalRow(Array[Any](stateUnderlyingRow)) @@ -227,11 +229,50 @@ class ArrowPythonRunnerWithState( arrowWriterForData.write(dataRow) numRowsForCurGroup += 1 totalNumRowsForBatch += 1 + + // Currently, this only works when the number of rows are greater than the minimum + // data count for sampling. And we technically have no way to pick some rows from + // record batch and measure the size of data, hence we leverage all data in current + // record batch. We only sample once as it could be costly. + if (sampledDataSizePerRow == 0 && totalNumRowsForBatch > minDataCountForSample) { + sampledDataSizePerRow = arrowWriterForData.sizeInBytes() / totalNumRowsForBatch + } + + // If it exceeds the condition of batch (only size, not about timeout) and + // there is more data for the same group, flush and construct a new batch. + + // if (sampledDataSizePerRow * totalNumRowsForBatch >= softLimitBytesPerBatch && + // dataIter.hasNext) { + // FIXME: DEBUGGING now... 
split the data per 10 elements <- 1 element for testing + if (numRowsForCurGroup % 10 == 1 && dataIter.hasNext) { + // Provide state metadata row as intermediate + val stateInfoRow = buildStateInfoRow(keyRow, groupState, startOffsetForCurGroup, + numRowsForCurGroup, isLastChunk = false) + arrowWriterForState.write(stateInfoRow) + totalNumStatesForBatch += 1 + + val remainingEmptyStateRows = totalNumRowsForBatch - totalNumStatesForBatch + (0 until remainingEmptyStateRows).foreach { _ => + arrowWriterForState.write(EMPTY_STATE_METADATA_ROW) + } + + arrowWriterForState.finish() + arrowWriterForData.finish() + writer.writeBatch() + arrowWriterForState.reset() + arrowWriterForData.reset() + + startOffsetForCurGroup = 0 + numRowsForCurGroup = 0 + totalNumRowsForBatch = 0 + totalNumStatesForBatch = 0 + lastBatchPurgedMillis = System.currentTimeMillis() + } } // Provide state metadata row val stateInfoRow = buildStateInfoRow(keyRow, groupState, startOffsetForCurGroup, - numRowsForCurGroup) + numRowsForCurGroup, isLastChunk = true) arrowWriterForState.write(stateInfoRow) totalNumStatesForBatch += 1 @@ -282,7 +323,7 @@ class ArrowPythonRunnerWithState( arrowWriterForState.write(EMPTY_STATE_METADATA_ROW) } - arrowWriterForState.finish() + arrowWriterForState.finish() arrowWriterForData.finish() writer.writeBatch() arrowWriterForState.reset() @@ -470,10 +511,12 @@ object ArrowPythonRunnerWithState { StructField("keyRowAsUnsafe", BinaryType), StructField("object", BinaryType), StructField("startOffset", IntegerType), - StructField("numRows", IntegerType) + StructField("numRows", IntegerType), + StructField("isLastChunk", BooleanType) ) ) // To avoid initializing a new row for empty state metadata row. 
- val EMPTY_STATE_METADATA_ROW = new GenericInternalRow(Array[Any](null, null, null, null, null)) + val EMPTY_STATE_METADATA_ROW = new GenericInternalRow( + Array[Any](null, null, null, null, null, null)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala index 6971aa8bbd155..808afffdef1e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution.python import org.apache.spark.TaskContext - import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow @@ -31,7 +30,6 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper.StateData import org.apache.spark.sql.execution.streaming.state.StateStore -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.ArrowUtils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala index 671621aa43b3e..378e4b0607d58 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala @@ -45,7 +45,7 @@ class FlatMapGroupsInPandasWithStateSuite extends 
StateStoreMetricsTest { | StructField("key", StringType()), | StructField("countAsString", StringType())]) | - |def func(key, pdf, state): + |def func(key, pdf, state, is_last_chunk): | assert state.getCurrentProcessingTimeMs() >= 0 | try: | state.getCurrentWatermarkMs() @@ -64,7 +64,10 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { | return pd.DataFrame() | else: | state.update((count,)) - | return pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) + | if is_last_chunk: + | return pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) + | else: + | return pd.DataFrame() |""".stripMargin val pythonFunc = TestGroupedMapPandasUDFWithState( name = "pandas_grouped_map_with_state", pythonScript = pythonScript) @@ -121,7 +124,7 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { | StructField("value", IntegerType()), | StructField("countAsString", StringType())]) | - |def func(key, pdf, state): + |def func(key, pdf, state, is_last_chunk): | count = state.getOption | if count is None: | count = 0 @@ -184,7 +187,7 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { | StructField("key", StringType()), | StructField("countAsString", StringType())]) | - |def func(key, pdf, state): + |def func(key, pdf, state, is_last_chunk): | count = state.getOption | if count is None: | count = 0 From 5cfd59c5de9be0c1279f37efd4c5b02ea1efc5a4 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Fri, 2 Sep 2022 18:44:55 +0900 Subject: [PATCH 30/44] WIP adjust the test code to make test pass with multiple calls --- .../FlatMapGroupsInPandasWithStateSuite.scala | 75 +++++++++++-------- 1 file changed, 45 insertions(+), 30 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala index 378e4b0607d58..e4ce785a3abe2 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala @@ -59,15 +59,16 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { | else: | count = count[0] | count += len(pdf) - | if count == 3: - | state.remove() - | return pd.DataFrame() - | else: - | state.update((count,)) - | if is_last_chunk: - | return pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) + | state.update((count,)) + | + | ret = pd.DataFrame() + | if is_last_chunk: + | if count >= 3: + | state.remove() | else: - | return pd.DataFrame() + | ret = pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) + | + | return ret |""".stripMargin val pythonFunc = TestGroupedMapPandasUDFWithState( name = "pandas_grouped_map_with_state", pythonScript = pythonScript) @@ -122,7 +123,7 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { |tpe = StructType([ | StructField("key", StringType()), | StructField("value", IntegerType()), - | StructField("countAsString", StringType())]) + | StructField("valueAsString", StringType())]) | |def func(key, pdf, state, is_last_chunk): | count = state.getOption @@ -132,7 +133,7 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { | count = count[0] | count = count + len(pdf) | state.update((count,)) - | return pdf.assign(countAsString=str(count)) + | return pdf.assign(valueAsString=lambda x: x.value.apply(str)) |""".stripMargin val pythonFunc = TestGroupedMapPandasUDFWithState( name = "pandas_grouped_map_with_state", pythonScript = pythonScript) @@ -143,7 +144,7 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { Seq( StructField("key", StringType), StructField("value", IntegerType), - StructField("countAsString", StringType))) + StructField("valueAsString", StringType))) val stateStructType = 
StructType(Seq(StructField("count", LongType))) val inputDataDS = inputData.toDS().selectExpr("_1 AS key", "_2 AS value") val result = @@ -155,14 +156,14 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { stateStructType, "Update", "NoTimeout") - .select("key", "value", "countAsString") + .select("key", "value", "valueAsString") testStream(result, Update)( AddData(inputData, ("a", 1)), CheckNewAnswer(("a", 1, "1")), assertNumStateRows(total = 1, updated = 1), AddData(inputData, ("a", 2), ("a", 3), ("b", 1)), - CheckNewAnswer(("a", 2, "3"), ("a", 3, "3"), ("b", 1, "1")), + CheckNewAnswer(("a", 2, "2"), ("a", 3, "3"), ("b", 1, "1")), assertNumStateRows(total = 2, updated = 2), StopStream, StartStream(), @@ -193,13 +194,17 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { | count = 0 | else: | count = count[0] - | count += len(pdf) - | if count == 3: - | state.remove() - | return pd.DataFrame({'key': [key[0]], 'countAsString': [str(-1)]}) - | else: - | state.update((count,)) - | return pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) + | state.update((count,)) + | + | ret = pd.DataFrame() + | if is_last_chunk: + | if count >= 3: + | state.remove() + | ret = pd.DataFrame({'key': [key[0]], 'countAsString': [str(-1)]}) + | else: + | ret = pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) + | + | return ret |""".stripMargin val pythonFunc = TestGroupedMapPandasUDFWithState( name = "pandas_grouped_map_with_state", pythonScript = pythonScript) @@ -258,7 +263,7 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { | StructField("key", StringType()), | StructField("countAsString", StringType())]) | - |def func(key, pdf, state): + |def func(key, pdf, state, is_last_chunk): | assert state.getCurrentProcessingTimeMs() >= 0 | try: | state.getCurrentWatermarkMs() @@ -277,8 +282,11 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { | count = count[0] | 
count += len(pdf) | state.update((count,)) - | state.setTimeoutDuration(10000) - | return pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) + | if is_last_chunk: + | state.setTimeoutDuration(10000) + | return pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) + | else: + | return pd.DataFrame() |""".stripMargin val pythonFunc = TestGroupedMapPandasUDFWithState( name = "pandas_grouped_map_with_state", pythonScript = pythonScript) @@ -361,7 +369,7 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { | StructField("key", StringType()), | StructField("maxEventTimeSec", IntegerType())]) | - |def func(key, pdf, state): + |def func(key, pdf, state, is_last_chunk): | assert state.getCurrentProcessingTimeMs() >= 0 | assert state.getCurrentWatermarkMs() >= -1 | @@ -379,10 +387,14 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { | pser = pdf.eventTime.apply( | lambda dt: (int(calendar.timegm(dt.utctimetuple()) + dt.microsecond))) | max_event_time_sec = int(max(pser.max(), m)) - | timeout_timestamp_sec = max_event_time_sec + timeout_delay_sec | state.update((max_event_time_sec,)) - | state.setTimeoutTimestamp(timeout_timestamp_sec * 1000) - | return pd.DataFrame({'key': [key[0]], 'maxEventTimeSec': [max_event_time_sec]}) + | if is_last_chunk: + | timeout_timestamp_sec = max_event_time_sec + timeout_delay_sec + | state.setTimeoutTimestamp(timeout_timestamp_sec * 1000) + | return pd.DataFrame({'key': [key[0]], + | 'maxEventTimeSec': [max_event_time_sec]}) + | else: + | return pd.DataFrame() |""".stripMargin val pythonFunc = TestGroupedMapPandasUDFWithState( name = "pandas_grouped_map_with_state", pythonScript = pythonScript) @@ -447,7 +459,7 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { | StructField("key", StringType()), | StructField("countAsString", StringType())]) | - |def func(key, pdf, state): + |def func(key, pdf, state, is_last_chunk): | if state.hasTimedOut: | 
state.remove() | return pd.DataFrame({'key': [key[0]], 'countAsString': [str(-1)]}) @@ -459,8 +471,11 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { | count = count[0] | count += len(pdf) | state.update((count,)) - | state.setTimeoutDuration(10000) - | return pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) + | if is_last_chunk: + | state.setTimeoutDuration(10000) + | return pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) + | else: + | return pd.DataFrame() |""".stripMargin val pythonFunc = TestGroupedMapPandasUDFWithState( name = "pandas_grouped_map_with_state", pythonScript = pythonScript) From 63f8f87e4c58ea4c249c094bb4f0c74978208859 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Mon, 5 Sep 2022 11:20:19 +0900 Subject: [PATCH 31/44] WIP refactor a bit... just extract the abstract classes to explicit ones --- .../python/ArrowPythonRunnerWithState.scala | 718 +++++++++--------- 1 file changed, 357 insertions(+), 361 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala index 21aeaf6a02341..a456ddf16f1dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.api.python.PythonSQLUtils import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} import org.apache.spark.sql.execution.arrow.ArrowWriter import org.apache.spark.sql.execution.arrow.ArrowWriter.createFieldWriter import 
org.apache.spark.sql.execution.streaming.GroupStateImpl @@ -67,8 +67,6 @@ class ArrowPythonRunnerWithState( (Iterator[(UnsafeRow, GroupStateImpl[Row])], Iterator[InternalRow])]( funcs, evalType, argOffsets) { - import ArrowPythonRunnerWithState._ - override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize @@ -77,10 +75,6 @@ class ArrowPythonRunnerWithState( "Pandas execution requires more than 4 bytes. Please set higher buffer. " + s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.") - private val schemaWithState = inputSchema.add("!__state__!", STATE_METADATA_SCHEMA) - - private val stateRowDeserializer = stateEncoder.createDeserializer() - private val workerConfWithRunnerConfs = workerConf + (SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH.key -> softLimitBytesPerBatch.toString) + @@ -105,185 +99,176 @@ class ArrowPythonRunnerWithState( inputIterator: Iterator[(UnsafeRow, GroupStateImpl[Row], Iterator[InternalRow])], partitionIndex: Int, context: TaskContext): WriterThread = { - new WriterThread(env, worker, inputIterator, partitionIndex, context) { + new StateWriterThread(env, worker, inputIterator, partitionIndex, context) + } - protected override def writeCommand(dataOut: DataOutputStream): Unit = { - handleMetadataBeforeExec(dataOut) - PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) - } + protected def newReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + pid: Option[Int], + releasedOrClosed: AtomicBoolean, + context: TaskContext) + : Iterator[(Iterator[(UnsafeRow, GroupStateImpl[Row])], Iterator[InternalRow])] = { + new StateReaderIterator(stream, writerThread, startTime, env, worker, pid, + releasedOrClosed, context) + } - private def buildStateInfoRow( - keyRow: UnsafeRow, - groupState: GroupStateImpl[Row], - startOffset: Int, - numRows: Int, - isLastChunk: Boolean): 
InternalRow = { - // NOTE: see ArrowPythonRunnerWithState.STATE_METADATA_SCHEMA - val stateUnderlyingRow = new GenericInternalRow( - Array[Any]( - UTF8String.fromString(groupState.json()), - keyRow.getBytes, - groupState.getOption.map(PythonSQLUtils.toPyRow).orNull, - startOffset, - numRows, - isLastChunk - ) + private class StateWriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[(UnsafeRow, GroupStateImpl[Row], Iterator[InternalRow])], + partitionIndex: Int, + context: TaskContext) + extends WriterThread(env, worker, inputIterator, partitionIndex, context) { + + import StateWriterThread._ + + private val schemaWithState = inputSchema.add("!__state__!", STATE_METADATA_SCHEMA) + + protected override def writeCommand(dataOut: DataOutputStream): Unit = { + handleMetadataBeforeExec(dataOut) + PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) + } + + private def buildStateInfoRow( + keyRow: UnsafeRow, + groupState: GroupStateImpl[Row], + startOffset: Int, + numRows: Int, + isLastChunk: Boolean): InternalRow = { + // NOTE: see ArrowPythonRunnerWithState.STATE_METADATA_SCHEMA + val stateUnderlyingRow = new GenericInternalRow( + Array[Any]( + UTF8String.fromString(groupState.json()), + keyRow.getBytes, + groupState.getOption.map(PythonSQLUtils.toPyRow).orNull, + startOffset, + numRows, + isLastChunk ) - new GenericInternalRow(Array[Any](stateUnderlyingRow)) - } + ) + new GenericInternalRow(Array[Any](stateUnderlyingRow)) + } - protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { - // We initialize all columns in data & state metadata for Arrow RecordBatch. 
- val arrowSchema = ArrowUtils.toArrowSchema(schemaWithState, timeZoneId) - val allocator = ArrowUtils.rootAllocator.newChildAllocator( - s"stdout writer for $pythonExec", 0, Long.MaxValue) - val root = VectorSchemaRoot.create(arrowSchema, allocator) - - Utils.tryWithSafeFinally { - // We logically group the columns by family and initialize writer separately, since it's - // lot more easier and probably performant to write the row directly rather than - // projecting the row to match up with the overall schema. - // The number of data rows and state metadata rows can be different which seems to matter - // for Arrow RecordBatch, so we append empty rows to cover it. - // We always produce at least one data row per grouping key whereas we only produce one - // state metadata row per grouping key, so we only need to fill up the empty rows in - // state metadata side. - val arrowWriterForData = { - val children = root.getFieldVectors().asScala.dropRight(1).map { vector => - vector.allocateNew() - createFieldWriter(vector) - } + protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + // We initialize all columns in data & state metadata for Arrow RecordBatch. + val arrowSchema = ArrowUtils.toArrowSchema(schemaWithState, timeZoneId) + val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"stdout writer for $pythonExec", 0, Long.MaxValue) + val root = VectorSchemaRoot.create(arrowSchema, allocator) + + Utils.tryWithSafeFinally { + // We logically group the columns by family and initialize writer separately, since it's + // lot more easier and probably performant to write the row directly rather than + // projecting the row to match up with the overall schema. + // The number of data rows and state metadata rows can be different which seems to matter + // for Arrow RecordBatch, so we append empty rows to cover it. 
+ // We always produce at least one data row per grouping key whereas we only produce one + // state metadata row per grouping key, so we only need to fill up the empty rows in + // state metadata side. + val arrowWriterForData = { + val children = root.getFieldVectors().asScala.dropRight(1).map { vector => + vector.allocateNew() + createFieldWriter(vector) + } - new ArrowWriter(root, children.toArray) + new ArrowWriter(root, children.toArray) + } + val arrowWriterForState = { + val children = root.getFieldVectors().asScala.takeRight(1).map { vector => + vector.allocateNew() + createFieldWriter(vector) } - val arrowWriterForState = { - val children = root.getFieldVectors().asScala.takeRight(1).map { vector => - vector.allocateNew() - createFieldWriter(vector) - } - new ArrowWriter(root, children.toArray) + new ArrowWriter(root, children.toArray) + } + + val writer = new ArrowStreamWriter(root, null, dataOut) + writer.start() + + // We apply bin-packing the data from multiple groups into one Arrow RecordBatch to + // gain the performance. In many cases, the amount of data per grouping key is quite + // small, which does not seem to maximize the benefits of using Arrow. + // + // We have to split the record batch down to each group in Python worker to convert the + // data for group to Pandas, but hopefully, Arrow RecordBatch provides the way to split + // the range of data and give a view, say, "zero-copy". To help splitting the range for + // data, we provide the "start offset" and the "number of data" in the state metadata. + // + // FIXME: probably need to change this as "hard limit" when addressing scalability. Worth + // noting that we may need to break down the data into chunks for a specific group + // having "small" number of data, because we also do bin-packing as well. Maybe we could + // concatenate these chunks in Python worker (serializer), with some hints e.g. + // We can get the information - the number of data in the chunk before reading. 
+ // + // Pretty sure we don't bin-pack all groups into a single record batch. We have a soft + // limit on the size - it's not a hard limit since we allow current group to write all + // data even it's going to exceed the limit. + // + // We perform some basic sampling for data to guess the size of the data very roughly, + // and simply multiply by the number of data to estimate the size. We extract the size of + // data from the record batch rather than UnsafeRow, as we don't hold the memory for + // UnsafeRow once we write to the record batch. If there is a memory bound here, it + // should come from record batch. + // + // In the meanwhile, we don't also want to let the current record batch collect the data + // indefinitely, since we are pipelining the process between executor and python worker. + // Python worker won't process any data if executor is not yet finalized a record + // batch, which defeats the purpose of pipelining. To address this, we also introduce + // timeout for constructing a record batch. This is a soft limit indeed as same as limit + // on the size - we allow current group to write all data even it's timed-out. + + // FIXME: Maybe better if we can extract out the batching logic into a separate class. + var numRowsForCurGroup = 0 + var startOffsetForCurGroup = 0 + var totalNumRowsForBatch = 0 + var totalNumStatesForBatch = 0 + + var sampledDataSizePerRow = 0 + var lastBatchPurgedMillis = System.currentTimeMillis() + + def finalizeCurrentArrowBatch(): Unit = { + val remainingEmptyStateRows = totalNumRowsForBatch - totalNumStatesForBatch + (0 until remainingEmptyStateRows).foreach { _ => + arrowWriterForState.write(EMPTY_STATE_METADATA_ROW) } - val writer = new ArrowStreamWriter(root, null, dataOut) - writer.start() - - // We apply bin-packing the data from multiple groups into one Arrow RecordBatch to - // gain the performance. 
In many cases, the amount of data per grouping key is quite - // small, which does not seem to maximize the benefits of using Arrow. - // - // We have to split the record batch down to each group in Python worker to convert the - // data for group to Pandas, but hopefully, Arrow RecordBatch provides the way to split - // the range of data and give a view, say, "zero-copy". To help splitting the range for - // data, we provide the "start offset" and the "number of data" in the state metadata. - // - // FIXME: probably need to change this as "hard limit" when addressing scalability. Worth - // noting that we may need to break down the data into chunks for a specific group - // having "small" number of data, because we also do bin-packing as well. Maybe we could - // concatenate these chunks in Python worker (serializer), with some hints e.g. - // We can get the information - the number of data in the chunk before reading. - // - // Pretty sure we don't bin-pack all groups into a single record batch. We have a soft - // limit on the size - it's not a hard limit since we allow current group to write all - // data even it's going to exceed the limit. - // - // We perform some basic sampling for data to guess the size of the data very roughly, - // and simply multiply by the number of data to estimate the size. We extract the size of - // data from the record batch rather than UnsafeRow, as we don't hold the memory for - // UnsafeRow once we write to the record batch. If there is a memory bound here, it - // should come from record batch. - // - // In the meanwhile, we don't also want to let the current record batch collect the data - // indefinitely, since we are pipelining the process between executor and python worker. - // Python worker won't process any data if executor is not yet finalized a record - // batch, which defeats the purpose of pipelining. To address this, we also introduce - // timeout for constructing a record batch. 
This is a soft limit indeed as same as limit - // on the size - we allow current group to write all data even it's timed-out. - - // FIXME: Maybe better if we can extract out the batching logic into a separate class. - var numRowsForCurGroup = 0 - var startOffsetForCurGroup = 0 - var totalNumRowsForBatch = 0 - var totalNumStatesForBatch = 0 - - var sampledDataSizePerRow = 0 - var lastBatchPurgedMillis = System.currentTimeMillis() - - while (inputIterator.hasNext) { - val (keyRow, groupState, dataIter) = inputIterator.next() - - assert(dataIter.hasNext, "should have at least one data row!") - - numRowsForCurGroup = 0 - - // Provide data rows - while (dataIter.hasNext) { - val dataRow = dataIter.next() - // TODO: if we think there will be non-small amount of data per grouping key, - // we could probably try out "dictionary encoding" for the optimization - // of storing same grouping keys multiple times. This may complicate the logic, as - // in IPC streaming format, DictionaryBatch will be provided separately along with - // RecordBatch, and I'm not sure whether the record batch can be directly converted - // to Pandas DataFrame / Series if the record batch refers to the dictionary batch. - // https://arrow.apache.org/docs/format/Columnar.html#dictionary-encoded-layout - // https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format - arrowWriterForData.write(dataRow) - numRowsForCurGroup += 1 - totalNumRowsForBatch += 1 - - // Currently, this only works when the number of rows are greater than the minimum - // data count for sampling. And we technically have no way to pick some rows from - // record batch and measure the size of data, hence we leverage all data in current - // record batch. We only sample once as it could be costly. 
- if (sampledDataSizePerRow == 0 && totalNumRowsForBatch > minDataCountForSample) { - sampledDataSizePerRow = arrowWriterForData.sizeInBytes() / totalNumRowsForBatch - } - - // If it exceeds the condition of batch (only size, not about timeout) and - // there is more data for the same group, flush and construct a new batch. - - // if (sampledDataSizePerRow * totalNumRowsForBatch >= softLimitBytesPerBatch && - // dataIter.hasNext) { - // FIXME: DEBUGGING now... split the data per 10 elements <- 1 element for testing - if (numRowsForCurGroup % 10 == 1 && dataIter.hasNext) { - // Provide state metadata row as intermediate - val stateInfoRow = buildStateInfoRow(keyRow, groupState, startOffsetForCurGroup, - numRowsForCurGroup, isLastChunk = false) - arrowWriterForState.write(stateInfoRow) - totalNumStatesForBatch += 1 - - val remainingEmptyStateRows = totalNumRowsForBatch - totalNumStatesForBatch - (0 until remainingEmptyStateRows).foreach { _ => - arrowWriterForState.write(EMPTY_STATE_METADATA_ROW) - } - - arrowWriterForState.finish() - arrowWriterForData.finish() - writer.writeBatch() - arrowWriterForState.reset() - arrowWriterForData.reset() - - startOffsetForCurGroup = 0 - numRowsForCurGroup = 0 - totalNumRowsForBatch = 0 - totalNumStatesForBatch = 0 - lastBatchPurgedMillis = System.currentTimeMillis() - } - } + arrowWriterForState.finish() + arrowWriterForData.finish() + writer.writeBatch() + arrowWriterForState.reset() + arrowWriterForData.reset() + + startOffsetForCurGroup = 0 + numRowsForCurGroup = 0 + totalNumRowsForBatch = 0 + totalNumStatesForBatch = 0 + lastBatchPurgedMillis = System.currentTimeMillis() + } + + while (inputIterator.hasNext) { + val (keyRow, groupState, dataIter) = inputIterator.next() - // Provide state metadata row - val stateInfoRow = buildStateInfoRow(keyRow, groupState, startOffsetForCurGroup, - numRowsForCurGroup, isLastChunk = true) - arrowWriterForState.write(stateInfoRow) - totalNumStatesForBatch += 1 + assert(dataIter.hasNext, 
"should have at least one data row!") - // The start offset for next group would be same as the total number of rows for batch, - // unless the next group starts with new batch. - startOffsetForCurGroup = totalNumRowsForBatch + numRowsForCurGroup = 0 - // FIXME: Do we need to come up with sampling "across record batches"? - // FIXME: Do we need to also come up with the size of state metadata as well? - // FIXME: Do we need to separate the case of "state with value" vs - // "state without value" on sampling? + // Provide data rows + while (dataIter.hasNext) { + val dataRow = dataIter.next() + // TODO: if we think there will be non-small amount of data per grouping key, + // we could probably try out "dictionary encoding" for the optimization + // of storing same grouping keys multiple times. This may complicate the logic, as + // in IPC streaming format, DictionaryBatch will be provided separately along with + // RecordBatch, and I'm not sure whether the record batch can be directly converted + // to Pandas DataFrame / Series if the record batch refers to the dictionary batch. + // https://arrow.apache.org/docs/format/Columnar.html#dictionary-encoded-layout + // https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format + arrowWriterForData.write(dataRow) + numRowsForCurGroup += 1 + totalNumRowsForBatch += 1 // Currently, this only works when the number of rows are greater than the minimum // data count for sampling. And we technically have no way to pick some rows from @@ -293,66 +278,99 @@ class ArrowPythonRunnerWithState( sampledDataSizePerRow = arrowWriterForData.sizeInBytes() / totalNumRowsForBatch } - // The soft-limit on size effectively works after the sampling has completed, since we - // multiply the number of rows by 0 if the sampling is still in progress. The - // soft-limit on timeout always applies. 
- if (sampledDataSizePerRow * totalNumRowsForBatch >= softLimitBytesPerBatch || - System.currentTimeMillis() - lastBatchPurgedMillis > softTimeoutMillsPurgeBatch) { - val remainingEmptyStateRows = totalNumRowsForBatch - totalNumStatesForBatch - (0 until remainingEmptyStateRows).foreach { _ => - arrowWriterForState.write(EMPTY_STATE_METADATA_ROW) - } - - arrowWriterForState.finish() - arrowWriterForData.finish() - writer.writeBatch() - arrowWriterForState.reset() - arrowWriterForData.reset() - - startOffsetForCurGroup = 0 - totalNumRowsForBatch = 0 - totalNumStatesForBatch = 0 - lastBatchPurgedMillis = System.currentTimeMillis() + // If it exceeds the condition of batch (only size, not about timeout) and + // there is more data for the same group, flush and construct a new batch. + + // if (sampledDataSizePerRow * totalNumRowsForBatch >= softLimitBytesPerBatch && + // dataIter.hasNext) { + // FIXME: DEBUGGING now... split the data per 10 elements <- 1 element for testing + if (numRowsForCurGroup % 10 == 1 && dataIter.hasNext) { + // Provide state metadata row as intermediate + val stateInfoRow = buildStateInfoRow(keyRow, groupState, startOffsetForCurGroup, + numRowsForCurGroup, isLastChunk = false) + arrowWriterForState.write(stateInfoRow) + totalNumStatesForBatch += 1 + + finalizeCurrentArrowBatch() } } - if (numRowsForCurGroup > 0) { - // We still have some rows in the current record batch. Need to flush them as well. 
- val remainingEmptyStateRows = totalNumRowsForBatch - totalNumStatesForBatch - (0 until remainingEmptyStateRows).foreach { _ => - arrowWriterForState.write(EMPTY_STATE_METADATA_ROW) - } + // Provide state metadata row + val stateInfoRow = buildStateInfoRow(keyRow, groupState, startOffsetForCurGroup, + numRowsForCurGroup, isLastChunk = true) + arrowWriterForState.write(stateInfoRow) + totalNumStatesForBatch += 1 + + // The start offset for next group would be same as the total number of rows for batch, + // unless the next group starts with new batch. + startOffsetForCurGroup = totalNumRowsForBatch + + // FIXME: Do we need to come up with sampling "across record batches"? + // FIXME: Do we need to also come up with the size of state metadata as well? + // FIXME: Do we need to separate the case of "state with value" vs + // "state without value" on sampling? + + // Currently, this only works when the number of rows are greater than the minimum + // data count for sampling. And we technically have no way to pick some rows from + // record batch and measure the size of data, hence we leverage all data in current + // record batch. We only sample once as it could be costly. + if (sampledDataSizePerRow == 0 && totalNumRowsForBatch > minDataCountForSample) { + sampledDataSizePerRow = arrowWriterForData.sizeInBytes() / totalNumRowsForBatch + } - arrowWriterForState.finish() - arrowWriterForData.finish() - writer.writeBatch() - arrowWriterForState.reset() - arrowWriterForData.reset() + // The soft-limit on size effectively works after the sampling has completed, since we + // multiply the number of rows by 0 if the sampling is still in progress. The + // soft-limit on timeout always applies. + if (sampledDataSizePerRow * totalNumRowsForBatch >= softLimitBytesPerBatch || + System.currentTimeMillis() - lastBatchPurgedMillis > softTimeoutMillsPurgeBatch) { + finalizeCurrentArrowBatch() } + } - // end writes footer to the output stream and doesn't clean any resources. 
- // It could throw exception if the output stream is closed, so it should be - // in the try block. - writer.end() - } { - // If we close root and allocator in TaskCompletionListener, there could be a race - // condition where the writer thread keeps writing to the VectorSchemaRoot while - // it's being closed by the TaskCompletion listener. - // Closing root and allocator here is cleaner because root and allocator is owned - // by the writer thread and is only visible to the writer thread. - // - // If the writer thread is interrupted by TaskCompletionListener, it should either - // (1) in the try block, in which case it will get an InterruptedException when - // performing io, and goes into the finally block or (2) in the finally block, - // in which case it will ignore the interruption and close the resources. - root.close() - allocator.close() + if (numRowsForCurGroup > 0) { + // We still have some rows in the current record batch. Need to flush them as well. + finalizeCurrentArrowBatch() } + + // end writes footer to the output stream and doesn't clean any resources. + // It could throw exception if the output stream is closed, so it should be + // in the try block. + writer.end() + } { + // If we close root and allocator in TaskCompletionListener, there could be a race + // condition where the writer thread keeps writing to the VectorSchemaRoot while + // it's being closed by the TaskCompletion listener. + // Closing root and allocator here is cleaner because root and allocator is owned + // by the writer thread and is only visible to the writer thread. + // + // If the writer thread is interrupted by TaskCompletionListener, it should either + // (1) in the try block, in which case it will get an InterruptedException when + // performing io, and goes into the finally block or (2) in the finally block, + // in which case it will ignore the interruption and close the resources. 
+ root.close() + allocator.close() } } } - protected def newReaderIterator( + object StateWriterThread { + val STATE_METADATA_SCHEMA: StructType = StructType( + Array( + StructField("properties", StringType), + StructField("keyRowAsUnsafe", BinaryType), + StructField("object", BinaryType), + StructField("startOffset", IntegerType), + StructField("numRows", IntegerType), + StructField("isLastChunk", BooleanType) + ) + ) + + // To avoid initializing a new row for empty state metadata row. + val EMPTY_STATE_METADATA_ROW = new GenericInternalRow( + Array[Any](null, null, null, null, null, null)) + } + + class StateReaderIterator( stream: DataInputStream, writerThread: WriterThread, startTime: Long, @@ -361,162 +379,140 @@ class ArrowPythonRunnerWithState( pid: Option[Int], releasedOrClosed: AtomicBoolean, context: TaskContext) - : Iterator[(Iterator[(UnsafeRow, GroupStateImpl[Row])], Iterator[InternalRow])] = { + extends ReaderIterator(stream, writerThread, startTime, env, worker, pid, + releasedOrClosed, context) { - new ReaderIterator( - stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) { + private val stateRowDeserializer = stateEncoder.createDeserializer() - private val allocator = ArrowUtils.rootAllocator.newChildAllocator( - s"stdin reader for $pythonExec", 0, Long.MaxValue) + private val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"stdin reader for $pythonExec", 0, Long.MaxValue) - private var reader: ArrowStreamReader = _ - private var root: VectorSchemaRoot = _ - private var schema: StructType = _ - private var vectors: Array[ColumnVector] = _ - private var unsafeProjForData: UnsafeProjection = _ + private var reader: ArrowStreamReader = _ + private var root: VectorSchemaRoot = _ + private var schema: StructType = _ + private var vectors: Array[ColumnVector] = _ - context.addTaskCompletionListener[Unit] { _ => - if (reader != null) { - reader.close(false) - } - allocator.close() + 
context.addTaskCompletionListener[Unit] { _ => + if (reader != null) { + reader.close(false) } + allocator.close() + } - private var batchLoaded = true + private var batchLoaded = true - protected override def read() - : (Iterator[(UnsafeRow, GroupStateImpl[Row])], Iterator[InternalRow]) = { - if (writerThread.exception.isDefined) { - throw writerThread.exception.get - } - try { - if (reader != null && batchLoaded) { - batchLoaded = reader.loadNextBatch() - if (batchLoaded) { - val batch = new ColumnarBatch(vectors) - batch.setNumRows(root.getRowCount) - deserializeColumnarBatch(batch) - } else { - reader.close(false) - allocator.close() - // Reach end of stream. Call `read()` again to read control data. - read() - } + protected override def read() + : (Iterator[(UnsafeRow, GroupStateImpl[Row])], Iterator[InternalRow]) = { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + try { + if (reader != null && batchLoaded) { + batchLoaded = reader.loadNextBatch() + if (batchLoaded) { + val batch = new ColumnarBatch(vectors) + batch.setNumRows(root.getRowCount) + deserializeColumnarBatch(batch) } else { - stream.readInt() match { - case SpecialLengths.START_ARROW_STREAM => - reader = new ArrowStreamReader(stream, allocator) - root = reader.getVectorSchemaRoot() - schema = ArrowUtils.fromArrowSchema(root.getSchema()) - - val dataAttributes = schema(0).dataType.asInstanceOf[StructType].toAttributes - unsafeProjForData = UnsafeProjection.create(dataAttributes, dataAttributes) - - vectors = root.getFieldVectors().asScala.map { vector => - new ArrowColumnVector(vector) - }.toArray[ColumnVector] - read() - case SpecialLengths.TIMING_DATA => - handleTimingData() - read() - case SpecialLengths.PYTHON_EXCEPTION_THROWN => - throw handlePythonException() - case SpecialLengths.END_OF_DATA_SECTION => - handleEndOfDataSection() - null - } + reader.close(false) + allocator.close() + // Reach end of stream. Call `read()` again to read control data. 
+ read() } - } catch handleException - } + } else { + stream.readInt() match { + case SpecialLengths.START_ARROW_STREAM => + reader = new ArrowStreamReader(stream, allocator) + root = reader.getVectorSchemaRoot() + schema = ArrowUtils.fromArrowSchema(root.getSchema()) + + // FIXME: should we validate schema here with value schema? + // FIXME: should we validate schema here with state metadata schema? + + vectors = root.getFieldVectors().asScala.map { vector => + new ArrowColumnVector(vector) + }.toArray[ColumnVector] + read() + case SpecialLengths.TIMING_DATA => + handleTimingData() + read() + case SpecialLengths.PYTHON_EXCEPTION_THROWN => + throw handlePythonException() + case SpecialLengths.END_OF_DATA_SECTION => + handleEndOfDataSection() + null + } + } + } catch handleException + } - private def deserializeColumnarBatch(batch: ColumnarBatch) - : (Iterator[(UnsafeRow, GroupStateImpl[Row])], Iterator[InternalRow]) = { - // This should at least have one row for state. Also, we ensure that all columns across - // data and state metadata have same number of rows, which is required by Arrow record - // batch. - assert(batch.numRows() > 0) - assert(schema.length == 2) - - def constructIterForData(batch: ColumnarBatch): Iterator[InternalRow] = { - // UDF returns a StructType column in ColumnarBatch, select the children here - val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] - // FIXME: should we validate schema here with value schema? - val outputVectors = schema(0).dataType.asInstanceOf[StructType] - .indices.map(structVector.getChild) - val flattenedBatch = new ColumnarBatch(outputVectors.toArray) - flattenedBatch.setNumRows(batch.numRows()) - - flattenedBatch.rowIterator.asScala.flatMap { row => - if (row.isNullAt(0)) { - // The entire row in record batch seems to be for state metadata. - None - } else { - // FIXME: would it work without this projection? 
- Some(unsafeProjForData(row)) - } + private def deserializeColumnarBatch(batch: ColumnarBatch) + : (Iterator[(UnsafeRow, GroupStateImpl[Row])], Iterator[InternalRow]) = { + // This should at least have one row for state. Also, we ensure that all columns across + // data and state metadata have same number of rows, which is required by Arrow record + // batch. + assert(batch.numRows() > 0) + assert(schema.length == 2) + + def constructIterForData(batch: ColumnarBatch): Iterator[InternalRow] = { + // UDF returns a StructType column in ColumnarBatch, select the children here + val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] + val outputVectors = schema(0).dataType.asInstanceOf[StructType] + .indices.map(structVector.getChild) + val flattenedBatch = new ColumnarBatch(outputVectors.toArray) + flattenedBatch.setNumRows(batch.numRows()) + + flattenedBatch.rowIterator.asScala.flatMap { row => + if (row.isNullAt(0)) { + // The entire row in record batch seems to be for state metadata. + None + } else { + Some(row) } } + } - def constructIterForState( - batch: ColumnarBatch): Iterator[(UnsafeRow, GroupStateImpl[Row])] = { - // UDF returns a StructType column in ColumnarBatch, select the children here - val structVector = batch.column(1).asInstanceOf[ArrowColumnVector] - // FIXME: should we validate schema here with state metadata schema? - val outputVectors = schema(1).dataType.asInstanceOf[StructType] - .indices.map(structVector.getChild) - val flattenedBatchForState = new ColumnarBatch(outputVectors.toArray) - flattenedBatchForState.setNumRows(batch.numRows()) - - flattenedBatchForState.rowIterator().asScala.flatMap { row => - implicit val formats = org.json4s.DefaultFormats - - if (row.isNullAt(0)) { - // The entire row in record batch seems to be for data. 
+ def constructIterForState( + batch: ColumnarBatch): Iterator[(UnsafeRow, GroupStateImpl[Row])] = { + // UDF returns a StructType column in ColumnarBatch, select the children here + val structVector = batch.column(1).asInstanceOf[ArrowColumnVector] + + val outputVectors = schema(1).dataType.asInstanceOf[StructType] + .indices.map(structVector.getChild) + val flattenedBatchForState = new ColumnarBatch(outputVectors.toArray) + flattenedBatchForState.setNumRows(batch.numRows()) + flattenedBatchForState.rowIterator().asScala.flatMap { row => + implicit val formats = org.json4s.DefaultFormats + + if (row.isNullAt(0)) { + // The entire row in record batch seems to be for data. + None + } else { + // Received state metadata does not need schema - this class already knows them. + // Array( + // StructField("properties", StringType), + // StructField("keyRowAsUnsafe", BinaryType), + // StructField("object", BinaryType), + // ) + // TODO: Do we want to rely on the column name rather than the ordinal for safety? + val propertiesAsJson = parse(row.getUTF8String(0).toString) + val keyRowAsUnsafeAsBinary = row.getBinary(1) + val keyRowAsUnsafe = new UnsafeRow(keySchema.fields.length) + keyRowAsUnsafe.pointTo(keyRowAsUnsafeAsBinary, keyRowAsUnsafeAsBinary.length) + val maybeObjectRow = if (row.isNullAt(2)) { None } else { - // Received state metadata does not need schema - this class already knows them. - // Array( - // StructField("properties", StringType), - // StructField("keyRowAsUnsafe", BinaryType), - // StructField("object", BinaryType), - // ) - // TODO: Do we want to rely on the column name rather than the ordinal for safety? 
- val propertiesAsJson = parse(row.getUTF8String(0).toString) - val keyRowAsUnsafeAsBinary = row.getBinary(1) - val keyRowAsUnsafe = new UnsafeRow(keySchema.fields.length) - keyRowAsUnsafe.pointTo(keyRowAsUnsafeAsBinary, keyRowAsUnsafeAsBinary.length) - val maybeObjectRow = if (row.isNullAt(2)) { - None - } else { - val pickledStateValue = row.getBinary(2) - Some(PythonSQLUtils.toJVMRow(pickledStateValue, stateValueSchema, - stateRowDeserializer)) - } - - Some(keyRowAsUnsafe, GroupStateImpl.fromJson(maybeObjectRow, propertiesAsJson)) + val pickledStateValue = row.getBinary(2) + Some(PythonSQLUtils.toJVMRow(pickledStateValue, stateValueSchema, + stateRowDeserializer)) } + + Some(keyRowAsUnsafe, GroupStateImpl.fromJson(maybeObjectRow, propertiesAsJson)) } } - - (constructIterForState(batch), constructIterForData(batch)) } + + (constructIterForState(batch), constructIterForData(batch)) } } } - -object ArrowPythonRunnerWithState { - val STATE_METADATA_SCHEMA: StructType = StructType( - Array( - StructField("properties", StringType), - StructField("keyRowAsUnsafe", BinaryType), - StructField("object", BinaryType), - StructField("startOffset", IntegerType), - StructField("numRows", IntegerType), - StructField("isLastChunk", BooleanType) - ) - ) - - // To avoid initializing a new row for empty state metadata row. - val EMPTY_STATE_METADATA_ROW = new GenericInternalRow( - Array[Any](null, null, null, null, null, null)) -} From 6e772cd736cbaa7af00b1d2a6177c80f76c922c8 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Mon, 5 Sep 2022 16:21:08 +0900 Subject: [PATCH 32/44] WIP iterator of DatFrame done! 
updated tests and they all passed --- python/pyspark/sql/pandas/serializers.py | 135 +++++++++--------- python/pyspark/worker.py | 32 +++-- .../FlatMapGroupsInPandasWithStateSuite.scala | 105 +++++++------- 3 files changed, 150 insertions(+), 122 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 24386b7538b01..86c144220b443 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -402,72 +402,81 @@ def __init__(self, timezone, safecheck, assign_cols_by_name, state_object_schema def load_stream(self, stream): import pyarrow as pa import json + from itertools import groupby from pyspark.sql.streaming.state import GroupStateImpl - batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) + def gen_data_and_state(batches): + state_for_current_group = None - state_for_current_group = None - for batch in batches: - batch_schema = batch.schema - data_schema = pa.schema([batch_schema[i] for i in range(0, len(batch_schema) - 1)]) - state_schema = pa.schema([batch_schema[-1], ]) + for batch in batches: + batch_schema = batch.schema + data_schema = pa.schema([batch_schema[i] for i in range(0, len(batch_schema) - 1)]) + state_schema = pa.schema([batch_schema[-1], ]) - batch_columns = batch.columns - data_columns = batch_columns[0:-1] - state_column = batch_columns[-1] + batch_columns = batch.columns + data_columns = batch_columns[0:-1] + state_column = batch_columns[-1] - data_batch = pa.RecordBatch.from_arrays(data_columns, schema=data_schema) - state_batch = pa.RecordBatch.from_arrays([state_column, ], schema=state_schema) + data_batch = pa.RecordBatch.from_arrays(data_columns, schema=data_schema) + state_batch = pa.RecordBatch.from_arrays([state_column, ], schema=state_schema) - state_arrow = pa.Table.from_batches([state_batch]).itercolumns() - state_pandas = [self.arrow_to_pandas(c) for c in state_arrow][0] + state_arrow = 
pa.Table.from_batches([state_batch]).itercolumns() + state_pandas = [self.arrow_to_pandas(c) for c in state_arrow][0] - for state_idx in range(0, len(state_pandas)): - state_info_col = state_pandas.iloc[state_idx] + for state_idx in range(0, len(state_pandas)): + state_info_col = state_pandas.iloc[state_idx] - if not state_info_col: - # no more data with grouping key + state - break + if not state_info_col: + # no more data with grouping key + state + break - state_info_col_properties = state_info_col['properties'] - state_info_col_key_row = state_info_col['keyRowAsUnsafe'] - state_info_col_object = state_info_col['object'] + state_info_col_properties = state_info_col['properties'] + state_info_col_key_row = state_info_col['keyRowAsUnsafe'] + state_info_col_object = state_info_col['object'] - data_start_offset = state_info_col['startOffset'] - num_data_rows = state_info_col['numRows'] - is_last_chunk = state_info_col['isLastChunk'] + data_start_offset = state_info_col['startOffset'] + num_data_rows = state_info_col['numRows'] + is_last_chunk = state_info_col['isLastChunk'] - state_properties = json.loads(state_info_col_properties) - if state_info_col_object: - state_object = self.pickleSer.loads(state_info_col_object) - else: - state_object = None - state_properties["optionalValue"] = state_object + state_properties = json.loads(state_info_col_properties) + if state_info_col_object: + state_object = self.pickleSer.loads(state_info_col_object) + else: + state_object = None + state_properties["optionalValue"] = state_object - if state_for_current_group: - # use the state, we already have state for same group and there should be some - # data in same group being processed earlier - state = state_for_current_group - else: - # there is no state being stored for same group, construct one - state = GroupStateImpl(keyAsUnsafe=state_info_col_key_row, - valueSchema=self.state_object_schema, **state_properties) + if state_for_current_group: + # use the state, we already have 
state for same group and there should be some + # data in same group being processed earlier + state = state_for_current_group + else: + # there is no state being stored for same group, construct one + state = GroupStateImpl(keyAsUnsafe=state_info_col_key_row, + valueSchema=self.state_object_schema, + **state_properties) - if is_last_chunk: - # discard the state being cached for same group - state_for_current_group = None - elif not state_for_current_group: - # there's no cached state but expected to have additional data in same group - # cache the current state - state_for_current_group = state + if is_last_chunk: + # discard the state being cached for same group + state_for_current_group = None + elif not state_for_current_group: + # there's no cached state but expected to have additional data in same group + # cache the current state + state_for_current_group = state - data_batch_for_group = data_batch.slice(data_start_offset, num_data_rows) - data_arrow = pa.Table.from_batches([data_batch_for_group]).itercolumns() + data_batch_for_group = data_batch.slice(data_start_offset, num_data_rows) + data_arrow = pa.Table.from_batches([data_batch_for_group]).itercolumns() + + data_pandas = [self.arrow_to_pandas(c) for c in data_arrow] + + # state info + yield (data_pandas, state, ) + + batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) - data_pandas = [self.arrow_to_pandas(c) for c in data_arrow] + data_state_generator = gen_data_and_state(batches) - # state info - yield (data_pandas, state, is_last_chunk, ) + for state, data in groupby(data_state_generator, key=lambda x: x[1]): + yield (data, state,) def dump_stream(self, iterator, stream): """ @@ -527,7 +536,6 @@ def init_stream_yield_batches(): pdf = packaged_result[0][0] state = packaged_result[0][1] - is_last_chunk = packaged_result[0][2] # this won't change across batches return_schema = packaged_result[1] @@ -538,22 +546,21 @@ def init_stream_yield_batches(): pdf_data_cnt += len(pdf) 
pdfs.append(pdf) - if is_last_chunk: - # pick up state for only last chunk as state should have been updated so far - state_properties = state.json().encode("utf-8") - state_key_row_as_binary = state._keyAsUnsafe - state_object = self.pickleSer.dumps(state._value_schema.toInternal(state._value)) + # pick up state for only last chunk as state should have been updated so far + state_properties = state.json().encode("utf-8") + state_key_row_as_binary = state._keyAsUnsafe + state_object = self.pickleSer.dumps(state._value_schema.toInternal(state._value)) - state_dict = { - 'properties': [state_properties, ], - 'keyRowAsUnsafe': [state_key_row_as_binary, ], - 'object': [state_object, ], - } + state_dict = { + 'properties': [state_properties, ], + 'keyRowAsUnsafe': [state_key_row_as_binary, ], + 'object': [state_object, ], + } - state_pdf = pd.DataFrame.from_dict(state_dict) + state_pdf = pd.DataFrame.from_dict(state_dict) - state_pdfs.append(state_pdf) - state_data_cnt += 1 + state_pdfs.append(state_pdf) + state_data_cnt += 1 # FIXME: threshold of sample data if sampled_data_size_per_row == 0 and pdf_data_cnt > self.minDataCountForSample: diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 011da20487024..7f51fbe88f68e 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -210,16 +210,18 @@ def wrapped(key_series, value_series): def wrap_grouped_map_pandas_udf_with_state(f, return_type): - def wrapped(key_series, value_series, state, is_last_chunk): + def wrapped(key_series, value_series_gen, state): import pandas as pd key = tuple(s[0] for s in key_series) + if state.hasTimedOut: # Timeout processing pass empty iterator. Here we return an empty DataFrame instead. 
- result = f(key, pd.DataFrame(columns=pd.concat(value_series, axis=1).columns), - state, is_last_chunk) + values = [pd.DataFrame(columns=pd.concat(next(value_series_gen), axis=1).columns), ] else: - result = f(key, pd.concat(value_series, axis=1), state, is_last_chunk) + values = (pd.concat(x, axis=1) for x in value_series_gen) + + result = f(key, values, state) if not isinstance(result, pd.DataFrame): raise TypeError( @@ -237,9 +239,9 @@ def wrapped(key_series, value_series, state, is_last_chunk): "Expected: {} Actual: {}".format(len(return_type), len(result.columns)) ) - return (result, state, is_last_chunk, ) + return (result, state, ) - return lambda k, v, s, l: [(wrapped(k, v, s, l), to_arrow_type(return_type))] + return lambda k, v, s: [(wrapped(k, v, s), to_arrow_type(return_type))] def wrap_grouped_agg_pandas_udf(f, return_type): @@ -570,11 +572,21 @@ def mapper(a): parsed_offsets = extract_key_value_indexes(arg_offsets) def mapper(a): - keys = [a[0][o] for o in parsed_offsets[0][0]] - vals = [a[0][o] for o in parsed_offsets[0][1]] + from itertools import tee + state = a[1] - is_last_chunk = a[2] - return f(keys, vals, state, is_last_chunk) + data_gen = (x[0] for x in a[0]) + + # We know there should be at least one item in the iterator/generator. + # We want to peek the first element to construct the key, hence applying + # tee to construct the key while we retain another iterator/generator + # for values. 
+ keys_gen, values_gen = tee(data_gen) + keys_elem = next(keys_gen) + keys = [keys_elem[o] for o in parsed_offsets[0][0]] + vals = ([x[o] for o in parsed_offsets[0][1]] for x in values_gen) + + return f(keys, vals, state) elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: # We assume there is only one UDF here because cogrouped map doesn't diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala index e4ce785a3abe2..5b8c0365a3cf3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala @@ -45,7 +45,7 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { | StructField("key", StringType()), | StructField("countAsString", StringType())]) | - |def func(key, pdf, state, is_last_chunk): + |def func(key, pdf_iter, state): | assert state.getCurrentProcessingTimeMs() >= 0 | try: | state.getCurrentWatermarkMs() @@ -58,15 +58,16 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { | count = 0 | else: | count = count[0] - | count += len(pdf) - | state.update((count,)) + | + | for pdf in pdf_iter: + | count += len(pdf) + | state.update((count,)) | | ret = pd.DataFrame() - | if is_last_chunk: - | if count >= 3: - | state.remove() - | else: - | ret = pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) + | if count >= 3: + | state.remove() + | else: + | ret = pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) | | return ret |""".stripMargin @@ -123,17 +124,23 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { |tpe = StructType([ | StructField("key", StringType()), | StructField("value", IntegerType()), - | StructField("valueAsString", StringType())]) + | 
StructField("countAsString", StringType())]) | - |def func(key, pdf, state, is_last_chunk): + |def func(key, pdf_iter, state): | count = state.getOption | if count is None: | count = 0 | else: | count = count[0] - | count = count + len(pdf) + | + | pdf_list = [] + | for pdf in pdf_iter: + | count += len(pdf) + | pdf_list.append(pdf) + | | state.update((count,)) - | return pdf.assign(valueAsString=lambda x: x.value.apply(str)) + | pdf_concat = pd.concat(pdf_list) + | return pdf_concat.assign(countAsString=str(count)) |""".stripMargin val pythonFunc = TestGroupedMapPandasUDFWithState( name = "pandas_grouped_map_with_state", pythonScript = pythonScript) @@ -144,7 +151,7 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { Seq( StructField("key", StringType), StructField("value", IntegerType), - StructField("valueAsString", StringType))) + StructField("countAsString", StringType))) val stateStructType = StructType(Seq(StructField("count", LongType))) val inputDataDS = inputData.toDS().selectExpr("_1 AS key", "_2 AS value") val result = @@ -156,14 +163,14 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { stateStructType, "Update", "NoTimeout") - .select("key", "value", "valueAsString") + .select("key", "value", "countAsString") testStream(result, Update)( AddData(inputData, ("a", 1)), CheckNewAnswer(("a", 1, "1")), assertNumStateRows(total = 1, updated = 1), AddData(inputData, ("a", 2), ("a", 3), ("b", 1)), - CheckNewAnswer(("a", 2, "2"), ("a", 3, "3"), ("b", 1, "1")), + CheckNewAnswer(("a", 2, "3"), ("a", 3, "3"), ("b", 1, "1")), assertNumStateRows(total = 2, updated = 2), StopStream, StartStream(), @@ -188,21 +195,24 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { | StructField("key", StringType()), | StructField("countAsString", StringType())]) | - |def func(key, pdf, state, is_last_chunk): + |def func(key, pdf_iter, state): | count = state.getOption | if count is None: | count = 0 | 
else: | count = count[0] + | + | for pdf in pdf_iter: + | count += len(pdf) + | | state.update((count,)) | | ret = pd.DataFrame() - | if is_last_chunk: - | if count >= 3: - | state.remove() - | ret = pd.DataFrame({'key': [key[0]], 'countAsString': [str(-1)]}) - | else: - | ret = pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) + | if count >= 3: + | state.remove() + | ret = pd.DataFrame({'key': [key[0]], 'countAsString': [str(-1)]}) + | else: + | ret = pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) | | return ret |""".stripMargin @@ -263,7 +273,7 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { | StructField("key", StringType()), | StructField("countAsString", StringType())]) | - |def func(key, pdf, state, is_last_chunk): + |def func(key, pdf_iter, state): | assert state.getCurrentProcessingTimeMs() >= 0 | try: | state.getCurrentWatermarkMs() @@ -280,13 +290,13 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { | count = 0 | else: | count = count[0] - | count += len(pdf) + | + | for pdf in pdf_iter: + | count += len(pdf) + | | state.update((count,)) - | if is_last_chunk: - | state.setTimeoutDuration(10000) - | return pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) - | else: - | return pd.DataFrame() + | state.setTimeoutDuration(10000) + | return pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) |""".stripMargin val pythonFunc = TestGroupedMapPandasUDFWithState( name = "pandas_grouped_map_with_state", pythonScript = pythonScript) @@ -369,7 +379,7 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { | StructField("key", StringType()), | StructField("maxEventTimeSec", IntegerType())]) | - |def func(key, pdf, state, is_last_chunk): + |def func(key, pdf_iter, state): | assert state.getCurrentProcessingTimeMs() >= 0 | assert state.getCurrentWatermarkMs() >= -1 | @@ -380,21 +390,20 @@ class FlatMapGroupsInPandasWithStateSuite extends 
StateStoreMetricsTest { | else: | m = state.getOption | if m is None: - | m = 0 + | max_event_time_sec = 0 | else: - | m = m[0] + | max_event_time_sec = m[0] | - | pser = pdf.eventTime.apply( + | for pdf in pdf_iter: + | pser = pdf.eventTime.apply( | lambda dt: (int(calendar.timegm(dt.utctimetuple()) + dt.microsecond))) - | max_event_time_sec = int(max(pser.max(), m)) + | max_event_time_sec = int(max(pser.max(), max_event_time_sec)) + | | state.update((max_event_time_sec,)) - | if is_last_chunk: - | timeout_timestamp_sec = max_event_time_sec + timeout_delay_sec - | state.setTimeoutTimestamp(timeout_timestamp_sec * 1000) - | return pd.DataFrame({'key': [key[0]], - | 'maxEventTimeSec': [max_event_time_sec]}) - | else: - | return pd.DataFrame() + | timeout_timestamp_sec = max_event_time_sec + timeout_delay_sec + | state.setTimeoutTimestamp(timeout_timestamp_sec * 1000) + | return pd.DataFrame({'key': [key[0]], + | 'maxEventTimeSec': [max_event_time_sec]}) |""".stripMargin val pythonFunc = TestGroupedMapPandasUDFWithState( name = "pandas_grouped_map_with_state", pythonScript = pythonScript) @@ -459,7 +468,7 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { | StructField("key", StringType()), | StructField("countAsString", StringType())]) | - |def func(key, pdf, state, is_last_chunk): + |def func(key, pdf_iter, state): | if state.hasTimedOut: | state.remove() | return pd.DataFrame({'key': [key[0]], 'countAsString': [str(-1)]}) @@ -469,13 +478,13 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { | count = 0 | else: | count = count[0] - | count += len(pdf) + | + | for pdf in pdf_iter: + | count += len(pdf) + | | state.update((count,)) - | if is_last_chunk: - | state.setTimeoutDuration(10000) - | return pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) - | else: - | return pd.DataFrame() + | state.setTimeoutDuration(10000) + | return pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) 
|""".stripMargin val pythonFunc = TestGroupedMapPandasUDFWithState( name = "pandas_grouped_map_with_state", pythonScript = pythonScript) From 00836b53f60632d0a7b7ef9515ed5a8d259d8846 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 6 Sep 2022 13:43:05 +0900 Subject: [PATCH 33/44] WIP FIX pyspark side test failure --- .../sql/tests/test_pandas_grouped_map_with_state.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/tests/test_pandas_grouped_map_with_state.py b/python/pyspark/sql/tests/test_pandas_grouped_map_with_state.py index a9a56c557fabd..414915c574e35 100644 --- a/python/pyspark/sql/tests/test_pandas_grouped_map_with_state.py +++ b/python/pyspark/sql/tests/test_pandas_grouped_map_with_state.py @@ -58,11 +58,16 @@ def test_apply_in_pandas_with_state_basic(self): ) state_type = StructType([StructField("c", LongType())]) - def func(key, pdf, state): + def func(key, pdf_iter, state): assert isinstance(state, GroupStateImpl) - state.update((len(pdf),)) + + total_len = 0 + for pdf in pdf_iter: + total_len += len(pdf) + + state.update((total_len,)) assert state.get[0] == 1 - return pd.DataFrame({"key": [key[0]], "countAsString": [str(len(pdf))]}) + return pd.DataFrame({"key": [key[0]], "countAsString": [str(total_len)]}) def check_results(batch_df, _): self.assertEqual( From 5fdde94a8049599a2eff7b993614c73dda4ea966 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Wed, 14 Sep 2022 14:51:22 +0900 Subject: [PATCH 34/44] WIP sort out codebase a bit --- .../apache/spark/api/python/PythonRunner.scala | 1 - python/pyspark/serializers.py | 1 - python/pyspark/sql/pandas/_typing/__init__.pyi | 2 +- python/pyspark/sql/pandas/serializers.py | 1 - python/pyspark/worker.py | 16 ++++++---------- .../plans/logical/pythonLogicalOperators.scala | 5 ----- .../spark/sql/RelationalGroupedDataset.scala | 1 - .../spark/sql/execution/SparkStrategies.scala | 2 +- .../sql/execution/python/PythonArrowOutput.scala | 3 --- 9 files changed, 
8 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 7616b7616587e..6c9377a436eb7 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -784,7 +784,6 @@ private[spark] object SpecialLengths { val END_OF_STREAM = -4 val NULL = -5 val START_ARROW_STREAM = -6 - val START_STATE_UPDATE = -7 } private[spark] object BarrierTaskContextMessageProtocol { diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index c6eef13c5bf58..8c5a941f376d2 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -85,7 +85,6 @@ class SpecialLengths: END_OF_STREAM = -4 NULL = -5 START_ARROW_STREAM = -6 - START_STATE_UPDATE = -7 class Serializer: diff --git a/python/pyspark/sql/pandas/_typing/__init__.pyi b/python/pyspark/sql/pandas/_typing/__init__.pyi index 82b861c51cf5c..d6b46523c34b6 100644 --- a/python/pyspark/sql/pandas/_typing/__init__.pyi +++ b/python/pyspark/sql/pandas/_typing/__init__.pyi @@ -257,7 +257,7 @@ PandasGroupedMapFunction = Union[ Callable[[Tuple, DataFrameLike], DataFrameLike], ] -PandasGroupedMapFunctionWithState = Callable[[Tuple, DataFrameLike, GroupStateImpl], DataFrameLike] +PandasGroupedMapFunctionWithState = Callable[[Tuple, Iterable[DataFrameLike], GroupStateImpl], DataFrameLike] class PandasVariadicGroupedAggFunction(Protocol): def __call__(self, *_: SeriesLike) -> LiteralType: ... 
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 86c144220b443..7250b1e1509ab 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -33,7 +33,6 @@ class SpecialLengths: END_OF_STREAM = -4 NULL = -5 START_ARROW_STREAM = -6 - START_STATE_UPDATE = -7 class ArrowCollectSerializer(Serializer): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 7f51fbe88f68e..68fc6209b226b 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -57,7 +57,8 @@ from pyspark.sql.pandas.serializers import ( ArrowStreamPandasUDFSerializer, CogroupUDFSerializer, - ArrowStreamUDFSerializer, ApplyInPandasWithStateSerializer, + ArrowStreamUDFSerializer, + ApplyInPandasWithStateSerializer, ) from pyspark.sql.pandas.types import to_arrow_type from pyspark.sql.types import StructType @@ -538,9 +539,7 @@ def extract_key_value_indexes(grouped_arg_offsets): idx += offsets_len return parsed - if eval_type in ( - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, - ): + if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: # We assume there is only one UDF here because grouped map doesn't # support combining multiple UDFs. assert num_udfs == 1 @@ -584,6 +583,8 @@ def mapper(a): keys_gen, values_gen = tee(data_gen) keys_elem = next(keys_gen) keys = [keys_elem[o] for o in parsed_offsets[0][0]] + + # This must be generator comprehension - do not materialize. 
vals = ([x[o] for o in parsed_offsets[0][1]] for x in values_gen) return f(keys, vals, state) @@ -687,7 +688,6 @@ def main(infile, outfile): ) # initialize global state - state_schema = None taskContext = None if isBarrier: taskContext = BarrierTaskContext._getOrCreate() @@ -773,9 +773,7 @@ def main(infile, outfile): if eval_type == PythonEvalType.NON_UDF: func, profiler, deserializer, serializer = read_command(pickleSer, infile) else: - func, profiler, deserializer, serializer = read_udfs( - pickleSer, infile, eval_type - ) + func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type) func_init_time = time.time() @@ -823,7 +821,6 @@ def process(): faulthandler.disable() faulthandler_log_file.close() os.remove(faulthandler_log_path) - finish_time = time.time() report_times(outfile, boot_time, init_time, func_init_time, finish_time) write_long(shuffle.MemoryBytesSpilled, outfile) @@ -831,7 +828,6 @@ def process(): # Mark the beginning of the accumulators section of the output write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) - write_int(len(_accumulatorRegistry), outfile) for (aid, accum) in _accumulatorRegistry.items(): pickleSer._write_with_length((aid, accum._value), outfile) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala index 67d072bc36824..e97ff7808f172 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -112,7 +112,6 @@ case class FlatMapCoGroupsInPandas( * @param outputAttrs used to define the output rows * @param stateType used to serialize/deserialize state before calling `functionExpr` * @param outputMode the output mode of `func` - * @param isMapGroupsWithState whether it is created by the 
`mapGroupsWithState` method * @param timeout used to timeout groups that have not received data in a while * @param child logical plan of the underlying data */ @@ -122,12 +121,8 @@ case class FlatMapGroupsInPandasWithState( outputAttrs: Seq[Attribute], stateType: StructType, outputMode: OutputMode, - isMapGroupsWithState: Boolean = false, timeout: GroupStateTimeout, child: LogicalPlan) extends UnaryNode { - if (isMapGroupsWithState) { - assert(outputMode == OutputMode.Update) - } override def output: Seq[Attribute] = outputAttrs diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 6c7b14b2334cf..69eb8101abf73 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -646,7 +646,6 @@ class RelationalGroupedDataset protected[sql]( outputAttrs, stateStructType, outputMode, - isMapGroupsWithState = false, timeoutConf, child = df.logicalPlan) Dataset.ofRows(df.sparkSession, plan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 7ec47f469adde..0f25f53da4228 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -691,7 +691,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object FlatMapGroupsInPandasWithStateStrategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case FlatMapGroupsInPandasWithState( - func, groupAttr, outputAttr, stateType, outputMode, _, timeout, child) => + func, groupAttr, outputAttr, stateType, outputMode, timeout, child) => val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION) val 
execPlan = python.FlatMapGroupsInPandasWithStateExec( func, groupAttr, outputAttr, stateType, None, stateVersion, outputMode, timeout, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala index c8398f2316b7a..d86827d556e03 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala @@ -107,9 +107,6 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[ new ArrowColumnVector(vector) }.toArray[ColumnVector] read() - case SpecialLengths.START_STATE_UPDATE => - handleStateUpdate(stream) - read() case SpecialLengths.TIMING_DATA => handleTimingData() read() From e7ad043814b638693a21434af4e0278b2ad3e30e Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 6 Sep 2022 15:37:17 +0900 Subject: [PATCH 35/44] WIP no batch query support in applyInPandasWithState --- .../scala/org/apache/spark/sql/execution/SparkStrategies.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 0f25f53da4228..8feb68909b2f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -814,7 +814,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { ) :: Nil case _: FlatMapGroupsInPandasWithState => // TODO(SPARK-XXXXX): Implement batch support for applyInPandasWithState - throw new UnsupportedOperationException("applyInPandasWithState is unsupported.") + throw new UnsupportedOperationException( + "applyInPandasWithState is unsupported in batch query. 
Use applyInPandas instead.") case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) => execution.CoGroupExec( f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, From 5070b81fc6a8c7acf0b383c4050ce31dd613430e Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 6 Sep 2022 16:43:03 +0900 Subject: [PATCH 36/44] WIP address some missed things --- .../UnsupportedOperationChecker.scala | 61 ++++++++++ ...psInPandasWithStateDistributionSuite.scala | 115 ++++++++++++++++++ .../FlatMapGroupsInPandasWithStateSuite.scala | 76 ++++++++++++ 3 files changed, 252 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateDistributionSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index c11ce7d3b90f1..aa3205b6f0347 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -64,6 +64,7 @@ object UnsupportedOperationChecker extends Logging { case s: Aggregate if s.isStreaming => true case _ @ Join(left, right, _, _, _) if left.isStreaming && right.isStreaming => true case f: FlatMapGroupsWithState if f.isStreaming => true + case f: FlatMapGroupsInPandasWithState if f.isStreaming => true case d: Deduplicate if d.isStreaming => true case _ => false } @@ -142,6 +143,16 @@ object UnsupportedOperationChecker extends Logging { " or the output mode is not append on a streaming DataFrames/Datasets")(plan) } + val applyInPandasWithStates = plan.collect { + case f: FlatMapGroupsInPandasWithState if f.isStreaming => f + } + + // Disallow multiple `applyInPandasWithState`s. 
+ if (applyInPandasWithStates.size >= 2) { + throwError( + "Multiple applyInPandasWithStates are not supported on a streaming DataFrames/Datasets")(plan) + } + // Disallow multiple streaming aggregations val aggregates = collectStreamingAggregates(plan) @@ -311,6 +322,56 @@ object UnsupportedOperationChecker extends Logging { } } + // applyInPandasWithState + case m: FlatMapGroupsInPandasWithState if m.isStreaming => + // Check compatibility with output modes and aggregations in query + val aggsInQuery = collectStreamingAggregates(plan) + + if (aggsInQuery.isEmpty) { + // applyInPandasWithState without aggregation: operation's output mode must + // match query output mode + m.outputMode match { + case InternalOutputModes.Update if outputMode != InternalOutputModes.Update => + throwError( + "applyInPandasWithState in update mode is not supported with " + + s"$outputMode output mode on a streaming DataFrame/Dataset") + + case InternalOutputModes.Append if outputMode != InternalOutputModes.Append => + throwError( + "applyInPandasWithState in append mode is not supported with " + + s"$outputMode output mode on a streaming DataFrame/Dataset") + + case _ => + } + } else { + // applyInPandasWithState with aggregation: update operation mode not allowed, and + // *groupsWithState after aggregation not allowed + if (m.outputMode == InternalOutputModes.Update) { + throwError( + "applyInPandasWithState in update mode is not supported with " + + "aggregation on a streaming DataFrame/Dataset") + } else if (collectStreamingAggregates(m).nonEmpty) { + throwError( + "applyInPandasWithState in append mode is not supported after " + + "aggregation on a streaming DataFrame/Dataset") + } + } + + // Check compatibility with timeout configs + if (m.timeout == EventTimeTimeout) { + // With event time timeout, watermark must be defined. 
+ val watermarkAttributes = m.child.output.collect { + case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => a + } + if (watermarkAttributes.isEmpty) { + throwError( + "Watermark must be specified in the query using " + + "'[Dataset/DataFrame].withWatermark()' for using event-time timeout in a " + + "applyInPandasWithState. Event-time timeout not supported without " + + "watermark.")(plan) + } + } + case d: Deduplicate if collectStreamingAggregates(d).nonEmpty => throwError("dropDuplicates is not supported after aggregation on a " + "streaming DataFrame/Dataset") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateDistributionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateDistributionSuite.scala new file mode 100644 index 0000000000000..75ee188445beb --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateDistributionSuite.scala @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.sql.IntegratedUDFTestUtils.{shouldTestPandasUDFs, TestGroupedMapPandasUDFWithState} +import org.apache.spark.sql.catalyst.expressions.PythonUDF +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Update +import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.streaming.util.{StatefulOpClusteredDistributionTestHelper, StreamManualClock} +import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType} + +class FlatMapGroupsInPandasWithStateDistributionSuite extends StreamTest + with StatefulOpClusteredDistributionTestHelper { + + import testImplicits._ + + test("applyInPandasWithState should require StatefulOpClusteredDistribution " + + "from children - without initial state") { + assume(shouldTestPandasUDFs) + + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count if state is defined, otherwise does not return anything + val pythonScript = + """ + |import pandas as pd + |from pyspark.sql.types import StructType, StructField, StringType, IntegerType + | + |tpe = StructType([ + | StructField("key1", StringType()), + | StructField("key2", StringType()), + | StructField("count", IntegerType())]) + | + |def func(key, pdf_iter, state): + | count = state.getOption + | if count is None: + | count = 0 + | else: + | count = count[0] + | + | for pdf in pdf_iter: + | count += len(pdf) + | state.update((count,)) + | + | ret = pd.DataFrame() + | if count >= 3: + | state.remove() + | else: + | ret = pd.DataFrame({'key1': [key[0]], 'key2': [key[1]], 'count': [count]}) + | + | return ret + |""".stripMargin + val pythonFunc = TestGroupedMapPandasUDFWithState( + name = "pandas_grouped_map_with_state", pythonScript = pythonScript) + + val inputData = MemoryStream[(String, String, Long)] + val 
outputStructType = StructType( + Seq( + StructField("key1", StringType), + StructField("key2", StringType), + StructField("count", IntegerType))) + val stateStructType = StructType(Seq(StructField("count", LongType))) + val inputDataDS = inputData.toDS().toDF("key1", "key2", "time") + .selectExpr("key1", "key2", "timestamp_seconds(time) as timestamp") + val result = + inputDataDS + .withWatermark("timestamp", "10 second") + .repartition($"key1") + .groupBy($"key1", $"key2") + .applyInPandasWithState( + pythonFunc(inputDataDS("key1"), inputDataDS("key2"), inputDataDS("timestamp")) + .expr.asInstanceOf[PythonUDF], + outputStructType, + stateStructType, + "Update", + "NoTimeout") + .select("key1", "key2", "count") + + val clock = new StreamManualClock + testStream(result, Update)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputData, ("a", "a", 1L)), + AdvanceManualClock(1 * 1000), // a is processed here for the first time. + CheckNewAnswer(("a", "a", 1)), + Execute { query => + val numPartitions = query.lastExecution.numStateStores + + val flatMapGroupsInPandasWithStateExecs = query.lastExecution.executedPlan.collect { + case f: FlatMapGroupsInPandasWithStateExec => f + } + + assert(flatMapGroupsInPandasWithStateExecs.length === 1) + assert(requireStatefulOpClusteredDistribution( + flatMapGroupsInPandasWithStateExecs.head, Seq(Seq("key1", "key2")), numPartitions)) + assert(hasDesiredHashPartitioningInChildren( + flatMapGroupsInPandasWithStateExecs.head, Seq(Seq("key1", "key2")), numPartitions)) + } + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala index 5b8c0365a3cf3..6a31eb7d86d50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.IntegratedUDFTestUtils._ import org.apache.spark.sql.catalyst.expressions.PythonUDF import org.apache.spark.sql.catalyst.plans.logical.{NoTimeout, ProcessingTimeTimeout} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Complete, Update} +import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions.timestamp_seconds import org.apache.spark.sql.internal.SQLConf @@ -521,4 +522,79 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { } testWithTimeout(NoTimeout) testWithTimeout(ProcessingTimeTimeout) + + test("applyInPandasWithState - uses state format version 2 by default") { + assume(shouldTestPandasUDFs) + + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count if state is defined, otherwise does not return anything + val pythonScript = + """ + |import pandas as pd + |from pyspark.sql.types import StructType, StructField, StringType + | + |tpe = StructType([ + | StructField("key", StringType()), + | StructField("countAsString", StringType())]) + | + |def func(key, pdf_iter, state): + | assert state.getCurrentProcessingTimeMs() >= 0 + | try: + | state.getCurrentWatermarkMs() + | assert False + | except RuntimeError as e: + | assert "watermark" in str(e) + | + | count = state.getOption + | if count is None: + | count = 0 + | else: + | count = count[0] + | + | for pdf in pdf_iter: + | count += len(pdf) + | state.update((count,)) + | + | ret = pd.DataFrame() + | if count >= 3: + | state.remove() + | else: + | ret = pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) + | + | return ret + |""".stripMargin + val pythonFunc = TestGroupedMapPandasUDFWithState( + name = "pandas_grouped_map_with_state", 
pythonScript = pythonScript) + + val inputData = MemoryStream[String] + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("countAsString", StringType))) + val stateStructType = StructType(Seq(StructField("count", LongType))) + val inputDataDS = inputData.toDS() + val result = + inputDataDS + .groupBy("value") + .applyInPandasWithState( + pythonFunc(inputDataDS("value")).expr.asInstanceOf[PythonUDF], + outputStructType, + stateStructType, + "Update", + "NoTimeout") + + testStream(result, Update)( + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + assertNumStateRows(total = 1, updated = 1), + Execute { query => + // Verify state format = 2 + val f = query.lastExecution.executedPlan.collect { + case f: FlatMapGroupsInPandasWithStateExec => f + } + assert(f.size == 1) + assert(f.head.stateFormatVersion == 2) + } + ) + } } From 1b919b8dc75b0cf165c6cb4158694f7a3eb62dc9 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Wed, 7 Sep 2022 13:29:25 +0900 Subject: [PATCH 37/44] WIP remove comments which are obsolete or won't be addressed --- .../python/ArrowPythonRunnerWithState.scala | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala index a456ddf16f1dd..3e6fb229ac5cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala @@ -197,12 +197,6 @@ class ArrowPythonRunnerWithState( // the range of data and give a view, say, "zero-copy". To help splitting the range for // data, we provide the "start offset" and the "number of data" in the state metadata. // - // FIXME: probably need to change this as "hard limit" when addressing scalability. 
Worth - // noting that we may need to break down the data into chunks for a specific group - // having "small" number of data, because we also do bin-packing as well. Maybe we could - // concatenate these chunks in Python worker (serializer), with some hints e.g. - // We can get the information - the number of data in the chunk before reading. - // // Pretty sure we don't bin-pack all groups into a single record batch. We have a soft // limit on the size - it's not a hard limit since we allow current group to write all // data even it's going to exceed the limit. @@ -258,14 +252,6 @@ class ArrowPythonRunnerWithState( // Provide data rows while (dataIter.hasNext) { val dataRow = dataIter.next() - // TODO: if we think there will be non-small amount of data per grouping key, - // we could probably try out "dictionary encoding" for the optimization - // of storing same grouping keys multiple times. This may complicate the logic, as - // in IPC streaming format, DictionaryBatch will be provided separately along with - // RecordBatch, and I'm not sure whether the record batch can be directly converted - // to Pandas DataFrame / Series if the record batch refers to the dictionary batch. 
- // https://arrow.apache.org/docs/format/Columnar.html#dictionary-encoded-layout - // https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format arrowWriterForData.write(dataRow) numRowsForCurGroup += 1 totalNumRowsForBatch += 1 From 198fc17f53708ddaf2c15dbed638d7ee66127798 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Wed, 7 Sep 2022 15:41:06 +0900 Subject: [PATCH 38/44] WIP change the return type of user function to Iterator[DataFrame] --- .../pyspark/sql/pandas/_typing/__init__.pyi | 2 +- python/pyspark/sql/pandas/serializers.py | 10 +-- .../test_pandas_grouped_map_with_state.py | 2 +- python/pyspark/worker.py | 45 +++++++++---- .../UnsupportedOperationChecker.scala | 3 +- ...psInPandasWithStateDistributionSuite.scala | 6 +- .../FlatMapGroupsInPandasWithStateSuite.scala | 67 +++++++++---------- 7 files changed, 77 insertions(+), 58 deletions(-) diff --git a/python/pyspark/sql/pandas/_typing/__init__.pyi b/python/pyspark/sql/pandas/_typing/__init__.pyi index d6b46523c34b6..7b972edc88dd2 100644 --- a/python/pyspark/sql/pandas/_typing/__init__.pyi +++ b/python/pyspark/sql/pandas/_typing/__init__.pyi @@ -257,7 +257,7 @@ PandasGroupedMapFunction = Union[ Callable[[Tuple, DataFrameLike], DataFrameLike], ] -PandasGroupedMapFunctionWithState = Callable[[Tuple, Iterable[DataFrameLike], GroupStateImpl], DataFrameLike] +PandasGroupedMapFunctionWithState = Callable[[Tuple, Iterable[DataFrameLike], GroupStateImpl], Iterable[DataFrameLike]] class PandasVariadicGroupedAggFunction(Protocol): def __call__(self, *_: SeriesLike) -> LiteralType: ... 
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 7250b1e1509ab..3a29e763cbac6 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -533,7 +533,7 @@ def init_stream_yield_batches(): for data in iterator: packaged_result = data[0] - pdf = packaged_result[0][0] + pdf_iter = packaged_result[0][0] state = packaged_result[0][1] # this won't change across batches return_schema = packaged_result[1] @@ -541,9 +541,11 @@ def init_stream_yield_batches(): # FIXME: arrow type to pandas type # FIXME: probably also need to check columns to validate? - if len(pdf) > 0: - pdf_data_cnt += len(pdf) - pdfs.append(pdf) + for pdf in pdf_iter: + # FIXME: probably need to reduce down the scope of record batch to this? + if len(pdf) > 0: + pdf_data_cnt += len(pdf) + pdfs.append(pdf) # pick up state for only last chunk as state should have been updated so far state_properties = state.json().encode("utf-8") diff --git a/python/pyspark/sql/tests/test_pandas_grouped_map_with_state.py b/python/pyspark/sql/tests/test_pandas_grouped_map_with_state.py index 414915c574e35..9271853ab625a 100644 --- a/python/pyspark/sql/tests/test_pandas_grouped_map_with_state.py +++ b/python/pyspark/sql/tests/test_pandas_grouped_map_with_state.py @@ -67,7 +67,7 @@ def func(key, pdf_iter, state): state.update((total_len,)) assert state.get[0] == 1 - return pd.DataFrame({"key": [key[0]], "countAsString": [str(total_len)]}) + yield pd.DataFrame({"key": [key[0]], "countAsString": [str(total_len)]}) def check_results(batch_df, _): self.assertEqual( diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 68fc6209b226b..3ba3178221967 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -222,25 +222,44 @@ def wrapped(key_series, value_series_gen, state): else: values = (pd.concat(x, axis=1) for x in value_series_gen) - result = f(key, values, state) + result_iter = f(key, values, state) 
- if not isinstance(result, pd.DataFrame): + def verify_element(result): + if not isinstance(result, pd.DataFrame): + raise TypeError( + "The type of element in return iterator of the user-defined function " + "should be pandas.DataFrame, but is {}".format(type(result)) + ) + # the number of columns of result have to match the return type + # but it is fine for result to have no columns at all if it is empty + if not ( + len(result.columns) == len(return_type) or len(result.columns) == 0 and result.empty + ): + raise RuntimeError( + "Number of columns of the element (pandas.DataFrame) in return iterator " + "doesn't match specified schema. " + "Expected: {} Actual: {}".format(len(return_type), len(result.columns)) + ) + + return result + + if isinstance(result_iter, pd.DataFrame): raise TypeError( "Return type of the user-defined function should be " - "pandas.DataFrame, but is {}".format(type(result)) + "iterable of pandas.DataFrame, but is {}".format(type(result_iter)) ) - # the number of columns of result have to match the return type - # but it is fine for result to have no columns at all if it is empty - if not ( - len(result.columns) == len(return_type) or len(result.columns) == 0 and result.empty - ): - raise RuntimeError( - "Number of columns of the returned pandas.DataFrame " - "doesn't match specified schema. 
" - "Expected: {} Actual: {}".format(len(return_type), len(result.columns)) + + try: + iter(result_iter) + except TypeError: + raise TypeError( + "Return type of the user-defined function should be " + "iterable, but is {}".format(type(result_iter)) ) - return (result, state, ) + result_iter_with_validation = (verify_element(x) for x in result_iter) + + return (result_iter_with_validation, state, ) return lambda k, v, s: [(wrapped(k, v, s), to_arrow_type(return_type))] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index aa3205b6f0347..99ba3802097b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -150,7 +150,8 @@ object UnsupportedOperationChecker extends Logging { // Disallow multiple `applyInPandasWithState`s. 
if (applyInPandasWithStates.size >= 2) { throwError( - "Multiple applyInPandasWithStates are not supported on a streaming DataFrames/Datasets")(plan) + "Multiple applyInPandasWithStates are not supported on a streaming " + + "DataFrames/Datasets")(plan) } // Disallow multiple streaming aggregations diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateDistributionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateDistributionSuite.scala index 75ee188445beb..9c6573fd782ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateDistributionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateDistributionSuite.scala @@ -57,13 +57,11 @@ class FlatMapGroupsInPandasWithStateDistributionSuite extends StreamTest | count += len(pdf) | state.update((count,)) | - | ret = pd.DataFrame() | if count >= 3: | state.remove() + | yield pd.DataFrame() | else: - | ret = pd.DataFrame({'key1': [key[0]], 'key2': [key[1]], 'count': [count]}) - | - | return ret + | yield pd.DataFrame({'key1': [key[0]], 'key2': [key[1]], 'count': [count]}) |""".stripMargin val pythonFunc = TestGroupedMapPandasUDFWithState( name = "pandas_grouped_map_with_state", pythonScript = pythonScript) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala index 6a31eb7d86d50..aa2d7169ce885 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala @@ -64,13 +64,11 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { | count += len(pdf) | state.update((count,)) | - | ret = pd.DataFrame() | if count >= 3: | 
state.remove() + | yield pd.DataFrame() | else: - | ret = pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) - | - | return ret + | yield pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) |""".stripMargin val pythonFunc = TestGroupedMapPandasUDFWithState( name = "pandas_grouped_map_with_state", pythonScript = pythonScript) @@ -125,23 +123,23 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { |tpe = StructType([ | StructField("key", StringType()), | StructField("value", IntegerType()), - | StructField("countAsString", StringType())]) + | StructField("valueAsString", StringType()), + | StructField("prevCountAsString", StringType())]) | |def func(key, pdf_iter, state): - | count = state.getOption - | if count is None: - | count = 0 + | prev_count = state.getOption + | if prev_count is None: + | prev_count = 0 | else: - | count = count[0] + | prev_count = prev_count[0] | - | pdf_list = [] + | count = prev_count | for pdf in pdf_iter: | count += len(pdf) - | pdf_list.append(pdf) + | yield pdf.assign(valueAsString=lambda x: x.value.apply(str), + | prevCountAsString=str(prev_count)) | | state.update((count,)) - | pdf_concat = pd.concat(pdf_list) - | return pdf_concat.assign(countAsString=str(count)) |""".stripMargin val pythonFunc = TestGroupedMapPandasUDFWithState( name = "pandas_grouped_map_with_state", pythonScript = pythonScript) @@ -152,7 +150,8 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { Seq( StructField("key", StringType), StructField("value", IntegerType), - StructField("countAsString", StringType))) + StructField("valueAsString", StringType), + StructField("prevCountAsString", StringType))) val stateStructType = StructType(Seq(StructField("count", LongType))) val inputDataDS = inputData.toDS().selectExpr("_1 AS key", "_2 AS value") val result = @@ -164,20 +163,23 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { stateStructType, "Update", "NoTimeout") - 
.select("key", "value", "countAsString") + .select("key", "value", "valueAsString", "prevCountAsString") testStream(result, Update)( AddData(inputData, ("a", 1)), - CheckNewAnswer(("a", 1, "1")), + CheckNewAnswer(("a", 1, "1", "0")), assertNumStateRows(total = 1, updated = 1), AddData(inputData, ("a", 2), ("a", 3), ("b", 1)), - CheckNewAnswer(("a", 2, "3"), ("a", 3, "3"), ("b", 1, "1")), + CheckNewAnswer(("a", 2, "2", "1"), ("a", 3, "3", "1"), ("b", 1, "1", "0")), assertNumStateRows(total = 2, updated = 2), StopStream, StartStream(), AddData(inputData, ("b", 2), ("c", 1), ("d", 1), ("e", 1)), - CheckNewAnswer(("b", 2, "2"), ("c", 1, "1"), ("d", 1, "1"), ("e", 1, "1")), - assertNumStateRows(total = 5, updated = 4) + CheckNewAnswer(("b", 2, "2", "1"), ("c", 1, "1", "0"), ("d", 1, "1", "0"), + ("e", 1, "1", "0")), + assertNumStateRows(total = 5, updated = 4), + AddData(inputData, ("a", 4)), + CheckNewAnswer(("a", 4, "4", "3")) ) } } @@ -211,11 +213,9 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { | ret = pd.DataFrame() | if count >= 3: | state.remove() - | ret = pd.DataFrame({'key': [key[0]], 'countAsString': [str(-1)]}) + | yield pd.DataFrame({'key': [key[0]], 'countAsString': [str(-1)]}) | else: - | ret = pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) - | - | return ret + | yield pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) |""".stripMargin val pythonFunc = TestGroupedMapPandasUDFWithState( name = "pandas_grouped_map_with_state", pythonScript = pythonScript) @@ -282,9 +282,10 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { | except RuntimeError as e: | assert "watermark" in str(e) | + | ret = None | if state.hasTimedOut: | state.remove() - | return pd.DataFrame({'key': [key[0]], 'countAsString': [str(-1)]}) + | yield pd.DataFrame({'key': [key[0]], 'countAsString': [str(-1)]}) | else: | count = state.getOption | if count is None: @@ -297,7 +298,7 @@ class 
FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { | | state.update((count,)) | state.setTimeoutDuration(10000) - | return pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) + | yield pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) |""".stripMargin val pythonFunc = TestGroupedMapPandasUDFWithState( name = "pandas_grouped_map_with_state", pythonScript = pythonScript) @@ -387,7 +388,7 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { | timeout_delay_sec = 5 | if state.hasTimedOut: | state.remove() - | return pd.DataFrame({'key': [key[0]], 'maxEventTimeSec': [-1]}) + | yield pd.DataFrame({'key': [key[0]], 'maxEventTimeSec': [-1]}) | else: | m = state.getOption | if m is None: @@ -403,8 +404,8 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { | state.update((max_event_time_sec,)) | timeout_timestamp_sec = max_event_time_sec + timeout_delay_sec | state.setTimeoutTimestamp(timeout_timestamp_sec * 1000) - | return pd.DataFrame({'key': [key[0]], - | 'maxEventTimeSec': [max_event_time_sec]}) + | yield pd.DataFrame({'key': [key[0]], + | 'maxEventTimeSec': [max_event_time_sec]}) |""".stripMargin val pythonFunc = TestGroupedMapPandasUDFWithState( name = "pandas_grouped_map_with_state", pythonScript = pythonScript) @@ -472,7 +473,7 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { |def func(key, pdf_iter, state): | if state.hasTimedOut: | state.remove() - | return pd.DataFrame({'key': [key[0]], 'countAsString': [str(-1)]}) + | yield pd.DataFrame({'key': [key[0]], 'countAsString': [str(-1)]}) | else: | count = state.getOption | if count is None: @@ -485,7 +486,7 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { | | state.update((count,)) | state.setTimeoutDuration(10000) - | return pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) + | yield pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) 
|""".stripMargin val pythonFunc = TestGroupedMapPandasUDFWithState( name = "pandas_grouped_map_with_state", pythonScript = pythonScript) @@ -555,13 +556,11 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { | count += len(pdf) | state.update((count,)) | - | ret = pd.DataFrame() | if count >= 3: | state.remove() + | yield pd.DataFrame() | else: - | ret = pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) - | - | return ret + | yield pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) |""".stripMargin val pythonFunc = TestGroupedMapPandasUDFWithState( name = "pandas_grouped_map_with_state", pythonScript = pythonScript) From f2a75f19586333efd82ea0e68c26619ef5a43625 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 13 Sep 2022 11:04:27 +0900 Subject: [PATCH 39/44] WIP remove unnecessary interface/implementation changes on GroupState as it bothers MiMa --- python/pyspark/sql/pandas/serializers.py | 7 +++++-- python/pyspark/sql/streaming/state.py | 13 ++++++++----- .../python/ArrowPythonRunnerWithState.scala | 15 +++++++++------ .../FlatMapGroupsInPandasWithStateExec.scala | 5 +++-- .../sql/execution/streaming/GroupStateImpl.scala | 9 +-------- .../spark/sql/streaming/TestGroupState.scala | 3 --- 6 files changed, 26 insertions(+), 26 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 3a29e763cbac6..37e77e21eaf43 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -23,7 +23,7 @@ from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer, CPickleSerializer from pyspark.sql.pandas.types import to_arrow_type -from pyspark.sql.types import StringType, StructType, BinaryType, StructField, BooleanType +from pyspark.sql.types import StringType, StructType, BinaryType, StructField, LongType class SpecialLengths: @@ -391,6 +391,7 @@ def __init__(self, timezone, safecheck, 
assign_cols_by_name, state_object_schema StructField('properties', StringType()), StructField('keyRowAsUnsafe', BinaryType()), StructField('object', BinaryType()), + StructField('oldTimeoutTimestamp', LongType()), ]) self.result_state_pdf_arrow_type = to_arrow_type(self.result_state_df_type) @@ -497,7 +498,7 @@ def construct_record_batch(pdfs, pdf_data_cnt, pdf_schema, state_pdfs, state_dat dict.fromkeys(pa.schema(pdf_schema).names), index=[x for x in range(0, empty_row_cnt_in_data)]) empty_rows_state = pd.DataFrame( - columns=['properties', 'keyRowAsUnsafe', 'object'], + columns=['properties', 'keyRowAsUnsafe', 'object', 'oldTimeoutTimestamp'], index=[x for x in range(0, empty_row_cnt_in_state)]) pdfs.append(empty_rows_pdf) @@ -551,11 +552,13 @@ def init_stream_yield_batches(): state_properties = state.json().encode("utf-8") state_key_row_as_binary = state._keyAsUnsafe state_object = self.pickleSer.dumps(state._value_schema.toInternal(state._value)) + state_old_timeout_timestamp = state.oldTimeoutTimestamp state_dict = { 'properties': [state_properties, ], 'keyRowAsUnsafe': [state_key_row_as_binary, ], 'object': [state_object, ], + 'oldTimeoutTimestamp': [state_old_timeout_timestamp, ], } state_pdf = pd.DataFrame.from_dict(state_dict) diff --git a/python/pyspark/sql/streaming/state.py b/python/pyspark/sql/streaming/state.py index 8b883c4b5b809..c036f9704557a 100644 --- a/python/pyspark/sql/streaming/state.py +++ b/python/pyspark/sql/streaming/state.py @@ -45,7 +45,6 @@ def __init__( defined: bool, updated: bool, removed: bool, - timeoutUpdated: bool, timeoutTimestamp: int, # Python internal state. keyAsUnsafe: bytes, @@ -70,7 +69,8 @@ def __init__( self._updated = updated self._removed = removed self._timeout_timestamp = timeoutTimestamp - self._timeout_updated = timeoutUpdated + # Python internal state. 
+ self._old_timeout_timestamp = timeoutTimestamp self._value_schema = valueSchema @@ -96,6 +96,12 @@ def getOption(self) -> Optional[Tuple]: def hasTimedOut(self) -> bool: return self._has_timed_out + # NOTE: this function is only available to PySpark implementation due to underlying + # implementation, do not port to Scala implementation! + @property + def oldTimeoutTimestamp(self) -> int: + return self._old_timeout_timestamp + def update(self, newValue: Tuple) -> None: if newValue is None: raise ValueError("'None' is not a valid state value") @@ -124,7 +130,6 @@ def setTimeoutDuration(self, durationMs: int) -> None: if durationMs <= 0: raise ValueError("Timeout duration must be positive") self._timeout_timestamp = durationMs + self._batch_processing_time_ms - self._timeout_updated = True # TODO(SPARK-XXXXX): Implement additionalDuration parameter. def setTimeoutTimestamp(self, timestampMs: int) -> None: @@ -150,7 +155,6 @@ def setTimeoutTimestamp(self, timestampMs: int) -> None: ) self._timeout_timestamp = timestampMs - self._timeout_updated = True def getCurrentWatermarkMs(self) -> int: if not self._watermark_present: @@ -184,6 +188,5 @@ def json(self) -> str: "updated": self._updated, "removed": self._removed, "timeoutTimestamp": self._timeout_timestamp, - "timeoutUpdated": self._timeout_updated, } ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala index 3e6fb229ac5cf..e205480fd61fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala @@ -64,7 +64,7 @@ class ArrowPythonRunnerWithState( softTimeoutMillsPurgeBatch: Long) extends BasePythonRunner[ (UnsafeRow, GroupStateImpl[Row], Iterator[InternalRow]), - (Iterator[(UnsafeRow, GroupStateImpl[Row])], 
Iterator[InternalRow])]( + (Iterator[(UnsafeRow, GroupStateImpl[Row], Long)], Iterator[InternalRow])]( funcs, evalType, argOffsets) { override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback @@ -111,7 +111,7 @@ class ArrowPythonRunnerWithState( pid: Option[Int], releasedOrClosed: AtomicBoolean, context: TaskContext) - : Iterator[(Iterator[(UnsafeRow, GroupStateImpl[Row])], Iterator[InternalRow])] = { + : Iterator[(Iterator[(UnsafeRow, GroupStateImpl[Row], Long)], Iterator[InternalRow])] = { new StateReaderIterator(stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) } @@ -388,7 +388,7 @@ class ArrowPythonRunnerWithState( private var batchLoaded = true protected override def read() - : (Iterator[(UnsafeRow, GroupStateImpl[Row])], Iterator[InternalRow]) = { + : (Iterator[(UnsafeRow, GroupStateImpl[Row], Long)], Iterator[InternalRow]) = { if (writerThread.exception.isDefined) { throw writerThread.exception.get } @@ -433,7 +433,7 @@ class ArrowPythonRunnerWithState( } private def deserializeColumnarBatch(batch: ColumnarBatch) - : (Iterator[(UnsafeRow, GroupStateImpl[Row])], Iterator[InternalRow]) = { + : (Iterator[(UnsafeRow, GroupStateImpl[Row], Long)], Iterator[InternalRow]) = { // This should at least have one row for state. Also, we ensure that all columns across // data and state metadata have same number of rows, which is required by Arrow record // batch. 
@@ -459,7 +459,7 @@ class ArrowPythonRunnerWithState( } def constructIterForState( - batch: ColumnarBatch): Iterator[(UnsafeRow, GroupStateImpl[Row])] = { + batch: ColumnarBatch): Iterator[(UnsafeRow, GroupStateImpl[Row], Long)] = { // UDF returns a StructType column in ColumnarBatch, select the children here val structVector = batch.column(1).asInstanceOf[ArrowColumnVector] @@ -479,6 +479,7 @@ class ArrowPythonRunnerWithState( // StructField("properties", StringType), // StructField("keyRowAsUnsafe", BinaryType), // StructField("object", BinaryType), + // StructField("oldTimeoutTimestamp", LongType), // ) // TODO: Do we want to rely on the column name rather than the ordinal for safety? val propertiesAsJson = parse(row.getUTF8String(0).toString) @@ -492,8 +493,10 @@ class ArrowPythonRunnerWithState( Some(PythonSQLUtils.toJVMRow(pickledStateValue, stateValueSchema, stateRowDeserializer)) } + val oldTimeoutTimestamp = row.getLong(3) - Some(keyRowAsUnsafe, GroupStateImpl.fromJson(maybeObjectRow, propertiesAsJson)) + Some((keyRowAsUnsafe, GroupStateImpl.fromJson(maybeObjectRow, propertiesAsJson), + oldTimeoutTimestamp)) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala index 808afffdef1e5..cccdac7c9c54c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -175,15 +175,16 @@ case class FlatMapGroupsInPandasWithStateExec( // When the iterator is consumed, then write changes to state. // state does not affect each others, hence when to update does not affect to the result. 
def onIteratorCompletion: Unit = { - stateIter.foreach { case (keyRow, newGroupState) => + stateIter.foreach { case (keyRow, newGroupState, oldTimeoutTimestamp) => if (newGroupState.isRemoved && !newGroupState.getTimeoutTimestampMs.isPresent()) { stateManager.removeState(store, keyRow) numRemovedStateRows += 1 } else { val currentTimeoutTimestamp = newGroupState.getTimeoutTimestampMs .orElse(NO_TIMESTAMP) + val hasTimeoutChanged = currentTimeoutTimestamp != oldTimeoutTimestamp val shouldWriteState = newGroupState.isUpdated || newGroupState.isRemoved || - newGroupState.isTimeoutUpdated + hasTimeoutChanged if (shouldWriteState) { val updatedStateObj = if (newGroupState.exists) newGroupState.get else null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala index 8b220cb202957..bcd3cfc4508dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala @@ -59,7 +59,6 @@ private[sql] class GroupStateImpl[S] private[sql]( private var updated: Boolean = false // whether value has been updated (but not removed) private var removed: Boolean = false // whether value has been removed private var timeoutTimestamp: Long = NO_TIMESTAMP - private var timeoutUpdated: Boolean = false // ========= Public API ========= override def exists: Boolean = defined @@ -104,7 +103,6 @@ private[sql] class GroupStateImpl[S] private[sql]( throw new IllegalArgumentException("Timeout duration must be positive") } timeoutTimestamp = durationMs + batchProcessingTimeMs - timeoutUpdated = true } override def setTimeoutDuration(duration: String): Unit = { @@ -122,7 +120,6 @@ private[sql] class GroupStateImpl[S] private[sql]( s"current watermark ($eventTimeWatermarkMs)") } timeoutTimestamp = timestampMs - timeoutUpdated = true } override def 
setTimeoutTimestamp(timestampMs: Long, additionalDuration: String): Unit = { @@ -161,8 +158,6 @@ private[sql] class GroupStateImpl[S] private[sql]( override def isUpdated: Boolean = updated - override def isTimeoutUpdated: Boolean = timeoutUpdated - override def getTimeoutTimestampMs: Optional[Long] = { if (timeoutTimestamp != NO_TIMESTAMP) { Optional.of(timeoutTimestamp) @@ -199,8 +194,7 @@ private[sql] class GroupStateImpl[S] private[sql]( "defined" -> JBool(defined) :: "updated" -> JBool(updated) :: "removed" -> JBool(removed) :: - "timeoutTimestamp" -> JLong(timeoutTimestamp) :: - "timeoutUpdated" -> JBool(timeoutUpdated) :: Nil + "timeoutTimestamp" -> JLong(timeoutTimestamp) :: Nil ))) } @@ -271,7 +265,6 @@ private[sql] object GroupStateImpl { newGroupState.removed = hmap("removed").asInstanceOf[Boolean] newGroupState.timeoutTimestamp = hmap("timeoutTimestamp").asInstanceOf[Number].longValue() - newGroupState.timeoutUpdated = hmap("timeoutUpdated").asInstanceOf[Boolean] newGroupState } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/TestGroupState.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/TestGroupState.scala index 346bde5df3e24..d53d6087d677c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/TestGroupState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/TestGroupState.scala @@ -123,9 +123,6 @@ trait TestGroupState[S] extends GroupState[S] { /** Whether the state has been updated but not removed */ def isUpdated: Boolean - /** FIXME: ... */ - def isTimeoutUpdated: Boolean - /** * Returns the timestamp if `setTimeoutTimestamp()` is called. 
* Or, returns batch processing time + the duration when From 3e5f5d4e878e5a9a1e71490191cf98751c273739 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 13 Sep 2022 20:21:22 +0900 Subject: [PATCH 40/44] WIP refine out some code --- python/pyspark/sql/pandas/serializers.py | 61 ++--- .../apache/spark/sql/internal/SQLConf.scala | 23 +- .../ApplyInPandasWithStatePythonRunner.scala | 197 ++++++++++++++++ .../python/ApplyInPandasWithStateWriter.scala | 220 ++++++++++++++++++ .../FlatMapGroupsInPandasWithStateExec.scala | 10 +- .../execution/python/PythonArrowInput.scala | 4 +- 6 files changed, 479 insertions(+), 36 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 37e77e21eaf43..2f8bac1092ccc 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -475,6 +475,7 @@ def gen_data_and_state(batches): data_state_generator = gen_data_and_state(batches) + # state will be same object for same grouping key for state, data in groupby(data_state_generator, key=lambda x: x[1]): yield (data, state,) @@ -486,6 +487,11 @@ def dump_stream(self, iterator, stream): """ def construct_record_batch(pdfs, pdf_data_cnt, pdf_schema, state_pdfs, state_data_cnt): + """ + Arrow RecordBatch requires all columns to have all same number of rows. + Insert empty data for state/data with less elements to compensate. + """ + import pandas as pd import pyarrow as pa @@ -525,9 +531,6 @@ def init_stream_yield_batches(): state_data_cnt = 0 sampled_data_size_per_row = 0 - sampled_state_size = 0 - # FIXME: sample with empty state size separately? 
- sampled_empty_state_size = 0 last_purged_time_ns = time.time_ns() @@ -539,15 +542,37 @@ def init_stream_yield_batches(): # this won't change across batches return_schema = packaged_result[1] - # FIXME: arrow type to pandas type - # FIXME: probably also need to check columns to validate? - for pdf in pdf_iter: - # FIXME: probably need to reduce down the scope of record batch to this? if len(pdf) > 0: pdf_data_cnt += len(pdf) pdfs.append(pdf) + if sampled_data_size_per_row == 0 and \ + pdf_data_cnt > self.minDataCountForSample: + memory_usages = [p.memory_usage(deep=True).sum() for p in pdfs] + sampled_data_size_per_row = sum(memory_usages) / pdf_data_cnt + + # This effectively works after the sampling has completed, since we multiply by 0 + # if the sampling is still in progress. + batch_over_limit_on_size = (sampled_data_size_per_row * pdf_data_cnt) >= \ + self.softLimitBytesPerBatch + + if batch_over_limit_on_size: + batch = construct_record_batch(pdfs, pdf_data_cnt, return_schema, + state_pdfs, state_data_cnt) + + pdfs = [] + state_pdfs = [] + pdf_data_cnt = 0 + state_data_cnt = 0 + last_purged_time_ns = time.time_ns() + + if should_write_start_length: + write_int(SpecialLengths.START_ARROW_STREAM, stream) + should_write_start_length = False + + yield batch + # pick up state for only last chunk as state should have been updated so far state_properties = state.json().encode("utf-8") state_key_row_as_binary = state._keyAsUnsafe @@ -566,24 +591,10 @@ def init_stream_yield_batches(): state_pdfs.append(state_pdf) state_data_cnt += 1 - # FIXME: threshold of sample data - if sampled_data_size_per_row == 0 and pdf_data_cnt > self.minDataCountForSample: - memory_usages = [p.memory_usage(deep=True).sum() for p in pdfs] - sampled_data_size_per_row = sum(memory_usages) / pdf_data_cnt - - # FIXME: threshold of sample data - if sampled_state_size == 0 and state_data_cnt > self.minDataCountForSample: - memory_usages = [p.memory_usage(deep=True).sum() for p in state_pdfs] -
sampled_state_size = sum(memory_usages) / state_data_cnt - - # This effectively works after the sampling has completed, size we multiply by 0 - # if the sampling is still in progress. - batch_over_limit_on_size = (sampled_data_size_per_row * pdf_data_cnt) + \ - (sampled_state_size * state_data_cnt) >= self.softLimitBytesPerBatch cur_time_ns = time.time_ns() is_timed_out_on_purge = ((cur_time_ns - last_purged_time_ns) // 1000000) >= \ self.softTimeoutMillisPurgeBatch - if batch_over_limit_on_size or is_timed_out_on_purge: + if is_timed_out_on_purge: batch = construct_record_batch(pdfs, pdf_data_cnt, return_schema, state_pdfs, state_data_cnt) @@ -604,14 +615,8 @@ def init_stream_yield_batches(): batch = construct_record_batch(pdfs, pdf_data_cnt, return_schema, state_pdfs, state_data_cnt) - pdfs = [] - state_pdfs = [] - pdf_data_cnt = 0 - state_data_cnt = 0 - if should_write_start_length: write_int(SpecialLengths.START_ARROW_STREAM, stream) - should_write_start_length = False yield batch diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 521b2b3a897dd..c8acb8ac09cd1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2707,21 +2707,38 @@ object SQLConf { val MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH = buildConf("spark.sql.execution.applyInPandasWithState.softLimitSizePerBatch") - // FIXME: doc + .internal() + .doc("When using applyInPandasWithState, set a soft limit of the accumulated size of " + + "records that can be written to a single ArrowRecordBatch in memory. This is used to " + + "restrict the amount of memory being used to materialize the data in both executor and " + + "Python worker. The accumulated size of records are calculated via sampling a set of " + + "records. 
Splitting the ArrowRecordBatch is performed per record, so unless a record " + + "is quite huge, the size of constructed ArrowRecordBatch will be around the " + + "configured value.") .version("3.4.0") .bytesConf(ByteUnit.BYTE) .createWithDefaultString("64MB") val MAP_PANDAS_UDF_WITH_STATE_MIN_DATA_COUNT_FOR_SAMPLE = buildConf("spark.sql.execution.applyInPandasWithState.minDataCountForSample") - // FIXME: doc + .internal() + .doc("When using applyInPandasWithState, specify the minimum number of records to sample " + + "the size of record. The size being retrieved from sampling will be used to estimate " + + "the accumulated size of records. Note that limiting by size does not work if the " + + "number of records are less than the configured value. For such case, ArrowRecordBatch " + + "will only be split for soft timeout.") .version("3.4.0") .intConf .createWithDefault(100) val MAP_PANDAS_UDF_WITH_STATE_SOFT_TIMEOUT_PURGE_BATCH = buildConf("spark.sql.execution.applyInPandasWithState.softTimeoutPurgeBatch") - // FIXME: doc + .internal() + .doc("When using applyInPandasWithState, specify the soft timeout for purging the " + + "ArrowRecordBatch. If batching records exceeds the timeout, Spark will force splitting " + + "the ArrowRecordBatch regardless of estimated size. 
This config ensures the receiver " + + "of data (both executor and Python worker) to not wait indefinitely for sender to " + + "complete the ArrowRecordBatch, which may hurt both throughput and latency.") .version("3.4.0") .timeConf(TimeUnit.MILLISECONDS) .createWithDefaultString("100ms") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala new file mode 100644 index 0000000000000..1bcf9610ba2ee --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.python + +import java.io._ + +import scala.collection.JavaConverters._ + +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.ipc.ArrowStreamWriter +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.api.python._ +import org.apache.spark.sql.Row +import org.apache.spark.sql.api.python.PythonSQLUtils +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.python.ApplyInPandasWithStatePythonRunner.{InType, OutType, OutTypeForState, STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER} +import org.apache.spark.sql.execution.python.ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA +import org.apache.spark.sql.execution.streaming.GroupStateImpl +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} + + +/** + * [[ArrowPythonRunner]] with [[org.apache.spark.sql.streaming.GroupState]]. 
+ */ +class ApplyInPandasWithStatePythonRunner( + funcs: Seq[ChainedPythonFunctions], + evalType: Int, + argOffsets: Array[Array[Int]], + inputSchema: StructType, + override protected val timeZoneId: String, + initialWorkerConf: Map[String, String], + stateEncoder: ExpressionEncoder[Row], + keySchema: StructType, + valueSchema: StructType, + stateValueSchema: StructType, + softLimitBytesPerBatch: Long, + minDataCountForSample: Int, + softTimeoutMillsPurgeBatch: Long) + extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets) + with PythonArrowInput[InType] + with PythonArrowOutput[OutType] { + + override protected val schema: StructType = inputSchema.add("!__state__!", STATE_METADATA_SCHEMA) + + override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback + + override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize + require( + bufferSize >= 4, + "Pandas execution requires more than 4 bytes. Please set higher buffer. " + + s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.") + + override protected val workerConf: Map[String, String] = initialWorkerConf + + (SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH.key -> + softLimitBytesPerBatch.toString) + + (SQLConf.MAP_PANDAS_UDF_WITH_STATE_MIN_DATA_COUNT_FOR_SAMPLE.key -> + minDataCountForSample.toString) + + (SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_TIMEOUT_PURGE_BATCH.key -> + softTimeoutMillsPurgeBatch.toString) + + private val stateRowDeserializer = stateEncoder.createDeserializer() + + override protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = { + super.handleMetadataBeforeExec(stream) + // Also write the schema for state value + PythonRDD.writeUTF(stateValueSchema.json, stream) + } + + protected def writeIteratorToArrowStream( + root: VectorSchemaRoot, + writer: ArrowStreamWriter, + dataOut: DataOutputStream, + inputIterator: Iterator[InType]): Unit = { + val w = new ApplyInPandasWithStateWriter(root, writer, softLimitBytesPerBatch, + 
minDataCountForSample, softTimeoutMillsPurgeBatch) + + while (inputIterator.hasNext) { + val (keyRow, groupState, dataIter) = inputIterator.next() + assert(dataIter.hasNext, "should have at least one data row!") + w.startNewGroup(keyRow, groupState) + + while (dataIter.hasNext) { + val dataRow = dataIter.next() + w.writeRow(dataRow) + } + + w.finalizeGroup() + } + + w.finalizeData() + } + + protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: StructType): OutType = { + // This should at least have one row for state. Also, we ensure that all columns across + // data and state metadata have same number of rows, which is required by Arrow record + // batch. + assert(batch.numRows() > 0) + assert(schema.length == 2) + + def getColumnarBatchForStructTypeColumn( + batch: ColumnarBatch, + ordinal: Int, + expectedType: StructType): ColumnarBatch = { + // UDF returns a StructType column in ColumnarBatch, select the children here + val structVector = batch.column(ordinal).asInstanceOf[ArrowColumnVector] + val dataType = schema(ordinal).dataType.asInstanceOf[StructType] + assert(dataType.sameType(expectedType)) + + val outputVectors = dataType.indices.map(structVector.getChild) + val flattenedBatch = new ColumnarBatch(outputVectors.toArray) + flattenedBatch.setNumRows(batch.numRows()) + + flattenedBatch + } + + def constructIterForData(batch: ColumnarBatch): Iterator[InternalRow] = { + val dataBatch = getColumnarBatchForStructTypeColumn(batch, 0, valueSchema) + dataBatch.rowIterator.asScala.flatMap { row => + if (row.isNullAt(0)) { + // The entire row in record batch seems to be for state metadata. 
+ None + } else { + Some(row) + } + } + } + + def constructIterForState(batch: ColumnarBatch): Iterator[OutTypeForState] = { + val stateMetadataBatch = getColumnarBatchForStructTypeColumn(batch, 1, + STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER) + + stateMetadataBatch.rowIterator().asScala.flatMap { row => + implicit val formats = org.json4s.DefaultFormats + + if (row.isNullAt(0)) { + // The entire row in record batch seems to be for data. + None + } else { + // NOTE: See STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER for the schema. + val propertiesAsJson = parse(row.getUTF8String(0).toString) + val keyRowAsUnsafeAsBinary = row.getBinary(1) + val keyRowAsUnsafe = new UnsafeRow(keySchema.fields.length) + keyRowAsUnsafe.pointTo(keyRowAsUnsafeAsBinary, keyRowAsUnsafeAsBinary.length) + val maybeObjectRow = if (row.isNullAt(2)) { + None + } else { + val pickledStateValue = row.getBinary(2) + Some(PythonSQLUtils.toJVMRow(pickledStateValue, stateValueSchema, + stateRowDeserializer)) + } + val oldTimeoutTimestamp = row.getLong(3) + + Some((keyRowAsUnsafe, GroupStateImpl.fromJson(maybeObjectRow, propertiesAsJson), + oldTimeoutTimestamp)) + } + } + } + + (constructIterForState(batch), constructIterForData(batch)) + } +} + +object ApplyInPandasWithStatePythonRunner { + type InType = (UnsafeRow, GroupStateImpl[Row], Iterator[InternalRow]) + type OutTypeForState = (UnsafeRow, GroupStateImpl[Row], Long) + type OutType = (Iterator[OutTypeForState], Iterator[InternalRow]) + + val STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER: StructType = StructType( + Array( + StructField("properties", StringType), + StructField("keyRowAsUnsafe", BinaryType), + StructField("object", BinaryType), + StructField("oldTimeoutTimestamp", LongType), + ) + ) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala new file mode 100644 index
0000000000000..781335f821eba --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import scala.collection.JavaConverters._ + +import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot} +import org.apache.arrow.vector.ipc.ArrowStreamWriter + +import org.apache.spark.sql.Row +import org.apache.spark.sql.api.python.PythonSQLUtils +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} +import org.apache.spark.sql.execution.arrow.ArrowWriter +import org.apache.spark.sql.execution.arrow.ArrowWriter.createFieldWriter +import org.apache.spark.sql.execution.streaming.GroupStateImpl +import org.apache.spark.sql.types.{BinaryType, BooleanType, IntegerType, StringType, StructField, StructType} +import org.apache.spark.unsafe.types.UTF8String + + +class ApplyInPandasWithStateWriter( + root: VectorSchemaRoot, + writer: ArrowStreamWriter, + softLimitBytesPerBatch: Long, + minDataCountForSample: Int, + softTimeoutMillsPurgeBatch: Long) { + + import 
ApplyInPandasWithStateWriter._ + + // We logically group the columns by family and initialize writer separately, since it's + // lot more easier and probably performant to write the row directly rather than + // projecting the row to match up with the overall schema. + // + // The number of data rows and state metadata rows can be different which seems to matter + // for Arrow RecordBatch, so we append empty rows to cover it. + // + // We always produce at least one data row per grouping key whereas we only produce one + // state metadata row per grouping key, so we only need to fill up the empty rows in + // state metadata side. + private val arrowWriterForData = createArrowWriter(root.getFieldVectors.asScala.dropRight(1)) + private val arrowWriterForState = createArrowWriter(root.getFieldVectors.asScala.takeRight(1)) + + // We apply bin-packing the data from multiple groups into one Arrow RecordBatch to + // gain the performance. In many cases, the amount of data per grouping key is quite + // small, which does not seem to maximize the benefits of using Arrow. + // + // We have to split the record batch down to each group in Python worker to convert the + // data for group to Pandas, but hopefully, Arrow RecordBatch provides the way to split + // the range of data and give a view, say, "zero-copy". To help splitting the range for + // data, we provide the "start offset" and the "number of data" in the state metadata. + // + // Pretty sure we don't bin-pack all groups into a single record batch. We have a soft + // limit on the size - it's not a hard limit since we allow current group to write all + // data even it's going to exceed the limit. + // + // We perform some basic sampling for data to guess the size of the data very roughly, + // and simply multiply by the number of data to estimate the size. We extract the size of + // data from the record batch rather than UnsafeRow, as we don't hold the memory for + // UnsafeRow once we write to the record batch. 
If there is a memory bound here, it + // should come from record batch. + // + // In the meanwhile, we don't also want to let the current record batch collect the data + // indefinitely, since we are pipelining the process between executor and python worker. + // Python worker won't process any data if executor is not yet finalized a record + // batch, which defeats the purpose of pipelining. To address this, we also introduce + // timeout for constructing a record batch. This is a soft limit indeed as same as limit + // on the size - we allow current group to write all data even it's timed-out. + + private var numRowsForCurGroup = 0 + private var startOffsetForCurGroup = 0 + private var totalNumRowsForBatch = 0 + private var totalNumStatesForBatch = 0 + + private var sampledDataSizePerRow = 0 + private var lastBatchPurgedMillis = System.currentTimeMillis() + + private var currentGroupKeyRow: UnsafeRow = _ + private var currentGroupState: GroupStateImpl[Row] = _ + + def startNewGroup(keyRow: UnsafeRow, groupState: GroupStateImpl[Row]): Unit = { + currentGroupKeyRow = keyRow + currentGroupState = groupState + } + + def writeRow(dataRow: InternalRow): Unit = { + // Currently, this only works when the number of rows are greater than the minimum + // data count for sampling. And we technically have no way to pick some rows from + // record batch and measure the size of data, hence we leverage all data in current + // record batch. We only sample once as it could be costly. + if (sampledDataSizePerRow == 0 && totalNumRowsForBatch > minDataCountForSample) { + sampledDataSizePerRow = arrowWriterForData.sizeInBytes() / totalNumRowsForBatch + } + + // If it exceeds the condition of batch (only size, not about timeout) and + // there is more data for the same group, flush and construct a new batch. + + // The soft-limit on size effectively works after the sampling has completed, since we + // multiply the number of rows by 0 if the sampling is still in progress. 
+ + // if (sampledDataSizePerRow * totalNumRowsForBatch >= softLimitBytesPerBatch) { + // FIXME: debug + if (totalNumRowsForBatch % 10 == 1) { + // Provide state metadata row as intermediate + val stateInfoRow = buildStateInfoRow(currentGroupKeyRow, currentGroupState, + startOffsetForCurGroup, numRowsForCurGroup, isLastChunk = false) + arrowWriterForState.write(stateInfoRow) + totalNumStatesForBatch += 1 + + finalizeCurrentArrowBatch() + } + + arrowWriterForData.write(dataRow) + numRowsForCurGroup += 1 + totalNumRowsForBatch += 1 + } + + def finalizeGroup(): Unit = { + // Provide state metadata row + val stateInfoRow = buildStateInfoRow(currentGroupKeyRow, currentGroupState, + startOffsetForCurGroup, numRowsForCurGroup, isLastChunk = true) + arrowWriterForState.write(stateInfoRow) + totalNumStatesForBatch += 1 + + // The start offset for next group would be same as the total number of rows for batch, + // unless the next group starts with new batch. + startOffsetForCurGroup = totalNumRowsForBatch + + // The soft-limit on timeout applies on finalization of each group. + if (System.currentTimeMillis() - lastBatchPurgedMillis > softTimeoutMillsPurgeBatch) { + finalizeCurrentArrowBatch() + } + } + + def finalizeData(): Unit = { + if (numRowsForCurGroup > 0) { + // We still have some rows in the current record batch. Need to flush them as well. 
+ finalizeCurrentArrowBatch() + } + } + + private def createArrowWriter(fieldVectors: Seq[FieldVector]): ArrowWriter = { + val children = fieldVectors.map { vector => + vector.allocateNew() + createFieldWriter(vector) + } + + new ArrowWriter(root, children.toArray) + } + + private def buildStateInfoRow( + keyRow: UnsafeRow, + groupState: GroupStateImpl[Row], + startOffset: Int, + numRows: Int, + isLastChunk: Boolean): InternalRow = { + // NOTE: see ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA + val stateUnderlyingRow = new GenericInternalRow( + Array[Any]( + UTF8String.fromString(groupState.json()), + keyRow.getBytes, + groupState.getOption.map(PythonSQLUtils.toPyRow).orNull, + startOffset, + numRows, + isLastChunk + ) + ) + new GenericInternalRow(Array[Any](stateUnderlyingRow)) + } + + private def finalizeCurrentArrowBatch(): Unit = { + val remainingEmptyStateRows = totalNumRowsForBatch - totalNumStatesForBatch + (0 until remainingEmptyStateRows).foreach { _ => + arrowWriterForState.write(EMPTY_STATE_METADATA_ROW) + } + + arrowWriterForState.finish() + arrowWriterForData.finish() + writer.writeBatch() + arrowWriterForState.reset() + arrowWriterForData.reset() + + startOffsetForCurGroup = 0 + numRowsForCurGroup = 0 + totalNumRowsForBatch = 0 + totalNumStatesForBatch = 0 + lastBatchPurgedMillis = System.currentTimeMillis() + } +} + +object ApplyInPandasWithStateWriter { + val STATE_METADATA_SCHEMA: StructType = StructType( + Array( + StructField("properties", StringType), + StructField("keyRowAsUnsafe", BinaryType), + StructField("object", BinaryType), + StructField("startOffset", IntegerType), + StructField("numRows", IntegerType), + StructField("isLastChunk", BooleanType) + ) + ) + + // To avoid initializing a new row for empty state metadata row. 
+ val EMPTY_STATE_METADATA_ROW = new GenericInternalRow( + Array[Any](null, null, null, null, null, null)) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala index cccdac7c9c54c..b833809561f75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -64,8 +64,6 @@ case class FlatMapGroupsInPandasWithStateExec( eventTimeWatermark: Option[Long], child: SparkPlan) extends UnaryExecNode with FlatMapGroupsWithStateExecBase { - private val keySchema: StructType = groupingAttributes.toStructType - // TODO(SPARK-XXXXX): Add the support of initial state. override protected val initialStateDeserializer: Expression = null override protected val initialStateGroupAttrs: Seq[Attribute] = null @@ -114,6 +112,12 @@ case class FlatMapGroupsInPandasWithStateExec( process(processIter, hasTimedOut = false) } + override def processNewDataWithInitialState( + childDataIter: Iterator[InternalRow], + initStateIter: Iterator[InternalRow]): Iterator[InternalRow] = { + throw new UnsupportedOperationException("Should not reach here!") + } + override def processTimedOutState(): Iterator[InternalRow] = { if (isTimeoutEnabled) { val timeoutThreshold = timeoutConf match { @@ -143,7 +147,7 @@ case class FlatMapGroupsInPandasWithStateExec( private def process( iter: Iterator[(UnsafeRow, StateData, Iterator[InternalRow])], hasTimedOut: Boolean): Iterator[InternalRow] = { - val runner = new ArrowPythonRunnerWithState( + val runner = new ApplyInPandasWithStatePythonRunner( chainedFunc, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, Array(argOffsets), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala index 6168d0f867adb..c0a2db8518a22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala @@ -21,11 +21,12 @@ import java.net.Socket import org.apache.arrow.vector.VectorSchemaRoot import org.apache.arrow.vector.ipc.ArrowStreamWriter - import org.apache.spark.{SparkEnv, TaskContext} + import org.apache.spark.api.python.{BasePythonRunner, PythonRDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.arrow.ArrowWriter +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.util.Utils @@ -76,7 +77,6 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] => val root = VectorSchemaRoot.create(arrowSchema, allocator) Utils.tryWithSafeFinally { - val arrowWriter = ArrowWriter.create(root) val writer = new ArrowStreamWriter(root, null, dataOut) writer.start() From 4e34d297545b65a694c950ebef940d62e5874bdf Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 13 Sep 2022 20:36:12 +0900 Subject: [PATCH 41/44] WIP fix scalastyle --- .../execution/python/ApplyInPandasWithStatePythonRunner.scala | 2 +- .../apache/spark/sql/execution/python/PythonArrowInput.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala index 1bcf9610ba2ee..213c9f4e712bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala @@ -191,7 
+191,7 @@ object ApplyInPandasWithStatePythonRunner { StructField("properties", StringType), StructField("keyRowAsUnsafe", BinaryType), StructField("object", BinaryType), - StructField("oldTimeoutTimestamp", LongType), + StructField("oldTimeoutTimestamp", LongType) ) ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala index c0a2db8518a22..37718fcfcb897 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala @@ -21,8 +21,8 @@ import java.net.Socket import org.apache.arrow.vector.VectorSchemaRoot import org.apache.arrow.vector.ipc.ArrowStreamWriter -import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.api.python.{BasePythonRunner, PythonRDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.arrow.ArrowWriter From 50e743ea41dd9b193708361751480be5aad2d394 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 13 Sep 2022 21:35:58 +0900 Subject: [PATCH 42/44] WIP remove obsolete class --- .../python/ArrowPythonRunnerWithState.scala | 507 ------------------ 1 file changed, 507 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala deleted file mode 100644 index e205480fd61fc..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala +++ /dev/null @@ -1,507 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.python - -import java.io._ -import java.net.Socket -import java.util.concurrent.atomic.AtomicBoolean - -import scala.collection.JavaConverters._ - -import org.apache.arrow.vector.VectorSchemaRoot -import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter} -import org.json4s._ -import org.json4s.jackson.JsonMethods._ - -import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.api.python._ -import org.apache.spark.sql.Row -import org.apache.spark.sql.api.python.PythonSQLUtils -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} -import org.apache.spark.sql.execution.arrow.ArrowWriter -import org.apache.spark.sql.execution.arrow.ArrowWriter.createFieldWriter -import org.apache.spark.sql.execution.streaming.GroupStateImpl -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types._ -import org.apache.spark.sql.util.ArrowUtils -import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} -import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.util.Utils - -/** - * [[ArrowPythonRunner]] with 
[[org.apache.spark.sql.streaming.GroupState]]. - */ -class ArrowPythonRunnerWithState( - funcs: Seq[ChainedPythonFunctions], - evalType: Int, - argOffsets: Array[Array[Int]], - inputSchema: StructType, - timeZoneId: String, - workerConf: Map[String, String], - stateEncoder: ExpressionEncoder[Row], - keySchema: StructType, - valueSchema: StructType, - stateValueSchema: StructType, - softLimitBytesPerBatch: Long, - minDataCountForSample: Int, - softTimeoutMillsPurgeBatch: Long) - extends BasePythonRunner[ - (UnsafeRow, GroupStateImpl[Row], Iterator[InternalRow]), - (Iterator[(UnsafeRow, GroupStateImpl[Row], Long)], Iterator[InternalRow])]( - funcs, evalType, argOffsets) { - - override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback - - override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize - require( - bufferSize >= 4, - "Pandas execution requires more than 4 bytes. Please set higher buffer. " + - s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.") - - private val workerConfWithRunnerConfs = workerConf + - (SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH.key -> - softLimitBytesPerBatch.toString) + - (SQLConf.MAP_PANDAS_UDF_WITH_STATE_MIN_DATA_COUNT_FOR_SAMPLE.key -> - minDataCountForSample.toString) + - (SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_TIMEOUT_PURGE_BATCH.key -> - softTimeoutMillsPurgeBatch.toString) - - protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = { - // Write config for the worker as a number of key -> value pairs of strings - stream.writeInt(workerConfWithRunnerConfs.size) - for ((k, v) <- workerConfWithRunnerConfs) { - PythonRDD.writeUTF(k, stream) - PythonRDD.writeUTF(v, stream) - } - PythonRDD.writeUTF(stateValueSchema.json, stream) - } - - protected override def newWriterThread( - env: SparkEnv, - worker: Socket, - inputIterator: Iterator[(UnsafeRow, GroupStateImpl[Row], Iterator[InternalRow])], - partitionIndex: Int, - context: TaskContext): WriterThread = { - new 
StateWriterThread(env, worker, inputIterator, partitionIndex, context) - } - - protected def newReaderIterator( - stream: DataInputStream, - writerThread: WriterThread, - startTime: Long, - env: SparkEnv, - worker: Socket, - pid: Option[Int], - releasedOrClosed: AtomicBoolean, - context: TaskContext) - : Iterator[(Iterator[(UnsafeRow, GroupStateImpl[Row], Long)], Iterator[InternalRow])] = { - new StateReaderIterator(stream, writerThread, startTime, env, worker, pid, - releasedOrClosed, context) - } - - private class StateWriterThread( - env: SparkEnv, - worker: Socket, - inputIterator: Iterator[(UnsafeRow, GroupStateImpl[Row], Iterator[InternalRow])], - partitionIndex: Int, - context: TaskContext) - extends WriterThread(env, worker, inputIterator, partitionIndex, context) { - - import StateWriterThread._ - - private val schemaWithState = inputSchema.add("!__state__!", STATE_METADATA_SCHEMA) - - protected override def writeCommand(dataOut: DataOutputStream): Unit = { - handleMetadataBeforeExec(dataOut) - PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) - } - - private def buildStateInfoRow( - keyRow: UnsafeRow, - groupState: GroupStateImpl[Row], - startOffset: Int, - numRows: Int, - isLastChunk: Boolean): InternalRow = { - // NOTE: see ArrowPythonRunnerWithState.STATE_METADATA_SCHEMA - val stateUnderlyingRow = new GenericInternalRow( - Array[Any]( - UTF8String.fromString(groupState.json()), - keyRow.getBytes, - groupState.getOption.map(PythonSQLUtils.toPyRow).orNull, - startOffset, - numRows, - isLastChunk - ) - ) - new GenericInternalRow(Array[Any](stateUnderlyingRow)) - } - - protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { - // We initialize all columns in data & state metadata for Arrow RecordBatch. 
- val arrowSchema = ArrowUtils.toArrowSchema(schemaWithState, timeZoneId) - val allocator = ArrowUtils.rootAllocator.newChildAllocator( - s"stdout writer for $pythonExec", 0, Long.MaxValue) - val root = VectorSchemaRoot.create(arrowSchema, allocator) - - Utils.tryWithSafeFinally { - // We logically group the columns by family and initialize writer separately, since it's - // lot more easier and probably performant to write the row directly rather than - // projecting the row to match up with the overall schema. - // The number of data rows and state metadata rows can be different which seems to matter - // for Arrow RecordBatch, so we append empty rows to cover it. - // We always produce at least one data row per grouping key whereas we only produce one - // state metadata row per grouping key, so we only need to fill up the empty rows in - // state metadata side. - val arrowWriterForData = { - val children = root.getFieldVectors().asScala.dropRight(1).map { vector => - vector.allocateNew() - createFieldWriter(vector) - } - - new ArrowWriter(root, children.toArray) - } - val arrowWriterForState = { - val children = root.getFieldVectors().asScala.takeRight(1).map { vector => - vector.allocateNew() - createFieldWriter(vector) - } - new ArrowWriter(root, children.toArray) - } - - val writer = new ArrowStreamWriter(root, null, dataOut) - writer.start() - - // We apply bin-packing the data from multiple groups into one Arrow RecordBatch to - // gain the performance. In many cases, the amount of data per grouping key is quite - // small, which does not seem to maximize the benefits of using Arrow. - // - // We have to split the record batch down to each group in Python worker to convert the - // data for group to Pandas, but hopefully, Arrow RecordBatch provides the way to split - // the range of data and give a view, say, "zero-copy". To help splitting the range for - // data, we provide the "start offset" and the "number of data" in the state metadata. 
- // - // Pretty sure we don't bin-pack all groups into a single record batch. We have a soft - // limit on the size - it's not a hard limit since we allow current group to write all - // data even it's going to exceed the limit. - // - // We perform some basic sampling for data to guess the size of the data very roughly, - // and simply multiply by the number of data to estimate the size. We extract the size of - // data from the record batch rather than UnsafeRow, as we don't hold the memory for - // UnsafeRow once we write to the record batch. If there is a memory bound here, it - // should come from record batch. - // - // In the meanwhile, we don't also want to let the current record batch collect the data - // indefinitely, since we are pipelining the process between executor and python worker. - // Python worker won't process any data if executor is not yet finalized a record - // batch, which defeats the purpose of pipelining. To address this, we also introduce - // timeout for constructing a record batch. This is a soft limit indeed as same as limit - // on the size - we allow current group to write all data even it's timed-out. - - // FIXME: Maybe better if we can extract out the batching logic into a separate class. 
- var numRowsForCurGroup = 0 - var startOffsetForCurGroup = 0 - var totalNumRowsForBatch = 0 - var totalNumStatesForBatch = 0 - - var sampledDataSizePerRow = 0 - var lastBatchPurgedMillis = System.currentTimeMillis() - - def finalizeCurrentArrowBatch(): Unit = { - val remainingEmptyStateRows = totalNumRowsForBatch - totalNumStatesForBatch - (0 until remainingEmptyStateRows).foreach { _ => - arrowWriterForState.write(EMPTY_STATE_METADATA_ROW) - } - - arrowWriterForState.finish() - arrowWriterForData.finish() - writer.writeBatch() - arrowWriterForState.reset() - arrowWriterForData.reset() - - startOffsetForCurGroup = 0 - numRowsForCurGroup = 0 - totalNumRowsForBatch = 0 - totalNumStatesForBatch = 0 - lastBatchPurgedMillis = System.currentTimeMillis() - } - - while (inputIterator.hasNext) { - val (keyRow, groupState, dataIter) = inputIterator.next() - - assert(dataIter.hasNext, "should have at least one data row!") - - numRowsForCurGroup = 0 - - // Provide data rows - while (dataIter.hasNext) { - val dataRow = dataIter.next() - arrowWriterForData.write(dataRow) - numRowsForCurGroup += 1 - totalNumRowsForBatch += 1 - - // Currently, this only works when the number of rows are greater than the minimum - // data count for sampling. And we technically have no way to pick some rows from - // record batch and measure the size of data, hence we leverage all data in current - // record batch. We only sample once as it could be costly. - if (sampledDataSizePerRow == 0 && totalNumRowsForBatch > minDataCountForSample) { - sampledDataSizePerRow = arrowWriterForData.sizeInBytes() / totalNumRowsForBatch - } - - // If it exceeds the condition of batch (only size, not about timeout) and - // there is more data for the same group, flush and construct a new batch. - - // if (sampledDataSizePerRow * totalNumRowsForBatch >= softLimitBytesPerBatch && - // dataIter.hasNext) { - // FIXME: DEBUGGING now... 
split the data per 10 elements <- 1 element for testing - if (numRowsForCurGroup % 10 == 1 && dataIter.hasNext) { - // Provide state metadata row as intermediate - val stateInfoRow = buildStateInfoRow(keyRow, groupState, startOffsetForCurGroup, - numRowsForCurGroup, isLastChunk = false) - arrowWriterForState.write(stateInfoRow) - totalNumStatesForBatch += 1 - - finalizeCurrentArrowBatch() - } - } - - // Provide state metadata row - val stateInfoRow = buildStateInfoRow(keyRow, groupState, startOffsetForCurGroup, - numRowsForCurGroup, isLastChunk = true) - arrowWriterForState.write(stateInfoRow) - totalNumStatesForBatch += 1 - - // The start offset for next group would be same as the total number of rows for batch, - // unless the next group starts with new batch. - startOffsetForCurGroup = totalNumRowsForBatch - - // FIXME: Do we need to come up with sampling "across record batches"? - // FIXME: Do we need to also come up with the size of state metadata as well? - // FIXME: Do we need to separate the case of "state with value" vs - // "state without value" on sampling? - - // Currently, this only works when the number of rows are greater than the minimum - // data count for sampling. And we technically have no way to pick some rows from - // record batch and measure the size of data, hence we leverage all data in current - // record batch. We only sample once as it could be costly. - if (sampledDataSizePerRow == 0 && totalNumRowsForBatch > minDataCountForSample) { - sampledDataSizePerRow = arrowWriterForData.sizeInBytes() / totalNumRowsForBatch - } - - // The soft-limit on size effectively works after the sampling has completed, since we - // multiply the number of rows by 0 if the sampling is still in progress. The - // soft-limit on timeout always applies. 
- if (sampledDataSizePerRow * totalNumRowsForBatch >= softLimitBytesPerBatch || - System.currentTimeMillis() - lastBatchPurgedMillis > softTimeoutMillsPurgeBatch) { - finalizeCurrentArrowBatch() - } - } - - if (numRowsForCurGroup > 0) { - // We still have some rows in the current record batch. Need to flush them as well. - finalizeCurrentArrowBatch() - } - - // end writes footer to the output stream and doesn't clean any resources. - // It could throw exception if the output stream is closed, so it should be - // in the try block. - writer.end() - } { - // If we close root and allocator in TaskCompletionListener, there could be a race - // condition where the writer thread keeps writing to the VectorSchemaRoot while - // it's being closed by the TaskCompletion listener. - // Closing root and allocator here is cleaner because root and allocator is owned - // by the writer thread and is only visible to the writer thread. - // - // If the writer thread is interrupted by TaskCompletionListener, it should either - // (1) in the try block, in which case it will get an InterruptedException when - // performing io, and goes into the finally block or (2) in the finally block, - // in which case it will ignore the interruption and close the resources. - root.close() - allocator.close() - } - } - } - - object StateWriterThread { - val STATE_METADATA_SCHEMA: StructType = StructType( - Array( - StructField("properties", StringType), - StructField("keyRowAsUnsafe", BinaryType), - StructField("object", BinaryType), - StructField("startOffset", IntegerType), - StructField("numRows", IntegerType), - StructField("isLastChunk", BooleanType) - ) - ) - - // To avoid initializing a new row for empty state metadata row. 
- val EMPTY_STATE_METADATA_ROW = new GenericInternalRow( - Array[Any](null, null, null, null, null, null)) - } - - class StateReaderIterator( - stream: DataInputStream, - writerThread: WriterThread, - startTime: Long, - env: SparkEnv, - worker: Socket, - pid: Option[Int], - releasedOrClosed: AtomicBoolean, - context: TaskContext) - extends ReaderIterator(stream, writerThread, startTime, env, worker, pid, - releasedOrClosed, context) { - - private val stateRowDeserializer = stateEncoder.createDeserializer() - - private val allocator = ArrowUtils.rootAllocator.newChildAllocator( - s"stdin reader for $pythonExec", 0, Long.MaxValue) - - private var reader: ArrowStreamReader = _ - private var root: VectorSchemaRoot = _ - private var schema: StructType = _ - private var vectors: Array[ColumnVector] = _ - - context.addTaskCompletionListener[Unit] { _ => - if (reader != null) { - reader.close(false) - } - allocator.close() - } - - private var batchLoaded = true - - protected override def read() - : (Iterator[(UnsafeRow, GroupStateImpl[Row], Long)], Iterator[InternalRow]) = { - if (writerThread.exception.isDefined) { - throw writerThread.exception.get - } - try { - if (reader != null && batchLoaded) { - batchLoaded = reader.loadNextBatch() - if (batchLoaded) { - val batch = new ColumnarBatch(vectors) - batch.setNumRows(root.getRowCount) - deserializeColumnarBatch(batch) - } else { - reader.close(false) - allocator.close() - // Reach end of stream. Call `read()` again to read control data. - read() - } - } else { - stream.readInt() match { - case SpecialLengths.START_ARROW_STREAM => - reader = new ArrowStreamReader(stream, allocator) - root = reader.getVectorSchemaRoot() - schema = ArrowUtils.fromArrowSchema(root.getSchema()) - - // FIXME: should we validate schema here with value schema? - // FIXME: should we validate schema here with state metadata schema? 
- - vectors = root.getFieldVectors().asScala.map { vector => - new ArrowColumnVector(vector) - }.toArray[ColumnVector] - read() - case SpecialLengths.TIMING_DATA => - handleTimingData() - read() - case SpecialLengths.PYTHON_EXCEPTION_THROWN => - throw handlePythonException() - case SpecialLengths.END_OF_DATA_SECTION => - handleEndOfDataSection() - null - } - } - } catch handleException - } - - private def deserializeColumnarBatch(batch: ColumnarBatch) - : (Iterator[(UnsafeRow, GroupStateImpl[Row], Long)], Iterator[InternalRow]) = { - // This should at least have one row for state. Also, we ensure that all columns across - // data and state metadata have same number of rows, which is required by Arrow record - // batch. - assert(batch.numRows() > 0) - assert(schema.length == 2) - - def constructIterForData(batch: ColumnarBatch): Iterator[InternalRow] = { - // UDF returns a StructType column in ColumnarBatch, select the children here - val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] - val outputVectors = schema(0).dataType.asInstanceOf[StructType] - .indices.map(structVector.getChild) - val flattenedBatch = new ColumnarBatch(outputVectors.toArray) - flattenedBatch.setNumRows(batch.numRows()) - - flattenedBatch.rowIterator.asScala.flatMap { row => - if (row.isNullAt(0)) { - // The entire row in record batch seems to be for state metadata. 
- None - } else { - Some(row) - } - } - } - - def constructIterForState( - batch: ColumnarBatch): Iterator[(UnsafeRow, GroupStateImpl[Row], Long)] = { - // UDF returns a StructType column in ColumnarBatch, select the children here - val structVector = batch.column(1).asInstanceOf[ArrowColumnVector] - - val outputVectors = schema(1).dataType.asInstanceOf[StructType] - .indices.map(structVector.getChild) - val flattenedBatchForState = new ColumnarBatch(outputVectors.toArray) - flattenedBatchForState.setNumRows(batch.numRows()) - flattenedBatchForState.rowIterator().asScala.flatMap { row => - implicit val formats = org.json4s.DefaultFormats - - if (row.isNullAt(0)) { - // The entire row in record batch seems to be for data. - None - } else { - // Received state metadata does not need schema - this class already knows them. - // Array( - // StructField("properties", StringType), - // StructField("keyRowAsUnsafe", BinaryType), - // StructField("object", BinaryType), - // StructField("oldTimeoutTimestamp", LongType), - // ) - // TODO: Do we want to rely on the column name rather than the ordinal for safety? 
- val propertiesAsJson = parse(row.getUTF8String(0).toString) - val keyRowAsUnsafeAsBinary = row.getBinary(1) - val keyRowAsUnsafe = new UnsafeRow(keySchema.fields.length) - keyRowAsUnsafe.pointTo(keyRowAsUnsafeAsBinary, keyRowAsUnsafeAsBinary.length) - val maybeObjectRow = if (row.isNullAt(2)) { - None - } else { - val pickledStateValue = row.getBinary(2) - Some(PythonSQLUtils.toJVMRow(pickledStateValue, stateValueSchema, - stateRowDeserializer)) - } - val oldTimeoutTimestamp = row.getLong(3) - - Some((keyRowAsUnsafe, GroupStateImpl.fromJson(maybeObjectRow, propertiesAsJson), - oldTimeoutTimestamp)) - } - } - } - - (constructIterForState(batch), constructIterForData(batch)) - } - } -} From d22d7db0a8165ebec98a5914113819d6619ee00c Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 13 Sep 2022 21:45:34 +0900 Subject: [PATCH 43/44] WIP remove the temp fix --- sql/core/pom.xml | 66 ------------------------------------------------ 1 file changed, 66 deletions(-) diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 736fd515d35d2..7203fc591081a 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -267,72 +267,6 @@ - - - - org.apache.maven.plugins - maven-checkstyle-plugin - 3.1.2 - - true - false - true - - ${basedir}/src/main/java - ${basedir}/src/main/scala - - - ${basedir}/src/test/java - - dev/checkstyle.xml - ${basedir}/target/checkstyle-output.xml - ${project.build.sourceEncoding} - ${project.reporting.outputEncoding} - - - - - com.puppycrawl.tools - checkstyle - 8.43 - - - - - - check - - - - - - org.scalastyle - scalastyle-maven-plugin - 1.0.0 - - true - false - false - false - false - ${basedir}/src/main/scala - ${basedir}/src/test/scala - ../../scalastyle-config.xml - ${basedir}/target/scalastyle-output.xml - ${project.build.sourceEncoding} - ${project.reporting.outputEncoding} - - - - - check - - - - From e60408f37e0ced881fa12bdb28907357026fb7e8 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Wed, 14 Sep 2022 14:54:35 +0900 Subject: [PATCH 44/44] 
remove unused code --- .../apache/spark/sql/execution/python/PythonArrowInput.scala | 1 - .../apache/spark/sql/execution/python/PythonArrowOutput.scala | 4 ---- 2 files changed, 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala index 37718fcfcb897..bf66791183ece 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala @@ -26,7 +26,6 @@ import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.api.python.{BasePythonRunner, PythonRDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.arrow.ArrowWriter -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala index d86827d556e03..339f114539c28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala @@ -41,10 +41,6 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[ protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: StructType): OUT - protected def handleStateUpdate(stream: DataInputStream): Unit = { - new IllegalStateException("Should not reach here!") - } - protected def newReaderIterator( stream: DataInputStream, writerThread: WriterThread,