66from __future__ import annotations
77
88from copy import deepcopy
9- from dataclasses import dataclass
10- from typing import TYPE_CHECKING , Protocol , runtime_checkable
9+ from dataclasses import dataclass , field
10+ from typing import TYPE_CHECKING , Any , Protocol , runtime_checkable
1111
1212if TYPE_CHECKING :
1313 from collections .abc import Callable
@@ -20,7 +20,11 @@ class ComposablePass(Protocol):
2020 """A Protocol which represents a composable Hugr transformation."""
2121
2222 def __call__ (self , hugr : Hugr , * , inplace : bool = True ) -> Hugr :
23- """Call the pass to transform a HUGR.
23+ """Call the pass to transform a HUGR, returning a Hugr."""
24+ return self .run (hugr , inplace = inplace ).hugr
25+
26+ def run (self , hugr : Hugr , * , inplace : bool = True ) -> PassResult :
27+ """Run the pass to transform a HUGR, returning a PassResult.
2428
2529 See :func:`_impl_pass_call` for a helper function to implement this method.
2630 """
@@ -32,29 +36,17 @@ def name(self) -> str:
3236
3337 def then (self , other : ComposablePass ) -> ComposablePass :
3438 """Perform another composable pass after this pass."""
35- # Provide a default implementation for composing passes.
36- pass_list = []
37- if isinstance (self , ComposedPass ):
38- pass_list .extend (self .passes )
39- else :
40- pass_list .append (self )
41-
42- if isinstance (other , ComposedPass ):
43- pass_list .extend (other .passes )
44- else :
45- pass_list .append (other )
46-
47- return ComposedPass (pass_list )
39+ return ComposedPass (self , other )
4840
4941
50- def impl_pass_call (
42+ def impl_pass_run (
5143 * ,
5244 hugr : Hugr ,
5345 inplace : bool ,
54- inplace_call : Callable [[Hugr ], None ] | None = None ,
55- copy_call : Callable [[Hugr ], Hugr ] | None = None ,
56- ) -> Hugr :
57- """Helper function to implement a ComposablePass.__call__ method, given an
46+ inplace_call : Callable [[Hugr ], PassResult ] | None = None ,
47+ copy_call : Callable [[Hugr ], PassResult ] | None = None ,
48+ ) -> PassResult :
49+ """Helper function to implement a ComposablePass.run method, given an
5850 inplace or copy-returning pass methods.
5951
6052 At least one of the `inplace_call` or `copy_call` arguments must be provided.
@@ -63,21 +55,25 @@ def impl_pass_call(
6355 :param inplace: Whether to apply the pass inplace.
6456 :param inplace_call: The method to apply the pass inplace.
6557 :param copy_call: The method to apply the pass by copying the Hugr.
66- :return: The transformed Hugr.
58+ :return: The result of the pass application.
59+ :raises ValueError: If neither `inplace_call` nor `copy_call` is provided.
6760 """
6861 if inplace and inplace_call is not None :
69- inplace_call (hugr )
70- return hugr
62+ return inplace_call (hugr )
7163 elif inplace and copy_call is not None :
72- new_hugr = copy_call (hugr )
73- hugr ._overwrite_hugr (new_hugr )
74- return hugr
64+ pass_result = copy_call (hugr )
65+ pass_result .hugr = hugr
66+ if pass_result .modified :
67+ hugr ._overwrite_hugr (pass_result .hugr )
68+ pass_result .original_dirty = True
69+ return pass_result
7570 elif not inplace and copy_call is not None :
7671 return copy_call (hugr )
7772 elif not inplace and inplace_call is not None :
7873 new_hugr = deepcopy (hugr )
79- inplace_call (new_hugr )
80- return new_hugr
74+ pass_result = inplace_call (new_hugr )
75+ pass_result .original_dirty = False
76+ return pass_result
8177 else :
8278 msg = "Pass must implement at least an inplace or copy run method"
8379 raise ValueError (msg )
@@ -89,24 +85,89 @@ class ComposedPass(ComposablePass):
8985
9086 passes : list [ComposablePass ]
9187
92- def __call__ (self , hugr : Hugr , * , inplace : bool = True ) -> Hugr :
93- def apply (hugr : Hugr ) -> Hugr :
94- result_hugr = hugr
95- for comp_pass in self .passes :
96- result_hugr = comp_pass (result_hugr , inplace = False )
97- return result_hugr
98-
99- def apply_inplace (hugr : Hugr ) -> None :
88+ def __init__ (self , * passes : ComposablePass ) -> None :
89+ self .passes = []
90+ for pass_ in passes :
91+ if isinstance (pass_ , ComposedPass ):
92+ self .passes .extend (pass_ .passes )
93+ else :
94+ self .passes .append (pass_ )
95+
96+ def run (self , hugr : Hugr , * , inplace : bool = True ) -> PassResult :
97+ def apply (hugr : Hugr ) -> PassResult :
98+ pass_result = PassResult (hugr = hugr )
10099 for comp_pass in self .passes :
101- comp_pass (hugr , inplace = True )
100+ new_result = comp_pass .run (pass_result .hugr , inplace = inplace )
101+ pass_result = pass_result .then (new_result )
102+ return pass_result
102103
103- return impl_pass_call (
104+ return impl_pass_run (
104105 hugr = hugr ,
105106 inplace = inplace ,
106- inplace_call = apply_inplace ,
107+ inplace_call = apply ,
107108 copy_call = apply ,
108109 )
109110
110111 @property
111112 def name (self ) -> str :
112113 return f"Composed({ ', ' .join (pass_ .name for pass_ in self .passes ) } )"
114+
115+
116+ @dataclass
117+ class PassResult :
118+ """The result of a series of composed passes applied to a HUGR.
119+
120+ Includes a flag indicating whether the passes modified the HUGR, and an
121+ arbitrary result object for each pass.
122+
123+ In some cases, `modified` may be set to `True` even if the pass did not
124+ modify the program.
125+
126+ :attr hugr: The transformed Hugr.
127+ :attr original_dirty: Whether the original HUGR was modified by the pass.
128+ :attr modified: Whether the pass made changes to the HUGR.
129+ :attr results: The result of each applied pass, as a tuple of the pass and
130+ the result.
131+ """
132+
133+ hugr : Hugr
134+ original_dirty : bool = False
135+ modified : bool = False
136+ results : list [tuple [ComposablePass , Any ]] = field (default_factory = list )
137+
138+ @classmethod
139+ def for_pass (
140+ cls ,
141+ pass_ : ComposablePass ,
142+ hugr : Hugr ,
143+ * ,
144+ result : Any ,
145+ inline : bool ,
146+ modified : bool = True ,
147+ ) -> PassResult :
148+ """Create a new PassResult after a pass application.
149+
150+ :param hugr: The Hugr that was transformed.
151+ :param pass_: The pass that was applied.
152+ :param result: The result of the pass application.
153+ :param inline: Whether the pass was applied inplace.
154+ :param modified: Whether the pass modified the HUGR.
155+ """
156+ return cls (
157+ hugr = hugr ,
158+ original_dirty = inline and modified ,
159+ modified = modified ,
160+ results = [(pass_ , result )],
161+ )
162+
163+ def then (self , other : PassResult ) -> PassResult :
164+ """Extend the PassResult with the results of another PassResult.
165+
166+ Keeps the hugr returned by the last pass.
167+ """
168+ return PassResult (
169+ hugr = other .hugr ,
170+ original_dirty = self .original_dirty or other .original_dirty ,
171+ modified = self .modified or other .modified ,
172+ results = self .results + other .results ,
173+ )
0 commit comments