diff --git a/dask_expr/_merge.py b/dask_expr/_merge.py
index 656b44d2d..cacc25c8e 100644
--- a/dask_expr/_merge.py
+++ b/dask_expr/_merge.py
@@ -1,11 +1,13 @@
 import functools
 
 from dask.dataframe.dispatch import make_meta, meta_nonempty
-from dask.utils import M, apply
+from dask.utils import M, apply, get_default_shuffle_method
 
-from dask_expr._expr import Blockwise, Expr, Index, Projection
+from dask_expr._expr import Blockwise, Expr, Index, PartitionsFiltered, Projection
 from dask_expr._repartition import Repartition
-from dask_expr._shuffle import Shuffle, _contains_index_name
+from dask_expr._shuffle import AssignPartitioningIndex, Shuffle, _contains_index_name
+
+_HASH_COLUMN_NAME = "__hash_partition"
 
 
 class Merge(Expr):
@@ -140,6 +142,29 @@ def _lower(self):
             if shuffle_right_on is None:
                 raise NotImplementedError("Cannot shuffle unnamed index")
 
+        if (shuffle_left_on or shuffle_right_on) and (
+            shuffle_backend == "p2p"
+            or shuffle_backend is None
+            and get_default_shuffle_method() == "p2p"
+        ):
+            left = AssignPartitioningIndex(
+                left, shuffle_left_on, _HASH_COLUMN_NAME, self.npartitions
+            )
+            right = AssignPartitioningIndex(
+                right, shuffle_right_on, _HASH_COLUMN_NAME, self.npartitions
+            )
+            return HashJoinP2P(
+                left,
+                right,
+                how=self.how,
+                left_on=left_on,
+                right_on=right_on,
+                suffixes=self.suffixes,
+                indicator=self.indicator,
+                left_index=left_index,
+                right_index=right_index,
+            )
+
         if shuffle_left_on:
             # Shuffle left
             left = Shuffle(
@@ -213,6 +238,123 @@ def _simplify_up(self, parent):
         return result[parent_columns]
 
 
+class HashJoinP2P(Merge, PartitionsFiltered):
+    """Hash join of two frames built from distributed's P2P shuffle tasks
+
+    Both inputs are expected to carry the ``_HASH_COLUMN_NAME`` column, which
+    ``Merge._lower`` adds via ``AssignPartitioningIndex``. That column is
+    dropped again when the output metadata is computed.
+    """
+
+    _parameters = [
+        "left",
+        "right",
+        "how",
+        "left_on",
+        "right_on",
+        "left_index",
+        "right_index",
+        "suffixes",
+        "indicator",
+        "_partitions",
+    ]
+    _defaults = {
+        "how": "inner",
+        "left_on": None,
+        "right_on": None,
+        "left_index": None,
+        "right_index": None,
+        "suffixes": ("_x", "_y"),
+        "indicator": False,
+        "_partitions": None,
+    }
+
+    def _lower(self):
+        # Nothing further to lower; the task graph is built in ``_layer``
+        return None
+
+    @functools.cached_property
+    def _meta(self):
+        # The hash column is only used for partitioning and is not part of
+        # the merged result
+        left = self.left._meta.drop(columns=_HASH_COLUMN_NAME)
+        right = self.right._meta.drop(columns=_HASH_COLUMN_NAME)
+        return left.merge(
+            right,
+            how=self.how,
+            left_on=self.left_on,
+            right_on=self.right_on,
+            indicator=self.indicator,
+            suffixes=self.suffixes,
+            left_index=self.left_index,
+            right_index=self.right_index,
+        )
+
+    def _layer(self) -> dict:
+        """Build the transfer, barrier and unpack tasks of the join"""
+        # Deferred imports: importing this module must not require
+        # ``distributed`` to be installed
+        from distributed.shuffle._core import ShuffleId, barrier_key
+        from distributed.shuffle._merge import merge_transfer, merge_unpack
+        from distributed.shuffle._shuffle import shuffle_barrier
+
+        dsk = {}
+        name_left = "hash-join-transfer-" + self.left._name
+        name_right = "hash-join-transfer-" + self.right._name
+        transfer_keys_left = []
+        transfer_keys_right = []
+        for i in range(self.left.npartitions):
+            transfer_keys_left.append((name_left, i))
+            dsk[(name_left, i)] = (
+                merge_transfer,
+                (self.left._name, i),
+                self.left._name,
+                i,
+                self.npartitions,
+                self._partitions,
+            )
+        for i in range(self.right.npartitions):
+            transfer_keys_right.append((name_right, i))
+            dsk[(name_right, i)] = (
+                merge_transfer,
+                (self.right._name, i),
+                self.right._name,
+                i,
+                self.npartitions,
+                self._partitions,
+            )
+
+        # One barrier per input, depending on all of that input's transfer tasks
+        _barrier_key_left = barrier_key(ShuffleId(self.left._name))
+        _barrier_key_right = barrier_key(ShuffleId(self.right._name))
+        dsk[_barrier_key_left] = (shuffle_barrier, self.left._name, transfer_keys_left)
+        dsk[_barrier_key_right] = (
+            shuffle_barrier,
+            self.right._name,
+            transfer_keys_right,
+        )
+
+        for part_out in self._partitions:
+            dsk[(self._name, part_out)] = (
+                merge_unpack,
+                self.left._name,
+                self.right._name,
+                part_out,
+                _barrier_key_left,
+                _barrier_key_right,
+                self.how,
+                self.left_on,
+                self.right_on,
+                self.left._meta.drop(columns=_HASH_COLUMN_NAME),
+                self.right._meta.drop(columns=_HASH_COLUMN_NAME),
+                self._meta,
+                self.suffixes,
+                self.left_index,
+                self.right_index,
+            )
+        return dsk
+
+
 class BlockwiseMerge(Merge, Blockwise):
     """Merge two dataframes with aligned partitions
 
diff --git a/dask_expr/tests/test_distributed.py b/dask_expr/tests/test_distributed.py
index 6591fc1c5..f69b5bced 100644
--- a/dask_expr/tests/test_distributed.py
+++ b/dask_expr/tests/test_distributed.py
@@ -2,6 +2,9 @@
 
 import pytest
 
+from dask_expr import from_pandas
+from dask_expr.tests._util import _backend_library
+
 distributed = pytest.importorskip("distributed")
 
 from distributed.utils_test import client as c  # noqa F401
@@ -9,6 +12,9 @@
 
 import dask_expr as dx
 
+# Set DataFrame backend for this module
+lib = _backend_library()
+
 
 @pytest.mark.parametrize("npartitions", [None, 1, 20])
 @gen_cluster(client=True)
@@ -31,3 +37,36 @@ async def test_p2p_shuffle(c, s, a, b, npartitions):
     assert x == y
     if npartitions != 1:
         assert x > z
+
+
+@pytest.mark.parametrize("npartitions_left", [5, 6])
+@gen_cluster(client=True)
+async def test_merge_p2p_shuffle(c, s, a, b, npartitions_left):
+    df_left = lib.DataFrame({"a": [1, 2, 3] * 100, "b": 2})
+    df_right = lib.DataFrame({"a": [4, 2, 3] * 100, "c": 2})
+    left = from_pandas(df_left, npartitions=npartitions_left)
+    right = from_pandas(df_right, npartitions=5)
+
+    out = left.merge(right, shuffle_backend="p2p")
+    assert out.npartitions == npartitions_left
+    x = c.compute(out)
+    x = await x
+    lib.testing.assert_frame_equal(x.reset_index(drop=True), df_left.merge(df_right))
+
+
+@pytest.mark.parametrize("npartitions_left", [5, 6])
+@gen_cluster(client=True)
+async def test_index_merge_p2p_shuffle(c, s, a, b, npartitions_left):
+    df_left = lib.DataFrame({"a": [1, 2, 3] * 100, "b": 2}).set_index("a")
+    df_right = lib.DataFrame({"a": [4, 2, 3] * 100, "c": 2})
+    left = from_pandas(df_left, npartitions=npartitions_left, sort=False)
+    right = from_pandas(df_right, npartitions=5)
+
+    out = left.merge(right, left_index=True, right_on="a", shuffle_backend="p2p")
+    assert out.npartitions == npartitions_left
+    x = c.compute(out)
+    x = await x
+    lib.testing.assert_frame_equal(
+        x.sort_index(),
+        df_left.merge(df_right, left_index=True, right_on="a").sort_index(),
+    )