Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: ComposablePass trait allowing sequencing and validation #1895

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
f63c3d1
Basic composable passes: MapErr, (a,b) for sequence, Validating
acl-cqc Jan 28, 2025
f3a0cb1
Add validate_if_test, rename ComposablePass::E -> Err
acl-cqc Jan 28, 2025
cac9b14
default to cfg! extension_inference
acl-cqc Jan 28, 2025
d64c229
Make composable public, re-export only ComposablePass
acl-cqc Jan 28, 2025
fccb130
Update const_fold...add_entry_point still a TODO
acl-cqc Jan 28, 2025
c7b93ac
Remove add_entry_point....plan to revert later
acl-cqc Jan 28, 2025
fbd1304
Update remove_dead_funcs
acl-cqc Jan 28, 2025
7f27dc3
BREAKING: rm monomorphize, rename monomorphize_ref -> monomorphize
acl-cqc Jan 28, 2025
fba2042
Update monomorphize
acl-cqc Jan 28, 2025
0f09a49
Remove hugr-passes::validation
acl-cqc Jan 28, 2025
f7f3493
Merge remote-tracking branch 'origin/main' into acl/composable_pass
acl-cqc Jan 29, 2025
c6f446b
Do not validate extensions in tests :-(
acl-cqc Jan 29, 2025
dab7039
MonomorphizePass: don't derive Default; tests use monomorphize restor…
acl-cqc Jan 29, 2025
6f4f1b7
Merge 'origin/main' into acl/composable_pass
acl-cqc Mar 14, 2025
7532b13
Test sequencing, deriving Clone+PartialEq for ConstFoldError
acl-cqc Mar 14, 2025
5b94749
And test validation
acl-cqc Mar 14, 2025
2487143
Driveby: no need for ConstFoldContext to implement Deref - from earli…
acl-cqc Mar 14, 2025
395b8c3
redundant missing docs
acl-cqc Mar 14, 2025
c3e2e17
a few tweaks
acl-cqc Mar 14, 2025
c8653d0
fmt
acl-cqc Mar 14, 2025
bbbcd6f
sequence -> then, add note about Infallible
acl-cqc Mar 14, 2025
80cedfc
rename Err -> Error
acl-cqc Mar 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 223 additions & 0 deletions hugr-passes/src/composable.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
//! Compiler passes and utilities for composing them

use std::{error::Error, marker::PhantomData};

use hugr_core::hugr::{hugrmut::HugrMut, ValidationError};
use hugr_core::HugrView;
use itertools::Either;

/// An optimization pass that can be sequenced with another and/or wrapped
/// e.g. by [ValidatingPass]
pub trait ComposablePass: Sized {
type Err: Error;
fn run(&self, hugr: &mut impl HugrMut) -> Result<(), Self::Err>;

fn map_err<E2: Error>(self, f: impl Fn(Self::Err) -> E2) -> impl ComposablePass<Err = E2> {
ErrMapper::new(self, f)
}

fn sequence(
self,
other: impl ComposablePass<Err = Self::Err>,
Copy link
Contributor Author

@acl-cqc acl-cqc Mar 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The other way to do this is

fn sequence_before<P2: ComposablePass>(self, other: P2) -> impl ComposablePass<Err = Self::Err>
   where P2::Err : Into<Self::Err>

(and similarly sequence_after) which avoids what may be a common-case use of map_err (i.e. with Into).

However, Infallible doesn't implement Into, so it's probably not much of a win. (I could add sequence_infallible(self, other: ComposablePass<Err=Infallible>) I guess, but would then want that before/after versions of that too)

) -> impl ComposablePass<Err = Self::Err> {
(self, other) // SequencePass::new(self, other) ?
}

fn sequence_either<P: ComposablePass>(
self,
other: P,
) -> impl ComposablePass<Err = Either<Self::Err, P::Err>> {
self.map_err(Either::Left)
.sequence(other.map_err(Either::Right))
}
}

struct ErrMapper<P, E, F>(P, F, PhantomData<E>);

impl<P: ComposablePass, E: Error, F: Fn(P::Err) -> E> ErrMapper<P, E, F> {
fn new(pass: P, err_fn: F) -> Self {
Self(pass, err_fn, PhantomData)
}
}

impl<P: ComposablePass, E: Error, F: Fn(P::Err) -> E> ComposablePass for ErrMapper<P, E, F> {
type Err = E;

fn run(&self, hugr: &mut impl HugrMut) -> Result<(), Self::Err> {
self.0.run(hugr).map_err(&self.1)
}
}

impl<E: Error, P1: ComposablePass<Err = E>, P2: ComposablePass<Err = E>> ComposablePass
for (P1, P2)
{
type Err = E;

fn run(&self, hugr: &mut impl HugrMut) -> Result<(), Self::Err> {
self.0.run(hugr)?;
self.1.run(hugr)
}
}

/// Error from a [ValidatingPass]
#[derive(thiserror::Error, Debug)]
pub enum ValidatePassError<E> {
#[error("Failed to validate input HUGR: {err}\n{pretty_hugr}")]
Input {
#[source]
err: ValidationError,
pretty_hugr: String,
},
#[error("Failed to validate output HUGR: {err}\n{pretty_hugr}")]
Output {
#[source]
err: ValidationError,
pretty_hugr: String,
},
#[error(transparent)]
Underlying(#[from] E),
}

/// Runs an underlying pass, but with validation of the Hugr
/// both before and afterwards.
pub struct ValidatingPass<P>(P, bool);

impl<P: ComposablePass> ValidatingPass<P> {
pub fn new_default(underlying: P) -> Self {
// Self(underlying, cfg!(feature = "extension_inference"))
// Sadly, many tests fail with extension inference, hence:
Self(underlying, false)
}

pub fn new_validating_extensions(underlying: P) -> Self {
Self(underlying, true)
}

pub fn new(underlying: P, validate_extensions: bool) -> Self {
Self(underlying, validate_extensions)
}

fn validation_impl<E>(
&self,
hugr: &impl HugrView,
mk_err: impl FnOnce(ValidationError, String) -> ValidatePassError<E>,
) -> Result<(), ValidatePassError<E>> {
match self.1 {
false => hugr.validate_no_extensions(),
true => hugr.validate(),
}
.map_err(|err| mk_err(err, hugr.mermaid_string()))
}
}

impl<P: ComposablePass> ComposablePass for ValidatingPass<P> {
type Err = ValidatePassError<P::Err>;

fn run(&self, hugr: &mut impl HugrMut) -> Result<(), Self::Err> {
self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Input {
err,
pretty_hugr,
})?;
self.0.run(hugr).map_err(ValidatePassError::Underlying)?;
self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Output {
err,
pretty_hugr,
})
}
}

