diff --git a/hugr-passes/src/composable.rs b/hugr-passes/src/composable.rs new file mode 100644 index 000000000..632205ec9 --- /dev/null +++ b/hugr-passes/src/composable.rs @@ -0,0 +1,231 @@ +//! Compiler passes and utilities for composing them + +use std::{error::Error, marker::PhantomData}; + +use hugr_core::hugr::{hugrmut::HugrMut, ValidationError}; +use hugr_core::HugrView; +use itertools::Either; + +/// An optimization pass that can be sequenced with another and/or wrapped +/// e.g. by [ValidatingPass] +pub trait ComposablePass: Sized { + type Error: Error; + fn run(&self, hugr: &mut impl HugrMut) -> Result<(), Self::Error>; + + fn map_err(self, f: impl Fn(Self::Error) -> E2) -> impl ComposablePass { + ErrMapper::new(self, f) + } + + /// Returns a [ComposablePass] that does "`self` then `other`", so long as + /// `other::Err` maps into ours. + fn then(self, other: P) -> impl ComposablePass + where + P::Error: Into, + { + (self, other.map_err(Into::into)) + } + + /// Returns a [ComposablePass] that does "`self` then `other`", combining + /// the two error types via `Either` + fn then_either( + self, + other: P, + ) -> impl ComposablePass> { + (self.map_err(Either::Left), other.map_err(Either::Right)) + } + + // Note: in the short term another variant could be useful: + // fn then_inf(self, other: impl ComposablePass) -> impl ComposablePass + // however this will become redundant when Infallible is replaced by ! (never_type) + // as (unlike Infallible) ! converts Into anything +} + +struct ErrMapper(P, F, PhantomData); + +impl E> ErrMapper { + fn new(pass: P, err_fn: F) -> Self { + Self(pass, err_fn, PhantomData) + } +} + +impl E> ComposablePass for ErrMapper { + type Error = E; + + fn run(&self, hugr: &mut impl HugrMut) -> Result<(), Self::Error> { + self.0.run(hugr).map_err(&self.1) + } +} + +impl, P2: ComposablePass> ComposablePass + for (P1, P2) +{ + type Error = E; + + fn run(&self, hugr: &mut impl HugrMut) -> Result<(), Self::Error> { + self.0.run(hugr)?; + self.1.run(hugr) + } +} + +/// Error from a [ValidatingPass] +#[derive(thiserror::Error, Debug)] +pub enum ValidatePassError { + #[error("Failed to validate input HUGR: {err}\n{pretty_hugr}")] + Input { + #[source] + err: ValidationError, + pretty_hugr: String, + }, + #[error("Failed to validate output HUGR: {err}\n{pretty_hugr}")] + Output { + #[source] + err: ValidationError, + pretty_hugr: String, + }, + #[error(transparent)] + Underlying(#[from] E), +} + +/// Runs an underlying pass, but with validation of the Hugr +/// both before and afterwards. +pub struct ValidatingPass

(P, bool); + +impl ValidatingPass

