Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions python/pyspark/sql/connect/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,9 +380,10 @@ def crossJoin(self, other: ParentDataFrame) -> ParentDataFrame:
)

def zip(self, other: ParentDataFrame) -> ParentDataFrame:
raise PySparkNotImplementedError(
errorClass="NOT_IMPLEMENTED",
messageParameters={"feature": "zip"},
other = self._check_same_session(other)
return DataFrame(
plan.Zip(self._plan, other._plan),
session=self._session,
)

def _check_same_session(self, other: ParentDataFrame) -> "DataFrame":
Expand Down Expand Up @@ -2515,10 +2516,6 @@ def _test() -> None:

globs = pyspark.sql.dataframe.__dict__.copy()

# `zip` is not yet supported on Spark Connect; the parent docstring's
# example would call into the connect impl and fail with NOT_IMPLEMENTED.
del pyspark.sql.dataframe.DataFrame.zip.__doc__

if not is_remote_only():
del pyspark.sql.dataframe.DataFrame.toJSON.__doc__
del pyspark.sql.dataframe.DataFrame.rdd.__doc__
Expand Down
34 changes: 34 additions & 0 deletions python/pyspark/sql/connect/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -1447,6 +1447,40 @@ def _repr_html_(self) -> str:
"""


class Zip(LogicalPlan):
def __init__(self, left: Optional[LogicalPlan], right: LogicalPlan) -> None:
super().__init__(left)
self.left = cast(LogicalPlan, left)
self.right = right

def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.zip.left.CopyFrom(self.left.plan(session))
plan.zip.right.CopyFrom(self.right.plan(session))
return self._with_relations(plan, session)

@property
def observations(self) -> Dict[str, "Observation"]:
return {**super().observations, **self.right.observations}

def print(self, indent: int = 0) -> str:
i = " " * indent
o = " " * (indent + LogicalPlan.INDENT)
n = indent + LogicalPlan.INDENT * 2
return f"{i}<Zip>\n{o}left=\n{self.left.print(n)}\n{o}right=\n{self.right.print(n)}"

def _repr_html_(self) -> str:
return f"""
<ul>
<li>
<b>Zip</b><br />
Left: {self.left._repr_html_()}
Right: {self.right._repr_html_()}
</li>
</ul>
"""


class SetOperation(LogicalPlan):
def __init__(
self,
Expand Down
354 changes: 178 additions & 176 deletions python/pyspark/sql/connect/proto/relations_pb2.py

Large diffs are not rendered by default.

41 changes: 41 additions & 0 deletions python/pyspark/sql/connect/proto/relations_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ class Relation(google.protobuf.message.Message):
CHUNKED_CACHED_LOCAL_RELATION_FIELD_NUMBER: builtins.int
RELATION_CHANGES_FIELD_NUMBER: builtins.int
NEAREST_BY_JOIN_FIELD_NUMBER: builtins.int
ZIP_FIELD_NUMBER: builtins.int
FILL_NA_FIELD_NUMBER: builtins.int
DROP_NA_FIELD_NUMBER: builtins.int
REPLACE_FIELD_NUMBER: builtins.int
Expand Down Expand Up @@ -226,6 +227,8 @@ class Relation(google.protobuf.message.Message):
@property
def nearest_by_join(self) -> global___NearestByJoin: ...
@property
def zip(self) -> global___Zip: ...
@property
def fill_na(self) -> global___NAFill:
"""NA functions"""
@property
Expand Down Expand Up @@ -314,6 +317,7 @@ class Relation(google.protobuf.message.Message):
chunked_cached_local_relation: global___ChunkedCachedLocalRelation | None = ...,
relation_changes: global___RelationChanges | None = ...,
nearest_by_join: global___NearestByJoin | None = ...,
zip: global___Zip | None = ...,
fill_na: global___NAFill | None = ...,
drop_na: global___NADrop | None = ...,
replace: global___NAReplace | None = ...,
Expand Down Expand Up @@ -459,6 +463,8 @@ class Relation(google.protobuf.message.Message):
b"with_relations",
"with_watermark",
b"with_watermark",
"zip",
b"zip",
],
) -> builtins.bool: ...
def ClearField(
Expand Down Expand Up @@ -590,6 +596,8 @@ class Relation(google.protobuf.message.Message):
b"with_relations",
"with_watermark",
b"with_watermark",
"zip",
b"zip",
],
) -> None: ...
def WhichOneof(
Expand Down Expand Up @@ -642,6 +650,7 @@ class Relation(google.protobuf.message.Message):
"chunked_cached_local_relation",
"relation_changes",
"nearest_by_join",
"zip",
"fill_na",
"drop_na",
"replace",
Expand Down Expand Up @@ -4742,3 +4751,35 @@ class NearestByJoin(google.protobuf.message.Message):
) -> None: ...

global___NearestByJoin = NearestByJoin

class Zip(google.protobuf.message.Message):
"""Relation of type [[Zip]].

Combines the columns of two DataFrames side-by-side. Both DataFrames must produce the same
canonicalized plan after stripping outer Project chains.
"""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

LEFT_FIELD_NUMBER: builtins.int
RIGHT_FIELD_NUMBER: builtins.int
@property
def left(self) -> global___Relation:
"""(Required) Left input relation."""
@property
def right(self) -> global___Relation:
"""(Required) Right input relation."""
def __init__(
self,
*,
left: global___Relation | None = ...,
right: global___Relation | None = ...,
) -> None: ...
def HasField(
self, field_name: typing_extensions.Literal["left", b"left", "right", b"right"]
) -> builtins.bool: ...
def ClearField(
self, field_name: typing_extensions.Literal["left", b"left", "right", b"right"]
) -> None: ...

global___Zip = Zip
14 changes: 14 additions & 0 deletions python/pyspark/sql/tests/connect/test_connect_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,20 @@ def test_crossjoin(self):
join_plan.root.join.join_type,
)

def test_zip(self):
left_input = self.connect.readTable(table_name=self.tbl_name)
right_input = self.connect.readTable(table_name=self.tbl_name)
plan = left_input.zip(right_input)._plan.to_proto(self.connect)
self.assertIsNotNone(plan.root.zip)
self.assertEqual(
plan.root.zip.left.read.named_table.unparsed_identifier,
self.tbl_name,
)
self.assertEqual(
plan.root.zip.right.read.named_table.unparsed_identifier,
self.tbl_name,
)

def test_filter(self):
df = self.connect.readTable(table_name=self.tbl_name)
plan = df.filter(df.col_name > 3)._plan.to_proto(self.connect)
Expand Down
14 changes: 3 additions & 11 deletions python/pyspark/sql/tests/connect/test_parity_zip.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,12 @@
# limitations under the License.
#

from pyspark.errors import PySparkNotImplementedError
from pyspark.sql.tests.test_zip import DataFrameZipTestsMixin
from pyspark.testing.connectutils import ReusedConnectTestCase


class ZipParityTests(ReusedConnectTestCase):
"""`DataFrame.zip` is classic-only for now; assert the Connect stub raises a clean
NOT_IMPLEMENTED instead of falling through to a generic error or appearing to work."""

def test_zip_raises_not_implemented(self):
df = self.spark.createDataFrame([(1, 2)], ["a", "b"])
with self.assertRaises(PySparkNotImplementedError) as ctx:
df.select("a").zip(df.select("b"))
self.assertEqual(ctx.exception.getCondition(), "NOT_IMPLEMENTED")
self.assertIn("zip", str(ctx.exception))
class ZipParityTests(DataFrameZipTestsMixin, ReusedConnectTestCase):
pass


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,10 @@ class PlanGenerationTestSuite extends ConnectFunSuite with Logging {
simple.withMetadata("id", builder.build())
}

test("zip") {
left.select("id").zip(left.select("a"))
}

test("zipWithIndex") {
simple.zipWithIndex()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,19 @@ class DataFrameSuite extends QueryTest with RemoteSparkSession {
spark.conf.unset("spark.sql.analyzer.strictDataFrameColumnResolution")
}
}

test("zip") {
val sparkSession = spark
import sparkSession.implicits._

val df = Seq((1, 2, 3), (4, 5, 6)).toDF("a", "b", "c")
val left = df.select("a")
val right = df.select("b")

val zipped = left.zip(right)
assert(zipped.columns === Array("a", "b"))
val rows = zipped.collect().sortBy(_.getInt(0))
assert(rows(0).getInt(0) === 1 && rows(0).getInt(1) === 2)
assert(rows(1).getInt(0) === 4 && rows(1).getInt(1) === 5)
}
}
13 changes: 13 additions & 0 deletions sql/connect/common/src/main/protobuf/spark/connect/relations.proto
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ message Relation {
ChunkedCachedLocalRelation chunked_cached_local_relation = 45;
RelationChanges relation_changes = 46;
NearestByJoin nearest_by_join = 47;
Zip zip = 48;

// NA functions
NAFill fill_na = 90;
Expand Down Expand Up @@ -1307,3 +1308,15 @@ message NearestByJoin {
// (Required) Ranking direction. Must be one of: "distance", "similarity".
string direction = 7;
}

// Relation of type [[Zip]].
//
// Combines the columns of two DataFrames side-by-side. Both DataFrames must produce the same
// canonicalized plan after stripping outer Project chains.
message Zip {
// (Required) Left input relation.
Relation left = 1;

// (Required) Right input relation.
Relation right = 2;
}
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,12 @@ class Dataset[T] private[sql] (

/** @inheritdoc */
def zip(other: sql.Dataset[_]): DataFrame = {
throw new UnsupportedOperationException("zip is not supported in Spark Connect")
checkSameSparkSession(other)
sparkSession.newDataFrame { builder =>
builder.getZipBuilder
.setLeft(plan.getRoot)
.setRight(other.asInstanceOf[Dataset[_]].plan.getRoot)
}
}

/** @inheritdoc */
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Project [id#0L, a#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0]
81 changes: 81 additions & 0 deletions sql/connect/common/src/test/resources/query-tests/queries/zip.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
{
"common": {
"planId": "4"
},
"zip": {
"left": {
"common": {
"planId": "1"
},
"project": {
"input": {
"common": {
"planId": "0"
},
"localRelation": {
"schema": "struct\u003cid:bigint,a:int,b:double\u003e"
}
},
"expressions": [{
"unresolvedAttribute": {
"unparsedIdentifier": "id"
},
"common": {
"origin": {
"jvmOrigin": {
"stackTrace": [{
"classLoaderName": "app",
"declaringClass": "org.apache.spark.sql.connect.Dataset",
"methodName": "select",
"fileName": "Dataset.scala"
}, {
"classLoaderName": "app",
"declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite",
"methodName": "~~trimmed~anonfun~~",
"fileName": "PlanGenerationTestSuite.scala"
}]
}
}
}
}]
}
},
"right": {
"common": {
"planId": "3"
},
"project": {
"input": {
"common": {
"planId": "2"
},
"localRelation": {
"schema": "struct\u003cid:bigint,a:int,b:double\u003e"
}
},
"expressions": [{
"unresolvedAttribute": {
"unparsedIdentifier": "a"
},
"common": {
"origin": {
"jvmOrigin": {
"stackTrace": [{
"classLoaderName": "app",
"declaringClass": "org.apache.spark.sql.connect.Dataset",
"methodName": "select",
"fileName": "Dataset.scala"
}, {
"classLoaderName": "app",
"declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite",
"methodName": "~~trimmed~anonfun~~",
"fileName": "PlanGenerationTestSuite.scala"
}]
}
}
}
}]
}
}
}
}
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ class SparkConnectPlanner(
case proto.Relation.RelTypeCase.LATERAL_JOIN => transformLateralJoin(rel.getLateralJoin)
case proto.Relation.RelTypeCase.NEAREST_BY_JOIN =>
transformNearestByJoin(rel.getNearestByJoin)
case proto.Relation.RelTypeCase.ZIP => transformZip(rel.getZip)
case proto.Relation.RelTypeCase.DEDUPLICATE => transformDeduplicate(rel.getDeduplicate)
case proto.Relation.RelTypeCase.SET_OP => transformSetOperation(rel.getSetOp)
case proto.Relation.RelTypeCase.SORT => transformSort(rel.getSort)
Expand Down Expand Up @@ -2591,6 +2592,11 @@ class SparkConnectPlanner(
.logicalPlan
}

private def transformZip(rel: proto.Zip): LogicalPlan = {
assertPlan(rel.hasLeft && rel.hasRight, "Both zip sides must be present")
logical.Zip(transformRelation(rel.getLeft), transformRelation(rel.getRight))
}

private def transformSort(sort: proto.Sort): LogicalPlan = {
assertPlan(sort.getOrderCount > 0, "'order' must be present and contain elements.")
logical.Sort(
Expand Down