1010from typing import TYPE_CHECKING , Protocol , runtime_checkable
1111
1212if TYPE_CHECKING :
13+ from collections .abc import Callable
14+
1315 from hugr .hugr .base import Hugr
1416
1517
@@ -18,22 +20,10 @@ class ComposablePass(Protocol):
1820 """A Protocol which represents a composable Hugr transformation."""
1921
2022 def __call__ (self , hugr : Hugr , * , inplace : bool = True ) -> Hugr :
21- """Call the pass to transform a HUGR."""
22- if inplace :
23- self ._apply_inplace (hugr )
24- return hugr
25- else :
26- return self ._apply (hugr )
27-
28- # At least one of the following _apply methods must be overriden
29- def _apply (self , hugr : Hugr ) -> Hugr :
30- hugr = deepcopy (hugr )
31- self ._apply_inplace (hugr )
32- return hugr
23+ """Call the pass to transform a HUGR.
3324
34- def _apply_inplace (self , hugr : Hugr ) -> None :
35- new_hugr = self ._apply (hugr )
36- hugr ._overwrite_hugr (new_hugr )
25+ See :func:`_impl_pass_call` for a helper function to implement this method.
26+ """
3727
3828 @property
3929 def name (self ) -> str :
@@ -57,21 +47,65 @@ def then(self, other: ComposablePass) -> ComposablePass:
5747 return ComposedPass (pass_list )
5848
5949
50+ def impl_pass_call (
51+ * ,
52+ hugr : Hugr ,
53+ inplace : bool ,
54+ inplace_call : Callable [[Hugr ], None ] | None = None ,
55+ copy_call : Callable [[Hugr ], Hugr ] | None = None ,
56+ ) -> Hugr :
57+ """Helper function to implement a ComposablePass.__call__ method, given an
58+ inplace or copy-returning pass methods.
59+
60+ At least one of the `inplace_call` or `copy_call` arguments must be provided.
61+
62+ :param hugr: The Hugr to apply the pass to.
63+ :param inplace: Whether to apply the pass inplace.
64+ :param inplace_call: The method to apply the pass inplace.
65+ :param copy_call: The method to apply the pass by copying the Hugr.
66+ :return: The transformed Hugr.
67+ """
68+ if inplace and inplace_call is not None :
69+ inplace_call (hugr )
70+ return hugr
71+ elif inplace and copy_call is not None :
72+ new_hugr = copy_call (hugr )
73+ hugr ._overwrite_hugr (new_hugr )
74+ return hugr
75+ elif not inplace and copy_call is not None :
76+ return copy_call (hugr )
77+ elif not inplace and inplace_call is not None :
78+ new_hugr = deepcopy (hugr )
79+ inplace_call (new_hugr )
80+ return new_hugr
81+ else :
82+ msg = "Pass must implement at least an inplace or copy run method"
83+ raise ValueError (msg )
84+
85+
6086@dataclass
6187class ComposedPass (ComposablePass ):
6288 """A sequence of composable passes."""
6389
6490 passes : list [ComposablePass ]
6591
66- def _apply (self , hugr : Hugr ) -> Hugr :
67- result_hugr = hugr
68- for comp_pass in self .passes :
69- result_hugr = comp_pass (result_hugr , inplace = False )
70- return result_hugr
71-
72- def _apply_inplace (self , hugr : Hugr ) -> None :
73- for comp_pass in self .passes :
74- comp_pass (hugr , inplace = True )
92+ def __call__ (self , hugr : Hugr , * , inplace : bool = True ) -> Hugr :
93+ def apply (hugr : Hugr ) -> Hugr :
94+ result_hugr = hugr
95+ for comp_pass in self .passes :
96+ result_hugr = comp_pass (result_hugr , inplace = False )
97+ return result_hugr
98+
99+ def apply_inplace (hugr : Hugr ) -> None :
100+ for comp_pass in self .passes :
101+ comp_pass (hugr , inplace = True )
102+
103+ return impl_pass_call (
104+ hugr = hugr ,
105+ inplace = inplace ,
106+ inplace_call = apply_inplace ,
107+ copy_call = apply ,
108+ )
75109
76110 @property
77111 def name (self ) -> str :
0 commit comments