-
Notifications
You must be signed in to change notification settings - Fork 13
feat: Result type for ComposablePasses #2703
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
0c4246f
63e120a
7481efb
e83887c
0e1ffa9
fe94fcc
71a0a86
e29aafd
1b6559c
e56cd7d
bfb6f26
8037052
859c811
d79a031
b3eabde
97e5406
07caa46
a1eebb0
ad8ed71
44f5fa3
21ee55a
be1cad4
8b9b59b
4c4cdcd
1e64469
1794edc
19662f5
7cf550d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,81 +6,88 @@ | |
| from __future__ import annotations | ||
|
|
||
| from copy import deepcopy | ||
| from dataclasses import dataclass | ||
| from typing import TYPE_CHECKING, Protocol, runtime_checkable | ||
| from dataclasses import dataclass, field | ||
| from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable | ||
|
|
||
| if TYPE_CHECKING: | ||
| from collections.abc import Callable | ||
|
|
||
| from hugr.hugr.base import Hugr | ||
|
|
||
|
|
||
| # Type alias for a pass name | ||
| PassName = str | ||
|
|
||
|
|
||
| @runtime_checkable | ||
| class ComposablePass(Protocol): | ||
| """A Protocol which represents a composable Hugr transformation.""" | ||
|
|
||
| def __call__(self, hugr: Hugr, *, inplace: bool = True) -> Hugr: | ||
| """Call the pass to transform a HUGR. | ||
| """Call the pass to transform a HUGR, returning a Hugr.""" | ||
| return self.run(hugr, inplace=inplace).hugr | ||
|
|
||
| def run(self, hugr: Hugr, *, inplace: bool = True) -> PassResult: | ||
| """Run the pass to transform a HUGR, returning a PassResult. | ||
|
|
||
| See :func:`_impl_pass_call` for a helper function to implement this method. | ||
| See :func:`implement_pass_run` for a helper function to implement this method. | ||
| """ | ||
|
|
||
| @property | ||
| def name(self) -> str: | ||
| def name(self) -> PassName: | ||
| """Returns the name of the pass.""" | ||
| return self.__class__.__name__ | ||
|
|
||
| def then(self, other: ComposablePass) -> ComposablePass: | ||
| """Perform another composable pass after this pass.""" | ||
| # Provide a default implementation for composing passes. | ||
| pass_list = [] | ||
| if isinstance(self, ComposedPass): | ||
| pass_list.extend(self.passes) | ||
| else: | ||
| pass_list.append(self) | ||
| return ComposedPass(self, other) | ||
|
|
||
| if isinstance(other, ComposedPass): | ||
| pass_list.extend(other.passes) | ||
| else: | ||
| pass_list.append(other) | ||
|
|
||
| return ComposedPass(pass_list) | ||
|
|
||
|
|
||
| def impl_pass_call( | ||
| def implement_pass_run( | ||
| composable_pass: ComposablePass, | ||
| *, | ||
| hugr: Hugr, | ||
| inplace: bool, | ||
| inplace_call: Callable[[Hugr], None] | None = None, | ||
| copy_call: Callable[[Hugr], Hugr] | None = None, | ||
| ) -> Hugr: | ||
| """Helper function to implement a ComposablePass.__call__ method, given an | ||
| inplace or copy-returning pass methods. | ||
| inplace_call: Callable[[Hugr], PassResult] | None = None, | ||
| copy_call: Callable[[Hugr], PassResult] | None = None, | ||
| ) -> PassResult: | ||
| """Helper function to implement a ComposablePass.run method, given an | ||
| inplace or copy-returning pass method. | ||
|
|
||
| At least one of the `inplace_call` or `copy_call` arguments must be provided. | ||
|
|
||
| :param composable_pass: The pass being run. Used for error messages. | ||
| :param hugr: The Hugr to apply the pass to. | ||
| :param inplace: Whether to apply the pass inplace. | ||
| :param inplace_call: The method to apply the pass inplace. | ||
| :param copy_call: The method to apply the pass by copying the Hugr. | ||
| :return: The transformed Hugr. | ||
| :return: The result of the pass application. | ||
| :raises ValueError: If neither `inplace_call` nor `copy_call` is provided. | ||
| """ | ||
| if inplace and inplace_call is not None: | ||
| inplace_call(hugr) | ||
| return hugr | ||
| elif inplace and copy_call is not None: | ||
| new_hugr = copy_call(hugr) | ||
| hugr._overwrite_hugr(new_hugr) | ||
| return hugr | ||
| elif not inplace and copy_call is not None: | ||
| return copy_call(hugr) | ||
| elif not inplace and inplace_call is not None: | ||
| new_hugr = deepcopy(hugr) | ||
| inplace_call(new_hugr) | ||
| return new_hugr | ||
| else: | ||
| msg = "Pass must implement at least an inplace or copy run method" | ||
| raise ValueError(msg) | ||
| if inplace: | ||
| if inplace_call is not None: | ||
| return inplace_call(hugr) | ||
| elif copy_call is not None: | ||
| pass_result = copy_call(hugr) | ||
| pass_result.hugr = hugr | ||
| if pass_result.modified: | ||
| hugr._overwrite_hugr(pass_result.hugr) | ||
| pass_result.inplace = True | ||
| return pass_result | ||
| elif not inplace: | ||
| if copy_call is not None: | ||
| return copy_call(hugr) | ||
| elif inplace_call is not None: | ||
| new_hugr = deepcopy(hugr) | ||
| pass_result = inplace_call(new_hugr) | ||
| pass_result.inplace = False | ||
| return pass_result | ||
|
|
||
| msg = ( | ||
| f"{composable_pass.name} needs to implement at least " | ||
| + "an inplace or copy run method" | ||
| ) | ||
| raise ValueError(msg) | ||
|
|
||
|
|
||
| @dataclass | ||
|
|
@@ -89,24 +96,92 @@ class ComposedPass(ComposablePass): | |
|
|
||
| passes: list[ComposablePass] | ||
|
|
||
| def __call__(self, hugr: Hugr, *, inplace: bool = True) -> Hugr: | ||
| def apply(hugr: Hugr) -> Hugr: | ||
| result_hugr = hugr | ||
| def __init__(self, *passes: ComposablePass) -> None: | ||
| self.passes = [] | ||
| for composable_pass in passes: | ||
| if isinstance(composable_pass, ComposedPass): | ||
| self.passes.extend(composable_pass.passes) | ||
| else: | ||
| self.passes.append(composable_pass) | ||
|
|
||
| def run(self, hugr: Hugr, *, inplace: bool = True) -> PassResult: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I actually think you can just do: Because
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's already the case, or I'm missing something? |
||
| def apply(inplace: bool, hugr: Hugr) -> PassResult: | ||
| pass_result = PassResult(hugr=hugr, inplace=inplace) | ||
| for comp_pass in self.passes: | ||
| result_hugr = comp_pass(result_hugr, inplace=False) | ||
| return result_hugr | ||
| new_result = comp_pass.run(pass_result.hugr, inplace=inplace) | ||
| pass_result = pass_result.then(new_result) | ||
| return pass_result | ||
|
|
||
| def apply_inplace(hugr: Hugr) -> None: | ||
| for comp_pass in self.passes: | ||
| comp_pass(hugr, inplace=True) | ||
|
|
||
| return impl_pass_call( | ||
| return implement_pass_run( | ||
| self, | ||
| hugr=hugr, | ||
| inplace=inplace, | ||
| inplace_call=apply_inplace, | ||
| copy_call=apply, | ||
| inplace_call=lambda hugr: apply(True, hugr), | ||
| copy_call=lambda hugr: apply(False, hugr), | ||
| ) | ||
|
|
||
| @property | ||
| def name(self) -> str: | ||
| return f"Composed({ ', '.join(pass_.name for pass_ in self.passes) })" | ||
| def name(self) -> PassName: | ||
| names = [composable_pass.name for composable_pass in self.passes] | ||
| return f"Composed({ ', '.join(names) })" | ||
|
|
||
|
|
||
| @dataclass | ||
| class PassResult: | ||
| """The result of a series of composed passes applied to a HUGR. | ||
|
|
||
| Includes a flag indicating whether the passes modified the HUGR, and an | ||
| arbitrary result object for each pass. | ||
|
|
||
| :attr hugr: The transformed Hugr. | ||
| :attr inplace: Whether the pass was applied inplace. | ||
| If this is `True`, `hugr` will be the same object passed as input. | ||
| If this is `False`, `hugr` will be an independent copy of the original Hugr. | ||
| :attr modified: Whether the pass made changes to the HUGR. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, so if
whereas if
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Conditionally aliasing the output is a bug waiting to happen. If we say the output is a copying the object then we should always do that.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure. Ok, so - if
if
|
||
| If `False`, `hugr` will have the same contents as the original Hugr. | ||
| If `True`, no guarantees are made about the contents of `hugr`. | ||
| :attr results: The result of each applied pass, as a tuple of the pass name | ||
| and the result. | ||
| """ | ||
|
|
||
| hugr: Hugr | ||
| inplace: bool = False | ||
| modified: bool = False | ||
| results: list[tuple[PassName, Any]] = field(default_factory=list) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure why we include the pass name here?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Mostly for debuggability, otherwise The result may also be serialized, so the original
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thought including the pass name was quite a nice solution! |
||
|
|
||
| @classmethod | ||
| def for_pass( | ||
| cls, | ||
| composable_pass: ComposablePass, | ||
| hugr: Hugr, | ||
| *, | ||
| result: Any, | ||
| inplace: bool, | ||
| modified: bool = True, | ||
| ) -> PassResult: | ||
| """Create a new PassResult after a pass application. | ||
|
|
||
| :param hugr: The Hugr that was transformed. | ||
| :param composable_pass: The pass that was applied. | ||
| :param result: The result of the pass application. | ||
| :param inplace: Whether the pass was applied inplace. | ||
| :param modified: Whether the pass modified the HUGR. | ||
| """ | ||
| return cls( | ||
| hugr=hugr, | ||
| inplace=inplace, | ||
| modified=modified, | ||
| results=[(composable_pass.name, result)], | ||
| ) | ||
|
|
||
| def then(self, other: PassResult) -> PassResult: | ||
| """Extend the PassResult with the results of another PassResult. | ||
|
|
||
| Keeps the hugr returned by the last pass. | ||
| """ | ||
| return PassResult( | ||
| hugr=other.hugr, | ||
| inplace=self.inplace and other.inplace, | ||
| modified=self.modified or other.modified, | ||
| results=self.results + other.results, | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,30 +1,105 @@ | ||
| from copy import deepcopy | ||
|
|
||
| import pytest | ||
|
|
||
| from hugr.hugr.base import Hugr | ||
| from hugr.passes._composable_pass import ComposablePass, ComposedPass, impl_pass_call | ||
| from hugr.passes._composable_pass import ( | ||
| ComposablePass, | ||
| ComposedPass, | ||
| PassResult, | ||
| implement_pass_run, | ||
| ) | ||
|
|
||
|
|
||
| def test_composable_pass() -> None: | ||
| class MyDummyPass(ComposablePass): | ||
| def __call__(self, hugr: Hugr, inplace: bool = True) -> Hugr: | ||
| return impl_pass_call( | ||
| class DummyInlinePass(ComposablePass): | ||
| def run(self, hugr: Hugr, inplace: bool = True) -> PassResult: | ||
| return implement_pass_run( | ||
| self, | ||
| hugr=hugr, | ||
| inplace=inplace, | ||
| inplace_call=lambda hugr: None, | ||
| inplace_call=lambda hugr: PassResult.for_pass( | ||
| self, | ||
| hugr, | ||
| result=None, | ||
| inplace=True, | ||
| # Say that we modified the HUGR even though we didn't | ||
| modified=True, | ||
| ), | ||
| ) | ||
|
|
||
| dummy = MyDummyPass() | ||
| class DummyCopyPass(ComposablePass): | ||
| def run(self, hugr: Hugr, inplace: bool = True) -> PassResult: | ||
| return implement_pass_run( | ||
| self, | ||
| hugr=hugr, | ||
| inplace=inplace, | ||
| copy_call=lambda hugr: PassResult.for_pass( | ||
| self, | ||
| deepcopy(hugr), | ||
| result=None, | ||
| inplace=False, | ||
| # Say that we modified the HUGR even though we didn't | ||
| modified=True, | ||
| ), | ||
| ) | ||
|
|
||
| composed_dummies = dummy.then(dummy) | ||
| dummy_inline = DummyInlinePass() | ||
| dummy_copy = DummyCopyPass() | ||
|
|
||
| my_composed_pass = ComposedPass([dummy, dummy]) | ||
| assert my_composed_pass.passes == [dummy, dummy] | ||
| composed_dummies = dummy_inline.then(dummy_copy) | ||
| assert isinstance(composed_dummies, ComposedPass) | ||
|
|
||
| assert isinstance(composed_dummies, ComposablePass) | ||
| assert composed_dummies == my_composed_pass | ||
| assert dummy_inline.name == "DummyInlinePass" | ||
| assert dummy_copy.name == "DummyCopyPass" | ||
| assert composed_dummies.name == "Composed(DummyInlinePass, DummyCopyPass)" | ||
| assert composed_dummies.then(dummy_inline).then(composed_dummies).name == ( | ||
| "Composed(" | ||
| + "DummyInlinePass, DummyCopyPass, " | ||
| + "DummyInlinePass, " | ||
| + "DummyInlinePass, DummyCopyPass)" | ||
| ) | ||
|
|
||
| assert dummy.name == "MyDummyPass" | ||
| assert composed_dummies.name == "Composed(MyDummyPass, MyDummyPass)" | ||
| # Apply the passes | ||
| hugr: Hugr = Hugr() | ||
| new_hugr = composed_dummies(hugr, inplace=False) | ||
| assert hugr == new_hugr | ||
| assert new_hugr is not hugr | ||
|
|
||
| assert ( | ||
| composed_dummies.then(my_composed_pass).name | ||
| == "Composed(MyDummyPass, MyDummyPass, MyDummyPass, MyDummyPass)" | ||
| ) | ||
| # Verify the pass results | ||
| hugr = Hugr() | ||
| inplace_result = composed_dummies.run(hugr, inplace=True) | ||
| assert inplace_result.modified | ||
| assert inplace_result.inplace | ||
| assert inplace_result.results == [ | ||
| ("DummyInlinePass", None), | ||
| ("DummyCopyPass", None), | ||
| ] | ||
| assert inplace_result.hugr is hugr | ||
|
|
||
| hugr = Hugr() | ||
| copy_result = composed_dummies.run(hugr, inplace=False) | ||
| assert copy_result.modified | ||
| assert not copy_result.inplace | ||
| assert copy_result.results == [ | ||
| ("DummyInlinePass", None), | ||
| ("DummyCopyPass", None), | ||
| ] | ||
| assert copy_result.hugr is not hugr | ||
|
|
||
|
|
||
| def test_invalid_composable_pass() -> None: | ||
| class DummyInvalidPass(ComposablePass): | ||
| def run(self, hugr: Hugr, inplace: bool = True) -> PassResult: | ||
| return implement_pass_run( | ||
| self, | ||
| hugr=hugr, | ||
| inplace=inplace, | ||
| ) | ||
|
|
||
| dummy_invalid = DummyInvalidPass() | ||
| with pytest.raises( | ||
| ValueError, | ||
| match="DummyInvalidPass needs to implement at least an inplace or copy run method", # noqa: E501 | ||
| ): | ||
| dummy_invalid.run(Hugr()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now that we have the
impl_pass_runfunction as a helper for implementingComposablePass.runwhere is the__call__method actually used in the pass implementation?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was a suggestion from Seyon. Most usecases just need the Hugr after the pass, so we provide a simple call method.
When we actually need to inspect the result/pass output we can use the other call.