Skip to content

Commit d79a031

Browse files
committed
feat: PassResult definition
1 parent 859c811 commit d79a031

File tree

2 files changed

+117
-45
lines changed

2 files changed

+117
-45
lines changed

hugr-py/src/hugr/passes/_composable_pass.py

Lines changed: 101 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from __future__ import annotations
77

88
from copy import deepcopy
9-
from dataclasses import dataclass
10-
from typing import TYPE_CHECKING, Protocol, runtime_checkable
9+
from dataclasses import dataclass, field
10+
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
1111

1212
if TYPE_CHECKING:
1313
from collections.abc import Callable
@@ -20,7 +20,11 @@ class ComposablePass(Protocol):
2020
"""A Protocol which represents a composable Hugr transformation."""
2121

2222
def __call__(self, hugr: Hugr, *, inplace: bool = True) -> Hugr:
23-
"""Call the pass to transform a HUGR.
23+
"""Call the pass to transform a HUGR, returning a Hugr."""
24+
return self.run(hugr, inplace=inplace).hugr
25+
26+
def run(self, hugr: Hugr, *, inplace: bool = True) -> PassResult:
27+
"""Run the pass to transform a HUGR, returning a PassResult.
2428
2529
See :func:`_impl_pass_call` for a helper function to implement this method.
2630
"""
@@ -32,29 +36,17 @@ def name(self) -> str:
3236

3337
def then(self, other: ComposablePass) -> ComposablePass:
3438
"""Perform another composable pass after this pass."""
35-
# Provide a default implementation for composing passes.
36-
pass_list = []
37-
if isinstance(self, ComposedPass):
38-
pass_list.extend(self.passes)
39-
else:
40-
pass_list.append(self)
41-
42-
if isinstance(other, ComposedPass):
43-
pass_list.extend(other.passes)
44-
else:
45-
pass_list.append(other)
46-
47-
return ComposedPass(pass_list)
39+
return ComposedPass(self, other)
4840

4941

50-
def impl_pass_call(
42+
def impl_pass_run(
5143
*,
5244
hugr: Hugr,
5345
inplace: bool,
54-
inplace_call: Callable[[Hugr], None] | None = None,
55-
copy_call: Callable[[Hugr], Hugr] | None = None,
56-
) -> Hugr:
57-
"""Helper function to implement a ComposablePass.__call__ method, given an
46+
inplace_call: Callable[[Hugr], PassResult] | None = None,
47+
copy_call: Callable[[Hugr], PassResult] | None = None,
48+
) -> PassResult:
49+
"""Helper function to implement a ComposablePass.run method, given an
5850
inplace or copy-returning pass methods.
5951
6052
At least one of the `inplace_call` or `copy_call` arguments must be provided.
@@ -63,21 +55,25 @@ def impl_pass_call(
6355
:param inplace: Whether to apply the pass inplace.
6456
:param inplace_call: The method to apply the pass inplace.
6557
:param copy_call: The method to apply the pass by copying the Hugr.
66-
:return: The transformed Hugr.
58+
:return: The result of the pass application.
59+
:raises ValueError: If neither `inplace_call` nor `copy_call` is provided.
6760
"""
6861
if inplace and inplace_call is not None:
69-
inplace_call(hugr)
70-
return hugr
62+
return inplace_call(hugr)
7163
elif inplace and copy_call is not None:
72-
new_hugr = copy_call(hugr)
73-
hugr._overwrite_hugr(new_hugr)
74-
return hugr
64+
pass_result = copy_call(hugr)
65+
pass_result.hugr = hugr
66+
if pass_result.modified:
67+
hugr._overwrite_hugr(pass_result.hugr)
68+
pass_result.original_dirty = True
69+
return pass_result
7570
elif not inplace and copy_call is not None:
7671
return copy_call(hugr)
7772
elif not inplace and inplace_call is not None:
7873
new_hugr = deepcopy(hugr)
79-
inplace_call(new_hugr)
80-
return new_hugr
74+
pass_result = inplace_call(new_hugr)
75+
pass_result.original_dirty = False
76+
return pass_result
8177
else:
8278
msg = "Pass must implement at least an inplace or copy run method"
8379
raise ValueError(msg)
@@ -89,24 +85,89 @@ class ComposedPass(ComposablePass):
8985

9086
passes: list[ComposablePass]
9187