pub(crate) fn validate_if_test<P: ComposablePass>(
pass: P,
hugr: &mut impl HugrMut,
) -> Result<(), ValidatePassError<P::Err>> {
if cfg!(test) {
ValidatingPass::new_default(pass).run(hugr)
} else {
pass.run(hugr).map_err(ValidatePassError::Underlying)
}
}

#[cfg(test)]
mod test {
use std::convert::Infallible;

use hugr_core::builder::{
Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder,
};
use hugr_core::extension::prelude::{bool_t, usize_t, ConstUsize};
use hugr_core::hugr::hugrmut::HugrMut;
use hugr_core::ops::{Input, Output, DEFAULT_OPTYPE, DFG, handle::NodeHandle};
use hugr_core::{Hugr, HugrView, IncomingPort, types::Signature};
use itertools::Either;

use crate::composable::{ValidatePassError, ValidatingPass};
use crate::const_fold::{ConstFoldError, ConstantFoldPass};
use crate::DeadCodeElimPass;

use super::ComposablePass;

#[test]
fn test_sequence() {
let mut mb = ModuleBuilder::new();
let id1 = mb
.define_function("id1", Signature::new_endo(usize_t()))
.unwrap();
let inps = id1.input_wires();
let id1 = id1.finish_with_outputs(inps).unwrap();
let id2 = mb
.define_function("id2", Signature::new_endo(usize_t()))
.unwrap();
let inps = id2.input_wires();
let id2 = id2.finish_with_outputs(inps).unwrap();
let hugr = mb.finish_hugr().unwrap();

let dce = DeadCodeElimPass::default().with_entry_points([id1.node()]);
let cfold =
ConstantFoldPass::default().with_inputs(id2.node(), [(0, ConstUsize::new(2).into())]);

cfold.run(&mut hugr.clone()).unwrap();

let exp_err = ConstFoldError::InvalidEntryPoint(id2.node(), DEFAULT_OPTYPE);
let r: Result<(), Either<Infallible, ConstFoldError>> = dce
.clone()
.sequence_either(cfold.clone())
.run(&mut hugr.clone());
assert_eq!(r, Err(Either::Right(exp_err.clone())));

let r: Result<(), ConstFoldError> = dce
.map_err(|inf| match inf {})
.sequence(cfold)
.run(&mut hugr.clone());
assert_eq!(r, Err(exp_err));
}

#[test]
fn test_validation() {
let mut h = Hugr::new(DFG {
signature: Signature::new(usize_t(), bool_t()),
});
let inp = h.add_node_with_parent(
h.root(),
Input {
types: usize_t().into(),
},
);
let outp = h.add_node_with_parent(
h.root(),
Output {
types: bool_t().into(),
},
);
h.connect(inp, 0, outp, 0);
let backup = h.clone();
let err = backup.validate().unwrap_err();

let no_inputs: [(IncomingPort, _); 0] = [];
let cfold = ConstantFoldPass::default().with_inputs(backup.root(), no_inputs);
cfold.run(&mut h).unwrap();
assert_eq!(h, backup); // Did nothing

let r = ValidatingPass(cfold, false).run(&mut h);
assert!(matches!(r, Err(ValidatePassError::Input { err: e, .. }) if e == err));
}
}
64 changes: 23 additions & 41 deletions hugr-passes/src/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,31 +20,30 @@ use hugr_core::{
};
use value_handle::ValueHandle;