{ + pub fn new_default(underlying: P) -> Self { + // Self(underlying, cfg!(feature = "extension_inference")) + // Sadly, many tests fail with extension inference, hence: + Self(underlying, false) + } + + pub fn new_validating_extensions(underlying: P) -> Self { + Self(underlying, true) + } + + pub fn new(underlying: P, validate_extensions: bool) -> Self { + Self(underlying, validate_extensions) + } + + fn validation_impl( + &self, + hugr: &impl HugrView, + mk_err: impl FnOnce(ValidationError, String) -> ValidatePassError, + ) -> Result<(), ValidatePassError> { + match self.1 { + false => hugr.validate_no_extensions(), + true => hugr.validate(), + } + .map_err(|err| mk_err(err, hugr.mermaid_string())) + } +} + +impl ComposablePass for ValidatingPass

{ + type Error = ValidatePassError; + + fn run(&self, hugr: &mut impl HugrMut) -> Result<(), Self::Error> { + self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Input { + err, + pretty_hugr, + })?; + self.0.run(hugr).map_err(ValidatePassError::Underlying)?; + self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Output { + err, + pretty_hugr, + }) + } +} + +pub(crate) fn validate_if_test( + pass: P, + hugr: &mut impl HugrMut, +) -> Result<(), ValidatePassError> { + if cfg!(test) { + ValidatingPass::new_default(pass).run(hugr) + } else { + pass.run(hugr).map_err(ValidatePassError::Underlying) + } +} + +#[cfg(test)] +mod test { + use std::convert::Infallible; + + use hugr_core::builder::{ + Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder, + }; + use hugr_core::extension::prelude::{bool_t, usize_t, ConstUsize}; + use hugr_core::hugr::hugrmut::HugrMut; + use hugr_core::ops::{handle::NodeHandle, Input, Output, DEFAULT_OPTYPE, DFG}; + use hugr_core::{types::Signature, Hugr, HugrView, IncomingPort}; + use itertools::Either; + + use crate::composable::{ValidatePassError, ValidatingPass}; + use crate::const_fold::{ConstFoldError, ConstantFoldPass}; + use crate::DeadCodeElimPass; + + use super::ComposablePass; + + #[test] + fn test_sequence() { + let mut mb = ModuleBuilder::new(); + let id1 = mb + .define_function("id1", Signature::new_endo(usize_t())) + .unwrap(); + let inps = id1.input_wires(); + let id1 = id1.finish_with_outputs(inps).unwrap(); + let id2 = mb + .define_function("id2", Signature::new_endo(usize_t())) + .unwrap(); + let inps = id2.input_wires(); + let id2 = id2.finish_with_outputs(inps).unwrap(); + let hugr = mb.finish_hugr().unwrap(); + + let dce = DeadCodeElimPass::default().with_entry_points([id1.node()]); + let cfold = + ConstantFoldPass::default().with_inputs(id2.node(), [(0, ConstUsize::new(2).into())]); + + cfold.run(&mut hugr.clone()).unwrap(); + + let exp_err = ConstFoldError::InvalidEntryPoint(id2.node(), DEFAULT_OPTYPE); + let r: Result<(), Either> = dce + .clone() + .then_either(cfold.clone()) + .run(&mut hugr.clone()); + assert_eq!(r, Err(Either::Right(exp_err.clone()))); + + let r: Result<(), ConstFoldError> = dce + .map_err(|inf| match inf {}) + .then(cfold) + .run(&mut hugr.clone()); + assert_eq!(r, Err(exp_err)); + } + + #[test] + fn test_validation() { + let mut h = Hugr::new(DFG { + signature: Signature::new(usize_t(), bool_t()), + }); + let inp = h.add_node_with_parent( + h.root(), + Input { + types: usize_t().into(), + }, + ); + let outp = h.add_node_with_parent( + h.root(), + Output { + types: bool_t().into(), + }, + ); + h.connect(inp, 0, outp, 0); + let backup = h.clone(); + let err = backup.validate().unwrap_err(); + + let no_inputs: [(IncomingPort, _); 0] = []; + let cfold = ConstantFoldPass::default().with_inputs(backup.root(), no_inputs); + cfold.run(&mut h).unwrap(); + assert_eq!(h, backup); // Did nothing + + let r = ValidatingPass(cfold, false).run(&mut h); + assert!(matches!(r, Err(ValidatePassError::Input { err: e, .. }) if e == err)); + } +} diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 7552ed36f..84b798cfb 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -20,17 +20,19 @@ use hugr_core::{ }; use value_handle::ValueHandle; -use crate::dataflow::{ - partial_from_const, ConstLoader, ConstLocation, DFContext, Machine, PartialValue, - TailLoopTermination, -}; use crate::dead_code::{DeadCodeElimPass, PreserveNode}; -use crate::validation::{ValidatePassError, ValidationLevel}; +use crate::ComposablePass; +use crate::{ + composable::validate_if_test, + dataflow::{ + partial_from_const, ConstLoader, ConstLocation, DFContext, Machine, PartialValue, + TailLoopTermination, + }, +}; #[derive(Debug, Clone, Default)] /// A configuration for the Constant Folding pass. pub struct ConstantFoldPass { - validation: ValidationLevel, allow_increase_termination: bool, /// Each outer key Node must be either: /// - a FuncDefn child of the root, if the root is a module; or @@ -38,13 +40,10 @@ pub struct ConstantFoldPass { inputs: HashMap>, } -#[derive(Debug, Error)] +#[derive(Clone, Debug, Error, PartialEq)] #[non_exhaustive] /// Errors produced by [ConstantFoldPass]. pub enum ConstFoldError { - #[error(transparent)] - #[allow(missing_docs)] - ValidationError(#[from] ValidatePassError), /// Error raised when a Node is specified as an entry-point but /// is neither a dataflow parent, nor a [CFG](OpType::CFG), nor /// a [Conditional](OpType::Conditional). @@ -53,12 +52,6 @@ pub enum ConstFoldError { } impl ConstantFoldPass { - /// Sets the validation level used before and after the pass is run - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - /// Allows the pass to remove potentially-non-terminating [TailLoop]s and [CFG] if their /// result (if/when they do terminate) is either known or not needed. /// @@ -90,9 +83,18 @@ impl ConstantFoldPass { .extend(inputs.into_iter().map(|(p, v)| (p.into(), v))); self } +} + +impl ComposablePass for ConstantFoldPass { + type Error = ConstFoldError; /// Run the Constant Folding pass. - fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), ConstFoldError> { + /// + /// # Errors + /// + /// [ConstFoldError::InvalidEntryPoint] if an entry-point added by [Self::with_inputs] + /// was of an invalid [OpType] + fn run(&self, hugr: &mut impl HugrMut) -> Result<(), ConstFoldError> { let fresh_node = Node::from(portgraph::NodeIndex::new( hugr.nodes().max().map_or(0, |n| n.index() + 1), )); @@ -168,23 +170,10 @@ impl ConstantFoldPass { } }) }) - .run(hugr)?; + .run(hugr) + .map_err(|inf| match inf {})?; // TODO use into_ok when available Ok(()) } - - /// Run the pass using this configuration. - /// - /// # Errors - /// - /// [ConstFoldError::ValidationError] if the Hugr does not validate before/afnerwards - /// (if [Self::validation_level] is set, or in tests) - /// - /// [ConstFoldError::InvalidEntryPoint] if an entry-point added by [Self::with_inputs] - /// was of an invalid OpType - pub fn run(&self, hugr: &mut H) -> Result<(), ConstFoldError> { - self.validation - .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr)) - } } /// Exhaustively apply constant folding to a HUGR. @@ -202,18 +191,11 @@ pub fn constant_fold_pass(h: &mut H) { } else { c }; - c.run(h).unwrap() + validate_if_test(c, h).unwrap() } struct ConstFoldContext<'a, H>(&'a H); -impl std::ops::Deref for ConstFoldContext<'_, H> { - type Target = H; - fn deref(&self) -> &H { - self.0 - } -} - impl> ConstLoader> for ConstFoldContext<'_, H> { type Node = H::Node; @@ -244,7 +226,7 @@ impl> ConstLoader> for ConstFoldCo }; // Returning the function body as a value, here, would be sufficient for inlining IndirectCall // but not for transforming to a direct Call. - let func = DescendantsGraph::>::try_new(&**self, node).ok()?; + let func = DescendantsGraph::>::try_new(self.0, node).ok()?; Some(ValueHandle::new_const_hugr( ConstLocation::Node(node), Box::new(func.extract_hugr()), diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index b84d65d7d..f815992e4 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -32,6 +32,7 @@ use hugr_core::types::{Signature, SumType, Type, TypeBound, TypeRow, TypeRowRV}; use hugr_core::{type_row, Hugr, HugrView, IncomingPort, Node}; use crate::dataflow::{partial_from_const, DFContext, PartialValue}; +use crate::ComposablePass as _; use super::{constant_fold_pass, ConstFoldContext, ConstantFoldPass, ValueHandle}; diff --git a/hugr-passes/src/dead_code.rs b/hugr-passes/src/dead_code.rs index b714dd6fd..fa891c274 100644 --- a/hugr-passes/src/dead_code.rs +++ b/hugr-passes/src/dead_code.rs @@ -1,13 +1,14 @@ //! Pass for removing dead code, i.e. that computes values that are then discarded use hugr_core::{hugr::hugrmut::HugrMut, ops::OpType, Hugr, HugrView, Node}; +use std::convert::Infallible; use std::fmt::{Debug, Formatter}; use std::{ collections::{HashMap, HashSet, VecDeque}, sync::Arc, }; -use crate::validation::{ValidatePassError, ValidationLevel}; +use crate::ComposablePass; /// Configuration for Dead Code Elimination pass #[derive(Clone)] @@ -18,7 +19,6 @@ pub struct DeadCodeElimPass { /// Callback identifying nodes that must be preserved even if their /// results are not used. Defaults to [PreserveNode::default_for]. preserve_callback: Arc, - validation: ValidationLevel, } impl Default for DeadCodeElimPass { @@ -26,7 +26,6 @@ impl Default for DeadCodeElimPass { Self { entry_points: Default::default(), preserve_callback: Arc::new(PreserveNode::default_for), - validation: ValidationLevel::default(), } } } @@ -39,13 +38,11 @@ impl Debug for DeadCodeElimPass { #[derive(Debug)] struct DCEDebug<'a> { entry_points: &'a Vec, - validation: ValidationLevel, } Debug::fmt( &DCEDebug { entry_points: &self.entry_points, - validation: self.validation, }, f, ) @@ -86,13 +83,6 @@ impl PreserveNode { } impl DeadCodeElimPass { - /// Sets the validation level used before and after the pass is run - #[allow(unused)] - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - /// Allows setting a callback that determines whether a node must be preserved /// (even when its result is not used) pub fn set_preserve_callback(mut self, cb: Arc) -> Self { @@ -146,24 +136,6 @@ impl DeadCodeElimPass { needed } - pub fn run(&self, hugr: &mut impl HugrMut) -> Result<(), ValidatePassError> { - self.validation.run_validated_pass(hugr, |h, _| { - self.run_no_validate(h); - Ok(()) - }) - } - - fn run_no_validate(&self, hugr: &mut impl HugrMut) { - let needed = self.find_needed_nodes(&*hugr); - let remove = hugr - .nodes() - .filter(|n| !needed.contains(n)) - .collect::>(); - for n in remove { - hugr.remove_node(n); - } - } - fn must_preserve( &self, h: &impl HugrView, @@ -185,6 +157,21 @@ impl DeadCodeElimPass { } } +impl ComposablePass for DeadCodeElimPass { + type Error = Infallible; + + fn run(&self, hugr: &mut impl HugrMut) -> Result<(), Infallible> { + let needed = self.find_needed_nodes(&*hugr); + let remove = hugr + .nodes() + .filter(|n| !needed.contains(n)) + .collect::>(); + for n in remove { + hugr.remove_node(n); + } + Ok(()) + } +} #[cfg(test)] mod test { use std::sync::Arc; @@ -196,6 +183,8 @@ mod test { use hugr_core::{ops::Value, type_row, HugrView}; use itertools::Itertools; + use crate::ComposablePass; + use super::{DeadCodeElimPass, PreserveNode}; #[test] diff --git a/hugr-passes/src/dead_funcs.rs b/hugr-passes/src/dead_funcs.rs index b114a9e42..1e38980e3 100644 --- a/hugr-passes/src/dead_funcs.rs +++ b/hugr-passes/src/dead_funcs.rs @@ -10,7 +10,10 @@ use hugr_core::{ }; use petgraph::visit::{Dfs, Walker}; -use crate::validation::{ValidatePassError, ValidationLevel}; +use crate::{ + composable::{validate_if_test, ValidatePassError}, + ComposablePass, +}; use super::call_graph::{CallGraph, CallGraphNode}; @@ -26,9 +29,6 @@ pub enum RemoveDeadFuncsError { /// The invalid node. node: N, }, - #[error(transparent)] - #[allow(missing_docs)] - ValidationError(#[from] ValidatePassError), } fn reachable_funcs<'a, H: HugrView>( @@ -64,17 +64,10 @@ fn reachable_funcs<'a, H: HugrView>( #[derive(Debug, Clone, Default)] /// A configuration for the Dead Function Removal pass. pub struct RemoveDeadFuncsPass { - validation: ValidationLevel, entry_points: Vec, } impl RemoveDeadFuncsPass { - /// Sets the validation level used before and after the pass is run - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - /// Adds new entry points - these must be [FuncDefn] nodes /// that are children of the [Module] at the root of the Hugr. /// @@ -87,16 +80,31 @@ impl RemoveDeadFuncsPass { self.entry_points.extend(entry_points); self } +} - /// Runs the pass (see [remove_dead_funcs]) with this configuration - pub fn run(&self, hugr: &mut H) -> Result<(), RemoveDeadFuncsError> { - self.validation.run_validated_pass(hugr, |hugr: &mut H, _| { - remove_dead_funcs(hugr, self.entry_points.iter().cloned()) - }) +impl ComposablePass for RemoveDeadFuncsPass { + type Error = RemoveDeadFuncsError; + fn run(&self, hugr: &mut impl HugrMut) -> Result<(), RemoveDeadFuncsError> { + let reachable = reachable_funcs( + &CallGraph::new(hugr), + hugr, + self.entry_points.iter().cloned(), + )? + .collect::>(); + let unreachable = hugr + .nodes() + .filter(|n| { + OpTag::Function.is_superset(hugr.get_optype(*n).tag()) && !reachable.contains(n) + }) + .collect::>(); + for n in unreachable { + hugr.remove_subtree(n); + } + Ok(()) } } -/// Delete from the Hugr any functions that are not used by either [Call] or +/// Deletes from the Hugr any functions that are not used by either [Call] or /// [LoadFunction] nodes in reachable parts. /// /// For [Module]-rooted Hugrs, `entry_points` may provide a list of entry points, @@ -118,16 +126,11 @@ impl RemoveDeadFuncsPass { pub fn remove_dead_funcs( h: &mut impl HugrMut, entry_points: impl IntoIterator, -) -> Result<(), RemoveDeadFuncsError> { - let reachable = reachable_funcs(&CallGraph::new(h), h, entry_points)?.collect::>(); - let unreachable = h - .nodes() - .filter(|n| OpTag::Function.is_superset(h.get_optype(*n).tag()) && !reachable.contains(n)) - .collect::>(); - for n in unreachable { - h.remove_subtree(n); - } - Ok(()) +) -> Result<(), ValidatePassError> { + validate_if_test( + RemoveDeadFuncsPass::default().with_module_entry_points(entry_points), + h, + ) } #[cfg(test)] @@ -142,7 +145,7 @@ mod test { }; use hugr_core::{extension::prelude::usize_t, types::Signature, HugrView}; - use super::RemoveDeadFuncsPass; + use super::remove_dead_funcs; #[rstest] #[case([], vec![])] // No entry_points removes everything! @@ -182,15 +185,14 @@ mod test { }) .collect::>(); - RemoveDeadFuncsPass::default() - .with_module_entry_points( - entry_points - .into_iter() - .map(|name| *avail_funcs.get(name).unwrap()) - .collect::>(), - ) - .run(&mut hugr) - .unwrap(); + remove_dead_funcs( + &mut hugr, + entry_points + .into_iter() + .map(|name| *avail_funcs.get(name).unwrap()) + .collect::>(), + ) + .unwrap(); let remaining_funcs = hugr .nodes() diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 38ef01505..047efe080 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -1,6 +1,8 @@ //! Compilation passes acting on the HUGR program representation. pub mod call_graph; +pub mod composable; +pub use composable::ComposablePass; pub mod const_fold; pub mod dataflow; pub mod dead_code; @@ -20,17 +22,9 @@ mod monomorphize; )] #[allow(deprecated)] pub use monomorphize::remove_polyfuncs; -// TODO: Deprecated re-export. Remove on a breaking release. -#[deprecated( - since = "0.14.1", - note = "Use `hugr_passes::MonomorphizePass` instead." -)] -#[allow(deprecated)] -pub use monomorphize::monomorphize; -pub use monomorphize::{MonomorphizeError, MonomorphizePass}; +pub use monomorphize::{monomorphize, MonomorphizePass}; pub mod nest_cfgs; pub mod non_local; -pub mod validation; pub use force_order::{force_order, force_order_by_key}; pub use lower::{lower_ops, replace_many_ops}; pub use non_local::{ensure_no_nonlocal_edges, nonlocal_edges}; diff --git a/hugr-passes/src/monomorphize.rs b/hugr-passes/src/monomorphize.rs index 1e3b2d9a0..4a04aa95c 100644 --- a/hugr-passes/src/monomorphize.rs +++ b/hugr-passes/src/monomorphize.rs @@ -1,5 +1,6 @@ use std::{ collections::{hash_map::Entry, HashMap}, + convert::Infallible, fmt::Write, ops::Deref, }; @@ -12,7 +13,9 @@ use hugr_core::{ use hugr_core::hugr::{hugrmut::HugrMut, Hugr, HugrView, OpType}; use itertools::Itertools as _; -use thiserror::Error; + +use crate::composable::{validate_if_test, ValidatePassError}; +use crate::ComposablePass; /// Replaces calls to polymorphic functions with calls to new monomorphic /// instantiations of the polymorphic ones. @@ -30,26 +33,8 @@ use thiserror::Error; /// children of the root node. We make best effort to ensure that names (derived /// from parent function names and concrete type args) of new functions are unique /// whenever the names of their parents are unique, but this is not guaranteed. -#[deprecated( - since = "0.14.1", - note = "Use `hugr_passes::MonomorphizePass` instead." -)] -// TODO: Deprecated. Remove on a breaking release and rename private `monomorphize_ref` to `monomorphize`. -pub fn monomorphize(mut h: Hugr) -> Hugr { - monomorphize_ref(&mut h); - h -} - -fn monomorphize_ref(h: &mut impl HugrMut) { - let root = h.root(); - // If the root is a polymorphic function, then there are no external calls, so nothing to do - if !is_polymorphic_funcdefn(h.get_optype(root)) { - mono_scan(h, root, None, &mut HashMap::new()); - if !h.get_optype(root).is_module() { - #[allow(deprecated)] // TODO remove in next breaking release and update docs - remove_polyfuncs_ref(h); - } - } +pub fn monomorphize(hugr: &mut impl HugrMut) -> Result<(), ValidatePassError> { + validate_if_test(MonomorphizePass, hugr) } /// Removes any polymorphic [FuncDefn]s from the Hugr. Note that if these have @@ -254,8 +239,6 @@ fn instantiate( mono_tgt } -use crate::validation::{ValidatePassError, ValidationLevel}; - /// Replaces calls to polymorphic functions with calls to new monomorphic /// instantiations of the polymorphic ones. /// @@ -271,38 +254,24 @@ use crate::validation::{ValidatePassError, ValidationLevel}; /// children of the root node. We make best effort to ensure that names (derived /// from parent function names and concrete type args) of new functions are unique /// whenever the names of their parents are unique, but this is not guaranteed. -#[derive(Debug, Clone, Default)] -pub struct MonomorphizePass { - validation: ValidationLevel, -} - -#[derive(Debug, Error)] -#[non_exhaustive] -/// Errors produced by [MonomorphizePass]. -pub enum MonomorphizeError { - #[error(transparent)] - #[allow(missing_docs)] - ValidationError(#[from] ValidatePassError), -} - -impl MonomorphizePass { - /// Sets the validation level used before and after the pass is run. - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - - /// Run the Monomorphization pass. - fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), MonomorphizeError> { - monomorphize_ref(hugr); +#[derive(Debug, Clone)] +pub struct MonomorphizePass; + +impl ComposablePass for MonomorphizePass { + type Error = Infallible; + + fn run(&self, h: &mut impl HugrMut) -> Result<(), Self::Error> { + let root = h.root(); + // If the root is a polymorphic function, then there are no external calls, so nothing to do + if !is_polymorphic_funcdefn(h.get_optype(root)) { + mono_scan(h, root, None, &mut HashMap::new()); + if !h.get_optype(root).is_module() { + #[allow(deprecated)] // TODO remove in next breaking release and update docs + remove_polyfuncs_ref(h); + } + } Ok(()) } - - /// Run the pass using specified configuration. - pub fn run(&self, hugr: &mut H) -> Result<(), MonomorphizeError> { - self.validation - .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr)) - } } struct TypeArgsList<'a>(&'a [TypeArg]); @@ -387,9 +356,9 @@ mod test { use hugr_core::{Hugr, HugrView, Node}; use rstest::rstest; - use crate::remove_dead_funcs; + use crate::{monomorphize, remove_dead_funcs}; - use super::{is_polymorphic, mangle_inner_func, mangle_name, MonomorphizePass}; + use super::{is_polymorphic, mangle_inner_func, mangle_name}; fn pair_type(ty: Type) -> Type { Type::new_tuple(vec![ty.clone(), ty]) @@ -410,7 +379,7 @@ mod test { let [i1] = dfg_builder.input_wires_arr(); let hugr = dfg_builder.finish_hugr_with_outputs([i1]).unwrap(); let mut hugr2 = hugr.clone(); - MonomorphizePass::default().run(&mut hugr2).unwrap(); + monomorphize(&mut hugr2).unwrap(); assert_eq!(hugr, hugr2); } @@ -472,7 +441,7 @@ mod test { .count(), 3 ); - MonomorphizePass::default().run(&mut hugr)?; + monomorphize(&mut hugr)?; let mono = hugr; mono.validate()?; @@ -493,7 +462,7 @@ mod test { ["double", "main", "triple"] ); let mut mono2 = mono.clone(); - MonomorphizePass::default().run(&mut mono2)?; + monomorphize(&mut mono2)?; assert_eq!(mono2, mono); // Idempotent @@ -601,7 +570,7 @@ mod test { .outputs_arr(); let mut hugr = outer.finish_hugr_with_outputs([e1, e2]).unwrap(); - MonomorphizePass::default().run(&mut hugr).unwrap(); + monomorphize(&mut hugr).unwrap(); let mono_hugr = hugr; mono_hugr.validate().unwrap(); let funcs = list_funcs(&mono_hugr); @@ -662,7 +631,7 @@ mod test { let mono = mono.finish_with_outputs([a, b]).unwrap(); let c = dfg.call(mono.handle(), &[], dfg.input_wires()).unwrap(); let mut hugr = dfg.finish_hugr_with_outputs(c.outputs()).unwrap(); - MonomorphizePass::default().run(&mut hugr)?; + monomorphize(&mut hugr)?; let mono_hugr = hugr; let mut funcs = list_funcs(&mono_hugr); @@ -719,7 +688,7 @@ mod test { module_builder.finish_hugr().unwrap() }; - MonomorphizePass::default().run(&mut hugr).unwrap(); + monomorphize(&mut hugr).unwrap(); remove_dead_funcs(&mut hugr, []).unwrap(); let funcs = list_funcs(&hugr); diff --git a/hugr-passes/src/validation.rs b/hugr-passes/src/validation.rs deleted file mode 100644 index baf3b86d8..000000000 --- a/hugr-passes/src/validation.rs +++ /dev/null @@ -1,93 +0,0 @@ -//! Provides [ValidationLevel] with tools to run passes with configurable -//! validation. - -use thiserror::Error; - -use hugr_core::hugr::{hugrmut::HugrMut, ValidationError}; -use hugr_core::HugrView; - -#[derive(Debug, Clone, Copy, Ord, Eq, PartialOrd, PartialEq)] -/// A type for running [HugrMut] algorithms with verification. -/// -/// Provides [ValidationLevel::run_validated_pass] to invoke a closure with pre and post -/// validation. -/// -/// The default level is `None` because validation can be expensive. -pub enum ValidationLevel { - /// Do no verification. - None, - /// Validate using [HugrView::validate_no_extensions]. This is useful when you - /// do not expect valid Extension annotations on Nodes. - WithoutExtensions, - /// Validate using [HugrView::validate]. - WithExtensions, -} - -#[derive(Error, Debug)] -#[allow(missing_docs)] -pub enum ValidatePassError { - #[error("Failed to validate input HUGR: {err}\n{pretty_hugr}")] - InputError { - #[source] - err: ValidationError, - pretty_hugr: String, - }, - #[error("Failed to validate output HUGR: {err}\n{pretty_hugr}")] - OutputError { - #[source] - err: ValidationError, - pretty_hugr: String, - }, -} - -impl Default for ValidationLevel { - fn default() -> Self { - if cfg!(test) { - // Many tests fail when run with Self::WithExtensions - Self::WithoutExtensions - } else { - Self::None - } - } -} - -impl ValidationLevel { - /// Run an operation on a [HugrMut]. `hugr` will be verified according to - /// [self](ValidationLevel), then `pass` will be invoked. If `pass` succeeds - /// then `hugr` will be verified again. - pub fn run_validated_pass( - &self, - hugr: &mut H, - pass: impl FnOnce(&mut H, &Self) -> Result, - ) -> Result - where - ValidatePassError: Into, - { - self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::InputError { - err, - pretty_hugr, - })?; - let result = pass(hugr, self)?; - self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::OutputError { - err, - pretty_hugr, - })?; - Ok(result) - } - - fn validation_impl( - &self, - hugr: &impl HugrView, - mk_err: impl FnOnce(ValidationError, String) -> ValidatePassError, - ) -> Result<(), E> - where - ValidatePassError: Into, - { - match self { - ValidationLevel::None => Ok(()), - ValidationLevel::WithoutExtensions => hugr.validate_no_extensions(), - ValidationLevel::WithExtensions => hugr.validate(), - } - .map_err(|err| mk_err(err, hugr.mermaid_string()).into()) - } -}