92-
def __call__(self, hugr: Hugr, *, inplace: bool = True) -> Hugr:
93-
def apply(hugr: Hugr) -> Hugr:
94-
result_hugr = hugr
95-
for comp_pass in self.passes:
96-
result_hugr = comp_pass(result_hugr, inplace=False)
97-
return result_hugr
98-
99-
def apply_inplace(hugr: Hugr) -> None:
88+
def __init__(self, *passes: ComposablePass) -> None:
89+
self.passes = []
90+
for pass_ in passes:
91+
if isinstance(pass_, ComposedPass):
92+
self.passes.extend(pass_.passes)
93+
else:
94+
self.passes.append(pass_)
95+
96+
def run(self, hugr: Hugr, *, inplace: bool = True) -> PassResult:
97+
def apply(hugr: Hugr) -> PassResult:
98+
pass_result = PassResult(hugr=hugr)
10099
for comp_pass in self.passes:
101-
comp_pass(hugr, inplace=True)
100+
new_result = comp_pass.run(pass_result.hugr, inplace=inplace)
101+
pass_result = pass_result.then(new_result)
102+
return pass_result
102103

103-
return impl_pass_call(
104+
return impl_pass_run(
104105
hugr=hugr,
105106
inplace=inplace,
106-
inplace_call=apply_inplace,
107+
inplace_call=apply,
107108
copy_call=apply,
108109
)
109110

110111
@property
111112
def name(self) -> str:
112113
return f"Composed({ ', '.join(pass_.name for pass_ in self.passes) })"
114+
115+
116+
@dataclass
117+
class PassResult:
118+
"""The result of a series of composed passes applied to a HUGR.
119+
120+
Includes a flag indicating whether the passes modified the HUGR, and an
121+
arbitrary result object for each pass.
122+
123+
In some cases, `modified` may be set to `True` even if the pass did not
124+
modify the program.
125+
126+
:attr hugr: The transformed Hugr.
127+
:attr original_dirty: Whether the original HUGR was modified by the pass.
128+
:attr modified: Whether the pass made changes to the HUGR.
129+
:attr results: The result of each applied pass, as a tuple of the pass and
130+
the result.
131+
"""
132+
133+
hugr: Hugr
134+
original_dirty: bool = False
135+
modified: bool = False
136+
results: list[tuple[ComposablePass, Any]] = field(default_factory=list)
137+
138+
@classmethod
139+
def for_pass(
140+
cls,
141+
pass_: ComposablePass,
142+
hugr: Hugr,
143+
*,
144+
result: Any,
145+
inline: bool,
146+
modified: bool = True,
147+
) -> PassResult:
148+
"""Create a new PassResult after a pass application.
149+
150+
:param hugr: The Hugr that was transformed.
151+
:param pass_: The pass that was applied.
152+
:param result: The result of the pass application.
153+
:param inline: Whether the pass was applied inplace.
154+
:param modified: Whether the pass modified the HUGR.
155+
"""
156+
return cls(
157+
hugr=hugr,
158+
original_dirty=inline and modified,
159+
modified=modified,
160+
results=[(pass_, result)],
161+
)
162+
163+
def then(self, other: PassResult) -> PassResult:
164+
"""Extend the PassResult with the results of another PassResult.
165+
166+
Keeps the hugr returned by the last pass.
167+
"""
168+
return PassResult(
169+
hugr=other.hugr,
170+
original_dirty=self.original_dirty or other.original_dirty,
171+
modified=self.modified or other.modified,
172+
results=self.results + other.results,
173+
)

hugr-py/tests/test_passes.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,32 @@
11
from hugr.hugr.base import Hugr
2-
from hugr.passes._composable_pass import ComposablePass, ComposedPass, impl_pass_call
2+
from hugr.passes._composable_pass import (
3+
ComposablePass,
4+
ComposedPass,
5+
PassResult,
6+
impl_pass_run,
7+
)
38

49

510
def test_composable_pass() -> None:
611
class MyDummyPass(ComposablePass):
7-
def __call__(self, hugr: Hugr, inplace: bool = True) -> Hugr:
8-
return impl_pass_call(
12+
def run(self, hugr: Hugr, inplace: bool = True) -> PassResult:
13+
return impl_pass_run(
914
hugr=hugr,
1015
inplace=inplace,
11-
inplace_call=lambda hugr: None,
16+
inplace_call=lambda hugr: PassResult.for_pass(
17+
self,
18+
hugr,
19+
result=None,
20+
inline=True,
21+
modified=False,
22+
),
1223
)
1324

1425
dummy = MyDummyPass()
1526

1627
composed_dummies = dummy.then(dummy)
1728

18-
my_composed_pass = ComposedPass([dummy, dummy])
29+
my_composed_pass = ComposedPass(dummy, dummy)
1930
assert my_composed_pass.passes == [dummy, dummy]
2031

2132
assert isinstance(composed_dummies, ComposablePass)

0 commit comments

Comments
 (0)