use crate::dataflow::{
partial_from_const, ConstLoader, ConstLocation, DFContext, Machine, PartialValue,
TailLoopTermination,
};
use crate::dead_code::{DeadCodeElimPass, PreserveNode};
use crate::validation::{ValidatePassError, ValidationLevel};
use crate::ComposablePass;
use crate::{
composable::validate_if_test,
dataflow::{
partial_from_const, ConstLoader, ConstLocation, DFContext, Machine, PartialValue,
TailLoopTermination,
},
};

#[derive(Debug, Clone, Default)]
/// A configuration for the Constant Folding pass.
pub struct ConstantFoldPass {
validation: ValidationLevel,
allow_increase_termination: bool,
/// Each outer key Node must be either:
/// - a FuncDefn child of the root, if the root is a module; or
/// - the root, if the root is not a Module
inputs: HashMap<Node, HashMap<IncomingPort, Value>>,
}

#[derive(Debug, Error)]
#[derive(Clone, Debug, Error, PartialEq)]
#[non_exhaustive]
/// Errors produced by [ConstantFoldPass].
pub enum ConstFoldError {
#[error(transparent)]
#[allow(missing_docs)]
ValidationError(#[from] ValidatePassError),
/// Error raised when a Node is specified as an entry-point but
/// is neither a dataflow parent, nor a [CFG](OpType::CFG), nor
/// a [Conditional](OpType::Conditional).
Expand All @@ -53,12 +52,6 @@ pub enum ConstFoldError {
}

impl ConstantFoldPass {
/// Sets the validation level used before and after the pass is run
pub fn validation_level(mut self, level: ValidationLevel) -> Self {
self.validation = level;
self
}

/// Allows the pass to remove potentially-non-terminating [TailLoop]s and [CFG] if their
/// result (if/when they do terminate) is either known or not needed.
///
Expand Down Expand Up @@ -90,9 +83,18 @@ impl ConstantFoldPass {
.extend(inputs.into_iter().map(|(p, v)| (p.into(), v)));
self
}
}

impl ComposablePass for ConstantFoldPass {
type Err = ConstFoldError;

/// Run the Constant Folding pass.
fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), ConstFoldError> {
///
/// # Errors
///
/// [ConstFoldError::InvalidEntryPoint] if an entry-point added by [Self::with_inputs]
/// was of an invalid [OpType]
fn run(&self, hugr: &mut impl HugrMut) -> Result<(), ConstFoldError> {
let fresh_node = Node::from(portgraph::NodeIndex::new(
hugr.nodes().max().map_or(0, |n| n.index() + 1),
));
Expand Down Expand Up @@ -168,23 +170,10 @@ impl ConstantFoldPass {
}
})
})
.run(hugr)?;
.run(hugr)
.map_err(|inf| match inf {})?; // TODO use into_ok when available
Ok(())
}

/// Run the pass using this configuration.
///
/// # Errors
///
/// [ConstFoldError::ValidationError] if the Hugr does not validate before/afnerwards
/// (if [Self::validation_level] is set, or in tests)
///
/// [ConstFoldError::InvalidEntryPoint] if an entry-point added by [Self::with_inputs]
/// was of an invalid OpType
pub fn run<H: HugrMut>(&self, hugr: &mut H) -> Result<(), ConstFoldError> {
self.validation
.run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr))
}
}

/// Exhaustively apply constant folding to a HUGR.
Expand All @@ -202,18 +191,11 @@ pub fn constant_fold_pass<H: HugrMut>(h: &mut H) {
} else {
c
};
c.run(h).unwrap()
validate_if_test(c, h).unwrap()
}

struct ConstFoldContext<'a, H>(&'a H);

impl<H: HugrView> std::ops::Deref for ConstFoldContext<'_, H> {
type Target = H;
fn deref(&self) -> &H {
self.0
}
}

impl<H: HugrView<Node = Node>> ConstLoader<ValueHandle<H::Node>> for ConstFoldContext<'_, H> {
type Node = H::Node;

Expand Down Expand Up @@ -244,7 +226,7 @@ impl<H: HugrView<Node = Node>> ConstLoader<ValueHandle<H::Node>> for ConstFoldCo
};
// Returning the function body as a value, here, would be sufficient for inlining IndirectCall
// but not for transforming to a direct Call.
let func = DescendantsGraph::<FuncID<true>>::try_new(&**self, node).ok()?;
let func = DescendantsGraph::<FuncID<true>>::try_new(self.0, node).ok()?;
Some(ValueHandle::new_const_hugr(
ConstLocation::Node(node),
Box::new(func.extract_hugr()),
Expand Down
1 change: 1 addition & 0 deletions hugr-passes/src/const_fold/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use hugr_core::types::{Signature, SumType, Type, TypeBound, TypeRow, TypeRowRV};
use hugr_core::{type_row, Hugr, HugrView, IncomingPort, Node};

use crate::dataflow::{partial_from_const, DFContext, PartialValue};
use crate::ComposablePass as _;

use super::{constant_fold_pass, ConstFoldContext, ConstantFoldPass, ValueHandle};

Expand Down
Loading
Loading