From 3670a9be505a87aa0ceba2f630f5f343af0eff1c Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Mon, 12 Feb 2024 18:12:37 +0100
Subject: [PATCH 1/2] Adjust partitions for merge

Instead of carrying over the maximum partition count of both sides,
scale the output partition count of a merge by the ratio of the merged
frame's column count to the column count of the side with more
partitions.
---
 dask_expr/_merge.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/dask_expr/_merge.py b/dask_expr/_merge.py
index 4091bbc53..d55b5f23e 100644
--- a/dask_expr/_merge.py
+++ b/dask_expr/_merge.py
@@ -139,7 +139,19 @@ def _meta(self):
     def _npartitions(self):
         if self.operand("_npartitions") is not None:
             return self.operand("_npartitions")
-        return max(self.left.npartitions, self.right.npartitions)
+        if min(self.left.npartitions, self.right.npartitions) == 1:
+            return max(self.left.npartitions, self.right.npartitions)
+        if self.left.npartitions <= self.right.npartitions:
+            df_lower = self.left
+            df_higher = self.right
+        else:
+            df_lower = self.right
+            df_higher = self.left
+        npartitions = df_higher.npartitions
+        factor = (len(df_lower.columns) + len(df_higher.columns)) / len(
+            df_higher.columns
+        )
+        return int(npartitions * factor)
 
     @property
     def _bcast_left(self):
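
The heuristic above reduces to simple arithmetic on column and partition counts. The following standalone sketch reproduces that computation for sanity-checking (plain Python for illustration; the function name and signature are not part of dask-expr):

    def merged_npartitions(left_ncols, left_nparts, right_ncols, right_nparts):
        # One side fitting into a single partition is effectively a broadcast
        # join, so the other side's partition count simply carries over.
        if min(left_nparts, right_nparts) == 1:
            return max(left_nparts, right_nparts)
        # Otherwise scale the larger side's partition count by how much wider
        # the merged frame is than that side alone.
        if left_nparts <= right_nparts:
            lower_ncols, higher_ncols, higher_nparts = left_ncols, right_ncols, right_nparts
        else:
            lower_ncols, higher_ncols, higher_nparts = right_ncols, left_ncols, left_nparts
        factor = (lower_ncols + higher_ncols) / higher_ncols
        return int(higher_nparts * factor)

    # Two 10-partition frames with 4 columns each: factor (4 + 4) / 4 == 2,
    # so the merge is planned with 20 output partitions.
    assert merged_npartitions(4, 10, 4, 10) == 20

The second patch refines this factor: it looks through pushed-down projections when choosing the denominator and subtracts merge keys that both sides share.
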
From 9664accf9a5ec86510aa6e1ea47b49599188a1df Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Mon, 12 Feb 2024 22:45:07 +0100
Subject: [PATCH 2/2] Implement a more intelligent partition count for merge

Measure the column ratio against the operation that actually set the
partition count: walk the expression graph past blockwise operations so
that projections pushed into the merge do not distort the estimate.
Subtract merge keys that appear on both sides, since they occur only
once in the result, and restore the previous behaviour for blockwise
merges, which do not repartition.
---
 dask_expr/_merge.py                 | 53 ++++++++++++++++++++++++++---
 dask_expr/tests/test_distributed.py |  8 ++---
 dask_expr/tests/test_merge.py       | 44 ++++++++++++++++++++++--
 3 files changed, 95 insertions(+), 10 deletions(-)

diff --git a/dask_expr/_merge.py b/dask_expr/_merge.py
index d55b5f23e..eec03df73 100644
--- a/dask_expr/_merge.py
+++ b/dask_expr/_merge.py
@@ -23,12 +23,14 @@
     Filter,
     Index,
     Isin,
+    MaybeAlignPartitions,
     PartitionsFiltered,
     Projection,
     Unaryop,
     determine_column_projection,
     is_filter_pushdown_available,
 )
+from dask_expr._reductions import Reduction
 from dask_expr._repartition import Repartition
 from dask_expr._shuffle import (
     RearrangeByColumn,
@@ -36,6 +38,7 @@
     _select_columns_or_index,
 )
 from dask_expr._util import _convert_to_list, _tokenize_deterministic, is_scalar
+from dask_expr.io import IO
 
 _HASH_COLUMN_NAME = "__hash_partition"
 _PARTITION_COLUMN = "_partitions"
@@ -135,6 +138,35 @@ def _meta(self):
         kwargs["how"] = "left"
         return make_meta(left.merge(right, **kwargs))
 
+    def _find_partition_changer(self, expr):
+        # Look for an operation that reorganizes the number of partitions.
+        # We ignore blockwise operations and reductions.
+        stack = [expr]
+        seen = set()
+        result_nodes = []
+        while stack:
+            node = stack.pop()
+            if node._name in seen:
+                continue
+            seen.add(node._name)
+
+            if isinstance(node, Reduction):
+                continue
+            elif node.ndim == 0 or node.npartitions == 1:
+                continue
+            elif isinstance(node, IO):
+                return node
+            elif isinstance(node, (Blockwise, MaybeAlignPartitions)):
+                stack.extend(node.dependencies())
+                continue
+
+            result_nodes.append(node)
+        if len(result_nodes):
+            # The node with the maximum number of partitions will most likely
+            # have dominated the resulting partition count.
+            return sorted(result_nodes, key=lambda x: x.npartitions)[-1]
+        return expr
+
     @functools.cached_property
     def _npartitions(self):
         if self.operand("_npartitions") is not None:
             return self.operand("_npartitions")
@@ -144,14 +176,21 @@ def _npartitions(self):
         if min(self.left.npartitions, self.right.npartitions) == 1:
             return max(self.left.npartitions, self.right.npartitions)
         if self.left.npartitions <= self.right.npartitions:
             df_lower = self.left
             df_higher = self.right
+            merge_base_columns = self._find_partition_changer(self.right).columns
         else:
             df_lower = self.right
             df_higher = self.left
+            merge_base_columns = self._find_partition_changer(self.left).columns
         npartitions = df_higher.npartitions
-        factor = (len(df_lower.columns) + len(df_higher.columns)) / len(
-            df_higher.columns
-        )
-        return int(npartitions * factor)
+        common_merge_columns = []
+        if self.left_on is not None and self.right_on is not None:
+            common_merge_columns = set(_convert_to_list(self.left_on)) & set(
+                _convert_to_list(self.right_on)
+            )
+        factor = (
+            len(df_lower.columns) + len(df_higher.columns) - len(common_merge_columns)
+        ) / len(merge_base_columns)
+        return int(math.floor(npartitions * factor))
 
     @property
     def _bcast_left(self):
@@ -808,6 +847,12 @@ class BlockwiseMerge(Merge, Blockwise):
 
     is_broadcast_join = False
 
+    @functools.cached_property
+    def _npartitions(self):
+        if self.operand("_npartitions") is not None:
+            return self.operand("_npartitions")
+        return max(self.left.npartitions, self.right.npartitions)
+
     def _divisions(self):
         if self.left.npartitions == self.right.npartitions:
             return super()._divisions()
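
The traversal in `_find_partition_changer` exists because projection pushdown can shrink a merge input's visible columns even though the underlying data layout is unchanged, which would otherwise distort the factor. A toy model of the traversal (illustrative `Node` class and names, not the dask-expr `Expr` API):

    from dataclasses import dataclass, field

    @dataclass
    class Node:
        name: str
        npartitions: int
        deps: list = field(default_factory=list)
        blockwise: bool = False  # blockwise ops keep their input's partitioning

    def find_partition_changer(expr):
        stack, seen, candidates = [expr], set(), []
        while stack:
            node = stack.pop()
            if node.name in seen:
                continue
            seen.add(node.name)
            if node.blockwise:
                # Look through partition-preserving operations.
                stack.extend(node.deps)
                continue
            candidates.append(node)
        # The candidate with the most partitions most likely set the count.
        return max(candidates, key=lambda n: n.npartitions) if candidates else expr

    io = Node("read_parquet", npartitions=40)
    projection = Node("projection", npartitions=40, deps=[io], blockwise=True)
    # The column ratio is measured against the IO node, not the projection.
    assert find_partition_changer(projection) is io

The adjusted expectations in the tests below follow directly from this factor.
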
diff --git a/dask_expr/tests/test_distributed.py b/dask_expr/tests/test_distributed.py
index 4c1813eac..281cb92e4 100644
--- a/dask_expr/tests/test_distributed.py
+++ b/dask_expr/tests/test_distributed.py
@@ -51,7 +51,7 @@ async def test_merge_p2p_shuffle(c, s, a, b, npartitions_left):
     right = from_pandas(df_right, npartitions=5)
 
     out = left.merge(right, shuffle_method="p2p")
-    assert out.npartitions == npartitions_left
+    assert out.npartitions == 8
     x = c.compute(out)
     x = await x
     pd.testing.assert_frame_equal(x.reset_index(drop=True), df_left.merge(df_right))
@@ -88,7 +88,7 @@ async def test_merge_index_precedence(c, s, a, b, shuffle, name):
 
     result = df.join(df2, shuffle_method=shuffle)
     x = await c.compute(result)
-    assert result.npartitions == 3
+    assert result.npartitions == 6
     pd.testing.assert_frame_equal(x.sort_index(ascending=False), pdf.join(pdf2))
 
 
@@ -222,7 +222,7 @@ async def test_index_merge_p2p_shuffle(c, s, a, b, npartitions_left):
     right = from_pandas(df_right, npartitions=5)
 
     out = left.merge(right, left_index=True, right_on="a", shuffle_method="p2p")
-    assert out.npartitions == npartitions_left
+    assert out.npartitions == (7 if npartitions_left == 5 else 18)
     x = c.compute(out)
     x = await x
     pd.testing.assert_frame_equal(
@@ -239,7 +239,7 @@ async def test_merge_p2p_shuffle(c, s, a, b):
     right = from_pandas(df_right, npartitions=5)
 
     out = left.merge(right, shuffle_method="p2p")[["b", "c"]]
-    assert out.npartitions == 6
+    assert out.npartitions == 8
     x = c.compute(out)
     x = await x
     pd.testing.assert_frame_equal(
diff --git a/dask_expr/tests/test_merge.py b/dask_expr/tests/test_merge.py
index 048ab3a15..56faa0fd9 100644
--- a/dask_expr/tests/test_merge.py
+++ b/dask_expr/tests/test_merge.py
@@ -231,8 +231,8 @@ def test_merge_combine_similar(npartitions_left, npartitions_right):
     query["new"] = query.b + query.c
     query = query.groupby(["a", "e", "x"]).new.sum()
     assert (
-        len(query.optimize().__dask_graph__()) <= 25
-    )  # 45 is the non-combined version
+        len(query.optimize().__dask_graph__()) <= 37
+    )  # the non-combined version is higher
 
     expected = pdf.merge(pdf2)
     expected["new"] = expected.b + expected.c
@@ -899,3 +899,43 @@ def test_merge_leftsemi():
     df2 = from_pandas(pdf2, npartitions=2)
     with pytest.raises(NotImplementedError, match="on columns from the index"):
         df1.merge(df2, how="leftsemi", on="aa")
+
+
+def test_merge_npartitions_adjustment():
+    pdf1 = pd.DataFrame(
+        {"a": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] * 2, "b": 1, "c": 1, "d": 1}
+    )
+    pdf2 = pd.DataFrame(
+        {"a": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] * 2, "b": 1, "x": 1, "y": 1}
+    )
+    pdf3 = pd.DataFrame(
+        {"a": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] * 2, "b": 1, "m": 1, "n": 1}
+    )
+    df1 = from_pandas(pdf1, npartitions=10)
+    df2 = from_pandas(pdf2, npartitions=10)
+    df3 = from_pandas(pdf3, npartitions=10)
+    result = df1.merge(df2, on="a")
+    assert result.optimize().npartitions == 17
+    result = df1.merge(df2)
+    assert result.optimize().npartitions == 15
+    result = df1.merge(df2, left_on=["a", "c"], right_on=["b", "x"])
+    assert result.optimize().npartitions == 20
+
+    result = df1.merge(df2)
+    assert result.optimize().npartitions == 15
+    result = result.dropna()  # block projections
+    result = result[["a", "b"]].merge(df3)
+    assert result.optimize().npartitions == 10
+
+    result = df1.merge(df2)
+    assert result.optimize().npartitions == 15
+    result = result.dropna()  # block projections
+    result = result[["a", "b", "x", "y"]].merge(df3)
+    assert result.optimize().npartitions == 15
+
+    result = df1.merge(df2)
+    assert result.optimize().npartitions == 15
+    result = result.dropna()  # block projections
+    result = result + result.a.sum()
+    result = result[["a", "b", "x", "y"]].merge(df3)
+    assert result.optimize().npartitions == 15
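
Assuming both patches are applied, the new planning behaviour is observable directly from the public API; this mirrors the first assertions of `test_merge_npartitions_adjustment` above:

    import pandas as pd
    from dask_expr import from_pandas

    pdf1 = pd.DataFrame({"a": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] * 2, "b": 1, "c": 1, "d": 1})
    pdf2 = pd.DataFrame({"a": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] * 2, "b": 1, "x": 1, "y": 1})
    df1 = from_pandas(pdf1, npartitions=10)
    df2 = from_pandas(pdf2, npartitions=10)

    # Four columns on each side, one shared merge key "a":
    # factor (4 + 4 - 1) / 4 == 1.75, so the plan gets floor(10 * 1.75) == 17
    # partitions instead of the previous max(10, 10) == 10.
    assert df1.merge(df2, on="a").optimize().npartitions == 17

    # Merging on both common columns "a" and "b" removes two duplicates:
    # factor (4 + 4 - 2) / 4 == 1.5, giving 15 planned partitions.
    assert df1.merge(df2).optimize().npartitions == 15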