Skip to content

Commit 859c811

Browse files
committed
idea: Alternative to multiple ComposablePass apply methods
1 parent 8037052 commit 859c811

File tree

2 files changed

+64
-26
lines changed

2 files changed

+64
-26
lines changed

hugr-py/src/hugr/passes/_composable_pass.py

Lines changed: 58 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from typing import TYPE_CHECKING, Protocol, runtime_checkable
1111

1212
if TYPE_CHECKING:
13+
from collections.abc import Callable
14+
1315
from hugr.hugr.base import Hugr
1416

1517

@@ -18,22 +20,10 @@ class ComposablePass(Protocol):
1820
"""A Protocol which represents a composable Hugr transformation."""
1921

2022
def __call__(self, hugr: Hugr, *, inplace: bool = True) -> Hugr:
21-
"""Call the pass to transform a HUGR."""
22-
if inplace:
23-
self._apply_inplace(hugr)
24-
return hugr
25-
else:
26-
return self._apply(hugr)
27-
28-
# At least one of the following _apply methods must be overriden
29-
def _apply(self, hugr: Hugr) -> Hugr:
30-
hugr = deepcopy(hugr)
31-
self._apply_inplace(hugr)
32-
return hugr
23+
"""Call the pass to transform a HUGR.
3324
34-
def _apply_inplace(self, hugr: Hugr) -> None:
35-
new_hugr = self._apply(hugr)
36-
hugr._overwrite_hugr(new_hugr)
25+
See :func:`_impl_pass_call` for a helper function to implement this method.
26+
"""
3727

3828
@property
3929
def name(self) -> str:
@@ -57,21 +47,65 @@ def then(self, other: ComposablePass) -> ComposablePass:
5747
return ComposedPass(pass_list)
5848

5949

50+
def impl_pass_call(
51+
*,
52+
hugr: Hugr,
53+
inplace: bool,
54+
inplace_call: Callable[[Hugr], None] | None = None,
55+
copy_call: Callable[[Hugr], Hugr] | None = None,
56+
) -> Hugr:
57+
"""Helper function to implement a ComposablePass.__call__ method, given an
58+
inplace or copy-returning pass methods.
59+
60+
At least one of the `inplace_call` or `copy_call` arguments must be provided.
61+
62+
:param hugr: The Hugr to apply the pass to.
63+
:param inplace: Whether to apply the pass inplace.
64+
:param inplace_call: The method to apply the pass inplace.
65+
:param copy_call: The method to apply the pass by copying the Hugr.
66+
:return: The transformed Hugr.
67+
"""
68+
if inplace and inplace_call is not None:
69+
inplace_call(hugr)
70+
return hugr
71+
elif inplace and copy_call is not None:
72+
new_hugr = copy_call(hugr)
73+
hugr._overwrite_hugr(new_hugr)
74+
return hugr
75+
elif not inplace and copy_call is not None:
76+
return copy_call(hugr)
77+
elif not inplace and inplace_call is not None:
78+
new_hugr = deepcopy(hugr)
79+
inplace_call(new_hugr)
80+
return new_hugr
81+
else:
82+
msg = "Pass must implement at least an inplace or copy run method"
83+
raise ValueError(msg)
84+
85+
6086
@dataclass
6187
class ComposedPass(ComposablePass):
6288
"""A sequence of composable passes."""
6389

6490
passes: list[ComposablePass]
6591

66-
def _apply(self, hugr: Hugr) -> Hugr:
67-
result_hugr = hugr
68-
for comp_pass in self.passes:
69-
result_hugr = comp_pass(result_hugr, inplace=False)
70-
return result_hugr
71-
72-
def _apply_inplace(self, hugr: Hugr) -> None:
73-
for comp_pass in self.passes:
74-
comp_pass(hugr, inplace=True)
92+
def __call__(self, hugr: Hugr, *, inplace: bool = True) -> Hugr:
93+
def apply(hugr: Hugr) -> Hugr:
94+
result_hugr = hugr
95+
for comp_pass in self.passes:
96+
result_hugr = comp_pass(result_hugr, inplace=False)
97+
return result_hugr
98+
99+
def apply_inplace(hugr: Hugr) -> None:
100+
for comp_pass in self.passes:
101+
comp_pass(hugr, inplace=True)
102+
103+
return impl_pass_call(
104+
hugr=hugr,
105+
inplace=inplace,
106+
inplace_call=apply_inplace,
107+
copy_call=apply,
108+
)
75109

76110
@property
77111
def name(self) -> str:

hugr-py/tests/test_passes.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
from hugr.hugr.base import Hugr
2-
from hugr.passes._composable_pass import ComposablePass, ComposedPass
2+
from hugr.passes._composable_pass import ComposablePass, ComposedPass, impl_pass_call
33

44

55
def test_composable_pass() -> None:
66
class MyDummyPass(ComposablePass):
77
def __call__(self, hugr: Hugr, inplace: bool = True) -> Hugr:
8-
return self(hugr, inplace)
8+
return impl_pass_call(
9+
hugr=hugr,
10+
inplace=inplace,
11+
inplace_call=lambda hugr: None,
12+
)
913

1014
dummy = MyDummyPass()
1115

0 commit comments

Comments
 (0)