diff --git a/hugr-py/src/hugr/passes/_composable_pass.py b/hugr-py/src/hugr/passes/_composable_pass.py index 371e8206a..68d935f3b 100644 --- a/hugr-py/src/hugr/passes/_composable_pass.py +++ b/hugr-py/src/hugr/passes/_composable_pass.py @@ -6,8 +6,8 @@ from __future__ import annotations from copy import deepcopy -from dataclasses import dataclass -from typing import TYPE_CHECKING, Protocol, runtime_checkable +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable if TYPE_CHECKING: from collections.abc import Callable @@ -15,72 +15,79 @@ from hugr.hugr.base import Hugr +# Type alias for a pass name +PassName = str + + @runtime_checkable class ComposablePass(Protocol): """A Protocol which represents a composable Hugr transformation.""" def __call__(self, hugr: Hugr, *, inplace: bool = True) -> Hugr: - """Call the pass to transform a HUGR. + """Call the pass to transform a HUGR, returning a Hugr.""" + return self.run(hugr, inplace=inplace).hugr + + def run(self, hugr: Hugr, *, inplace: bool = True) -> PassResult: + """Run the pass to transform a HUGR, returning a PassResult. - See :func:`_impl_pass_call` for a helper function to implement this method. + See :func:`implement_pass_run` for a helper function to implement this method. """ @property - def name(self) -> str: + def name(self) -> PassName: """Returns the name of the pass.""" return self.__class__.__name__ def then(self, other: ComposablePass) -> ComposablePass: """Perform another composable pass after this pass.""" - # Provide a default implementation for composing passes. - pass_list = [] - if isinstance(self, ComposedPass): - pass_list.extend(self.passes) - else: - pass_list.append(self) + return ComposedPass(self, other) - if isinstance(other, ComposedPass): - pass_list.extend(other.passes) - else: - pass_list.append(other) - return ComposedPass(pass_list) - - -def impl_pass_call( +def implement_pass_run( + composable_pass: ComposablePass, *, hugr: Hugr, inplace: bool, - inplace_call: Callable[[Hugr], None] | None = None, - copy_call: Callable[[Hugr], Hugr] | None = None, -) -> Hugr: - """Helper function to implement a ComposablePass.__call__ method, given an - inplace or copy-returning pass methods. + inplace_call: Callable[[Hugr], PassResult] | None = None, + copy_call: Callable[[Hugr], PassResult] | None = None, +) -> PassResult: + """Helper function to implement a ComposablePass.run method, given an + inplace or copy-returning pass method. At least one of the `inplace_call` or `copy_call` arguments must be provided. + :param composable_pass: The pass being run. Used for error messages. :param hugr: The Hugr to apply the pass to. :param inplace: Whether to apply the pass inplace. :param inplace_call: The method to apply the pass inplace. :param copy_call: The method to apply the pass by copying the Hugr. - :return: The transformed Hugr. + :return: The result of the pass application. + :raises ValueError: If neither `inplace_call` nor `copy_call` is provided. """ - if inplace and inplace_call is not None: - inplace_call(hugr) - return hugr - elif inplace and copy_call is not None: - new_hugr = copy_call(hugr) - hugr._overwrite_hugr(new_hugr) - return hugr - elif not inplace and copy_call is not None: - return copy_call(hugr) - elif not inplace and inplace_call is not None: - new_hugr = deepcopy(hugr) - inplace_call(new_hugr) - return new_hugr - else: - msg = "Pass must implement at least an inplace or copy run method" - raise ValueError(msg) + if inplace: + if inplace_call is not None: + return inplace_call(hugr) + elif copy_call is not None: + pass_result = copy_call(hugr) + pass_result.hugr = hugr + if pass_result.modified: + hugr._overwrite_hugr(pass_result.hugr) + pass_result.inplace = True + return pass_result + elif not inplace: + if copy_call is not None: + return copy_call(hugr) + elif inplace_call is not None: + new_hugr = deepcopy(hugr) + pass_result = inplace_call(new_hugr) + pass_result.inplace = False + return pass_result + + msg = ( + f"{composable_pass.name} needs to implement at least " + + "an inplace or copy run method" + ) + raise ValueError(msg) @dataclass @@ -89,24 +96,92 @@ class ComposedPass(ComposablePass): passes: list[ComposablePass] - def __call__(self, hugr: Hugr, *, inplace: bool = True) -> Hugr: - def apply(hugr: Hugr) -> Hugr: - result_hugr = hugr + def __init__(self, *passes: ComposablePass) -> None: + self.passes = [] + for composable_pass in passes: + if isinstance(composable_pass, ComposedPass): + self.passes.extend(composable_pass.passes) + else: + self.passes.append(composable_pass) + + def run(self, hugr: Hugr, *, inplace: bool = True) -> PassResult: + def apply(inplace: bool, hugr: Hugr) -> PassResult: + pass_result = PassResult(hugr=hugr, inplace=inplace) for comp_pass in self.passes: - result_hugr = comp_pass(result_hugr, inplace=False) - return result_hugr + new_result = comp_pass.run(pass_result.hugr, inplace=inplace) + pass_result = pass_result.then(new_result) + return pass_result - def apply_inplace(hugr: Hugr) -> None: - for comp_pass in self.passes: - comp_pass(hugr, inplace=True) - - return impl_pass_call( + return implement_pass_run( + self, hugr=hugr, inplace=inplace, - inplace_call=apply_inplace, - copy_call=apply, + inplace_call=lambda hugr: apply(True, hugr), + copy_call=lambda hugr: apply(False, hugr), ) @property - def name(self) -> str: - return f"Composed({ ', '.join(pass_.name for pass_ in self.passes) })" + def name(self) -> PassName: + names = [composable_pass.name for composable_pass in self.passes] + return f"Composed({ ', '.join(names) })" + + +@dataclass +class PassResult: + """The result of a series of composed passes applied to a HUGR. + + Includes a flag indicating whether the passes modified the HUGR, and an + arbitrary result object for each pass. + + :attr hugr: The transformed Hugr. + :attr inplace: Whether the pass was applied inplace. + If this is `True`, `hugr` will be the same object passed as input. + If this is `False`, `hugr` will be an independent copy of the original Hugr. + :attr modified: Whether the pass made changes to the HUGR. + If `False`, `hugr` will have the same contents as the original Hugr. + If `True`, no guarantees are made about the contents of `hugr`. + :attr results: The result of each applied pass, as a tuple of the pass name + and the result. + """ + + hugr: Hugr + inplace: bool = False + modified: bool = False + results: list[tuple[PassName, Any]] = field(default_factory=list) + + @classmethod + def for_pass( + cls, + composable_pass: ComposablePass, + hugr: Hugr, + *, + result: Any, + inplace: bool, + modified: bool = True, + ) -> PassResult: + """Create a new PassResult after a pass application. + + :param hugr: The Hugr that was transformed. + :param composable_pass: The pass that was applied. + :param result: The result of the pass application. + :param inplace: Whether the pass was applied inplace. + :param modified: Whether the pass modified the HUGR. + """ + return cls( + hugr=hugr, + inplace=inplace, + modified=modified, + results=[(composable_pass.name, result)], + ) + + def then(self, other: PassResult) -> PassResult: + """Extend the PassResult with the results of another PassResult. + + Keeps the hugr returned by the last pass. + """ + return PassResult( + hugr=other.hugr, + inplace=self.inplace and other.inplace, + modified=self.modified or other.modified, + results=self.results + other.results, + ) diff --git a/hugr-py/tests/test_passes.py b/hugr-py/tests/test_passes.py index 64426dcda..d7e517c0b 100644 --- a/hugr-py/tests/test_passes.py +++ b/hugr-py/tests/test_passes.py @@ -1,30 +1,105 @@ +from copy import deepcopy + +import pytest + from hugr.hugr.base import Hugr -from hugr.passes._composable_pass import ComposablePass, ComposedPass, impl_pass_call +from hugr.passes._composable_pass import ( + ComposablePass, + ComposedPass, + PassResult, + implement_pass_run, +) def test_composable_pass() -> None: - class MyDummyPass(ComposablePass): - def __call__(self, hugr: Hugr, inplace: bool = True) -> Hugr: - return impl_pass_call( + class DummyInlinePass(ComposablePass): + def run(self, hugr: Hugr, inplace: bool = True) -> PassResult: + return implement_pass_run( + self, hugr=hugr, inplace=inplace, - inplace_call=lambda hugr: None, + inplace_call=lambda hugr: PassResult.for_pass( + self, + hugr, + result=None, + inplace=True, + # Say that we modified the HUGR even though we didn't + modified=True, + ), ) - dummy = MyDummyPass() + class DummyCopyPass(ComposablePass): + def run(self, hugr: Hugr, inplace: bool = True) -> PassResult: + return implement_pass_run( + self, + hugr=hugr, + inplace=inplace, + copy_call=lambda hugr: PassResult.for_pass( + self, + deepcopy(hugr), + result=None, + inplace=False, + # Say that we modified the HUGR even though we didn't + modified=True, + ), + ) - composed_dummies = dummy.then(dummy) + dummy_inline = DummyInlinePass() + dummy_copy = DummyCopyPass() - my_composed_pass = ComposedPass([dummy, dummy]) - assert my_composed_pass.passes == [dummy, dummy] + composed_dummies = dummy_inline.then(dummy_copy) + assert isinstance(composed_dummies, ComposedPass) - assert isinstance(composed_dummies, ComposablePass) - assert composed_dummies == my_composed_pass + assert dummy_inline.name == "DummyInlinePass" + assert dummy_copy.name == "DummyCopyPass" + assert composed_dummies.name == "Composed(DummyInlinePass, DummyCopyPass)" + assert composed_dummies.then(dummy_inline).then(composed_dummies).name == ( + "Composed(" + + "DummyInlinePass, DummyCopyPass, " + + "DummyInlinePass, " + + "DummyInlinePass, DummyCopyPass)" + ) - assert dummy.name == "MyDummyPass" - assert composed_dummies.name == "Composed(MyDummyPass, MyDummyPass)" + # Apply the passes + hugr: Hugr = Hugr() + new_hugr = composed_dummies(hugr, inplace=False) + assert hugr == new_hugr + assert new_hugr is not hugr - assert ( - composed_dummies.then(my_composed_pass).name - == "Composed(MyDummyPass, MyDummyPass, MyDummyPass, MyDummyPass)" - ) + # Verify the pass results + hugr = Hugr() + inplace_result = composed_dummies.run(hugr, inplace=True) + assert inplace_result.modified + assert inplace_result.inplace + assert inplace_result.results == [ + ("DummyInlinePass", None), + ("DummyCopyPass", None), + ] + assert inplace_result.hugr is hugr + + hugr = Hugr() + copy_result = composed_dummies.run(hugr, inplace=False) + assert copy_result.modified + assert not copy_result.inplace + assert copy_result.results == [ + ("DummyInlinePass", None), + ("DummyCopyPass", None), + ] + assert copy_result.hugr is not hugr + + +def test_invalid_composable_pass() -> None: + class DummyInvalidPass(ComposablePass): + def run(self, hugr: Hugr, inplace: bool = True) -> PassResult: + return implement_pass_run( + self, + hugr=hugr, + inplace=inplace, + ) + + dummy_invalid = DummyInvalidPass() + with pytest.raises( + ValueError, + match="DummyInvalidPass needs to implement at least an inplace or copy run method", # noqa: E501 + ): + dummy_invalid.run(Hugr())