Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: ComposablePass trait allowing sequencing and validation #1895

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
f63c3d1
Basic composable passes: MapErr, (a,b) for sequence, Validating
acl-cqc Jan 28, 2025
f3a0cb1
Add validate_if_test, rename ComposablePass::E -> Err
acl-cqc Jan 28, 2025
cac9b14
default to cfg! extension_inference
acl-cqc Jan 28, 2025
d64c229
Make composable public, re-export only ComposablePass
acl-cqc Jan 28, 2025
fccb130
Update const_fold...add_entry_point still a TODO
acl-cqc Jan 28, 2025
c7b93ac
Remove add_entry_point....plan to revert later
acl-cqc Jan 28, 2025
fbd1304
Update remove_dead_funcs
acl-cqc Jan 28, 2025
7f27dc3
BREAKING: rm monomorphize, rename monomorphize_ref -> monomorphize
acl-cqc Jan 28, 2025
fba2042
Update monomorphize
acl-cqc Jan 28, 2025
0f09a49
Remove hugr-passes::validation
acl-cqc Jan 28, 2025
f7f3493
Merge remote-tracking branch 'origin/main' into acl/composable_pass
acl-cqc Jan 29, 2025
c6f446b
Do not validate extensions in tests :-(
acl-cqc Jan 29, 2025
dab7039
MonomorphizePass: don't derive Default; tests use monomorphize restor…
acl-cqc Jan 29, 2025
6f4f1b7
Merge 'origin/main' into acl/composable_pass
acl-cqc Mar 14, 2025
7532b13
Test sequencing, deriving Clone+PartialEq for ConstFoldError
acl-cqc Mar 14, 2025
5b94749
And test validation
acl-cqc Mar 14, 2025
2487143
Driveby: no need for ConstFoldContext to implement Deref - from earli…
acl-cqc Mar 14, 2025
395b8c3
redundant missing docs
acl-cqc Mar 14, 2025
c3e2e17
a few tweaks
acl-cqc Mar 14, 2025
c8653d0
fmt
acl-cqc Mar 14, 2025
bbbcd6f
sequence -> then, add note about Infallible
acl-cqc Mar 14, 2025
80cedfc
rename Err -> Error
acl-cqc Mar 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
231 changes: 231 additions & 0 deletions hugr-passes/src/composable.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
//! Compiler passes and utilities for composing them

use std::{error::Error, marker::PhantomData};

use hugr_core::hugr::{hugrmut::HugrMut, ValidationError};
use hugr_core::HugrView;
use itertools::Either;

/// An optimization pass that can be sequenced with another and/or wrapped
/// e.g. by [ValidatingPass]
pub trait ComposablePass: Sized {
type Error: Error;
fn run(&self, hugr: &mut impl HugrMut) -> Result<(), Self::Error>;

fn map_err<E2: Error>(self, f: impl Fn(Self::Error) -> E2) -> impl ComposablePass<Error = E2> {
ErrMapper::new(self, f)
}

/// Returns a [ComposablePass] that does "`self` then `other`", so long as
/// `other::Err` maps into ours.
fn then<P: ComposablePass>(self, other: P) -> impl ComposablePass<Error = Self::Error>
where
P::Error: Into<Self::Error>,
{
(self, other.map_err(Into::into))
}

/// Returns a [ComposablePass] that does "`self` then `other`", combining
/// the two error types via `Either`
fn then_either<P: ComposablePass>(
self,
other: P,
) -> impl ComposablePass<Error = Either<Self::Error, P::Error>> {
(self.map_err(Either::Left), other.map_err(Either::Right))
}

// Note: in the short term another variant could be useful:
// fn then_inf(self, other: impl ComposablePass<Err=Infallible>) -> impl ComposablePass<Err = Self::Err>
// however this will become redundant when Infallible is replaced by ! (never_type)
// as (unlike Infallible) ! converts Into anything
}

struct ErrMapper<P, E, F>(P, F, PhantomData<E>);

impl<P: ComposablePass, E: Error, F: Fn(P::Error) -> E> ErrMapper<P, E, F> {
fn new(pass: P, err_fn: F) -> Self {
Self(pass, err_fn, PhantomData)
}
}

impl<P: ComposablePass, E: Error, F: Fn(P::Error) -> E> ComposablePass for ErrMapper<P, E, F> {
type Error = E;

fn run(&self, hugr: &mut impl HugrMut) -> Result<(), Self::Error> {
self.0.run(hugr).map_err(&self.1)
}
}

impl<E: Error, P1: ComposablePass<Error = E>, P2: ComposablePass<Error = E>> ComposablePass
for (P1, P2)
{
type Error = E;

fn run(&self, hugr: &mut impl HugrMut) -> Result<(), Self::Error> {
self.0.run(hugr)?;
self.1.run(hugr)
}
}

/// Error from a [ValidatingPass]
#[derive(thiserror::Error, Debug)]
pub enum ValidatePassError<E> {
#[error("Failed to validate input HUGR: {err}\n{pretty_hugr}")]
Input {
#[source]
err: ValidationError,
pretty_hugr: String,
},
#[error("Failed to validate output HUGR: {err}\n{pretty_hugr}")]
Output {
#[source]
err: ValidationError,
pretty_hugr: String,
},
#[error(transparent)]
Underlying(#[from] E),
}

/// Runs an underlying pass, but with validation of the Hugr
/// both before and afterwards.
pub struct ValidatingPass<P>(P, bool);

impl<P: ComposablePass> ValidatingPass<P> {
pub fn new_default(underlying: P) -> Self {
// Self(underlying, cfg!(feature = "extension_inference"))
// Sadly, many tests fail with extension inference, hence:
Self(underlying, false)
}

pub fn new_validating_extensions(underlying: P) -> Self {
Self(underlying, true)
}

pub fn new(underlying: P, validate_extensions: bool) -> Self {
Self(underlying, validate_extensions)
}

fn validation_impl<E>(
&self,
hugr: &impl HugrView,
mk_err: impl FnOnce(ValidationError, String) -> ValidatePassError<E>,
) -> Result<(), ValidatePassError<E>> {
match self.1 {
false => hugr.validate_no_extensions(),
true => hugr.validate(),
}
.map_err(|err| mk_err(err, hugr.mermaid_string()))
}
}

impl<P: ComposablePass> ComposablePass for ValidatingPass<P> {
type Error = ValidatePassError<P::Error>;

fn run(&self, hugr: &mut impl HugrMut) -> Result<(), Self::Error> {
self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Input {
err,
pretty_hugr,
})?;
self.0.run(hugr).map_err(ValidatePassError::Underlying)?;
self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Output {
err,
pretty_hugr,
})
}
}

