Skip to content

feat!: ComposablePass trait allowing sequencing and validation #1895

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 39 commits into from
Apr 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
f63c3d1
Basic composable passes: MapErr, (a,b) for sequence, Validating
acl-cqc Jan 28, 2025
f3a0cb1
Add validate_if_test, rename ComposablePass::E -> Err
acl-cqc Jan 28, 2025
cac9b14
default to cfg! extension_inference
acl-cqc Jan 28, 2025
d64c229
Make composable public, re-export only ComposablePass
acl-cqc Jan 28, 2025
fccb130
Update const_fold...add_entry_point still a TODO
acl-cqc Jan 28, 2025
c7b93ac
Remove add_entry_point....plan to revert later
acl-cqc Jan 28, 2025
fbd1304
Update remove_dead_funcs
acl-cqc Jan 28, 2025
7f27dc3
BREAKING: rm monomorphize, rename monomorphize_ref -> monomorphize
acl-cqc Jan 28, 2025
fba2042
Update monomorphize
acl-cqc Jan 28, 2025
0f09a49
Remove hugr-passes::validation
acl-cqc Jan 28, 2025
f7f3493
Merge remote-tracking branch 'origin/main' into acl/composable_pass
acl-cqc Jan 29, 2025
c6f446b
Do not validate extensions in tests :-(
acl-cqc Jan 29, 2025
dab7039
MonomorphizePass: don't derive Default; tests use monomorphize restor…
acl-cqc Jan 29, 2025
6f4f1b7
Merge 'origin/main' into acl/composable_pass
acl-cqc Mar 14, 2025
7532b13
Test sequencing, deriving Clone+PartialEq for ConstFoldError
acl-cqc Mar 14, 2025
5b94749
And test validation
acl-cqc Mar 14, 2025
2487143
Driveby: no need for ConstFoldContext to implement Deref - from earli…
acl-cqc Mar 14, 2025
395b8c3
redundant missing docs
acl-cqc Mar 14, 2025
c3e2e17
a few tweaks
acl-cqc Mar 14, 2025
c8653d0
fmt
acl-cqc Mar 14, 2025
bbbcd6f
sequence -> then, add note about Infallible
acl-cqc Mar 14, 2025
80cedfc
rename Err -> Error
acl-cqc Mar 17, 2025
55a1f18
Merge remote-tracking branch 'origin/main' into acl/composable_pass
acl-cqc Apr 5, 2025
917b41f
Fix UntuplePass and ReplaceTypes (making CompasablePass) - adding Result
acl-cqc Apr 5, 2025
7254443
WIP IfTrueThen
acl-cqc Apr 5, 2025
b00a272
Combine then+then_either via trait ErrorCombiner, applies to IfTrueTh…
acl-cqc Apr 5, 2025
29f8095
Reorder, separate with ----- comments, some missing docs
acl-cqc Apr 6, 2025
1a9d28b
Merge remote-tracking branch 'origin/main' into acl/composable_pass
acl-cqc Apr 15, 2025
9d4b478
test_sequence => test_then, both orders
acl-cqc Apr 15, 2025
58b7326
IfTrueThen=>IfThen, test; derive PartialEq for UntupleResult
acl-cqc Apr 15, 2025
8a1e4c5
imports wtf git
acl-cqc Apr 15, 2025
5daab98
Fix all-features
acl-cqc Apr 15, 2025
4bbe0a6
Merge remote-tracking branch 'origin/main' into acl/composable_pass
acl-cqc Apr 16, 2025
0fc3ed6
Merge remote-tracking branch 'origin/release-rs-v0.16.0' into acl/com…
acl-cqc Apr 16, 2025
75fc36a
tidy imports, correct comment
acl-cqc Apr 16, 2025
b7b9569
inline change_type_then_untup
acl-cqc Apr 16, 2025
c543be3
clippy
acl-cqc Apr 16, 2025
662d765
Fix all-features by disabling extension validation
acl-cqc Apr 16, 2025
a597008
comment typo
acl-cqc Apr 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
361 changes: 361 additions & 0 deletions hugr-passes/src/composable.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,361 @@
//! Compiler passes and utilities for composing them

use std::{error::Error, marker::PhantomData};

use hugr_core::hugr::{hugrmut::HugrMut, ValidationError};
use hugr_core::HugrView;
use itertools::Either;

/// An optimization pass that can be sequenced with another and/or wrapped
/// e.g. by [ValidatingPass]
pub trait ComposablePass: Sized {
type Error: Error;
type Result; // Would like to default to () but currently unstable

fn run(&self, hugr: &mut impl HugrMut) -> Result<Self::Result, Self::Error>;

fn map_err<E2: Error>(
self,
f: impl Fn(Self::Error) -> E2,
) -> impl ComposablePass<Result = Self::Result, Error = E2> {
ErrMapper::new(self, f)
}

/// Returns a [ComposablePass] that does "`self` then `other`", so long as
/// `other::Err` can be combined with ours.
fn then<P: ComposablePass, E: ErrorCombiner<Self::Error, P::Error>>(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a possible method then2 that requires E: ErrorCombiner<P::Error, Self::Error> i.e. that would allow when our error-type maps into P's (as opposed to then which requires the other way around).

However, both ways also allow returning an Either error and then2 would reverse that too (you probably wouldn't bother using then2 if you wanted an Either but). So then I wonder, should we also reverse the order of results...and the doubt as to what's the right API has lead me to avoid providing then2 entirely. But I could definitely be persuaded....

(Sadly we can't allow ErrorCombiner to work for both <A, B: Into<A>> and <A: Into<B>, B> because that would provide two conflicting impls when A and B are the same type :-( :-( )

self,
other: P,
) -> impl ComposablePass<Result = (Self::Result, P::Result), Error = E> {
struct Sequence<E, P1, P2>(P1, P2, PhantomData<E>);
impl<E, P1, P2> ComposablePass for Sequence<E, P1, P2>
where
P1: ComposablePass,
P2: ComposablePass,
E: ErrorCombiner<P1::Error, P2::Error>,
{
type Error = E;

type Result = (P1::Result, P2::Result);

fn run(&self, hugr: &mut impl HugrMut) -> Result<Self::Result, Self::Error> {
let res1 = self.0.run(hugr).map_err(E::from_first)?;
let res2 = self.1.run(hugr).map_err(E::from_second)?;
Ok((res1, res2))
}
}

Sequence(self, other, PhantomData)
}
}

/// Trait for combining the error types from two different passes
/// into a single error.
pub trait ErrorCombiner<A, B>: Error {
fn from_first(a: A) -> Self;
fn from_second(b: B) -> Self;
}

impl<A: Error, B: Into<A>> ErrorCombiner<A, B> for A {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It'd be nice to have

impl<E: Error, A: Into<E>, B: Into<E>> ErrorCombiner<A, B> for E

instead, which covers the Either impl.
But that breaks the tests below as it doesn't allow us to have ErrorCombiner<A, Infallible> -.-

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FTR...AFAICS Either does not implement From from either left or right, e.g. I get two errors here (one on each .into()):

fn foo<A,B>(a:A, b:B, which: bool) -> Either<A,B> {
    if which {a.into()} else {b.into()}
}

fn from_first(a: A) -> Self {
a
}

fn from_second(b: B) -> Self {
b.into()
}
}

impl<A: Error, B: Error> ErrorCombiner<A, B> for Either<A, B> {
fn from_first(a: A) -> Self {
Either::Left(a)
}

fn from_second(b: B) -> Self {
Either::Right(b)
}
}

// Note: in the short term we could wish for two more impls:
// impl<E:Error> ErrorCombiner<Infallible, E> for E
// impl<E:Error> ErrorCombiner<E, Infallible> for E
// however, these aren't possible as they conflict with
// impl<A, B:Into<A>> ErrorCombiner<A,B> for A
// when A=E=Infallible, boo :-(.
// However this will become possible, indeed automatic, when Infallible is replaced
// by ! (never_type) as (unlike Infallible) ! converts Into anything
Comment on lines +79 to +86
Copy link
Collaborator

@aborgna-q aborgna-q Apr 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is covered by impl<A: Error, B: Error> ErrorCombiner<A, B> for Either<A, B>, no?
It is required to compose DeadCodeElimPass in test_then.


// ErrMapper ------------------------------
struct ErrMapper<P, E, F>(P, F, PhantomData<E>);

impl<P: ComposablePass, E: Error, F: Fn(P::Error) -> E> ErrMapper<P, E, F> {
fn new(pass: P, err_fn: F) -> Self {
Self(pass, err_fn, PhantomData)
}
}

impl<P: ComposablePass, E: Error, F: Fn(P::Error) -> E> ComposablePass for ErrMapper<P, E, F> {
type Error = E;
type Result = P::Result;

fn run(&self, hugr: &mut impl HugrMut) -> Result<P::Result, Self::Error> {
self.0.run(hugr).map_err(&self.1)
}
}

// ValidatingPass ------------------------------

/// Error from a [ValidatingPass]
#[derive(thiserror::Error, Debug)]
pub enum ValidatePassError<E> {
#[error("Failed to validate input HUGR: {err}\n{pretty_hugr}")]
Input {
#[source]
err: ValidationError,
pretty_hugr: String,
},
#[error("Failed to validate output HUGR: {err}\n{pretty_hugr}")]
Output {
#[source]
err: ValidationError,
pretty_hugr: String,
},
#[error(transparent)]
Underlying(#[from] E),
}

/// Runs an underlying pass, but with validation of the Hugr
/// both before and afterwards.
pub struct ValidatingPass<P>(P, bool);

impl<P: ComposablePass> ValidatingPass<P> {
pub fn new_default(underlying: P) -> Self {
// Self(underlying, cfg!(feature = "extension_inference"))
// Sadly, many tests fail with extension inference, hence:
Self(underlying, false)
}

pub fn new_validating_extensions(underlying: P) -> Self {
Self(underlying, true)
}

pub fn new(underlying: P, validate_extensions: bool) -> Self {
Self(underlying, validate_extensions)
}

fn validation_impl<E>(
&self,
hugr: &impl HugrView,
mk_err: impl FnOnce(ValidationError, String) -> ValidatePassError<E>,
) -> Result<(), ValidatePassError<E>> {
match self.1 {
false => hugr.validate_no_extensions(),
true => hugr.validate(),
}
.map_err(|err| mk_err(err, hugr.mermaid_string()))
}
}

impl<P: ComposablePass> ComposablePass for ValidatingPass<P> {
type Error = ValidatePassError<P::Error>;
type Result = P::Result;

fn run(&self, hugr: &mut impl HugrMut) -> Result<P::Result, Self::Error> {
self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Input {
err,
pretty_hugr,
})?;
let res = self.0.run(hugr).map_err(ValidatePassError::Underlying)?;
self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Output {
err,
pretty_hugr,
})?;
Ok(res)
}
}

// IfThen ------------------------------
/// [ComposablePass] that executes a first pass that returns a `bool`
/// result; and then, if-and-only-if that first result was true,
/// executes a second pass
pub struct IfThen<E, A, B>(A, B, PhantomData<E>);

impl<A: ComposablePass<Result = bool>, B: ComposablePass, E: ErrorCombiner<A::Error, B::Error>>
IfThen<E, A, B>
{
/// Make a new instance given the [ComposablePass] to run first
/// and (maybe) second
pub fn new(fst: A, opt_snd: B) -> Self {
Self(fst, opt_snd, PhantomData)
}
}

impl<A: ComposablePass<Result = bool>, B: ComposablePass, E: ErrorCombiner<A::Error, B::Error>>
ComposablePass for IfThen<E, A, B>
{
type Error = E;

type Result = Option<B::Result>;

fn run(&self, hugr: &mut impl HugrMut) -> Result<Self::Result, Self::Error> {
let res: bool = self.0.run(hugr).map_err(ErrorCombiner::from_first)?;
res.then(|| self.1.run(hugr).map_err(ErrorCombiner::from_second))
.transpose()
}
}

pub(crate) fn validate_if_test<P: ComposablePass>(
pass: P,
hugr: &mut impl HugrMut,
) -> Result<P::Result, ValidatePassError<P::Error>> {
if cfg!(test) {
ValidatingPass::new_default(pass).run(hugr)
} else {
pass.run(hugr).map_err(ValidatePassError::Underlying)
}
}

#[cfg(test)]
mod test {
use itertools::{Either, Itertools};
use std::convert::Infallible;

use hugr_core::builder::{
Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder,
ModuleBuilder,
};
use hugr_core::extension::prelude::{
bool_t, usize_t, ConstUsize, MakeTuple, UnpackTuple, PRELUDE_ID,
};
use hugr_core::hugr::hugrmut::HugrMut;
use hugr_core::ops::{handle::NodeHandle, Input, OpType, Output, DEFAULT_OPTYPE, DFG};
use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES;
use hugr_core::types::{Signature, TypeRow};
use hugr_core::{Hugr, HugrView, IncomingPort};

use crate::const_fold::{ConstFoldError, ConstantFoldPass};
use crate::untuple::{UntupleRecursive, UntupleResult};
use crate::{DeadCodeElimPass, ReplaceTypes, UntuplePass};

use super::{validate_if_test, ComposablePass, IfThen, ValidatePassError, ValidatingPass};

#[test]
fn test_then() {
let mut mb = ModuleBuilder::new();
let id1 = mb
.define_function("id1", Signature::new_endo(usize_t()))
.unwrap();
let inps = id1.input_wires();
let id1 = id1.finish_with_outputs(inps).unwrap();
let id2 = mb
.define_function("id2", Signature::new_endo(usize_t()))
.unwrap();
let inps = id2.input_wires();
let id2 = id2.finish_with_outputs(inps).unwrap();
let hugr = mb.finish_hugr().unwrap();

let dce = DeadCodeElimPass::default().with_entry_points([id1.node()]);
let cfold =
ConstantFoldPass::default().with_inputs(id2.node(), [(0, ConstUsize::new(2).into())]);

cfold.run(&mut hugr.clone()).unwrap();

let exp_err = ConstFoldError::InvalidEntryPoint(id2.node(), DEFAULT_OPTYPE);
let r: Result<_, Either<Infallible, ConstFoldError>> =
dce.clone().then(cfold.clone()).run(&mut hugr.clone());
assert_eq!(r, Err(Either::Right(exp_err.clone())));

let r = dce
.clone()
.map_err(|inf| match inf {})
.then(cfold.clone())
.run(&mut hugr.clone());
assert_eq!(r, Err(exp_err));

let r2: Result<_, Either<_, _>> = cfold.then(dce).run(&mut hugr.clone());
r2.unwrap();
}

#[test]
fn test_validation() {
let mut h = Hugr::new(DFG {
signature: Signature::new(usize_t(), bool_t()),
});
let inp = h.add_node_with_parent(
h.root(),
Input {
types: usize_t().into(),
},
);
let outp = h.add_node_with_parent(
h.root(),
Output {
types: bool_t().into(),
},
);
h.connect(inp, 0, outp, 0);
let backup = h.clone();
let err = backup.validate().unwrap_err();

let no_inputs: [(IncomingPort, _); 0] = [];
let cfold = ConstantFoldPass::default().with_inputs(backup.root(), no_inputs);
cfold.run(&mut h).unwrap();
assert_eq!(h, backup); // Did nothing

let r = ValidatingPass(cfold, false).run(&mut h);
assert!(matches!(r, Err(ValidatePassError::Input { err: e, .. }) if e == err));
}

#[test]
fn test_if_then() {
let tr = TypeRow::from(vec![usize_t(); 2]);

let h = {
let sig = Signature::new_endo(tr.clone()).with_extension_delta(PRELUDE_ID);
let mut fb = FunctionBuilder::new("tupuntup", sig).unwrap();
let [a, b] = fb.input_wires_arr();
let tup = fb
.add_dataflow_op(MakeTuple::new(tr.clone()), [a, b])
.unwrap();
let untup = fb
.add_dataflow_op(UnpackTuple::new(tr.clone()), tup.outputs())
.unwrap();
fb.finish_hugr_with_outputs(untup.outputs()).unwrap()
};

let untup = UntuplePass::new(UntupleRecursive::Recursive);
{
// Change usize_t to INT_TYPES[6], and if that did anything (it will!), then Untuple
let mut repl = ReplaceTypes::default();
let usize_custom_t = usize_t().as_extension().unwrap().clone();
repl.replace_type(usize_custom_t, INT_TYPES[6].clone());
let ifthen = IfThen::<Either<_, _>, _, _>::new(repl, untup.clone());

let mut h = h.clone();
let r = validate_if_test(ifthen, &mut h).unwrap();
assert_eq!(
r,
Some(UntupleResult {
rewrites_applied: 1
})
);
let [tuple_in, tuple_out] = h.children(h.root()).collect_array().unwrap();
assert_eq!(h.output_neighbours(tuple_in).collect_vec(), [tuple_out; 2]);
}

// Change INT_TYPES[5] to INT_TYPES[6]; that won't do anything, so don't Untuple
let mut repl = ReplaceTypes::default();
let i32_custom_t = INT_TYPES[5].as_extension().unwrap().clone();
repl.replace_type(i32_custom_t, INT_TYPES[6].clone());
let ifthen = IfThen::<Either<_, _>, _, _>::new(repl, untup);
let mut h = h;
let r = validate_if_test(ifthen, &mut h).unwrap();
assert_eq!(r, None);
assert_eq!(h.children(h.root()).count(), 4);
let mktup = h
.output_neighbours(h.first_child(h.root()).unwrap())
.next()
.unwrap();
assert_eq!(h.get_optype(mktup), &OpType::from(MakeTuple::new(tr)));
}
}
Loading
Loading