[SPARK-41005][CONNECT][PYTHON][FOLLOW-UP] Fetch/send partitions in parallel for Arrow based collect #38613
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -145,36 +145,10 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte | |
| newIter.map { batch: Array[Byte] => (batch, newIter.rowCountInLastBatch) } | ||
| } | ||
|
|
||
| val signal = new Object | ||
| val partitions = collection.mutable.Map.empty[Int, Array[Batch]] | ||
|
|
||
| val processPartition = (iter: Iterator[Batch]) => iter.toArray | ||
|
|
||
| // This callback is executed by the DAGScheduler thread. | ||
| // After fetching a partition, it inserts the partition into the Map, and then | ||
| // wakes up the main thread. | ||
| val resultHandler = (partitionId: Int, partition: Array[Batch]) => { | ||
| signal.synchronized { | ||
| partitions(partitionId) = partition | ||
| signal.notify() | ||
| } | ||
| () | ||
| } | ||
|
|
||
| spark.sparkContext.runJob(batches, processPartition, resultHandler) | ||
|
|
||
| // The main thread will wait until 0-th partition is available, | ||
| // then send it to client and wait for next partition. | ||
| var currentPartitionId = 0 | ||
| while (currentPartitionId < numPartitions) { | ||
| val partition = signal.synchronized { | ||
| while (!partitions.contains(currentPartitionId)) { | ||
| signal.wait() | ||
| } | ||
| partitions.remove(currentPartitionId).get | ||
| } | ||
|
|
||
| partition.foreach { case (bytes, count) => | ||
| def writeBatches(arrowBatches: Array[Batch]): Unit = { | ||
| for (arrowBatch <- arrowBatches) { | ||
| val (bytes, count) = arrowBatch | ||
| val response = proto.Response.newBuilder().setClientId(clientId) | ||
| val batch = proto.Response.ArrowBatch | ||
| .newBuilder() | ||
|
|
@@ -185,9 +159,30 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte | |
| responseObserver.onNext(response.build()) | ||
| numSent += 1 | ||
| } | ||
| } | ||
|
|
||
| // Store collection results for worst case of 1 to N-1 partitions | ||
| val results = new Array[Array[Batch]](numPartitions - 1) | ||
| var lastIndex = -1 // index of last partition written | ||
|
|
||
| currentPartitionId += 1 | ||
| // Handler to eagerly write partitions in order | ||
| val resultHandler = (partitionId: Int, partition: Array[Batch]) => { | ||
|
Contributor
Does it need to be synchronized?
Member
Author
Nope, it doesn't (because it's guided by the index). This approach is actually from the initial ordered implementation of collect with Arrow (which was in production for a very long time), 82c18c2#diff-459628811d7786c705fbb2b7a381ecd2b88f183f44ab607d43b3d33ea48d390fR3282-R3318. |
||
| // If result is from next partition in order | ||
| if (partitionId - 1 == lastIndex) { | ||
| writeBatches(partition) | ||
| lastIndex += 1 | ||
| // Write stored partitions that come next in order | ||
| while (lastIndex < results.length && results(lastIndex) != null) { | ||
| writeBatches(results(lastIndex)) | ||
| results(lastIndex) = null | ||
| lastIndex += 1 | ||
| } | ||
| } else { | ||
| // Store partitions received out of order | ||
| results(partitionId - 1) = partition | ||
| } | ||
| } | ||
| spark.sparkContext.runJob(batches, (iter: Iterator[Batch]) => iter.toArray, resultHandler) | ||
|
Contributor
maybe we can create a threadpool? (shared across
Member
Author
Hm, I just noticed the review comment. I believe this matches our current implementation in PySpark. If we should improve it, let's improve both paths together. I would prefer to match them and deduplicate the logic first.
Contributor
+1 for matching the implementations
Contributor
Why use a thread pool if you have a thread sitting around? |
||
| } | ||
|
|
||
| // Make sure at least 1 batch will be sent. | ||
|
|
||
The reason I suggested using locks and the main thread to write the results is exactly what this comment is trying to convey. You don't want these operations to happen inside the DAGScheduler thread. If you keep that thread blocked on something unrelated to scheduling, you stop all other scheduling. This is particularly bad in an environment where multiple users might be running code at the same time.
We have seen in higher concurrency scenarios that this does become a problem. Throughput will plateau because the DAGScheduler is doing the wrong things. I would like to avoid that.
good point! We should write it down as code comments. @zhengruifeng can you help with it?
ok. let me add a comment for this
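For readers following the thread, here is a minimal, Spark-free sketch of the lock-based approach described above: the callback only stores the partition and wakes the consumer, and the calling thread does the ordered writing. The `OrderedSendSketch` object, the `run` helper, and the simulated producer thread are hypothetical names for illustration only; in the PR the callback would be invoked by `sparkContext.runJob`, and the write would go through `responseObserver.onNext`.

```scala
import scala.collection.mutable

object OrderedSendSketch {
  // Stand-in for the Batch type in the diff: (Arrow bytes, row count).
  type Batch = (Array[Byte], Long)

  // Delivers partitions in `arrivalOrder` from a simulated scheduler thread
  // and returns the order in which the calling thread "sends" them.
  def run(arrivalOrder: Seq[Int]): Seq[Int] = {
    val numPartitions = arrivalOrder.length
    val signal = new Object
    val partitions = mutable.Map.empty[Int, Array[Batch]]
    val sent = mutable.ArrayBuffer.empty[Int]

    // Cheap callback: store the partition and wake the consumer.
    // No I/O happens here, so the thread invoking it is released quickly.
    val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
      signal.synchronized {
        partitions(partitionId) = partition
        signal.notifyAll()
      }
    }

    // Assumption: this thread plays the role of the DAGScheduler thread.
    val producer = new Thread(() => {
      for (id <- arrivalOrder) {
        resultHandler(id, Array((Array[Byte](id.toByte), 1L)))
      }
    })
    producer.start()

    // The calling thread waits for the next partition in order, then writes it.
    var currentPartitionId = 0
    while (currentPartitionId < numPartitions) {
      val partition = signal.synchronized {
        while (!partitions.contains(currentPartitionId)) {
          signal.wait()
        }
        partitions.remove(currentPartitionId).get
      }
      // Placeholder for building the response and calling onNext.
      partition.foreach { case (bytes, _) => sent += bytes(0).toInt }
      currentPartitionId += 1
    }
    producer.join()
    sent.toSeq
  }

  def main(args: Array[String]): Unit = {
    println(run(Seq(2, 0, 3, 1)).mkString(","))
  }
}
```

Calling `OrderedSendSketch.run(Seq(2, 0, 3, 1))` returns the partition ids in ascending order regardless of arrival order. The trade-off against the index-guided handler in the diff is one extra lock and a blocked caller thread, in exchange for keeping the potentially slow write off the thread that delivers results.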