pub(crate) fn validate_if_test<P: ComposablePass>(
pass: P,
hugr: &mut impl HugrMut,
) -> Result<(), ValidatePassError<P::Error>> {
if cfg!(test) {
ValidatingPass::new_default(pass).run(hugr)
} else {
pass.run(hugr).map_err(ValidatePassError::Underlying)
}
}

#[cfg(test)]
mod test {
use std::convert::Infallible;

use hugr_core::builder::{
Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder,
};
use hugr_core::extension::prelude::{bool_t, usize_t, ConstUsize};
use hugr_core::hugr::hugrmut::HugrMut;
use hugr_core::ops::{handle::NodeHandle, Input, Output, DEFAULT_OPTYPE, DFG};
use hugr_core::{types::Signature, Hugr, HugrView, IncomingPort};
use itertools::Either;

use crate::composable::{ValidatePassError, ValidatingPass};
use crate::const_fold::{ConstFoldError, ConstantFoldPass};
use crate::DeadCodeElimPass;

use super::ComposablePass;

#[test]
fn test_sequence() {
let mut mb = ModuleBuilder::new();
let id1 = mb
.define_function("id1", Signature::new_endo(usize_t()))
.unwrap();
let inps = id1.input_wires();
let id1 = id1.finish_with_outputs(inps).unwrap();
let id2 = mb
.define_function("id2", Signature::new_endo(usize_t()))
.unwrap();
let inps = id2.input_wires();
let id2 = id2.finish_with_outputs(inps).unwrap();
let hugr = mb.finish_hugr().unwrap();

let dce = DeadCodeElimPass::default().with_entry_points([id1.node()]);
let cfold =
ConstantFoldPass::default().with_inputs(id2.node(), [(0, ConstUsize::new(2).into())]);

cfold.run(&mut hugr.clone()).unwrap();

let exp_err = ConstFoldError::InvalidEntryPoint(id2.node(), DEFAULT_OPTYPE);
let r: Result<(), Either<Infallible, ConstFoldError>> = dce
.clone()
.then_either(cfold.clone())
.run(&mut hugr.clone());
assert_eq!(r, Err(Either::Right(exp_err.clone())));

let r: Result<(), ConstFoldError> = dce
.map_err(|inf| match inf {})
.then(cfold)
.run(&mut hugr.clone());
assert_eq!(r, Err(exp_err));
}

#[test]
fn test_validation() {
let mut h = Hugr::new(DFG {
signature: Signature::new(usize_t(), bool_t()),
});
let inp = h.add_node_with_parent(
h.root(),
Input {
types: usize_t().into(),
},
);
let outp = h.add_node_with_parent(
h.root(),
Output {
types: bool_t().into(),
},
);
h.connect(inp, 0, outp, 0);
let backup = h.clone();
let err = backup.validate().unwrap_err();

let no_inputs: [(IncomingPort, _); 0] = [];
let cfold = ConstantFoldPass::default().with_inputs(backup.root(), no_inputs);
cfold.run(&mut h).unwrap();
assert_eq!(h, backup); // Did nothing

let r = ValidatingPass(cfold, false).run(&mut h);
assert!(matches!(r, Err(ValidatePassError::Input { err: e, .. }) if e == err));
}
}
64 changes: 23 additions & 41 deletions hugr-passes/src/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,31 +20,30 @@ use hugr_core::{
};
use value_handle::ValueHandle;

use crate::dataflow::{
partial_from_const, ConstLoader, ConstLocation, DFContext, Machine, PartialValue,
TailLoopTermination,
};
use crate::dead_code::{DeadCodeElimPass, PreserveNode};
use crate::validation::{ValidatePassError, ValidationLevel};
use crate::ComposablePass;
use crate::{
composable::validate_if_test,
dataflow::{
partial_from_const, ConstLoader, ConstLocation, DFContext, Machine, PartialValue,
TailLoopTermination,
},
};

#[derive(Debug, Clone, Default)]
/// A configuration for the Constant Folding pass.
pub struct ConstantFoldPass {
validation: ValidationLevel,
allow_increase_termination: bool,
/// Each outer key Node must be either:
/// - a FuncDefn child of the root, if the root is a module; or
/// - the root, if the root is not a Module
inputs: HashMap<Node, HashMap<IncomingPort, Value>>,
}

#[derive(Debug, Error)]
#[derive(Clone, Debug, Error, PartialEq)]
#[non_exhaustive]
/// Errors produced by [ConstantFoldPass].
pub enum ConstFoldError {
#[error(transparent)]
#[allow(missing_docs)]
ValidationError(#[from] ValidatePassError),
/// Error raised when a Node is specified as an entry-point but
/// is neither a dataflow parent, nor a [CFG](OpType::CFG), nor
/// a [Conditional](OpType::Conditional).
Expand All @@ -53,12 +52,6 @@ pub enum ConstFoldError {
}

impl ConstantFoldPass {
/// Sets the validation level used before and after the pass is run
pub fn validation_level(mut self, level: ValidationLevel) -> Self {
self.validation = level;
self
}

/// Allows the pass to remove potentially-non-terminating [TailLoop]s and [CFG] if their
/// result (if/when they do terminate) is either known or not needed.
///
Expand Down Expand Up @@ -90,9 +83,18 @@ impl ConstantFoldPass {
.extend(inputs.into_iter().map(|(p, v)| (p.into(), v)));
self
}
}

impl ComposablePass for ConstantFoldPass {
type Error = ConstFoldError;

/// Run the Constant Folding pass.
fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), ConstFoldError> {
///
/// # Errors
///
/// [ConstFoldError::InvalidEntryPoint] if an entry-point added by [Self::with_inputs]
/// was of an invalid [OpType]
fn run(&self, hugr: &mut impl HugrMut) -> Result<(), ConstFoldError> {
let fresh_node = Node::from(portgraph::NodeIndex::new(
hugr.nodes().max().map_or(0, |n| n.index() + 1),
));
Expand Down Expand Up @@ -168,23 +170,10 @@ impl ConstantFoldPass {
}
})
})
.run(hugr)?;
.run(hugr)
.map_err(|inf| match inf {})?; // TODO use into_ok when available
Ok(())
}

/// Run the pass using this configuration.
///
/// # Errors
///
/// [ConstFoldError::ValidationError] if the Hugr does not validate before/afnerwards
/// (if [Self::validation_level] is set, or in tests)
///
/// [ConstFoldError::InvalidEntryPoint] if an entry-point added by [Self::with_inputs]
/// was of an invalid OpType
pub fn run<H: HugrMut>(&self, hugr: &mut H) -> Result<(), ConstFoldError> {
self.validation
.run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr))
}
}

/// Exhaustively apply constant folding to a HUGR.
Expand All @@ -202,18 +191,11 @@ pub fn constant_fold_pass<H: HugrMut>(h: &mut H) {
} else {
c
};
c.run(h).unwrap()
validate_if_test(c, h).unwrap()
}

struct ConstFoldContext<'a, H>(&'a H);

impl<H: HugrView> std::ops::Deref for ConstFoldContext<'_, H> {
type Target = H;
fn deref(&self) -> &H {
self.0
}
}

impl<H: HugrView<Node = Node>> ConstLoader<ValueHandle<H::Node>> for ConstFoldContext<'_, H> {
type Node = H::Node;

Expand Down Expand Up @@ -244,7 +226,7 @@ impl<H: HugrView<Node = Node>> ConstLoader<ValueHandle<H::Node>> for ConstFoldCo
};
// Returning the function body as a value, here, would be sufficient for inlining IndirectCall
// but not for transforming to a direct Call.
let func = DescendantsGraph::<FuncID<true>>::try_new(&**self, node).ok()?;
let func = DescendantsGraph::<FuncID<true>>::try_new(self.0, node).ok()?;
Some(ValueHandle::new_const_hugr(
ConstLocation::Node(node),
Box::new(func.extract_hugr()),
Expand Down
1 change: 1 addition & 0 deletions hugr-passes/src/const_fold/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use hugr_core::types::{Signature, SumType, Type, TypeBound, TypeRow, TypeRowRV};
use hugr_core::{type_row, Hugr, HugrView, IncomingPort, Node};

use crate::dataflow::{partial_from_const, DFContext, PartialValue};
use crate::ComposablePass as _;

use super::{constant_fold_pass, ConstFoldContext, ConstantFoldPass, ValueHandle};

Expand Down
Loading
Loading