From 692a75978494f977e975d300d31623f9ed80442a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 31 Mar 2025 20:06:04 +0100 Subject: [PATCH 01/33] No need for ConstFoldContext to impl Deref --- hugr-passes/src/const_fold.rs | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 7552ed36f..1dc4554cf 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -207,13 +207,6 @@ pub fn constant_fold_pass(h: &mut H) { struct ConstFoldContext<'a, H>(&'a H); -impl std::ops::Deref for ConstFoldContext<'_, H> { - type Target = H; - fn deref(&self) -> &H { - self.0 - } -} - impl> ConstLoader> for ConstFoldContext<'_, H> { type Node = H::Node; @@ -244,7 +237,7 @@ impl> ConstLoader> for ConstFoldCo }; // Returning the function body as a value, here, would be sufficient for inlining IndirectCall // but not for transforming to a direct Call. - let func = DescendantsGraph::>::try_new(&**self, node).ok()?; + let func = DescendantsGraph::>::try_new(self.0, node).ok()?; Some(ValueHandle::new_const_hugr( ConstLocation::Node(node), Box::new(func.extract_hugr()), From 48143f558ba5c43f23eb9fe9531804ffe3f403fe Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 4 Apr 2025 15:24:17 +0100 Subject: [PATCH 02/33] Add PartialValue::LoadedFunction --- hugr-passes/src/const_fold.rs | 2 +- hugr-passes/src/dataflow.rs | 12 +- hugr-passes/src/dataflow/datalog.rs | 22 +-- hugr-passes/src/dataflow/partial_value.rs | 214 ++++++++++++++-------- hugr-passes/src/dataflow/results.rs | 18 +- hugr-passes/src/dataflow/test.rs | 7 +- hugr-passes/src/dataflow/value_row.rs | 38 ++-- 7 files changed, 190 insertions(+), 123 deletions(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 1dc4554cf..7368bf710 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -131,7 +131,7 @@ impl ConstantFoldPass { n, ip, results - .try_read_wire_concrete::(Wire::new(src, outp)) + .try_read_wire_concrete::(Wire::new(src, outp)) .ok()?, )) }) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 43caa9c94..bec178e29 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -9,7 +9,7 @@ mod results; pub use results::{AnalysisResults, TailLoopTermination}; mod partial_value; -pub use partial_value::{AbstractValue, PartialSum, PartialValue, Sum}; +pub use partial_value::{AbstractValue, LoadedFunction, PartialSum, PartialValue, Sum}; use hugr_core::ops::constant::OpaqueValue; use hugr_core::ops::{ExtensionOp, Value}; @@ -31,8 +31,8 @@ pub trait DFContext: ConstLoader { &mut self, _node: Self::Node, _e: &ExtensionOp, - _ins: &[PartialValue], - _outs: &mut [PartialValue], + _ins: &[PartialValue], + _outs: &mut [PartialValue], ) { } } @@ -94,7 +94,7 @@ pub fn partial_from_const<'a, V, CL: ConstLoader>( cl: &CL, loc: impl Into>, cst: &Value, -) -> PartialValue +) -> PartialValue where CL::Node: 'a, { @@ -120,8 +120,8 @@ where /// A row of inputs to a node contains bottom (can't happen, the node /// can't execute) if any element [contains_bottom](PartialValue::contains_bottom). -pub fn row_contains_bottom<'a, V: AbstractValue + 'a>( - elements: impl IntoIterator>, +pub fn row_contains_bottom<'a, V: 'a, N: 'a>( + elements: impl IntoIterator>, ) -> bool { elements.into_iter().any(PartialValue::contains_bottom) } diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 13e510daf..0e195240a 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -15,7 +15,7 @@ use super::{ PartialValue, }; -type PV = PartialValue; +type PV = PartialValue; /// Basic structure for performing an analysis. Usage: /// 1. Make a new instance via [Self::new()] @@ -27,7 +27,7 @@ type PV = PartialValue; /// 3. Call [Self::run] to produce [AnalysisResults] pub struct Machine( H, - HashMap)>>, + HashMap)>>, ); impl Machine { @@ -40,7 +40,7 @@ impl Machine { impl Machine { /// Provide initial values for a wire - these will be `join`d with any computed /// or any value previously prepopulated for the same Wire. - pub fn prepopulate_wire(&mut self, w: Wire, v: PartialValue) { + pub fn prepopulate_wire(&mut self, w: Wire, v: PartialValue) { for (n, inp) in self.0.linked_inputs(w.node(), w.source()) { self.1.entry(n).or_default().push((inp, v.clone())); } @@ -54,7 +54,7 @@ impl Machine { pub fn prepopulate_inputs( &mut self, parent: H::Node, - in_values: impl IntoIterator)>, + in_values: impl IntoIterator)>, ) -> Result<(), OpType> { match self.0.get_optype(parent) { OpType::DataflowBlock(_) | OpType::Case(_) | OpType::FuncDefn(_) => { @@ -102,7 +102,7 @@ impl Machine { pub fn run( mut self, context: impl DFContext, - in_values: impl IntoIterator)>, + in_values: impl IntoIterator)>, ) -> AnalysisResults { let root = self.0.root(); if self.0.get_optype(root).is_module() { @@ -138,7 +138,7 @@ impl Machine { pub(super) fn run_datalog( mut ctx: impl DFContext, hugr: H, - in_wire_value_proto: Vec<(H::Node, IncomingPort, PV)>, + in_wire_value_proto: Vec<(H::Node, IncomingPort, PV)>, ) -> AnalysisResults { // ascent-(macro-)generated code generates a bunch of warnings, // keep code in here to a minimum. @@ -155,9 +155,9 @@ pub(super) fn run_datalog( relation parent_of_node(H::Node, H::Node); // is parent of relation input_child(H::Node, H::Node); // has 1st child that is its `Input` relation output_child(H::Node, H::Node); // has 2nd child that is its `Output` - lattice out_wire_value(H::Node, OutgoingPort, PV); // produces, on , the value - lattice in_wire_value(H::Node, IncomingPort, PV); // receives, on , the value - lattice node_in_value_row(H::Node, ValueRow); // 's inputs are + lattice out_wire_value(H::Node, OutgoingPort, PV); // produces, on , the value + lattice in_wire_value(H::Node, IncomingPort, PV); // receives, on , the value + lattice node_in_value_row(H::Node, ValueRow); // 's inputs are node(n) <-- for n in hugr.nodes(); @@ -341,9 +341,9 @@ fn propagate_leaf_op( ctx: &mut impl DFContext, hugr: &H, n: H::Node, - ins: &[PV], + ins: &[PV], num_outs: usize, -) -> Option> { +) -> Option> { match hugr.get_optype(n) { // Handle basics here. We could instead leave these to DFContext, // but at least we'd want these impls to be easily reusable. diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index f2a497806..d2781d534 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -1,7 +1,8 @@ use ascent::lattice::BoundedLattice; use ascent::Lattice; use hugr_core::ops::Value; -use hugr_core::types::{ConstTypeError, SumType, Type, TypeEnum, TypeRow}; +use hugr_core::types::{ConstTypeError, SumType, Type, TypeArg, TypeEnum, TypeRow}; +use hugr_core::Node; use itertools::{zip_eq, Itertools}; use std::cmp::Ordering; use std::collections::HashMap; @@ -51,15 +52,23 @@ pub struct Sum { pub st: SumType, } +/// The output of an [LoadFunction](hugr_core::ops::LoadFunction) - a "pointer" +/// to a function at a specific node, instantiated with the provided type-args. +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct LoadedFunction { + pub func_node: N, + pub args: Vec, +} + /// A representation of a value of [SumType], that may have one or more possible tags, /// with a [PartialValue] representation of each element-value of each possible tag. #[derive(PartialEq, Clone, Eq)] -pub struct PartialSum(pub HashMap>>); +pub struct PartialSum(pub HashMap>>); -impl PartialSum { +impl PartialSum { /// New instance for a single known tag. /// (Multi-tag instances can be created via [Self::try_join_mut].) - pub fn new_variant(tag: usize, values: impl IntoIterator>) -> Self { + pub fn new_variant(tag: usize, values: impl IntoIterator>) -> Self { Self(HashMap::from([(tag, Vec::from_iter(values))])) } @@ -75,9 +84,21 @@ impl PartialSum { pv.assert_invariants(); } } + + /// Whether this sum might have the specified tag + pub fn supports_tag(&self, tag: usize) -> bool { + self.0.contains_key(&tag) + } + + /// Can this ever occur at runtime? See [PartialValue::contains_bottom] + pub fn contains_bottom(&self) -> bool { + self.0 + .iter() + .all(|(_tag, elements)| row_contains_bottom(elements)) + } } -impl PartialSum { +impl PartialSum { /// Joins (towards `Top`) self with another [PartialSum]. If successful, returns /// whether `self` has changed. /// @@ -141,12 +162,9 @@ impl PartialSum { } Ok(changed) } +} - /// Whether this sum might have the specified tag - pub fn supports_tag(&self, tag: usize) -> bool { - self.0.contains_key(&tag) - } - +impl PartialSum { /// Turns this instance into a [Sum] of some "concrete" value type `C`, /// *if* this PartialSum has exactly one possible tag. /// @@ -155,10 +173,14 @@ impl PartialSum { /// If this PartialSum had multiple possible tags; or if `typ` was not a [TypeEnum::Sum] /// supporting the single possible tag with the correct number of elements and no row variables; /// or if converting a child element failed via [PartialValue::try_into_concrete]. - pub fn try_into_sum(self, typ: &Type) -> Result, ExtractValueError> + pub fn try_into_sum( + self, + typ: &Type, + ) -> Result, ExtractValueError> where V: TryInto, Sum: TryInto, + LoadedFunction: TryInto, { if self.0.len() != 1 { return Err(ExtractValueError::MultipleVariants(self)); @@ -185,22 +207,15 @@ impl PartialSum { num_elements: v.len(), }) } - - /// Can this ever occur at runtime? See [PartialValue::contains_bottom] - pub fn contains_bottom(&self) -> bool { - self.0 - .iter() - .all(|(_tag, elements)| row_contains_bottom(elements)) - } } /// An error converting a [PartialValue] or [PartialSum] into a concrete value type /// via [PartialValue::try_into_concrete] or [PartialSum::try_into_sum] #[derive(Clone, Debug, PartialEq, Eq, Error)] #[allow(missing_docs)] -pub enum ExtractValueError { +pub enum ExtractValueError { #[error("PartialSum value had multiple possible tags: {0}")] - MultipleVariants(PartialSum), + MultipleVariants(PartialSum), #[error("Value contained `Bottom`")] ValueIsBottom, #[error("Value contained `Top`")] @@ -209,6 +224,8 @@ pub enum ExtractValueError { CouldNotConvert(V, #[source] VE), #[error("Could not build Sum from concrete element values")] CouldNotBuildSum(#[source] SE), + #[error("Could not turn LoadedFunction into concrete")] + CouldNotLoadFunction(#[source] LE), #[error("Expected a SumType with tag {tag} having {num_elements} elements, found {typ}")] BadSumType { typ: Type, @@ -217,14 +234,14 @@ pub enum ExtractValueError { }, } -impl PartialSum { +impl PartialSum { /// If this Sum might have the specified `tag`, get the elements inside that tag. - pub fn variant_values(&self, variant: usize) -> Option>> { + pub fn variant_values(&self, variant: usize) -> Option>> { self.0.get(&variant).cloned() } } -impl PartialOrd for PartialSum { +impl PartialOrd for PartialSum { fn partial_cmp(&self, other: &Self) -> Option { let max_key = self.0.keys().chain(other.0.keys()).copied().max().unwrap(); let (mut keys1, mut keys2) = (vec![0; max_key + 1], vec![0; max_key + 1]); @@ -254,13 +271,13 @@ impl PartialOrd for PartialSum { } } -impl std::fmt::Debug for PartialSum { +impl std::fmt::Debug for PartialSum { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.0.fmt(f) } } -impl Hash for PartialSum { +impl Hash for PartialSum { fn hash(&self, state: &mut H) { for (k, v) in &self.0 { k.hash(state); @@ -273,30 +290,32 @@ impl Hash for PartialSum { /// for use in dataflow analysis, including that an instance may be a [PartialSum] /// of values of the underlying representation #[derive(PartialEq, Clone, Eq, Hash, Debug)] -pub enum PartialValue { +pub enum PartialValue { /// No possibilities known (so far) Bottom, + /// The output of an [LoadFunction](hugr_core::ops::LoadFunction) + LoadedFunction(LoadedFunction), /// A single value (of the underlying representation) Value(V), /// Sum (with at least one, perhaps several, possible tags) of underlying values - PartialSum(PartialSum), + PartialSum(PartialSum), /// Might be more than one distinct value of the underlying type `V` Top, } -impl From for PartialValue { +impl From for PartialValue { fn from(v: V) -> Self { Self::Value(v) } } -impl From> for PartialValue { - fn from(v: PartialSum) -> Self { +impl From> for PartialValue { + fn from(v: PartialSum) -> Self { Self::PartialSum(v) } } -impl PartialValue { +impl PartialValue { fn assert_invariants(&self) { if let Self::PartialSum(ps) = self { ps.assert_invariants(); @@ -312,33 +331,59 @@ impl PartialValue { pub fn new_unit() -> Self { Self::new_variant(0, []) } + + /// New instance of self for a [LoadFunction](hugr_core::ops::LoadFunction) + pub fn new_load(func_node: N, args: impl Into>) -> Self { + Self::LoadedFunction(LoadedFunction { + func_node, + args: args.into(), + }) + } + + /// Tells us whether this value might be a Sum with the specified `tag` + pub fn supports_tag(&self, tag: usize) -> bool { + match self { + PartialValue::Bottom | PartialValue::Value(_) | PartialValue::LoadedFunction(_) => { + false + } + PartialValue::PartialSum(ps) => ps.supports_tag(tag), + PartialValue::Top => true, + } + } + + /// A value contains bottom means that it cannot occur during execution: + /// it may be an artefact during bootstrapping of the analysis, or else + /// the value depends upon a `panic` or a loop that + /// [never terminates](super::TailLoopTermination::NeverBreaks). + pub fn contains_bottom(&self) -> bool { + match self { + PartialValue::Bottom => true, + PartialValue::Top | PartialValue::Value(_) | PartialValue::LoadedFunction(_) => false, + PartialValue::PartialSum(ps) => ps.contains_bottom(), + } + } } -impl PartialValue { +impl PartialValue { /// If this value might be a Sum with the specified `tag`, get the elements inside that tag. /// /// # Panics /// /// if the value is believed, for that tag, to have a number of values other than `len` - pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { + pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { let vals = match self { - PartialValue::Bottom | PartialValue::Value(_) => return None, + PartialValue::Bottom | PartialValue::Value(_) | PartialValue::LoadedFunction(_) => { + return None + } PartialValue::PartialSum(ps) => ps.variant_values(tag)?, PartialValue::Top => vec![PartialValue::Top; len], }; assert_eq!(vals.len(), len); Some(vals) } +} - /// Tells us whether this value might be a Sum with the specified `tag` - pub fn supports_tag(&self, tag: usize) -> bool { - match self { - PartialValue::Bottom | PartialValue::Value(_) => false, - PartialValue::PartialSum(ps) => ps.supports_tag(tag), - PartialValue::Top => true, - } - } - +impl PartialValue { /// Turns this instance into some "concrete" value type `C`, *if* it is a single value, /// or a [Sum](PartialValue::PartialSum) (of a single tag) convertible by /// [PartialSum::try_into_sum]. @@ -348,16 +393,23 @@ impl PartialValue { /// If this PartialValue was `Top` or `Bottom`, or was a [PartialSum](PartialValue::PartialSum) /// that could not be converted into a [Sum] by [PartialSum::try_into_sum] (e.g. if `typ` is /// incorrect), or if that [Sum] could not be converted into a `V2`. - pub fn try_into_concrete(self, typ: &Type) -> Result> + pub fn try_into_concrete( + self, + typ: &Type, + ) -> Result> where V: TryInto, Sum: TryInto, + LoadedFunction: TryInto, { match self { Self::Value(v) => v .clone() .try_into() - .map_err(|e| ExtractValueError::CouldNotConvert(v.clone(), e)), + .map_err(|e| ExtractValueError::CouldNotConvert(v, e)), + Self::LoadedFunction(lf) => lf + .try_into() + .map_err(ExtractValueError::CouldNotLoadFunction), Self::PartialSum(ps) => ps .try_into_sum(typ)? .try_into() @@ -366,18 +418,6 @@ impl PartialValue { Self::Bottom => Err(ExtractValueError::ValueIsBottom), } } - - /// A value contains bottom means that it cannot occur during execution: - /// it may be an artefact during bootstrapping of the analysis, or else - /// the value depends upon a `panic` or a loop that - /// [never terminates](super::TailLoopTermination::NeverBreaks). - pub fn contains_bottom(&self) -> bool { - match self { - PartialValue::Bottom => true, - PartialValue::Top | PartialValue::Value(_) => false, - PartialValue::PartialSum(ps) => ps.contains_bottom(), - } - } } impl TryFrom> for Value { @@ -388,7 +428,15 @@ impl TryFrom> for Value { } } -impl Lattice for PartialValue { +impl TryFrom> for Value { + type Error = LoadedFunction; + + fn try_from(value: LoadedFunction) -> Result { + Err(value) + } +} + +impl Lattice for PartialValue { fn join_mut(&mut self, other: Self) -> bool { self.assert_invariants(); let mut old_self = Self::Top; @@ -400,13 +448,17 @@ impl Lattice for PartialValue { Some((h3, b)) => (Self::Value(h3), b), None => (Self::Top, true), }, + (Self::LoadedFunction(lf1), Self::LoadedFunction(lf2)) + if lf1.func_node == lf2.func_node => + { + // TODO we should also require TypeArgs to be equal by at the moment these are ignored + (Self::LoadedFunction(lf1), false) + } (Self::PartialSum(mut ps1), Self::PartialSum(ps2)) => match ps1.try_join_mut(ps2) { Ok(ch) => (Self::PartialSum(ps1), ch), Err(_) => (Self::Top, true), }, - (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => { - (Self::Top, true) - } + _ => (Self::Top, true), }; *self = res; ch @@ -423,20 +475,24 @@ impl Lattice for PartialValue { Some((h3, ch)) => (Self::Value(h3), ch), None => (Self::Bottom, true), }, + (Self::LoadedFunction(lf1), Self::LoadedFunction(lf2)) + if lf1.func_node == lf2.func_node => + { + // TODO we should also require TypeArgs to be equal by at the moment these are ignored + (Self::LoadedFunction(lf1), false) + } (Self::PartialSum(mut ps1), Self::PartialSum(ps2)) => match ps1.try_meet_mut(ps2) { Ok(ch) => (Self::PartialSum(ps1), ch), Err(_) => (Self::Bottom, true), }, - (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => { - (Self::Bottom, true) - } + _ => (Self::Bottom, true), }; *self = res; ch } } -impl BoundedLattice for PartialValue { +impl BoundedLattice for PartialValue { fn top() -> Self { Self::Top } @@ -446,7 +502,7 @@ impl BoundedLattice for PartialValue { } } -impl PartialOrd for PartialValue { +impl PartialOrd for PartialValue { fn partial_cmp(&self, other: &Self) -> Option { use std::cmp::Ordering; match (self, other) { @@ -457,6 +513,9 @@ impl PartialOrd for PartialValue { (Self::Top, _) => Some(Ordering::Greater), (_, Self::Top) => Some(Ordering::Less), (Self::Value(v1), Self::Value(v2)) => (v1 == v2).then_some(Ordering::Equal), + (Self::LoadedFunction(lf1), Self::LoadedFunction(lf2)) => { + (lf1 == lf2).then_some(Ordering::Equal) + } (Self::PartialSum(ps1), Self::PartialSum(ps2)) => ps1.partial_cmp(ps2), _ => None, } @@ -468,6 +527,7 @@ mod test { use std::sync::Arc; use ascent::{lattice::BoundedLattice, Lattice}; + use hugr_core::Node; use itertools::{zip_eq, Itertools as _}; use prop::sample::subsequence; use proptest::prelude::*; @@ -506,7 +566,7 @@ mod test { } impl TestSumType { - fn check_value(&self, pv: &PartialValue) -> bool { + fn check_value(&self, pv: &PartialValue) -> bool { match (self, pv) { (_, PartialValue::Bottom) | (_, PartialValue::Top) => true, (Self::Leaf(None), _) => pv == &PartialValue::new_unit(), @@ -567,7 +627,7 @@ mod test { fn single_sum_strat( tag: usize, elems: Vec>, - ) -> impl Strategy> { + ) -> impl Strategy> { elems .iter() .map(Arc::as_ref) @@ -578,11 +638,11 @@ mod test { fn partial_sum_strat( variants: &[Vec>], - ) -> impl Strategy> { + ) -> impl Strategy> { // We have to clone the `variants` here but only as far as the Vec>> let tagged_variants = variants.iter().cloned().enumerate().collect::>(); // The type annotation here (and the .boxed() enabling it) are just for documentation - let sum_variants_strat: BoxedStrategy>> = + let sum_variants_strat: BoxedStrategy>> = subsequence(tagged_variants, 1..=variants.len()) .prop_flat_map(|selected_variants| { selected_variants @@ -591,7 +651,7 @@ mod test { .collect::>() }) .boxed(); - sum_variants_strat.prop_map(|psums: Vec>| { + sum_variants_strat.prop_map(|psums: Vec>| { let mut psums = psums.into_iter(); let first = psums.next().unwrap(); psums.fold(first, |mut a, b| { @@ -603,7 +663,7 @@ mod test { fn any_partial_value_of_type( ust: &TestSumType, - ) -> impl Strategy> { + ) -> impl Strategy> { match ust { TestSumType::Leaf(None) => Just(PartialValue::new_unit()).boxed(), TestSumType::Leaf(Some(i)) => (0..*i) @@ -616,15 +676,16 @@ mod test { fn any_partial_value_with( params: ::Parameters, - ) -> impl Strategy> { + ) -> impl Strategy> { any_with::(params).prop_flat_map(|t| any_partial_value_of_type(&t)) } - fn any_partial_value() -> impl Strategy> { + fn any_partial_value() -> impl Strategy> { any_partial_value_with(Default::default()) } - fn any_partial_values() -> impl Strategy; N]> { + fn any_partial_values( + ) -> impl Strategy; N]> { any::().prop_flat_map(|ust| { TryInto::<[_; N]>::try_into( (0..N) @@ -635,7 +696,8 @@ mod test { }) } - fn any_typed_partial_value() -> impl Strategy)> { + fn any_typed_partial_value( + ) -> impl Strategy)> { any::() .prop_flat_map(|t| any_partial_value_of_type(&t).prop_map(move |v| (t.clone(), v))) } diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index c40f1d87f..c3d1db696 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -2,16 +2,16 @@ use std::collections::HashMap; use hugr_core::{HugrView, IncomingPort, PortIndex, Wire}; -use super::{partial_value::ExtractValueError, AbstractValue, PartialValue, Sum}; +use super::{partial_value::ExtractValueError, AbstractValue, LoadedFunction, PartialValue, Sum}; /// Results of a dataflow analysis, packaged with the Hugr for easy inspection. /// Methods allow inspection, specifically [read_out_wire](Self::read_out_wire). pub struct AnalysisResults { pub(super) hugr: H, - pub(super) in_wire_value: Vec<(H::Node, IncomingPort, PartialValue)>, + pub(super) in_wire_value: Vec<(H::Node, IncomingPort, PartialValue)>, pub(super) case_reachable: Vec<(H::Node, H::Node)>, pub(super) bb_reachable: Vec<(H::Node, H::Node)>, - pub(super) out_wire_values: HashMap, PartialValue>, + pub(super) out_wire_values: HashMap, PartialValue>, } impl AnalysisResults { @@ -21,7 +21,7 @@ impl AnalysisResults { } /// Gets the lattice value computed for the given wire - pub fn read_out_wire(&self, w: Wire) -> Option> { + pub fn read_out_wire(&self, w: Wire) -> Option> { self.out_wire_values.get(&w).cloned() } @@ -84,12 +84,14 @@ impl AnalysisResults { /// `None` if the analysis did not produce a result for that wire, or if /// the Hugr did not have a [Type](hugr_core::types::Type) for the specified wire /// `Some(e)` if [conversion to a concrete value](PartialValue::try_into_concrete) failed with error `e` - pub fn try_read_wire_concrete( + pub fn try_read_wire_concrete( &self, w: Wire, - ) -> Result>> + ) -> Result>> where - V2: TryFrom + TryFrom, Error = SE>, + V2: TryFrom + + TryFrom, Error = SE> + + TryFrom, Error = LE>, { let v = self.read_out_wire(w).ok_or(None)?; let (_, typ) = self @@ -116,7 +118,7 @@ pub enum TailLoopTermination { } impl TailLoopTermination { - fn from_control_value(v: &PartialValue) -> Self { + fn from_control_value(v: &PartialValue) -> Self { let (may_continue, may_break) = (v.supports_tag(0), v.supports_tag(1)); if may_break { if may_continue { diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 3af0097f7..94443cc90 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -19,7 +19,10 @@ use hugr_core::{ use hugr_core::{Hugr, Node, Wire}; use rstest::{fixture, rstest}; -use super::{AbstractValue, ConstLoader, DFContext, Machine, PartialValue, TailLoopTermination}; +use super::{ + AbstractValue, ConstLoader, DFContext, Machine, PartialValue, + TailLoopTermination, +}; // ------- Minimal implementation of DFContext and AbstractValue ------- #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -296,7 +299,7 @@ fn test_conditional() { let cond_r1: Value = results.try_read_wire_concrete(cond_o1).unwrap(); assert_eq!(cond_r1, Value::false_val()); assert!(results - .try_read_wire_concrete::(cond_o2) + .try_read_wire_concrete::(cond_o2) .is_err()); assert_eq!(results.case_reachable(case1.node()), Some(false)); // arg_pv is variant 1 or 2 only diff --git a/hugr-passes/src/dataflow/value_row.rs b/hugr-passes/src/dataflow/value_row.rs index 50cf10318..43c842d91 100644 --- a/hugr-passes/src/dataflow/value_row.rs +++ b/hugr-passes/src/dataflow/value_row.rs @@ -5,25 +5,25 @@ use std::{ ops::{Index, IndexMut}, }; -use ascent::{lattice::BoundedLattice, Lattice}; +use ascent::Lattice; use itertools::zip_eq; use super::{AbstractValue, PartialValue}; #[derive(PartialEq, Clone, Debug, Eq, Hash)] -pub(super) struct ValueRow(Vec>); +pub(super) struct ValueRow(Vec>); -impl ValueRow { +impl ValueRow { pub fn new(len: usize) -> Self { - Self(vec![PartialValue::bottom(); len]) + Self(vec![PartialValue::Bottom; len]) } - pub fn set(mut self, idx: usize, v: PartialValue) -> Self { + pub fn set(mut self, idx: usize, v: PartialValue) -> Self { *self.0.get_mut(idx).unwrap() = v; self } - pub fn singleton(v: PartialValue) -> Self { + pub fn singleton(v: PartialValue) -> Self { Self(vec![v]) } @@ -34,25 +34,25 @@ impl ValueRow { &self, variant: usize, len: usize, - ) -> Option>> { + ) -> Option>> { let vals = self[0].variant_values(variant, len)?; Some(vals.into_iter().chain(self.0[1..].to_owned())) } } -impl FromIterator> for ValueRow { - fn from_iter>>(iter: T) -> Self { +impl FromIterator> for ValueRow { + fn from_iter>>(iter: T) -> Self { Self(iter.into_iter().collect()) } } -impl PartialOrd for ValueRow { +impl PartialOrd for ValueRow { fn partial_cmp(&self, other: &Self) -> Option { self.0.partial_cmp(&other.0) } } -impl Lattice for ValueRow { +impl Lattice for ValueRow { fn join_mut(&mut self, other: Self) -> bool { assert_eq!(self.0.len(), other.0.len()); let mut changed = false; @@ -72,30 +72,30 @@ impl Lattice for ValueRow { } } -impl IntoIterator for ValueRow { - type Item = PartialValue; +impl IntoIterator for ValueRow { + type Item = PartialValue; - type IntoIter = > as IntoIterator>::IntoIter; + type IntoIter = > as IntoIterator>::IntoIter; fn into_iter(self) -> Self::IntoIter { self.0.into_iter() } } -impl Index for ValueRow +impl Index for ValueRow where - Vec>: Index, + Vec>: Index, { - type Output = > as Index>::Output; + type Output = > as Index>::Output; fn index(&self, index: Idx) -> &Self::Output { self.0.index(index) } } -impl IndexMut for ValueRow +impl IndexMut for ValueRow where - Vec>: IndexMut, + Vec>: IndexMut, { fn index_mut(&mut self, index: Idx) -> &mut Self::Output { self.0.index_mut(index) From 5809594880e4c1af4f679a17b503953c9b92dcdc Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 4 Apr 2025 15:40:07 +0100 Subject: [PATCH 03/33] Docs, clippy::type_complexity --- hugr-passes/src/dataflow/datalog.rs | 11 ++++++----- hugr-passes/src/dataflow/partial_value.rs | 2 ++ hugr-passes/src/dataflow/results.rs | 10 +++++++--- hugr-passes/src/dataflow/test.rs | 5 +---- 4 files changed, 16 insertions(+), 12 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 0e195240a..934cecdff 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -17,6 +17,8 @@ use super::{ type PV = PartialValue; +type NodeInputs = Vec<(IncomingPort, PV)>; + /// Basic structure for performing an analysis. Usage: /// 1. Make a new instance via [Self::new()] /// 2. (Optionally) zero or more calls to [Self::prepopulate_wire] and/or @@ -25,10 +27,7 @@ type PV = PartialValue; /// [Self::prepopulate_inputs] can be used on each externally-callable /// [FuncDefn](OpType::FuncDefn) to set all inputs to [PartialValue::Top]. /// 3. Call [Self::run] to produce [AnalysisResults] -pub struct Machine( - H, - HashMap)>>, -); +pub struct Machine(H, HashMap>); impl Machine { /// Create a new Machine to analyse the given Hugr(View) @@ -135,10 +134,12 @@ impl Machine { } } +pub(super) type InWire = (N, IncomingPort, PartialValue); + pub(super) fn run_datalog( mut ctx: impl DFContext, hugr: H, - in_wire_value_proto: Vec<(H::Node, IncomingPort, PV)>, + in_wire_value_proto: Vec>, ) -> AnalysisResults { // ascent-(macro-)generated code generates a bunch of warnings, // keep code in here to a minimum. diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index d2781d534..46e981244 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -56,7 +56,9 @@ pub struct Sum { /// to a function at a specific node, instantiated with the provided type-args. #[derive(Clone, Debug, Hash, PartialEq, Eq)] pub struct LoadedFunction { + /// The [FuncDefn](hugr_core::ops::FuncDefn) or `FuncDecl`` that was loaded pub func_node: N, + /// The type arguments provided when loading pub args: Vec, } diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index c3d1db696..a6b864f5e 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -1,14 +1,17 @@ use std::collections::HashMap; -use hugr_core::{HugrView, IncomingPort, PortIndex, Wire}; +use hugr_core::{HugrView, PortIndex, Wire}; -use super::{partial_value::ExtractValueError, AbstractValue, LoadedFunction, PartialValue, Sum}; +use super::{ + datalog::InWire, partial_value::ExtractValueError, AbstractValue, LoadedFunction, PartialValue, + Sum, +}; /// Results of a dataflow analysis, packaged with the Hugr for easy inspection. /// Methods allow inspection, specifically [read_out_wire](Self::read_out_wire). pub struct AnalysisResults { pub(super) hugr: H, - pub(super) in_wire_value: Vec<(H::Node, IncomingPort, PartialValue)>, + pub(super) in_wire_value: Vec>, pub(super) case_reachable: Vec<(H::Node, H::Node)>, pub(super) bb_reachable: Vec<(H::Node, H::Node)>, pub(super) out_wire_values: HashMap, PartialValue>, @@ -84,6 +87,7 @@ impl AnalysisResults { /// `None` if the analysis did not produce a result for that wire, or if /// the Hugr did not have a [Type](hugr_core::types::Type) for the specified wire /// `Some(e)` if [conversion to a concrete value](PartialValue::try_into_concrete) failed with error `e` + #[allow(clippy::type_complexity)] pub fn try_read_wire_concrete( &self, w: Wire, diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 94443cc90..ca0cfdb44 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -19,10 +19,7 @@ use hugr_core::{ use hugr_core::{Hugr, Node, Wire}; use rstest::{fixture, rstest}; -use super::{ - AbstractValue, ConstLoader, DFContext, Machine, PartialValue, - TailLoopTermination, -}; +use super::{AbstractValue, ConstLoader, DFContext, Machine, PartialValue, TailLoopTermination}; // ------- Minimal implementation of DFContext and AbstractValue ------- #[derive(Debug, Clone, PartialEq, Eq, Hash)] From cccefe4c49df0de66bb50319248db9dcc4341c73 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 4 Apr 2025 15:25:49 +0100 Subject: [PATCH 04/33] datalog handles LoadFunction, drop value_from_function (TODO ConstFoldCtx ignores self.0) --- hugr-passes/src/const_fold.rs | 31 +++++------------------------ hugr-passes/src/dataflow.rs | 5 +++-- hugr-passes/src/dataflow/datalog.rs | 12 ++++++----- 3 files changed, 15 insertions(+), 33 deletions(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 7368bf710..ef5fc355c 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -7,15 +7,11 @@ use std::{collections::HashMap, sync::Arc}; use thiserror::Error; use hugr_core::{ - hugr::{ - hugrmut::HugrMut, - views::{DescendantsGraph, ExtractHugr, HierarchyView}, - }, + hugr::hugrmut::HugrMut, ops::{ - constant::OpaqueValue, handle::FuncID, Const, DataflowOpTrait, ExtensionOp, LoadConstant, - OpType, Value, + constant::OpaqueValue, Const, DataflowOpTrait, ExtensionOp, LoadConstant, OpType, Value, }, - types::{EdgeKind, TypeArg}, + types::EdgeKind, HugrView, IncomingPort, Node, NodeIndex, OutgoingPort, PortIndex, Wire, }; use value_handle::ValueHandle; @@ -205,7 +201,8 @@ pub fn constant_fold_pass(h: &mut H) { c.run(h).unwrap() } -struct ConstFoldContext<'a, H>(&'a H); +// Probably intend to remove this in a future PR, but not certain, so leaving in for now +struct ConstFoldContext<'a, H>(#[allow(unused)] &'a H); impl> ConstLoader> for ConstFoldContext<'_, H> { type Node = H::Node; @@ -225,24 +222,6 @@ impl> ConstLoader> for ConstFoldCo ) -> Option> { Some(ValueHandle::new_const_hugr(loc, Box::new(h.clone()))) } - - fn value_from_function( - &self, - node: H::Node, - type_args: &[TypeArg], - ) -> Option> { - if !type_args.is_empty() { - // TODO: substitution across Hugr (https://github.com/CQCL/hugr/issues/709) - return None; - }; - // Returning the function body as a value, here, would be sufficient for inlining IndirectCall - // but not for transforming to a direct Call. - let func = DescendantsGraph::>::try_new(self.0, node).ok()?; - Some(ValueHandle::new_const_hugr( - ConstLocation::Node(node), - Box::new(func.extract_hugr()), - )) - } } impl> DFContext> for ConstFoldContext<'_, H> { diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index bec178e29..7029ec0b2 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -55,8 +55,8 @@ impl From for ConstLocation<'_, N> { } /// Trait for loading [PartialValue]s from constant [Value]s in a Hugr. -/// Implementors will likely want to override some/all of [Self::value_from_opaque], -/// [Self::value_from_const_hugr], and [Self::value_from_function]: the defaults +/// Implementors will likely want to override either/both of [Self::value_from_opaque] +/// and [Self::value_from_const_hugr]: the defaults /// are "correct" but maximally conservative (minimally informative). pub trait ConstLoader { /// The type of nodes in the Hugr. @@ -81,6 +81,7 @@ pub trait ConstLoader { /// [FuncDefn]: hugr_core::ops::FuncDefn /// [FuncDecl]: hugr_core::ops::FuncDecl /// [LoadFunction]: hugr_core::ops::LoadFunction + #[deprecated(note = "Automatically handled by Datalog, implementation will be ignored")] fn value_from_function(&self, _node: Self::Node, _type_args: &[TypeArg]) -> Option { None } diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 934cecdff..cce8cf402 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -12,7 +12,7 @@ use hugr_core::{HugrView, IncomingPort, OutgoingPort, PortIndex as _, Wire}; use super::value_row::ValueRow; use super::{ partial_from_const, row_contains_bottom, AbstractValue, AnalysisResults, DFContext, - PartialValue, + LoadedFunction, PartialValue, }; type PV = PartialValue; @@ -381,10 +381,12 @@ fn propagate_leaf_op( .unwrap() .0; // Node could be a FuncDefn or a FuncDecl, so do not pass the node itself - Some(ValueRow::singleton( - ctx.value_from_function(func_node, &load_op.type_args) - .map_or(PV::Top, PV::Value), - )) + Some(ValueRow::singleton(PartialValue::LoadedFunction( + LoadedFunction { + func_node, + args: load_op.type_args.clone(), + }, + ))) } OpType::ExtensionOp(e) => { Some(ValueRow::from_iter(if row_contains_bottom(ins) { From ca442c4b7facdc96e0de676a3e48a2a3a481d37f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 4 Apr 2025 15:17:36 +0100 Subject: [PATCH 05/33] And handle CallIndirect --- hugr-passes/src/dataflow/datalog.rs | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index cce8cf402..d1f47be23 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -323,6 +323,24 @@ pub(super) fn run_datalog( func_call(call, func), output_child(func, outp), in_wire_value(outp, p, v); + + // CallIndirect -------------------- + relation indirect_call(H::Node, H::Node); // is an `IndirectCall` to `FuncDefn` + indirect_call(call, func_node) <-- + node(call), + if let OpType::CallIndirect(_) = hugr.get_optype(*call), + in_wire_value(call, IncomingPort::from(0), v), + if let PartialValue::LoadedFunction(LoadedFunction {func_node, ..}) = v; + + out_wire_value(inp, OutgoingPort::from(p.index()-1), v) <-- + indirect_call(call, func), + input_child(func, inp), + in_wire_value(call, p, v); + + out_wire_value(call, OutgoingPort::from(p.index()), v) <-- + indirect_call(call, func), + output_child(func, outp), + in_wire_value(outp, p, v); }; let out_wire_values = all_results .out_wire_value @@ -363,8 +381,7 @@ fn propagate_leaf_op( ins.iter().cloned(), )])), OpType::Input(_) | OpType::Output(_) | OpType::ExitBlock(_) => None, // handled by parent - OpType::Call(_) => None, // handled via Input/Output of FuncDefn - OpType::Const(_) => None, // handled by LoadConstant: + OpType::Call(_) | OpType::CallIndirect(_) => None, // handled via Input/Output of FuncDefn OpType::LoadConstant(load_op) => { assert!(ins.is_empty()); // static edge, so need to find constant let const_node = hugr @@ -404,6 +421,7 @@ fn propagate_leaf_op( outs })) } - o => todo!("Unhandled: {:?}", o), // At least CallIndirect, and OpType is "non-exhaustive" + // We only call propagate_leaf_op for dataflow op non-containers, + o => todo!("Unhandled: {:?}", o), // and OpType is non-exhaustive } } From c4f4c80bfb01ef5b7d47160880d73fbaa721751d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 4 Apr 2025 15:57:57 +0100 Subject: [PATCH 06/33] Remove Hugr from ConstFoldContext, deparametrize/drop bounds --- hugr-passes/src/const_fold.rs | 28 ++++++++++++++-------------- hugr-passes/src/const_fold/test.rs | 6 ++---- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index ef5fc355c..0ddb9613c 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -7,6 +7,7 @@ use std::{collections::HashMap, sync::Arc}; use thiserror::Error; use hugr_core::{ + core::HugrNode, hugr::hugrmut::HugrMut, ops::{ constant::OpaqueValue, Const, DataflowOpTrait, ExtensionOp, LoadConstant, OpType, Value, @@ -98,7 +99,7 @@ impl ConstantFoldPass { n, in_vals.iter().map(|(p, v)| { let const_with_dummy_loc = partial_from_const( - &ConstFoldContext(hugr), + &ConstFoldContext, ConstLocation::Field(p.index(), &fresh_node.into()), v, ); @@ -108,7 +109,7 @@ impl ConstantFoldPass { .map_err(|opty| ConstFoldError::InvalidEntryPoint(n, opty))?; } - let results = m.run(ConstFoldContext(hugr), []); + let results = m.run(ConstFoldContext, []); let mb_root_inp = hugr.get_io(hugr.root()).map(|[i, _]| i); let wires_to_break = hugr @@ -201,36 +202,35 @@ pub fn constant_fold_pass(h: &mut H) { c.run(h).unwrap() } -// Probably intend to remove this in a future PR, but not certain, so leaving in for now -struct ConstFoldContext<'a, H>(#[allow(unused)] &'a H); +struct ConstFoldContext; -impl> ConstLoader> for ConstFoldContext<'_, H> { - type Node = H::Node; +impl ConstLoader> for ConstFoldContext { + type Node = N; fn value_from_opaque( &self, - loc: ConstLocation, + loc: ConstLocation, val: &OpaqueValue, - ) -> Option> { + ) -> Option> { Some(ValueHandle::new_opaque(loc, val.clone())) } fn value_from_const_hugr( &self, - loc: ConstLocation, + loc: ConstLocation, h: &hugr_core::Hugr, - ) -> Option> { + ) -> Option> { Some(ValueHandle::new_const_hugr(loc, Box::new(h.clone()))) } } -impl> DFContext> for ConstFoldContext<'_, H> { +impl DFContext> for ConstFoldContext { fn interpret_leaf_op( &mut self, - node: H::Node, + node: N, op: &ExtensionOp, - ins: &[PartialValue>], - outs: &mut [PartialValue>], + ins: &[PartialValue, N>], + outs: &mut [PartialValue, N>], ) { let sig = op.signature(); let known_ins = sig diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index b84d65d7d..58e69c568 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -42,8 +42,7 @@ fn value_handling(#[case] k: impl CustomConst + Clone, #[case] eq: bool) { let n = Node::from(portgraph::NodeIndex::new(7)); let st = SumType::new([vec![k.get_type()], vec![]]); let subject_val = Value::sum(0, [k.clone().into()], st).unwrap(); - let temp = Hugr::default(); - let ctx: ConstFoldContext = ConstFoldContext(&temp); + let ctx = ConstFoldContext; let v1 = partial_from_const(&ctx, n, &subject_val); let v1_subfield = { @@ -114,8 +113,7 @@ fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) { v.get_custom_value::().unwrap().value() } let [n, n_a, n_b] = [0, 1, 2].map(portgraph::NodeIndex::new).map(Node::from); - let temp = Hugr::default(); - let mut ctx = ConstFoldContext(&temp); + let mut ctx = ConstFoldContext; let v_a = partial_from_const(&ctx, n_a, &f2c(a)); let v_b = partial_from_const(&ctx, n_b, &f2c(b)); assert_eq!(unwrap_float(v_a.clone()), a); From c6abad4afd5a657c9a7d0499928bd26014eeab7f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 4 Apr 2025 15:27:22 +0100 Subject: [PATCH 07/33] Add FoldVal, constant_fold2 (taking FoldVals) --- hugr-core/src/extension.rs | 2 +- hugr-core/src/extension/const_fold.rs | 96 ++++++++++++++++++++-- hugr-core/src/extension/op_def.rs | 21 ++++- hugr-core/src/ops/custom.rs | 7 +- hugr-passes/src/const_fold.rs | 73 ++++++++++------ hugr-passes/src/const_fold/value_handle.rs | 23 +++++- hugr-passes/src/dataflow/partial_value.rs | 12 +++ 7 files changed, 199 insertions(+), 35 deletions(-) diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index 408c88e15..3f981d1e5 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -34,7 +34,7 @@ pub mod resolution; pub mod simple_op; mod type_def; -pub use const_fold::{fold_out_row, ConstFold, ConstFoldResult, Folder}; +pub use const_fold::{fold_out_row, ConstFold, ConstFoldResult, FoldVal, Folder}; pub use op_def::{ CustomSignatureFunc, CustomValidator, LowerFunc, OpDef, SignatureFromArgs, SignatureFunc, ValidateJustArgs, ValidateTypeArgs, diff --git a/hugr-core/src/extension/const_fold.rs b/hugr-core/src/extension/const_fold.rs index a2cd66f42..74e02df54 100644 --- a/hugr-core/src/extension/const_fold.rs +++ b/hugr-core/src/extension/const_fold.rs @@ -2,18 +2,87 @@ use std::fmt::Formatter; use std::fmt::Debug; +use crate::ops::constant::{OpaqueValue, Sum}; use crate::ops::Value; -use crate::types::TypeArg; +use crate::types::{SumType, TypeArg}; +use crate::{IncomingPort, Node, OutgoingPort, PortIndex}; -use crate::IncomingPort; -use crate::OutgoingPort; +/// Representation of values used for constant folding. +// Should we be non-exhaustive?? +// No point in parametrizing by HugrNode since then ConstFold would not be dyn/object-safe +#[derive(Clone, PartialEq, Default)] +pub enum FoldVal { + /// Value is unknown, must assume that it could be anything + #[default] + Unknown, + /// A variant of a [SumType] + Sum { + /// Which variant of the sum type this value is. + tag: usize, + /// Describes the type of the whole value. + // Can we deprecate this immediately? It is only for converting to Value + sum_type: SumType, + /// A value for each element (type) within the variant + items: Vec, + }, + /// A constant value defined by an extension + Extension(OpaqueValue), + /// A function pointer loaded from a [FuncDefn](crate::ops::FuncDefn) or `FuncDecl` + LoadedFunction(Node, Vec), // Deliberately skipping Function(Box) ATM +} -use crate::ops; +impl TryFrom for Value { + type Error = Option; + + fn try_from(value: FoldVal) -> Result { + match value { + FoldVal::Unknown => Err(None), + FoldVal::Sum { + tag, + sum_type, + items, + } => { + let values = items + .into_iter() + .map(Value::try_from) + .collect::, _>>()?; + Ok(Value::Sum(Sum { + tag, + values, + sum_type, + })) + } + FoldVal::Extension(e) => Ok(Value::Extension { e }), + FoldVal::LoadedFunction(node, _) => Err(Some(node)), + } + } +} + +impl From for FoldVal { + fn from(value: Value) -> Self { + match value { + Value::Extension { e } => FoldVal::Extension(e), + Value::Function { .. } => FoldVal::Unknown, + Value::Sum(Sum { + tag, + values, + sum_type, + }) => { + let items = values.into_iter().map(FoldVal::from).collect(); + FoldVal::Sum { + tag, + sum_type, + items, + } + } + } + } +} /// Output of constant folding an operation, None indicates folding was either /// not possible or unsuccessful. An empty vector indicates folding was /// successful and no values are output. -pub type ConstFoldResult = Option>; +pub type ConstFoldResult = Option>; /// Tag some output constants with [`OutgoingPort`] inferred from the ordering. pub fn fold_out_row(consts: impl IntoIterator) -> ConstFoldResult { @@ -27,6 +96,23 @@ pub fn fold_out_row(consts: impl IntoIterator) -> ConstFoldResult /// Trait implemented by extension operations that can perform constant folding. pub trait ConstFold: Send + Sync { + /// Given type arguments `type_args` and [`FoldVal`]s for each input, + /// update the outputs (these will be initialized to [FoldVal::Unknown]). + /// + /// Defaults to calling [Self::fold] with those arguments that can be converted --- + /// [FoldVal::LoadedFunction]s will be lost as these are not representable as [Value]s. + fn fold2(&self, type_args: &[TypeArg], inputs: Vec, outputs: &mut [FoldVal]) { + let consts = inputs + .into_iter() + .enumerate() + .filter_map(|(p, fv)| Some((p.into(), fv.try_into().ok()?))) + .collect::>(); + let outs = self.fold(type_args, &consts); + for (p, v) in outs.unwrap_or_default() { + outputs[p.index()] = v.into(); + } + } + /// Given type arguments `type_args` and /// [`crate::ops::Const`] values for inputs at [`crate::IncomingPort`]s, /// try to evaluate the operation. diff --git a/hugr-core/src/extension/op_def.rs b/hugr-core/src/extension/op_def.rs index d5c9a5b5d..53c3a89ae 100644 --- a/hugr-core/src/extension/op_def.rs +++ b/hugr-core/src/extension/op_def.rs @@ -4,15 +4,16 @@ use std::collections::HashMap; use std::fmt::{Debug, Formatter}; use std::sync::{Arc, Weak}; +use super::const_fold::FoldVal; use super::{ ConstFold, ConstFoldResult, Extension, ExtensionBuildError, ExtensionId, ExtensionSet, SignatureError, }; -use crate::ops::{OpName, OpNameRef}; +use crate::ops::{OpName, OpNameRef, Value}; use crate::types::type_param::{check_type_args, TypeArg, TypeParam}; use crate::types::{FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature}; -use crate::Hugr; +use crate::{Hugr, IncomingPort}; mod serialize_signature_func; /// Trait necessary for binary computations of OpDef signature @@ -460,11 +461,25 @@ impl OpDef { pub fn constant_fold( &self, type_args: &[TypeArg], - consts: &[(crate::IncomingPort, crate::ops::Value)], + consts: &[(IncomingPort, Value)], ) -> ConstFoldResult { (self.constant_folder.as_ref())?.fold(type_args, consts) } + /// Evaluate an instance of this [`OpDef`] defined by the `type_args`, given + /// [FoldVal] values for each input, and update the outputs, which should be + /// initialised to [FoldVal::Unknown]. + pub fn constant_fold2( + &self, + type_args: &[TypeArg], + inputs: Vec, + outputs: &mut [FoldVal], + ) { + if let Some(cf) = self.constant_folder.as_ref() { + cf.fold2(type_args, inputs, outputs) + } + } + /// Returns a reference to the signature function of this [`OpDef`]. pub fn signature_func(&self) -> &SignatureFunc { &self.signature_func diff --git a/hugr-core/src/ops/custom.rs b/hugr-core/src/ops/custom.rs index 96e884f6d..b7abe2527 100644 --- a/hugr-core/src/ops/custom.rs +++ b/hugr-core/src/ops/custom.rs @@ -12,7 +12,7 @@ use { ::proptest_derive::Arbitrary, }; -use crate::extension::{ConstFoldResult, ExtensionId, OpDef, SignatureError}; +use crate::extension::{ConstFoldResult, ExtensionId, FoldVal, OpDef, SignatureError}; use crate::types::{type_param::TypeArg, Signature}; use crate::{ops, IncomingPort, Node}; @@ -96,6 +96,11 @@ impl ExtensionOp { self.def().constant_fold(self.args(), consts) } + /// Attempt to evaluate this operation, See ['OpDef::constant_fold2`] + pub fn constant_fold2(&self, inputs: Vec, outputs: &mut [FoldVal]) { + self.def().constant_fold2(self.args(), inputs, outputs) + } + /// Creates a new [`OpaqueOp`] as a downgraded version of this /// [`ExtensionOp`]. /// diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 0ddb9613c..df9464b31 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -3,11 +3,12 @@ //! An (example) use of the [dataflow analysis framework](super::dataflow). pub mod value_handle; +use itertools::Itertools; use std::{collections::HashMap, sync::Arc}; use thiserror::Error; use hugr_core::{ - core::HugrNode, + extension::FoldVal, hugr::hugrmut::HugrMut, ops::{ constant::OpaqueValue, Const, DataflowOpTrait, ExtensionOp, LoadConstant, OpType, Value, @@ -18,7 +19,7 @@ use hugr_core::{ use value_handle::ValueHandle; use crate::dataflow::{ - partial_from_const, ConstLoader, ConstLocation, DFContext, Machine, PartialValue, + partial_from_const, ConstLoader, ConstLocation, DFContext, Machine, PartialSum, PartialValue, TailLoopTermination, }; use crate::dead_code::{DeadCodeElimPass, PreserveNode}; @@ -127,6 +128,9 @@ impl ConstantFoldPass { (!hugr.get_optype(src).is_load_constant() && Some(src) != mb_root_inp).then_some(( n, ip, + // TODO switch to using FoldVal rather than Value here, thus enabling turning CallIndirect + // into Call when the function input is known. (This will mean we will be unable to handle + // Value::Function's, so best to wait until those are removed.) results .try_read_wire_concrete::(Wire::new(src, outp)) .ok()?, @@ -204,51 +208,72 @@ pub fn constant_fold_pass(h: &mut H) { struct ConstFoldContext; -impl ConstLoader> for ConstFoldContext { - type Node = N; +impl ConstLoader> for ConstFoldContext { + type Node = Node; fn value_from_opaque( &self, - loc: ConstLocation, + loc: ConstLocation, val: &OpaqueValue, - ) -> Option> { + ) -> Option> { Some(ValueHandle::new_opaque(loc, val.clone())) } fn value_from_const_hugr( &self, - loc: ConstLocation, + loc: ConstLocation, h: &hugr_core::Hugr, - ) -> Option> { + ) -> Option> { Some(ValueHandle::new_const_hugr(loc, Box::new(h.clone()))) } } -impl DFContext> for ConstFoldContext { +impl DFContext> for ConstFoldContext { fn interpret_leaf_op( &mut self, - node: N, + node: Node, op: &ExtensionOp, - ins: &[PartialValue, N>], - outs: &mut [PartialValue, N>], + ins: &[PartialValue, Node>], + outs: &mut [PartialValue, Node>], ) { let sig = op.signature(); - let known_ins = sig + let inp_fvs = sig .input_types() .iter() - .enumerate() - .zip(ins.iter()) - .filter_map(|((i, ty), pv)| { - pv.clone() - .try_into_concrete(ty) - .ok() - .map(|v| (IncomingPort::from(i), v)) - }) + .zip_eq(ins.iter()) + .map(|(ty, pv)| pv.clone().try_into_concrete(ty).unwrap_or_default()) .collect::>(); - for (p, v) in op.constant_fold(&known_ins).unwrap_or_default() { - outs[p.index()] = - partial_from_const(self, ConstLocation::Field(p.index(), &node.into()), &v); + let mut out_fvs = vec![FoldVal::Unknown; outs.len()]; + op.constant_fold2(inp_fvs, &mut out_fvs); + for ((p, out), out_fv) in outs.iter_mut().enumerate().zip(out_fvs) { + // UGH. Need a partial_from_const for FoldVal, *as well* as the one from Value + // 'coz we need to keep the latter for constants!! + *out = pv_from_fold_val(ConstLocation::Field(p, &node.into()), out_fv); + } + } +} + +fn pv_from_fold_val( + loc: ConstLocation, + value: FoldVal, +) -> PartialValue, Node> { + match value { + FoldVal::Unknown => PartialValue::Top, + FoldVal::Sum { + tag, + sum_type: _, + items, + } => PartialValue::PartialSum(PartialSum::new_variant( + tag, + items + .into_iter() + .enumerate() + .map(|(i, v)| pv_from_fold_val(ConstLocation::Field(i, &loc), v)), + )), + FoldVal::Extension(opaque_value) => { + PartialValue::Value(ValueHandle::new_opaque(loc, opaque_value)) } + FoldVal::LoadedFunction(node, args) => PartialValue::new_load(node, args), } } diff --git a/hugr-passes/src/const_fold/value_handle.rs b/hugr-passes/src/const_fold/value_handle.rs index bda7bffd2..362f50b4b 100644 --- a/hugr-passes/src/const_fold/value_handle.rs +++ b/hugr-passes/src/const_fold/value_handle.rs @@ -5,12 +5,13 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use hugr_core::core::HugrNode; +use hugr_core::extension::FoldVal; use hugr_core::ops::constant::OpaqueValue; use hugr_core::ops::Value; use hugr_core::{Hugr, Node}; use itertools::Either; -use crate::dataflow::{AbstractValue, ConstLocation}; +use crate::dataflow::{AbstractValue, ConstLocation, LoadedFunction}; /// A custom constant that has been successfully hashed via [TryHash](hugr_core::ops::constant::TryHash) #[derive(Clone, Debug)] @@ -173,6 +174,26 @@ impl From> for Value { } } +impl From> for FoldVal { + fn from(value: ValueHandle) -> FoldVal { + match value { + ValueHandle::Hashable(HashedConst { val, .. }) + | ValueHandle::Unhashable { + leaf: Either::Left(val), + .. + } => FoldVal::Extension(Arc::try_unwrap(val).unwrap_or_else(|a| a.as_ref().clone())), + _ => FoldVal::Unknown, + } + } +} + +impl From> for FoldVal { + fn from(value: LoadedFunction) -> Self { + let LoadedFunction { func_node, args } = value; + FoldVal::LoadedFunction(func_node, args) + } +} + #[cfg(test)] mod test { use hugr_core::{ diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 46e981244..1d3367bfd 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -1,5 +1,6 @@ use ascent::lattice::BoundedLattice; use ascent::Lattice; +use hugr_core::extension::FoldVal; use hugr_core::ops::Value; use hugr_core::types::{ConstTypeError, SumType, Type, TypeArg, TypeEnum, TypeRow}; use hugr_core::Node; @@ -430,6 +431,17 @@ impl TryFrom> for Value { } } +impl From> for FoldVal { + fn from(value: Sum) -> Self { + let Sum { tag, values, st } = value; + Self::Sum { + tag, + sum_type: st, + items: values, + } + } +} + impl TryFrom> for Value { type Error = LoadedFunction; From f6196e52107a10927e20ce2c22791520d013b89c Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 4 Apr 2025 17:44:30 +0100 Subject: [PATCH 08/33] Switch fold2 to take &[FoldVal] not Vec --- hugr-core/src/extension/const_fold.rs | 5 +++-- hugr-core/src/extension/op_def.rs | 2 +- hugr-core/src/ops/custom.rs | 2 +- hugr-passes/src/const_fold.rs | 2 +- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/hugr-core/src/extension/const_fold.rs b/hugr-core/src/extension/const_fold.rs index 74e02df54..a0e7c28f0 100644 --- a/hugr-core/src/extension/const_fold.rs +++ b/hugr-core/src/extension/const_fold.rs @@ -101,9 +101,10 @@ pub trait ConstFold: Send + Sync { /// /// Defaults to calling [Self::fold] with those arguments that can be converted --- /// [FoldVal::LoadedFunction]s will be lost as these are not representable as [Value]s. - fn fold2(&self, type_args: &[TypeArg], inputs: Vec, outputs: &mut [FoldVal]) { + fn fold2(&self, type_args: &[TypeArg], inputs: &[FoldVal], outputs: &mut [FoldVal]) { let consts = inputs - .into_iter() + .iter() + .cloned() .enumerate() .filter_map(|(p, fv)| Some((p.into(), fv.try_into().ok()?))) .collect::>(); diff --git a/hugr-core/src/extension/op_def.rs b/hugr-core/src/extension/op_def.rs index 53c3a89ae..f344b65b1 100644 --- a/hugr-core/src/extension/op_def.rs +++ b/hugr-core/src/extension/op_def.rs @@ -472,7 +472,7 @@ impl OpDef { pub fn constant_fold2( &self, type_args: &[TypeArg], - inputs: Vec, + inputs: &[FoldVal], outputs: &mut [FoldVal], ) { if let Some(cf) = self.constant_folder.as_ref() { diff --git a/hugr-core/src/ops/custom.rs b/hugr-core/src/ops/custom.rs index b7abe2527..aca062639 100644 --- a/hugr-core/src/ops/custom.rs +++ b/hugr-core/src/ops/custom.rs @@ -97,7 +97,7 @@ impl ExtensionOp { } /// Attempt to evaluate this operation, See ['OpDef::constant_fold2`] - pub fn constant_fold2(&self, inputs: Vec, outputs: &mut [FoldVal]) { + pub fn constant_fold2(&self, inputs: &[FoldVal], outputs: &mut [FoldVal]) { self.def().constant_fold2(self.args(), inputs, outputs) } diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index df9464b31..25b720770 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -244,7 +244,7 @@ impl DFContext> for ConstFoldContext { .map(|(ty, pv)| pv.clone().try_into_concrete(ty).unwrap_or_default()) .collect::>(); let mut out_fvs = vec![FoldVal::Unknown; outs.len()]; - op.constant_fold2(inp_fvs, &mut out_fvs); + op.constant_fold2(&inp_fvs, &mut out_fvs); for ((p, out), out_fv) in outs.iter_mut().enumerate().zip(out_fvs) { // UGH. Need a partial_from_const for FoldVal, *as well* as the one from Value // 'coz we need to keep the latter for constants!! From 7c5466adaec9becf55bfa5e7fbb6fc0a2fa2a844 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 4 Apr 2025 17:45:27 +0100 Subject: [PATCH 09/33] Deprecate old fold routines; port tests, add helpers --- hugr-core/src/extension/const_fold.rs | 51 ++++++++++++++++++- hugr-core/src/extension/op_def.rs | 2 + hugr-core/src/extension/prelude/generic.rs | 17 ++++--- hugr-core/src/ops/custom.rs | 2 + .../std_extensions/arithmetic/conversions.rs | 39 +++++--------- .../std_extensions/arithmetic/float_ops.rs | 27 ++++------ .../src/std_extensions/arithmetic/int_ops.rs | 34 ++++--------- .../src/std_extensions/collections/list.rs | 31 +++++------ hugr-core/src/std_extensions/logic.rs | 41 ++++++--------- 9 files changed, 125 insertions(+), 119 deletions(-) diff --git a/hugr-core/src/extension/const_fold.rs b/hugr-core/src/extension/const_fold.rs index a0e7c28f0..819c3c8f7 100644 --- a/hugr-core/src/extension/const_fold.rs +++ b/hugr-core/src/extension/const_fold.rs @@ -2,6 +2,7 @@ use std::fmt::Formatter; use std::fmt::Debug; +use crate::ops::constant::CustomConst; use crate::ops::constant::{OpaqueValue, Sum}; use crate::ops::Value; use crate::types::{SumType, TypeArg}; @@ -10,7 +11,7 @@ use crate::{IncomingPort, Node, OutgoingPort, PortIndex}; /// Representation of values used for constant folding. // Should we be non-exhaustive?? // No point in parametrizing by HugrNode since then ConstFold would not be dyn/object-safe -#[derive(Clone, PartialEq, Default)] +#[derive(Clone, Debug, PartialEq, Default)] pub enum FoldVal { /// Value is unknown, must assume that it could be anything #[default] @@ -31,6 +32,52 @@ pub enum FoldVal { LoadedFunction(Node, Vec), // Deliberately skipping Function(Box) ATM } +impl From for FoldVal +where + T: CustomConst, +{ + fn from(value: T) -> Self { + Self::Extension(value.into()) + } +} + +impl FoldVal { + /// Returns a constant "false" value, i.e. the first variant of Sum((), ()). + pub const fn false_val() -> Self { + Self::Sum { + tag: 0, + sum_type: SumType::Unit { size: 2 }, + items: vec![], + } + } + + /// Returns a constant "true" value, i.e. the second variant of Sum((), ()). + pub const fn true_val() -> Self { + Self::Sum { + tag: 1, + sum_type: SumType::Unit { size: 2 }, + items: vec![], + } + } + + /// Returns a constant boolean - either [Self::false_val] or [Self::true_val] + pub const fn from_bool(b: bool) -> Self { + if b { + Self::true_val() + } else { + Self::false_val() + } + } + + /// Extract the specified type of [CustomConst] fro this instance, if it is one + pub fn get_custom_value(&self) -> Option<&T> { + let Self::Extension(e) = self else { + return None; + }; + e.value().downcast_ref() + } +} + impl TryFrom for Value { type Error = Option; @@ -108,6 +155,7 @@ pub trait ConstFold: Send + Sync { .enumerate() .filter_map(|(p, fv)| Some((p.into(), fv.try_into().ok()?))) .collect::>(); + #[allow(deprecated)] // remove this when fold is removed let outs = self.fold(type_args, &consts); for (p, v) in outs.unwrap_or_default() { outputs[p.index()] = v.into(); @@ -117,6 +165,7 @@ pub trait ConstFold: Send + Sync { /// Given type arguments `type_args` and /// [`crate::ops::Const`] values for inputs at [`crate::IncomingPort`]s, /// try to evaluate the operation. + #[deprecated(note = "Use fold2")] fn fold( &self, type_args: &[TypeArg], diff --git a/hugr-core/src/extension/op_def.rs b/hugr-core/src/extension/op_def.rs index f344b65b1..a3c809dc6 100644 --- a/hugr-core/src/extension/op_def.rs +++ b/hugr-core/src/extension/op_def.rs @@ -458,11 +458,13 @@ impl OpDef { /// Evaluate an instance of this [`OpDef`] defined by the `type_args`, given /// [`crate::ops::Const`] values for inputs at [`crate::IncomingPort`]s. + #[deprecated(note = "use constant_fold2")] pub fn constant_fold( &self, type_args: &[TypeArg], consts: &[(IncomingPort, Value)], ) -> ConstFoldResult { + #[allow(deprecated)] // we are in deprecated function, remove at same time (self.constant_folder.as_ref())?.fold(type_args, consts) } diff --git a/hugr-core/src/extension/prelude/generic.rs b/hugr-core/src/extension/prelude/generic.rs index 7db49ed79..ff9877bc3 100644 --- a/hugr-core/src/extension/prelude/generic.rs +++ b/hugr-core/src/extension/prelude/generic.rs @@ -169,11 +169,14 @@ impl HasConcrete for LoadNatDef { mod tests { use crate::{ builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr}, - extension::prelude::{usize_t, ConstUsize}, - ops::{constant, OpType}, + extension::{ + prelude::{usize_t, ConstUsize}, + FoldVal, + }, + ops::OpType, type_row, types::TypeArg, - HugrView, OutgoingPort, + HugrView, }; use super::LoadNat; @@ -209,10 +212,10 @@ mod tests { match optype { OpType::ExtensionOp(ext_op) => { - let result = ext_op.constant_fold(&[]); - let exp_port: OutgoingPort = 0.into(); - let exp_val: constant::Value = ConstUsize::new(5).into(); - assert_eq!(result, Some(vec![(exp_port, exp_val)])) + let mut out = [FoldVal::Unknown]; + ext_op.constant_fold2(&[], &mut out); + let exp_val: FoldVal = ConstUsize::new(5).into(); + assert_eq!(out, [exp_val]) } _ => panic!(), } diff --git a/hugr-core/src/ops/custom.rs b/hugr-core/src/ops/custom.rs index aca062639..53170dd0e 100644 --- a/hugr-core/src/ops/custom.rs +++ b/hugr-core/src/ops/custom.rs @@ -92,7 +92,9 @@ impl ExtensionOp { } /// Attempt to evaluate this operation. See [`OpDef::constant_fold`]. + #[deprecated(note = "use constant_fold2")] pub fn constant_fold(&self, consts: &[(IncomingPort, ops::Value)]) -> ConstFoldResult { + #[allow(deprecated)] // in deprecated function, remove at same time self.def().constant_fold(self.args(), consts) } diff --git a/hugr-core/src/std_extensions/arithmetic/conversions.rs b/hugr-core/src/std_extensions/arithmetic/conversions.rs index abeb61ab0..f466c2418 100644 --- a/hugr-core/src/std_extensions/arithmetic/conversions.rs +++ b/hugr-core/src/std_extensions/arithmetic/conversions.rs @@ -212,10 +212,9 @@ impl HasDef for ConvertOpType { mod test { use rstest::rstest; - use crate::extension::prelude::ConstUsize; - use crate::ops::Value; + use crate::extension::{prelude::ConstUsize, FoldVal}; + use crate::std_extensions::arithmetic::int_types::ConstInt; - use crate::IncomingPort; use super::*; @@ -249,35 +248,21 @@ mod test { } #[rstest] - #[case::itobool_false(ConvertOpDef::itobool.without_log_width(), &[ConstInt::new_u(0, 0).unwrap().into()], &[Value::false_val()])] - #[case::itobool_true(ConvertOpDef::itobool.without_log_width(), &[ConstInt::new_u(0, 1).unwrap().into()], &[Value::true_val()])] - #[case::ifrombool_false(ConvertOpDef::ifrombool.without_log_width(), &[Value::false_val()], &[ConstInt::new_u(0, 0).unwrap().into()])] - #[case::ifrombool_true(ConvertOpDef::ifrombool.without_log_width(), &[Value::true_val()], &[ConstInt::new_u(0, 1).unwrap().into()])] + #[case::itobool_false(ConvertOpDef::itobool.without_log_width(), &[ConstInt::new_u(0, 0).unwrap().into()], &[FoldVal::false_val()])] + #[case::itobool_true(ConvertOpDef::itobool.without_log_width(), &[ConstInt::new_u(0, 1).unwrap().into()], &[FoldVal::true_val()])] + #[case::ifrombool_false(ConvertOpDef::ifrombool.without_log_width(), &[FoldVal::false_val()], &[ConstInt::new_u(0, 0).unwrap().into()])] + #[case::ifrombool_true(ConvertOpDef::ifrombool.without_log_width(), &[FoldVal::true_val()], &[ConstInt::new_u(0, 1).unwrap().into()])] #[case::itousize(ConvertOpDef::itousize.without_log_width(), &[ConstInt::new_u(6, 42).unwrap().into()], &[ConstUsize::new(42).into()])] #[case::ifromusize(ConvertOpDef::ifromusize.without_log_width(), &[ConstUsize::new(42).into()], &[ConstInt::new_u(6, 42).unwrap().into()])] fn convert_fold( #[case] op: ConvertOpType, - #[case] inputs: &[Value], - #[case] outputs: &[Value], + #[case] inputs: &[FoldVal], + #[case] outputs: &[FoldVal], ) { - use crate::ops::Value; - - let consts: Vec<(IncomingPort, Value)> = inputs - .iter() - .enumerate() - .map(|(i, v)| (i.into(), v.clone())) - .collect(); - - let res = op - .to_extension_op() + let mut working = vec![FoldVal::Unknown; outputs.len()]; + op.to_extension_op() .unwrap() - .constant_fold(&consts) - .unwrap(); - - for (i, expected) in outputs.iter().enumerate() { - let res_val: &Value = &res.get(i).unwrap().1; - - assert_eq!(res_val, expected); - } + .constant_fold2(inputs, &mut working); + assert_eq!(&working, outputs); } } diff --git a/hugr-core/src/std_extensions/arithmetic/float_ops.rs b/hugr-core/src/std_extensions/arithmetic/float_ops.rs index 08b478535..3cbc11ca5 100644 --- a/hugr-core/src/std_extensions/arithmetic/float_ops.rs +++ b/hugr-core/src/std_extensions/arithmetic/float_ops.rs @@ -129,7 +129,10 @@ impl MakeRegisteredOp for FloatOps { #[cfg(test)] mod test { + use crate::extension::FoldVal; + use crate::std_extensions::arithmetic::float_types::ConstF64; use cgmath::AbsDiffEq; + use itertools::Itertools; use rstest::rstest; use super::*; @@ -154,28 +157,20 @@ mod test { #[case::fceil(FloatOps::fceil, &[42.42], &[43.])] #[case::fround(FloatOps::fround, &[42.42], &[42.])] fn float_fold(#[case] op: FloatOps, #[case] inputs: &[f64], #[case] outputs: &[f64]) { - use crate::ops::Value; - use crate::std_extensions::arithmetic::float_types::ConstF64; - let consts: Vec<_> = inputs .iter() - .enumerate() - .map(|(i, &x)| (i.into(), Value::extension(ConstF64::new(x)))) + .map(|&f| FoldVal::from(ConstF64::new(f))) .collect(); - let res = op - .to_extension_op() + let mut actual = vec![FoldVal::Unknown; outputs.len()]; + op.to_extension_op() .unwrap() - .constant_fold(&consts) - .unwrap(); - - for (i, expected) in outputs.iter().enumerate() { - let res_val: f64 = res - .get(i) - .unwrap() - .1 + .constant_fold2(&consts, &mut actual); + + for (act, expected) in actual.into_iter().zip_eq(outputs) { + let res_val: f64 = act .get_custom_value::() - .expect("This function assumes all incoming constants are floats.") + .expect("This function assumes all output constants are floats.") .value(); assert!( diff --git a/hugr-core/src/std_extensions/arithmetic/int_ops.rs b/hugr-core/src/std_extensions/arithmetic/int_ops.rs index d0ae7baa7..0f3889ca6 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_ops.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_ops.rs @@ -350,12 +350,11 @@ fn sum_ty_with_err(t: Type) -> Type { #[cfg(test)] mod test { + use itertools::Itertools; use rstest::rstest; - use crate::{ - ops::dataflow::DataflowOpTrait, std_extensions::arithmetic::int_types::int_type, - types::Signature, - }; + use crate::std_extensions::arithmetic::int_types::{int_type, ConstInt}; + use crate::{extension::FoldVal, ops::dataflow::DataflowOpTrait, types::Signature}; use super::*; @@ -442,33 +441,20 @@ mod test { #[case] outputs: &[u64], #[case] log_width: u8, ) { - use crate::ops::Value; - use crate::std_extensions::arithmetic::int_types::ConstInt; - let consts: Vec<_> = inputs .iter() - .enumerate() - .map(|(i, &x)| { - ( - i.into(), - Value::extension(ConstInt::new_u(log_width, x).unwrap()), - ) - }) + .map(|&x| FoldVal::from(ConstInt::new_u(log_width, x).unwrap())) .collect(); - let res = op - .to_extension_op() + let mut outs = vec![FoldVal::Unknown; outputs.len()]; + op.to_extension_op() .unwrap() - .constant_fold(&consts) - .unwrap(); + .constant_fold2(&consts, &mut outs); - for (i, &expected) in outputs.iter().enumerate() { - let res_val: u64 = res - .get(i) - .unwrap() - .1 + for (act, &expected) in outs.into_iter().zip_eq(outputs) { + let res_val: u64 = act .get_custom_value::() - .expect("This function assumes all incoming constants are floats.") + .expect("This function assumes all result constants are floats.") .value_u(); assert_eq!(res_val, expected); diff --git a/hugr-core/src/std_extensions/collections/list.rs b/hugr-core/src/std_extensions/collections/list.rs index 98804bab0..f67467a75 100644 --- a/hugr-core/src/std_extensions/collections/list.rs +++ b/hugr-core/src/std_extensions/collections/list.rs @@ -383,12 +383,11 @@ mod test { use rstest::rstest; use crate::extension::prelude::{ - const_fail_tuple, const_none, const_ok_tuple, const_some_tuple, + const_fail_tuple, const_none, const_ok_tuple, const_some_tuple, qb_t, usize_t, ConstUsize, }; use crate::ops::OpTrait; - use crate::PortIndex; use crate::{ - extension::prelude::{qb_t, usize_t, ConstUsize}, + extension::FoldVal, std_extensions::arithmetic::float_types::{float64_type, ConstF64}, types::TypeRow, }; @@ -506,29 +505,23 @@ mod test { #[case::insert_invalid(ListOp::insert, &[TestVal::List(vec![77,88,42]), TestVal::Idx(52), TestVal::Elem(99)], &[TestVal::List(vec![77,88,42]), TestVal::Err(Type::UNIT.into(), vec![TestVal::Elem(99)])])] #[case::length(ListOp::length, &[TestVal::List(vec![77,88,42])], &[TestVal::Elem(3)])] fn list_fold(#[case] op: ListOp, #[case] inputs: &[TestVal], #[case] outputs: &[TestVal]) { - let consts: Vec<_> = inputs + let inputs: Vec<_> = inputs .iter() - .enumerate() - .map(|(i, x)| (i.into(), x.to_value())) + .map(TestVal::to_value) + .map(FoldVal::from) .collect(); - let res = op - .with_type(usize_t()) + let mut actual = vec![FoldVal::Unknown; outputs.len()]; + + op.with_type(usize_t()) .to_extension_op() .unwrap() - .constant_fold(&consts) - .unwrap(); + .constant_fold2(&inputs, &mut actual); - for (i, expected) in outputs.iter().enumerate() { - let expected = expected.to_value(); - let res_val = res - .iter() - .find(|(port, _)| port.index() == i) - .unwrap() - .1 - .clone(); + for (act, expected) in actual.into_iter().zip_eq(outputs) { + let expected = expected.to_value().into(); - assert_eq!(res_val, expected); + assert_eq!(act, expected); } } } diff --git a/hugr-core/src/std_extensions/logic.rs b/hugr-core/src/std_extensions/logic.rs index fcc8be9d3..a25816d65 100644 --- a/hugr-core/src/std_extensions/logic.rs +++ b/hugr-core/src/std_extensions/logic.rs @@ -173,15 +173,15 @@ pub(crate) mod test { use std::sync::Arc; use super::{extension, LogicOp, FALSE_NAME, TRUE_NAME}; + use crate::{ - extension::{ - prelude::bool_t, - simple_op::{MakeOpDef, MakeRegisteredOp}, - }, - ops::{NamedOp, Value}, + extension::simple_op::{MakeOpDef, MakeRegisteredOp}, + extension::{prelude::bool_t, ConstFold, FoldVal}, + ops::NamedOp, Extension, }; + use itertools::Itertools; use rstest::rstest; use strum::IntoEnumIterator; @@ -248,15 +248,11 @@ pub(crate) mod test { use itertools::Itertools; use crate::extension::ConstFold; - let in_vals = ins - .into_iter() - .enumerate() - .map(|(i, b)| (i.into(), Value::from_bool(b))) - .collect_vec(); - assert_eq!( - Some(vec![(0.into(), Value::from_bool(out))]), - op.fold(&[(in_vals.len() as u64).into()], &in_vals) - ); + let in_vals = ins.into_iter().map(FoldVal::from_bool).collect_vec(); + let type_args = [(in_vals.len() as u64).into()]; + let mut outs = [FoldVal::Unknown]; + op.fold2(&type_args, &in_vals, &mut outs); + assert_eq!(outs, [FoldVal::from_bool(out)]); } #[rstest] @@ -272,18 +268,13 @@ pub(crate) mod test { #[case] ins: impl IntoIterator>, #[case] mb_out: Option, ) { - use itertools::Itertools; - - use crate::extension::ConstFold; - let in_vals0 = ins.into_iter().enumerate().collect_vec(); - let num_args = in_vals0.len() as u64; - let in_vals = in_vals0 + let in_vals = ins .into_iter() - .filter_map(|(i, mb_b)| mb_b.map(|b| (i.into(), Value::from_bool(b)))) + .map(|mb_b| mb_b.map_or(FoldVal::Unknown, FoldVal::from_bool)) .collect_vec(); - assert_eq!( - mb_out.map(|out| vec![(0.into(), Value::from_bool(out))]), - op.fold(&[num_args.into()], &in_vals) - ); + let type_args = [(in_vals.len() as u64).into()]; + let mut outs = [FoldVal::Unknown]; + op.fold2(&type_args, &in_vals, &mut outs); + assert_eq!(outs, [mb_out.map_or(FoldVal::Unknown, FoldVal::from_bool)]); } } From c9b89600e13e4a29254d1e4fcf8c039e3f8f5ce0 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 4 Apr 2025 18:05:22 +0100 Subject: [PATCH 10/33] Deprecate Value::Function --- hugr-core/src/extension/const_fold.rs | 1 + hugr-core/src/extension/resolution/types.rs | 1 + .../src/extension/resolution/types_mut.rs | 1 + hugr-core/src/hugr/serialize/test.rs | 2 + hugr-core/src/ops/constant.rs | 63 +++++++++++-------- hugr-passes/src/const_fold/value_handle.rs | 2 + hugr-passes/src/dataflow.rs | 2 + hugr-passes/src/replace_types.rs | 1 + 8 files changed, 48 insertions(+), 25 deletions(-) diff --git a/hugr-core/src/extension/const_fold.rs b/hugr-core/src/extension/const_fold.rs index 819c3c8f7..aa08fce1b 100644 --- a/hugr-core/src/extension/const_fold.rs +++ b/hugr-core/src/extension/const_fold.rs @@ -109,6 +109,7 @@ impl From for FoldVal { fn from(value: Value) -> Self { match value { Value::Extension { e } => FoldVal::Extension(e), + #[allow(deprecated)] // remove when Value::Function removed Value::Function { .. } => FoldVal::Unknown, Value::Sum(Sum { tag, diff --git a/hugr-core/src/extension/resolution/types.rs b/hugr-core/src/extension/resolution/types.rs index 6094f0aee..1cae0661f 100644 --- a/hugr-core/src/extension/resolution/types.rs +++ b/hugr-core/src/extension/resolution/types.rs @@ -245,6 +245,7 @@ fn collect_value_exts( let typ = e.get_type(); collect_type_exts(&typ, used_extensions, missing_extensions); } + #[allow(deprecated)] // remove when Value::Function removed Value::Function { hugr: _ } => { // The extensions used by nested hugrs do not need to be counted for the root hugr. } diff --git a/hugr-core/src/extension/resolution/types_mut.rs b/hugr-core/src/extension/resolution/types_mut.rs index d70d6b861..bc8c9e395 100644 --- a/hugr-core/src/extension/resolution/types_mut.rs +++ b/hugr-core/src/extension/resolution/types_mut.rs @@ -249,6 +249,7 @@ pub(super) fn resolve_value_exts( }); } } + #[allow(deprecated)] // remove when Value::Function removed Value::Function { hugr } => { // We don't need to add the nested hugr's extensions to the main one here, // but we run resolution on it independently. diff --git a/hugr-core/src/hugr/serialize/test.rs b/hugr-core/src/hugr/serialize/test.rs index 49a7b9321..048ffda2a 100644 --- a/hugr-core/src/hugr/serialize/test.rs +++ b/hugr-core/src/hugr/serialize/test.rs @@ -476,6 +476,7 @@ fn roundtrip_sumtype(#[case] sum_type: SumType) { #[case(Value::extension(ConstInt::new_u(2,1).unwrap()))] #[case(Value::sum(1,[Value::extension(ConstInt::new_u(2,1).unwrap())], SumType::new([vec![], vec![INT_TYPES[2].clone()]])).unwrap())] #[case(Value::tuple([Value::false_val(), Value::extension(ConstInt::new_s(2,1).unwrap())]))] +#[allow(deprecated)] // remove when Value::Function removed #[case(Value::function(crate::builder::test::simple_dfg_hugr()).unwrap())] fn roundtrip_value(#[case] value: Value) { check_testing_roundtrip(value); @@ -538,6 +539,7 @@ fn roundtrip_polyfunctype_varlen(#[case] poly_func_type: PolyFuncTypeRV) { #[case(ops::AliasDefn { name: "aliasdefn".into(), definition: Type::new_unit_sum(4)})] #[case(ops::AliasDecl { name: "aliasdecl".into(), bound: TypeBound::Any})] #[case(ops::Const::new(Value::false_val()))] +#[allow(deprecated)] // remove when Value::Function removed #[case(ops::Const::new(Value::function(crate::builder::test::simple_dfg_hugr()).unwrap()))] #[case(ops::Input::new(vec![Type::new_var_use(3,TypeBound::Copyable)]))] #[case(ops::Output::new(vec![Type::new_function(FuncValueType::new_endo(type_row![]))]))] diff --git a/hugr-core/src/ops/constant.rs b/hugr-core/src/ops/constant.rs index 794e6eaaa..ac91594c8 100644 --- a/hugr-core/src/ops/constant.rs +++ b/hugr-core/src/ops/constant.rs @@ -197,28 +197,35 @@ impl From for SerialSum { } } -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] -#[serde(tag = "v")] -/// A value that can be stored as a static constant. Representing core types and -/// extension types. -pub enum Value { - /// An extension constant value, that can check it is of a given [CustomType]. - Extension { - #[serde(flatten)] - /// The custom constant value. - e: OpaqueValue, - }, - /// A higher-order function value. - // TODO use a root parametrised hugr, e.g. Hugr. - Function { - /// A Hugr defining the function. - hugr: Box, - }, - /// A Sum variant, with a tag indicating the index of the variant and its - /// value. - #[serde(alias = "Tuple")] - Sum(Sum), +mod inner { + #![allow(deprecated)] // serde-generated code refers to the deprecated Value::Function + + use super::{Hugr, OpaqueValue, Sum}; + #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] + #[serde(tag = "v")] + /// A value that can be stored as a static constant. Representing core types and + /// extension types. + pub enum Value { + /// An extension constant value, that can check it is of a given [CustomType]. + Extension { + #[serde(flatten)] + /// The custom constant value. + e: OpaqueValue, + }, + /// A higher-order function value. + // TODO use a root parametrised hugr, e.g. Hugr. + #[deprecated(note = "Flatten and lift contents to a FuncDefn")] + Function { + /// A Hugr defining the function. + hugr: Box, + }, + /// A Sum variant, with a tag indicating the index of the variant and its + /// value. + #[serde(alias = "Tuple")] + Sum(Sum), + } } +pub use inner::Value; /// An opaque newtype around a [`Box`](CustomConst). /// @@ -382,6 +389,7 @@ impl Value { match self { Self::Extension { e } => e.get_type(), Self::Sum(Sum { sum_type, .. }) => sum_type.clone().into(), + #[allow(deprecated)] // remove when Value::Function removed Self::Function { hugr } => { let func_type = mono_fn_type(hugr).unwrap_or_else(|e| panic!("{}", e)); Type::new_function(func_type.into_owned()) @@ -419,9 +427,11 @@ impl Value { /// # Errors /// /// Returns an error if the Hugr root node does not define a function. + #[deprecated(note = "Flatten and lift contents to a FuncDefn")] pub fn function(hugr: impl Into) -> Result { let hugr = hugr.into(); mono_fn_type(&hugr)?; + #[allow(deprecated)] // In deprecated function, remove at same time Ok(Self::Function { hugr: Box::new(hugr), }) @@ -501,6 +511,7 @@ impl Value { fn name(&self) -> OpName { match self { Self::Extension { e } => format!("const:custom:{}", e.name()), + #[allow(deprecated)] // remove when Value::Function removed Self::Function { hugr: h } => { let Ok(t) = mono_fn_type(h) else { panic!("HUGR root node isn't a valid function parent."); @@ -527,6 +538,7 @@ impl Value { pub fn extension_reqs(&self) -> ExtensionSet { match self { Self::Extension { e } => e.extension_reqs().clone(), + #[allow(deprecated)] // remove when Value::Function removed Self::Function { .. } => ExtensionSet::new(), // no extensions required to load Hugr (only to run) Self::Sum(Sum { values, .. }) => { ExtensionSet::union_over(values.iter().map(|x| x.extension_reqs())) @@ -538,6 +550,7 @@ impl Value { pub fn validate(&self) -> Result<(), ConstTypeError> { match self { Self::Extension { e } => Ok(e.value().validate()?), + #[allow(deprecated)] // remove when Value::Function removed Self::Function { hugr } => { mono_fn_type(hugr)?; Ok(()) @@ -568,6 +581,7 @@ impl Value { pub fn try_hash(&self, st: &mut H) -> bool { match self { Value::Extension { e } => e.value().try_hash(&mut *st), + #[allow(deprecated)] // remove when Value::Function removed Value::Function { .. } => false, Value::Sum(s) => s.try_hash(st), } @@ -740,6 +754,7 @@ pub(crate) mod test { ); } + #[allow(deprecated)] // remove when Value::Function removed #[rstest] fn function_value(simple_dfg_hugr: Hugr) { let v = Value::function(simple_dfg_hugr).unwrap(); @@ -944,10 +959,8 @@ pub(crate) mod test { type Strategy = BoxedStrategy; fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { use ::proptest::collection::vec; - let leaf_strat = prop_oneof![ - any::().prop_map(|e| Self::Extension { e }), - crate::proptest::any_hugr().prop_map(|x| Value::function(x).unwrap()) - ]; + let leaf_strat = any::().prop_map(|e| Self::Extension { e }); + leaf_strat .prop_recursive( 3, // No more than 3 branch levels deep diff --git a/hugr-passes/src/const_fold/value_handle.rs b/hugr-passes/src/const_fold/value_handle.rs index 362f50b4b..de49a2975 100644 --- a/hugr-passes/src/const_fold/value_handle.rs +++ b/hugr-passes/src/const_fold/value_handle.rs @@ -164,6 +164,8 @@ impl From> for Value { } => Value::Extension { e: Arc::try_unwrap(val).unwrap_or_else(|a| a.as_ref().clone()), }, + #[allow(deprecated)] + // When we remove Value::Function, have to change `leaf` to be OpaqueValue only ValueHandle::Unhashable { leaf: Either::Right(hugr), .. diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 7029ec0b2..756fd6a1d 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -70,6 +70,7 @@ pub trait ConstLoader { /// Produces an abstract value from a Hugr in a [Value::Function], if possible. /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. + #[deprecated(note = "Remove along with Value::Function")] fn value_from_const_hugr(&self, _loc: ConstLocation, _h: &Hugr) -> Option { None } @@ -112,6 +113,7 @@ where .value_from_opaque(loc, e) .map(PartialValue::from) .unwrap_or(PartialValue::Top), + #[allow(deprecated)] // remove when Value::Function removed Value::Function { hugr } => cl .value_from_const_hugr(loc, hugr) .map(PartialValue::from) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index fb72b86f8..456af23ae 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -344,6 +344,7 @@ impl ReplaceTypes { false } }), + #[allow(deprecated)] // remove when Value::Function removed Value::Function { hugr } => self.run_no_validate(&mut **hugr), } } From a90dc97a07e1203d0d70230c7fff892790779cb5 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 4 Apr 2025 20:41:05 +0100 Subject: [PATCH 11/33] First bit of test - default to Top if called-func unknown --- hugr-passes/src/dataflow/datalog.rs | 11 ++++- hugr-passes/src/dataflow/test.rs | 65 ++++++++++++++++++++++++++++- 2 files changed, 73 insertions(+), 3 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index d1f47be23..008d9a91d 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -6,7 +6,7 @@ use ascent::lattice::BoundedLattice; use itertools::Itertools; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; -use hugr_core::ops::{OpTrait, OpType, TailLoop}; +use hugr_core::ops::{DataflowOpTrait, OpTrait, OpType, TailLoop}; use hugr_core::{HugrView, IncomingPort, OutgoingPort, PortIndex as _, Wire}; use super::value_row::ValueRow; @@ -341,6 +341,15 @@ pub(super) fn run_datalog( indirect_call(call, func), output_child(func, outp), in_wire_value(outp, p, v); + + // Default out-value is Bottom, but if we can't determine the called function, + // assign everything to Top + out_wire_value(call, p, PV::Top) <-- + node(call), + if let OpType::CallIndirect(ci) = hugr.get_optype(*call), + in_wire_value(call, IncomingPort::from(0), v), + if !matches!(v, PartialValue::LoadedFunction(_)), + for p in ci.signature().output_ports(); }; let out_wire_values = all_results .out_wire_value diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index ca0cfdb44..f51ad2600 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -1,9 +1,9 @@ use ascent::{lattice::BoundedLattice, Lattice}; -use hugr_core::builder::{CFGBuilder, Container, DataflowHugr, ModuleBuilder}; +use hugr_core::builder::{inout_sig, CFGBuilder, Container, DataflowHugr, ModuleBuilder}; use hugr_core::hugr::views::{DescendantsGraph, HierarchyView}; use hugr_core::ops::handle::DfgID; -use hugr_core::ops::TailLoop; +use hugr_core::ops::{CallIndirect, TailLoop}; use hugr_core::types::TypeRow; use hugr_core::{ builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, @@ -547,3 +547,64 @@ fn test_module() { ); } } + +#[test] +fn call_indirect() { + let b2b = || Signature::new_endo(bool_t()); + let mut dfb = DFGBuilder::new(inout_sig(vec![bool_t(); 3], vec![bool_t(); 2])).unwrap(); + + let [id1, id2] = ["id1", "[id2]"].map(|name| { + let fb = dfb.define_function(name, b2b()).unwrap(); + let [inp] = fb.input_wires_arr(); + fb.finish_with_outputs([inp]).unwrap() + }); + + let [inp_direct, which, inp_indirect] = dfb.input_wires_arr(); + let [res1] = dfb + .call(id1.handle(), &[], [inp_direct]) + .unwrap() + .outputs_arr(); + + // We'll unconditionally load both functions, to demonstrate that it's + // the CallIndirect that matters, not just which functions are loaded. + let lf1 = dfb.load_func(id1.handle(), &[]).unwrap(); + let lf2 = dfb.load_func(id2.handle(), &[]).unwrap(); + let bool_func = || Type::new_function(b2b()); + let mut cond = dfb + .conditional_builder( + (vec![type_row![]; 2], which), + [(bool_func(), lf1), (bool_func(), lf2)], + bool_func().into(), + ) + .unwrap(); + let case_false = cond.case_builder(0).unwrap(); + let [f0, _f1] = case_false.input_wires_arr(); + case_false.finish_with_outputs([f0]).unwrap(); + let case_true = cond.case_builder(1).unwrap(); + let [_f0, f1] = case_true.input_wires_arr(); + case_true.finish_with_outputs([f1]).unwrap(); + let [tgt] = cond.finish_sub_container().unwrap().outputs_arr(); + let [res2] = dfb + .add_dataflow_op(CallIndirect { signature: b2b() }, [tgt, inp_indirect]) + .unwrap() + .outputs_arr(); + let h = dfb.finish_hugr_with_outputs([res1, res2]).unwrap(); + + // 1. Test with `which` unknown -> second output unknown + let (w1, w2) = (Wire::new(h.root(), 0), Wire::new(h.root(), 1)); + for inp1 in [pv_false(), pv_true()] { + for inp2 in [pv_false(), pv_true()] { + let results = Machine::new(&h).run( + TestContext, + [ + (0.into(), inp1.clone()), + (1.into(), PartialValue::Top), + (2.into(), inp2), + ], + ); + assert_eq!(results.read_out_wire(w1), Some(inp1.clone())); + assert_eq!(results.read_out_wire(w2), Some(PartialValue::Top)); + ); + } + } +} From 6e2ad7f1f50a7160869b5a4911fe7283a2834323 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 4 Apr 2025 20:49:21 +0100 Subject: [PATCH 12/33] Test2, fix port numbering and narrow unknown-called-func case --- hugr-passes/src/dataflow/datalog.rs | 6 ++++-- hugr-passes/src/dataflow/test.rs | 15 +++++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 008d9a91d..127d0c79d 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -335,7 +335,8 @@ pub(super) fn run_datalog( out_wire_value(inp, OutgoingPort::from(p.index()-1), v) <-- indirect_call(call, func), input_child(func, inp), - in_wire_value(call, p, v); + in_wire_value(call, p, v) + if p.index() > 0; out_wire_value(call, OutgoingPort::from(p.index()), v) <-- indirect_call(call, func), @@ -348,7 +349,8 @@ pub(super) fn run_datalog( node(call), if let OpType::CallIndirect(ci) = hugr.get_optype(*call), in_wire_value(call, IncomingPort::from(0), v), - if !matches!(v, PartialValue::LoadedFunction(_)), + // Second alternative below addresses function::Value's: + if matches!(v, PartialValue::Top | PartialValue::Value(_)), for p in ci.signature().output_ports(); }; let out_wire_values = all_results diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index f51ad2600..66ffb891c 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -604,7 +604,22 @@ fn call_indirect() { ); assert_eq!(results.read_out_wire(w1), Some(inp1.clone())); assert_eq!(results.read_out_wire(w2), Some(PartialValue::Top)); + } + } + + // 2. Test with `which` selecting second function -> both passthrough + for inp1 in [pv_false(), pv_true()] { + for inp2 in [pv_false(), pv_true()] { + let results = Machine::new(&h).run( + TestContext, + [ + (0.into(), inp1.clone()), + (1.into(), pv_true()), + (2.into(), inp2.clone()), + ], ); + assert_eq!(results.read_out_wire(w1), Some(inp1.clone())); + assert_eq!(results.read_out_wire(w2), Some(inp2.clone())); } } } From 2cd85479a77896c026bad76c1e10d95f4070622b Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 4 Apr 2025 20:58:39 +0100 Subject: [PATCH 13/33] Test3, plus refactor test --- hugr-passes/src/dataflow/test.rs | 36 ++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 66ffb891c..3e8a454d8 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -590,18 +590,17 @@ fn call_indirect() { .outputs_arr(); let h = dfb.finish_hugr_with_outputs([res1, res2]).unwrap(); + let run = |v0, v1, v2| { + Machine::new(&h).run( + TestContext, + [(0.into(), v0), (1.into(), v1), (2.into(), v2)], + ) + }; // 1. Test with `which` unknown -> second output unknown let (w1, w2) = (Wire::new(h.root(), 0), Wire::new(h.root(), 1)); for inp1 in [pv_false(), pv_true()] { for inp2 in [pv_false(), pv_true()] { - let results = Machine::new(&h).run( - TestContext, - [ - (0.into(), inp1.clone()), - (1.into(), PartialValue::Top), - (2.into(), inp2), - ], - ); + let results = run(inp1.clone(), PartialValue::Top, inp2); assert_eq!(results.read_out_wire(w1), Some(inp1.clone())); assert_eq!(results.read_out_wire(w2), Some(PartialValue::Top)); } @@ -610,16 +609,21 @@ fn call_indirect() { // 2. Test with `which` selecting second function -> both passthrough for inp1 in [pv_false(), pv_true()] { for inp2 in [pv_false(), pv_true()] { - let results = Machine::new(&h).run( - TestContext, - [ - (0.into(), inp1.clone()), - (1.into(), pv_true()), - (2.into(), inp2.clone()), - ], - ); + let results = run(inp1.clone(), pv_true(), inp2.clone()); assert_eq!(results.read_out_wire(w1), Some(inp1.clone())); assert_eq!(results.read_out_wire(w2), Some(inp2.clone())); } } + + //3. Test with `which` selecting first function -> alias + for (inp1, inp2) in [(pv_false(), pv_true()), (pv_true(), pv_false())] { + // A. same input bool to both calls + let results1 = run(inp1.clone(), pv_false(), inp1.clone()); + assert_eq!(results1.read_out_wire(w1), Some(inp1.clone())); + assert_eq!(results1.read_out_wire(w2), Some(inp1.clone())); + // B. different inputs to both calls. Both alias - even the Call. + let results2 = run(inp1, pv_false(), inp2); + assert_eq!(results2.read_out_wire(w1), Some(pv_true_or_false())); + assert_eq!(results2.read_out_wire(w2), Some(pv_true_or_false())); + } } From ba33062d3f5562017b52e9dab8e1a2c8c429753a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 4 Apr 2025 21:23:25 +0100 Subject: [PATCH 14/33] refactor into cases of rstest --- hugr-passes/src/dataflow/test.rs | 53 +++++++++++++++----------------- 1 file changed, 24 insertions(+), 29 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 3e8a454d8..3737d65cd 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -548,8 +548,12 @@ fn test_module() { } } -#[test] -fn call_indirect() { +#[rstest] +#[case(pv_false(), pv_false())] +#[case(pv_false(), pv_true())] +#[case(pv_true(), pv_false())] +#[case(pv_true(), pv_true())] +fn call_indirect(#[case] inp1: PartialValue, #[case] inp2: PartialValue) { let b2b = || Signature::new_endo(bool_t()); let mut dfb = DFGBuilder::new(inout_sig(vec![bool_t(); 3], vec![bool_t(); 2])).unwrap(); @@ -590,40 +594,31 @@ fn call_indirect() { .outputs_arr(); let h = dfb.finish_hugr_with_outputs([res1, res2]).unwrap(); - let run = |v0, v1, v2| { + let run = |which| { Machine::new(&h).run( TestContext, - [(0.into(), v0), (1.into(), v1), (2.into(), v2)], + [ + (0.into(), inp1.clone()), + (1.into(), which), + (2.into(), inp2.clone()), + ], ) }; - // 1. Test with `which` unknown -> second output unknown let (w1, w2) = (Wire::new(h.root(), 0), Wire::new(h.root(), 1)); - for inp1 in [pv_false(), pv_true()] { - for inp2 in [pv_false(), pv_true()] { - let results = run(inp1.clone(), PartialValue::Top, inp2); - assert_eq!(results.read_out_wire(w1), Some(inp1.clone())); - assert_eq!(results.read_out_wire(w2), Some(PartialValue::Top)); - } - } + + // 1. Test with `which` unknown -> second output unknown + let results = run(PartialValue::Top); + assert_eq!(results.read_out_wire(w1), Some(inp1.clone())); + assert_eq!(results.read_out_wire(w2), Some(PartialValue::Top)); // 2. Test with `which` selecting second function -> both passthrough - for inp1 in [pv_false(), pv_true()] { - for inp2 in [pv_false(), pv_true()] { - let results = run(inp1.clone(), pv_true(), inp2.clone()); - assert_eq!(results.read_out_wire(w1), Some(inp1.clone())); - assert_eq!(results.read_out_wire(w2), Some(inp2.clone())); - } - } + let results = run(pv_true()); + assert_eq!(results.read_out_wire(w1), Some(inp1.clone())); + assert_eq!(results.read_out_wire(w2), Some(inp2.clone())); //3. Test with `which` selecting first function -> alias - for (inp1, inp2) in [(pv_false(), pv_true()), (pv_true(), pv_false())] { - // A. same input bool to both calls - let results1 = run(inp1.clone(), pv_false(), inp1.clone()); - assert_eq!(results1.read_out_wire(w1), Some(inp1.clone())); - assert_eq!(results1.read_out_wire(w2), Some(inp1.clone())); - // B. different inputs to both calls. Both alias - even the Call. - let results2 = run(inp1, pv_false(), inp2); - assert_eq!(results2.read_out_wire(w1), Some(pv_true_or_false())); - assert_eq!(results2.read_out_wire(w2), Some(pv_true_or_false())); - } + let results = run(pv_false()); + let out = Some(inp1.join(inp2)); + assert_eq!(results.read_out_wire(w1), out); + assert_eq!(results.read_out_wire(w2), out); } From ee128ce8b1fe7603fcdf80ec323d831cfccad207 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 4 Apr 2025 22:32:11 +0100 Subject: [PATCH 15/33] PartialSum default N=Node --- hugr-passes/src/dataflow/partial_value.rs | 25 ++++++++++------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 46e981244..9472c62f9 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -65,7 +65,7 @@ pub struct LoadedFunction { /// A representation of a value of [SumType], that may have one or more possible tags, /// with a [PartialValue] representation of each element-value of each possible tag. #[derive(PartialEq, Clone, Eq)] -pub struct PartialSum(pub HashMap>>); +pub struct PartialSum(pub HashMap>>); impl PartialSum { /// New instance for a single known tag. @@ -529,7 +529,6 @@ mod test { use std::sync::Arc; use ascent::{lattice::BoundedLattice, Lattice}; - use hugr_core::Node; use itertools::{zip_eq, Itertools as _}; use prop::sample::subsequence; use proptest::prelude::*; @@ -568,7 +567,7 @@ mod test { } impl TestSumType { - fn check_value(&self, pv: &PartialValue) -> bool { + fn check_value(&self, pv: &PartialValue) -> bool { match (self, pv) { (_, PartialValue::Bottom) | (_, PartialValue::Top) => true, (Self::Leaf(None), _) => pv == &PartialValue::new_unit(), @@ -629,7 +628,7 @@ mod test { fn single_sum_strat( tag: usize, elems: Vec>, - ) -> impl Strategy> { + ) -> impl Strategy> { elems .iter() .map(Arc::as_ref) @@ -640,11 +639,11 @@ mod test { fn partial_sum_strat( variants: &[Vec>], - ) -> impl Strategy> { + ) -> impl Strategy> { // We have to clone the `variants` here but only as far as the Vec>> let tagged_variants = variants.iter().cloned().enumerate().collect::>(); // The type annotation here (and the .boxed() enabling it) are just for documentation - let sum_variants_strat: BoxedStrategy>> = + let sum_variants_strat: BoxedStrategy>> = subsequence(tagged_variants, 1..=variants.len()) .prop_flat_map(|selected_variants| { selected_variants @@ -653,7 +652,7 @@ mod test { .collect::>() }) .boxed(); - sum_variants_strat.prop_map(|psums: Vec>| { + sum_variants_strat.prop_map(|psums: Vec>| { let mut psums = psums.into_iter(); let first = psums.next().unwrap(); psums.fold(first, |mut a, b| { @@ -665,7 +664,7 @@ mod test { fn any_partial_value_of_type( ust: &TestSumType, - ) -> impl Strategy> { + ) -> impl Strategy> { match ust { TestSumType::Leaf(None) => Just(PartialValue::new_unit()).boxed(), TestSumType::Leaf(Some(i)) => (0..*i) @@ -678,16 +677,15 @@ mod test { fn any_partial_value_with( params: ::Parameters, - ) -> impl Strategy> { + ) -> impl Strategy> { any_with::(params).prop_flat_map(|t| any_partial_value_of_type(&t)) } - fn any_partial_value() -> impl Strategy> { + fn any_partial_value() -> impl Strategy> { any_partial_value_with(Default::default()) } - fn any_partial_values( - ) -> impl Strategy; N]> { + fn any_partial_values() -> impl Strategy; N]> { any::().prop_flat_map(|ust| { TryInto::<[_; N]>::try_into( (0..N) @@ -698,8 +696,7 @@ mod test { }) } - fn any_typed_partial_value( - ) -> impl Strategy)> { + fn any_typed_partial_value() -> impl Strategy)> { any::() .prop_flat_map(|t| any_partial_value_of_type(&t).prop_map(move |v| (t.clone(), v))) } From 2757a3312e71b71211b05eb0722eea4c97e0635e Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Sat, 5 Apr 2025 08:59:12 +0100 Subject: [PATCH 16/33] require TryFrom,Error=LoadedFunction>, format w/ Debug --- hugr-passes/src/const_fold.rs | 2 +- hugr-passes/src/dataflow/partial_value.rs | 18 +++++++++--------- hugr-passes/src/dataflow/results.rs | 6 +++--- hugr-passes/src/dataflow/test.rs | 2 +- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index ef5fc355c..0599d26af 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -127,7 +127,7 @@ impl ConstantFoldPass { n, ip, results - .try_read_wire_concrete::(Wire::new(src, outp)) + .try_read_wire_concrete::(Wire::new(src, outp)) .ok()?, )) }) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 9472c62f9..41691fb06 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -175,14 +175,14 @@ impl PartialSum { /// If this PartialSum had multiple possible tags; or if `typ` was not a [TypeEnum::Sum] /// supporting the single possible tag with the correct number of elements and no row variables; /// or if converting a child element failed via [PartialValue::try_into_concrete]. - pub fn try_into_sum( + pub fn try_into_sum( self, typ: &Type, - ) -> Result, ExtractValueError> + ) -> Result, ExtractValueError> where V: TryInto, Sum: TryInto, - LoadedFunction: TryInto, + LoadedFunction: TryInto>, { if self.0.len() != 1 { return Err(ExtractValueError::MultipleVariants(self)); @@ -215,7 +215,7 @@ impl PartialSum { /// via [PartialValue::try_into_concrete] or [PartialSum::try_into_sum] #[derive(Clone, Debug, PartialEq, Eq, Error)] #[allow(missing_docs)] -pub enum ExtractValueError { +pub enum ExtractValueError { #[error("PartialSum value had multiple possible tags: {0}")] MultipleVariants(PartialSum), #[error("Value contained `Bottom`")] @@ -226,8 +226,8 @@ pub enum ExtractValueError { CouldNotConvert(V, #[source] VE), #[error("Could not build Sum from concrete element values")] CouldNotBuildSum(#[source] SE), - #[error("Could not turn LoadedFunction into concrete")] - CouldNotLoadFunction(#[source] LE), + #[error("Could not convert into concrete function pointer {0}")] + CouldNotLoadFunction(LoadedFunction), #[error("Expected a SumType with tag {tag} having {num_elements} elements, found {typ}")] BadSumType { typ: Type, @@ -395,14 +395,14 @@ impl PartialValue { /// If this PartialValue was `Top` or `Bottom`, or was a [PartialSum](PartialValue::PartialSum) /// that could not be converted into a [Sum] by [PartialSum::try_into_sum] (e.g. if `typ` is /// incorrect), or if that [Sum] could not be converted into a `V2`. - pub fn try_into_concrete( + pub fn try_into_concrete( self, typ: &Type, - ) -> Result> + ) -> Result> where V: TryInto, Sum: TryInto, - LoadedFunction: TryInto, + LoadedFunction: TryInto>, { match self { Self::Value(v) => v diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index a6b864f5e..d5d6bb2fa 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -88,14 +88,14 @@ impl AnalysisResults { /// the Hugr did not have a [Type](hugr_core::types::Type) for the specified wire /// `Some(e)` if [conversion to a concrete value](PartialValue::try_into_concrete) failed with error `e` #[allow(clippy::type_complexity)] - pub fn try_read_wire_concrete( + pub fn try_read_wire_concrete( &self, w: Wire, - ) -> Result>> + ) -> Result>> where V2: TryFrom + TryFrom, Error = SE> - + TryFrom, Error = LE>, + + TryFrom, Error = LoadedFunction>, { let v = self.read_out_wire(w).ok_or(None)?; let (_, typ) = self diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 3737d65cd..f57f3d7aa 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -296,7 +296,7 @@ fn test_conditional() { let cond_r1: Value = results.try_read_wire_concrete(cond_o1).unwrap(); assert_eq!(cond_r1, Value::false_val()); assert!(results - .try_read_wire_concrete::(cond_o2) + .try_read_wire_concrete::(cond_o2) .is_err()); assert_eq!(results.case_reachable(case1.node()), Some(false)); // arg_pv is variant 1 or 2 only From 3c79739926e4eef3368878be3e8e5baf1d927f33 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 7 Apr 2025 10:18:28 +0100 Subject: [PATCH 17/33] Use PartialValue::new_load --- hugr-passes/src/dataflow/datalog.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 127d0c79d..ee0a1a144 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -409,11 +409,9 @@ fn propagate_leaf_op( .unwrap() .0; // Node could be a FuncDefn or a FuncDecl, so do not pass the node itself - Some(ValueRow::singleton(PartialValue::LoadedFunction( - LoadedFunction { - func_node, - args: load_op.type_args.clone(), - }, + Some(ValueRow::singleton(PartialValue::new_load( + func_node, + load_op.type_args.clone(), ))) } OpType::ExtensionOp(e) => { From 1191150e862035ee7e3dc539a980ea36c2e559f0 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 7 Apr 2025 16:13:11 +0100 Subject: [PATCH 18/33] Make call_indirect a lattice, add load_func --- hugr-passes/src/dataflow/datalog.rs | 59 ++++++++++++++++++++++++++--- 1 file changed, 54 insertions(+), 5 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index ee0a1a144..986378dc5 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -3,6 +3,8 @@ use std::collections::HashMap; use ascent::lattice::BoundedLattice; +use ascent::Lattice; +use hugr_core::core::HugrNode; use itertools::Itertools; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; @@ -325,21 +327,23 @@ pub(super) fn run_datalog( in_wire_value(outp, p, v); // CallIndirect -------------------- - relation indirect_call(H::Node, H::Node); // is an `IndirectCall` to `FuncDefn` - indirect_call(call, func_node) <-- + lattice indirect_call(H::Node, LatticeWrapper); // is an `IndirectCall` to `FuncDefn` + indirect_call(call, tgt) <-- node(call), if let OpType::CallIndirect(_) = hugr.get_optype(*call), in_wire_value(call, IncomingPort::from(0), v), - if let PartialValue::LoadedFunction(LoadedFunction {func_node, ..}) = v; + let tgt = load_func(v); out_wire_value(inp, OutgoingPort::from(p.index()-1), v) <-- - indirect_call(call, func), + indirect_call(call, lv), + if let LatticeWrapper::Value(func) = lv, input_child(func, inp), in_wire_value(call, p, v) if p.index() > 0; out_wire_value(call, OutgoingPort::from(p.index()), v) <-- - indirect_call(call, func), + indirect_call(call, lv), + if let LatticeWrapper::Value(func) = lv, output_child(func, outp), in_wire_value(outp, p, v); @@ -367,6 +371,51 @@ pub(super) fn run_datalog( } } +#[derive(PartialEq, Eq, Hash, Clone, PartialOrd)] +enum LatticeWrapper { + Bottom, + Value(T), + Top, +} + +impl Lattice for LatticeWrapper { + fn meet_mut(&mut self, other: Self) -> bool { + if *self == other || *self == LatticeWrapper::Bottom || other == LatticeWrapper::Top { + return false; + }; + if *self == LatticeWrapper::Top || other == LatticeWrapper::Bottom { + *self = other; + return true; + }; + // Both are `Value`s and not equal + *self = LatticeWrapper::Bottom; + true + } + + fn join_mut(&mut self, other: Self) -> bool { + if *self == other || *self == LatticeWrapper::Top || other == LatticeWrapper::Bottom { + return false; + }; + if *self == LatticeWrapper::Bottom || other == LatticeWrapper::Top { + *self = other; + return true; + }; + // Both are `Value`s and are not equal + *self = LatticeWrapper::Top; + true + } +} + +fn load_func(v: &PV) -> LatticeWrapper { + match v { + PartialValue::Bottom | PartialValue::PartialSum(_) => LatticeWrapper::Bottom, + PartialValue::LoadedFunction(LoadedFunction { func_node, .. }) => { + LatticeWrapper::Value(*func_node) + } + PartialValue::Value(_) | PartialValue::Top => LatticeWrapper::Top, + } +} + fn propagate_leaf_op( ctx: &mut impl DFContext, hugr: &H, From 4ba5603bf0c2c0a6ebda744ee27799892e2c66a5 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Apr 2025 13:07:17 +0100 Subject: [PATCH 19/33] Drop Hugr from inside ConstFoldContext --- hugr-passes/src/const_fold.rs | 27 +++++++++++++-------------- hugr-passes/src/const_fold/test.rs | 6 ++---- 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 0599d26af..31a1ddfdb 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -98,7 +98,7 @@ impl ConstantFoldPass { n, in_vals.iter().map(|(p, v)| { let const_with_dummy_loc = partial_from_const( - &ConstFoldContext(hugr), + &ConstFoldContext, ConstLocation::Field(p.index(), &fresh_node.into()), v, ); @@ -108,7 +108,7 @@ impl ConstantFoldPass { .map_err(|opty| ConstFoldError::InvalidEntryPoint(n, opty))?; } - let results = m.run(ConstFoldContext(hugr), []); + let results = m.run(ConstFoldContext, []); let mb_root_inp = hugr.get_io(hugr.root()).map(|[i, _]| i); let wires_to_break = hugr @@ -201,36 +201,35 @@ pub fn constant_fold_pass(h: &mut H) { c.run(h).unwrap() } -// Probably intend to remove this in a future PR, but not certain, so leaving in for now -struct ConstFoldContext<'a, H>(#[allow(unused)] &'a H); +struct ConstFoldContext; -impl> ConstLoader> for ConstFoldContext<'_, H> { - type Node = H::Node; +impl ConstLoader> for ConstFoldContext { + type Node = Node; fn value_from_opaque( &self, - loc: ConstLocation, + loc: ConstLocation, val: &OpaqueValue, - ) -> Option> { + ) -> Option> { Some(ValueHandle::new_opaque(loc, val.clone())) } fn value_from_const_hugr( &self, - loc: ConstLocation, + loc: ConstLocation, h: &hugr_core::Hugr, - ) -> Option> { + ) -> Option> { Some(ValueHandle::new_const_hugr(loc, Box::new(h.clone()))) } } -impl> DFContext> for ConstFoldContext<'_, H> { +impl DFContext> for ConstFoldContext { fn interpret_leaf_op( &mut self, - node: H::Node, + node: Node, op: &ExtensionOp, - ins: &[PartialValue>], - outs: &mut [PartialValue>], + ins: &[PartialValue>], + outs: &mut [PartialValue>], ) { let sig = op.signature(); let known_ins = sig diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index b84d65d7d..58e69c568 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -42,8 +42,7 @@ fn value_handling(#[case] k: impl CustomConst + Clone, #[case] eq: bool) { let n = Node::from(portgraph::NodeIndex::new(7)); let st = SumType::new([vec![k.get_type()], vec![]]); let subject_val = Value::sum(0, [k.clone().into()], st).unwrap(); - let temp = Hugr::default(); - let ctx: ConstFoldContext = ConstFoldContext(&temp); + let ctx = ConstFoldContext; let v1 = partial_from_const(&ctx, n, &subject_val); let v1_subfield = { @@ -114,8 +113,7 @@ fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) { v.get_custom_value::().unwrap().value() } let [n, n_a, n_b] = [0, 1, 2].map(portgraph::NodeIndex::new).map(Node::from); - let temp = Hugr::default(); - let mut ctx = ConstFoldContext(&temp); + let mut ctx = ConstFoldContext; let v_a = partial_from_const(&ctx, n_a, &f2c(a)); let v_b = partial_from_const(&ctx, n_b, &f2c(b)); assert_eq!(unwrap_float(v_a.clone()), a); From 4db3f5d72d1661bd0ec099a33afdcb075740785a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= Date: Tue, 15 Apr 2025 13:52:31 +0100 Subject: [PATCH 20/33] ci: Run ci checks on PRs to any branch --- .github/workflows/ci-py.yml | 2 +- .github/workflows/ci-rs.yml | 2 +- .github/workflows/pr-title.yml | 2 +- .github/workflows/semver-checks.yml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci-py.yml b/.github/workflows/ci-py.yml index c393c195f..6ef3edce5 100644 --- a/.github/workflows/ci-py.yml +++ b/.github/workflows/ci-py.yml @@ -6,7 +6,7 @@ on: - main pull_request: branches: - - '*' + - '**' merge_group: types: [checks_requested] workflow_dispatch: {} diff --git a/.github/workflows/ci-rs.yml b/.github/workflows/ci-rs.yml index d7c94e3a3..824291a8b 100644 --- a/.github/workflows/ci-rs.yml +++ b/.github/workflows/ci-rs.yml @@ -6,7 +6,7 @@ on: - main pull_request: branches: - - '*' + - '**' merge_group: types: [checks_requested] workflow_dispatch: {} diff --git a/.github/workflows/pr-title.yml b/.github/workflows/pr-title.yml index 333eec96f..f778c1363 100644 --- a/.github/workflows/pr-title.yml +++ b/.github/workflows/pr-title.yml @@ -2,7 +2,7 @@ name: Check Conventional Commits format on: pull_request_target: branches: - - main + - '**' types: - opened - edited diff --git a/.github/workflows/semver-checks.yml b/.github/workflows/semver-checks.yml index e884b2e36..2c410aa85 100644 --- a/.github/workflows/semver-checks.yml +++ b/.github/workflows/semver-checks.yml @@ -2,7 +2,7 @@ name: Rust Semver Checks on: pull_request_target: branches: - - main + - '**' jobs: # Check if changes were made to the relevant files. From 81447ecf83acdba6d50427a64637447899553054 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Tue, 15 Apr 2025 15:58:19 +0100 Subject: [PATCH 21/33] feat!: Allow generic Nodes in HugrMut insert operations (#2075) `insert_hugr`, `insert_from_view`, and `insert_subgraph` were written before we made `Node` a type generic, and incorrectly assumed the return type maps were always `hugr::Node`s. The methods were either unusable or incorrect when using generic `HugrView`s source/targets with non-base node types. This PR fixes that, and additionally allows us us to have `SiblingSubgraph::extract_subgraph` work for generic `HugrViews`. BREAKING CHANGE: Added Node type parameters to extraction operations in `HugrMut`. --- hugr-core/src/builder/build_traits.rs | 2 +- hugr-core/src/hugr/hugrmut.rs | 114 ++++++++++++------- hugr-core/src/hugr/views/sibling_subgraph.rs | 4 +- 3 files changed, 75 insertions(+), 45 deletions(-) diff --git a/hugr-core/src/builder/build_traits.rs b/hugr-core/src/builder/build_traits.rs index f1613895d..e17d172ca 100644 --- a/hugr-core/src/builder/build_traits.rs +++ b/hugr-core/src/builder/build_traits.rs @@ -119,7 +119,7 @@ pub trait Container { } /// Insert a copy of a HUGR as a child of the container. - fn add_hugr_view(&mut self, child: &impl HugrView) -> InsertionResult { + fn add_hugr_view(&mut self, child: &H) -> InsertionResult { let parent = self.container_node(); self.hugr_mut().insert_from_view(parent, child) } diff --git a/hugr-core/src/hugr/hugrmut.rs b/hugr-core/src/hugr/hugrmut.rs index f3ef094be..38eb59222 100644 --- a/hugr-core/src/hugr/hugrmut.rs +++ b/hugr-core/src/hugr/hugrmut.rs @@ -1,13 +1,15 @@ //! Low-level interface for modifying a HUGR. use core::panic; -use std::collections::{BTreeMap, HashMap}; +use std::collections::{BTreeMap, HashMap, HashSet}; use std::sync::Arc; use portgraph::view::{NodeFilter, NodeFiltered}; use portgraph::{LinkMut, PortMut, PortView, SecondaryMap}; +use crate::core::HugrNode; use crate::extension::ExtensionRegistry; +use crate::hugr::internal::HugrInternals; use crate::hugr::views::SiblingSubgraph; use crate::hugr::{HugrView, Node, OpType, RootTagged}; use crate::hugr::{NodeMetadata, Rewrite}; @@ -162,10 +164,10 @@ pub trait HugrMut: HugrMutInternals { /// correspondingly for `Dom` edges) fn copy_descendants( &mut self, - root: Node, - new_parent: Node, + root: Self::Node, + new_parent: Self::Node, subst: Option, - ) -> BTreeMap { + ) -> BTreeMap { panic_invalid_node(self, root); panic_invalid_node(self, new_parent); self.hugr_mut().copy_descendants(root, new_parent, subst) @@ -225,7 +227,7 @@ pub trait HugrMut: HugrMutInternals { /// /// If the root node is not in the graph. #[inline] - fn insert_hugr(&mut self, root: Node, other: Hugr) -> InsertionResult { + fn insert_hugr(&mut self, root: Self::Node, other: Hugr) -> InsertionResult { panic_invalid_node(self, root); self.hugr_mut().insert_hugr(root, other) } @@ -236,7 +238,11 @@ pub trait HugrMut: HugrMutInternals { /// /// If the root node is not in the graph. #[inline] - fn insert_from_view(&mut self, root: Node, other: &impl HugrView) -> InsertionResult { + fn insert_from_view( + &mut self, + root: Self::Node, + other: &H, + ) -> InsertionResult { panic_invalid_node(self, root); self.hugr_mut().insert_from_view(root, other) } @@ -255,12 +261,12 @@ pub trait HugrMut: HugrMutInternals { // TODO: Try to preserve the order when possible? We cannot always ensure // it, since the subgraph may have arbitrary nodes without including their // parent. - fn insert_subgraph( + fn insert_subgraph( &mut self, - root: Node, - other: &impl HugrView, - subgraph: &SiblingSubgraph, - ) -> HashMap { + root: Self::Node, + other: &H, + subgraph: &SiblingSubgraph, + ) -> HashMap { panic_invalid_node(self, root); self.hugr_mut().insert_subgraph(root, other, subgraph) } @@ -307,20 +313,32 @@ pub trait HugrMut: HugrMutInternals { /// Records the result of inserting a Hugr or view /// via [HugrMut::insert_hugr] or [HugrMut::insert_from_view]. -pub struct InsertionResult { +/// +/// Contains a map from the nodes in the source HUGR to the nodes in the +/// target HUGR, using their respective `Node` types. +pub struct InsertionResult { /// The node, after insertion, that was the root of the inserted Hugr. /// /// That is, the value in [InsertionResult::node_map] under the key that was the [HugrView::root] - pub new_root: Node, + pub new_root: TargetN, /// Map from nodes in the Hugr/view that was inserted, to their new /// positions in the Hugr into which said was inserted. - pub node_map: HashMap, + pub node_map: HashMap, } -fn translate_indices( +/// Translate a portgraph node index map into a map from nodes in the source +/// HUGR to nodes in the target HUGR. +/// +/// This is as a helper in `insert_hugr` and `insert_subgraph`, where the source +/// HUGR may be an arbitrary `HugrView` with generic node types. +fn translate_indices( + mut source_node: impl FnMut(portgraph::NodeIndex) -> N, + mut target_node: impl FnMut(portgraph::NodeIndex) -> Node, node_map: HashMap, -) -> impl Iterator { - node_map.into_iter().map(|(k, v)| (k.into(), v.into())) +) -> impl Iterator { + node_map + .into_iter() + .map(move |(k, v)| (source_node(k), target_node(v))) } /// Impl for non-wrapped Hugrs. Overwrites the recursive default-impls to directly use the hugr. @@ -406,7 +424,11 @@ impl + AsMut> HugrMut for T (src_port, dst_port) } - fn insert_hugr(&mut self, root: Node, mut other: Hugr) -> InsertionResult { + fn insert_hugr( + &mut self, + root: Self::Node, + mut other: Hugr, + ) -> InsertionResult { let (new_root, node_map) = insert_hugr_internal(self.as_mut(), root, &other); // Update the optypes and metadata, taking them from the other graph. // @@ -423,11 +445,16 @@ impl + AsMut> HugrMut for T ); InsertionResult { new_root, - node_map: translate_indices(node_map).collect(), + node_map: translate_indices(|n| other.get_node(n), |n| self.get_node(n), node_map) + .collect(), } } - fn insert_from_view(&mut self, root: Node, other: &impl HugrView) -> InsertionResult { + fn insert_from_view( + &mut self, + root: Self::Node, + other: &H, + ) -> InsertionResult { let (new_root, node_map) = insert_hugr_internal(self.as_mut(), root, other); // Update the optypes and metadata, copying them from the other graph. // @@ -444,22 +471,28 @@ impl + AsMut> HugrMut for T ); InsertionResult { new_root, - node_map: translate_indices(node_map).collect(), + node_map: translate_indices(|n| other.get_node(n), |n| self.get_node(n), node_map) + .collect(), } } - fn insert_subgraph( + fn insert_subgraph( &mut self, - root: Node, - other: &impl HugrView, - subgraph: &SiblingSubgraph, - ) -> HashMap { + root: Self::Node, + other: &H, + subgraph: &SiblingSubgraph, + ) -> HashMap { // Create a portgraph view with the explicit list of nodes defined by the subgraph. - let portgraph: NodeFiltered<_, NodeFilter<&[Node]>, &[Node]> = + let context: HashSet = subgraph + .nodes() + .iter() + .map(|&n| other.get_pg_index(n)) + .collect(); + let portgraph: NodeFiltered<_, NodeFilter>, _> = NodeFiltered::new_node_filtered( other.portgraph(), - |node, ctx| ctx.contains(&node.into()), - subgraph.nodes(), + |node, ctx| ctx.contains(&node), + context, ); let node_map = insert_subgraph_internal(self.as_mut(), root, other, &portgraph); // Update the optypes and metadata, copying them from the other graph. @@ -473,25 +506,24 @@ impl + AsMut> HugrMut for T self.use_extensions(exts); } } - translate_indices(node_map).collect() + translate_indices(|n| other.get_node(n), |n| self.get_node(n), node_map).collect() } fn copy_descendants( &mut self, - root: Node, - new_parent: Node, + root: Self::Node, + new_parent: Self::Node, subst: Option, - ) -> BTreeMap { + ) -> BTreeMap { let mut descendants = self.base_hugr().hierarchy.descendants(root.pg_index()); let root2 = descendants.next(); debug_assert_eq!(root2, Some(root.pg_index())); let nodes = Vec::from_iter(descendants); - let node_map = translate_indices( - portgraph::view::Subgraph::with_nodes(&mut self.as_mut().graph, nodes) - .copy_in_parent() - .expect("Is a MultiPortGraph"), - ) - .collect::>(); + let node_map = portgraph::view::Subgraph::with_nodes(&mut self.as_mut().graph, nodes) + .copy_in_parent() + .expect("Is a MultiPortGraph"); + let node_map = translate_indices(|n| self.get_node(n), |n| self.get_node(n), node_map) + .collect::>(); for node in self.children(root).collect::>() { self.set_parent(*node_map.get(&node).unwrap(), new_parent); @@ -563,10 +595,10 @@ fn insert_hugr_internal( /// sibling order in the hierarchy. This is due to the subgraph not necessarily /// having a single root, so the logic for reconstructing the hierarchy is not /// able to just do a BFS. -fn insert_subgraph_internal( +fn insert_subgraph_internal( hugr: &mut Hugr, root: Node, - other: &impl HugrView, + other: &impl HugrView, portgraph: &impl portgraph::LinkView, ) -> HashMap { let node_map = hugr diff --git a/hugr-core/src/hugr/views/sibling_subgraph.rs b/hugr-core/src/hugr/views/sibling_subgraph.rs index a0bf1a3da..c681fafc9 100644 --- a/hugr-core/src/hugr/views/sibling_subgraph.rs +++ b/hugr-core/src/hugr/views/sibling_subgraph.rs @@ -446,16 +446,14 @@ impl SiblingSubgraph { nu_out, )) } -} -impl SiblingSubgraph { /// Create a new Hugr containing only the subgraph. /// /// The new Hugr will contain a [FuncDefn][crate::ops::FuncDefn] root /// with the same signature as the subgraph and the specified `name` pub fn extract_subgraph( &self, - hugr: &impl HugrView, + hugr: &impl HugrView, name: impl Into, ) -> Hugr { let mut builder = FunctionBuilder::new(name, self.signature(hugr)).unwrap(); From ef1cba0a85f803423e9f14450844ad4f7300f1fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Tue, 15 Apr 2025 16:02:17 +0100 Subject: [PATCH 22/33] fix!: Don't expose `HugrMutInternals` (#2071) `HugrMutInternals` is part of the semi-private traits defined in `hugr-core`. While most things get re-exported in `hugr`, we `*Internal` traits require you to explicitly declare a dependency on the `-core` package (as we don't want most users to have to interact with them). For some reason there was a public re-export of the trait in a re-exported module, so it ended up appearing in `hugr` anyways. BREAKING CHANGE: Removed public re-export of `HugrMutInternals` from `hugr`. --- hugr-core/src/hugr/rewrite/simple_replace.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/hugr-core/src/hugr/rewrite/simple_replace.rs b/hugr-core/src/hugr/rewrite/simple_replace.rs index cf7f2922a..b4ec37db1 100644 --- a/hugr-core/src/hugr/rewrite/simple_replace.rs +++ b/hugr-core/src/hugr/rewrite/simple_replace.rs @@ -4,7 +4,6 @@ use std::collections::HashMap; use crate::core::HugrNode; use crate::hugr::hugrmut::InsertionResult; -pub use crate::hugr::internal::HugrMutInternals; use crate::hugr::views::SiblingSubgraph; use crate::hugr::{HugrMut, HugrView, Rewrite}; use crate::ops::{OpTag, OpTrait, OpType}; From a8e45536ab85106d457c08e2b39fa5f81d912ad6 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Apr 2025 13:29:22 +0100 Subject: [PATCH 23/33] trait AsConcrete --- hugr-passes/src/const_fold.rs | 2 +- hugr-passes/src/const_fold/value_handle.rs | 23 ++++-- hugr-passes/src/dataflow.rs | 2 +- hugr-passes/src/dataflow/partial_value.rs | 82 ++++++++++------------ hugr-passes/src/dataflow/results.rs | 12 +--- hugr-passes/src/dataflow/test.rs | 29 ++++++-- 6 files changed, 84 insertions(+), 66 deletions(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 31a1ddfdb..e73e3cd0e 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -127,7 +127,7 @@ impl ConstantFoldPass { n, ip, results - .try_read_wire_concrete::(Wire::new(src, outp)) + .try_read_wire_concrete::(Wire::new(src, outp)) .ok()?, )) }) diff --git a/hugr-passes/src/const_fold/value_handle.rs b/hugr-passes/src/const_fold/value_handle.rs index bda7bffd2..e5c99a8e7 100644 --- a/hugr-passes/src/const_fold/value_handle.rs +++ b/hugr-passes/src/const_fold/value_handle.rs @@ -1,16 +1,18 @@ //! Total equality (and hence [AbstractValue] support for [Value]s //! (by adding a source-Node and part unhashable constants) use std::collections::hash_map::DefaultHasher; // Moves into std::hash in Rust 1.76. +use std::convert::Infallible; use std::hash::{Hash, Hasher}; use std::sync::Arc; use hugr_core::core::HugrNode; use hugr_core::ops::constant::OpaqueValue; use hugr_core::ops::Value; +use hugr_core::types::ConstTypeError; use hugr_core::{Hugr, Node}; use itertools::Either; -use crate::dataflow::{AbstractValue, ConstLocation}; +use crate::dataflow::{AbstractValue, AsConcrete, ConstLocation, LoadedFunction, Sum}; /// A custom constant that has been successfully hashed via [TryHash](hugr_core::ops::constant::TryHash) #[derive(Clone, Debug)] @@ -153,9 +155,12 @@ impl Hash for ValueHandle { // Unfortunately we need From for Value to be able to pass // Value's into interpret_leaf_op. So that probably doesn't make sense... -impl From> for Value { - fn from(value: ValueHandle) -> Self { - match value { +impl AsConcrete, N> for Value { + type ValErr = Infallible; + type SumErr = ConstTypeError; + + fn from_value(value: ValueHandle) -> Result { + Ok(match value { ValueHandle::Hashable(HashedConst { val, .. }) | ValueHandle::Unhashable { leaf: Either::Left(val), @@ -169,7 +174,15 @@ impl From> for Value { } => Value::function(Arc::try_unwrap(hugr).unwrap_or_else(|a| a.as_ref().clone())) .map_err(|e| e.to_string()) .unwrap(), - } + }) + } + + fn from_sum(value: Sum) -> Result { + Self::sum(value.tag, value.values, value.st) + } + + fn from_func(func: LoadedFunction) -> Result> { + Err(func) } } diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 7029ec0b2..1f7c1ae5a 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -9,7 +9,7 @@ mod results; pub use results::{AnalysisResults, TailLoopTermination}; mod partial_value; -pub use partial_value::{AbstractValue, LoadedFunction, PartialSum, PartialValue, Sum}; +pub use partial_value::{AbstractValue, AsConcrete, LoadedFunction, PartialSum, PartialValue, Sum}; use hugr_core::ops::constant::OpaqueValue; use hugr_core::ops::{ExtensionOp, Value}; diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 41691fb06..a79ad8b28 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -1,7 +1,6 @@ use ascent::lattice::BoundedLattice; use ascent::Lattice; -use hugr_core::ops::Value; -use hugr_core::types::{ConstTypeError, SumType, Type, TypeArg, TypeEnum, TypeRow}; +use hugr_core::types::{SumType, Type, TypeArg, TypeEnum, TypeRow}; use hugr_core::Node; use itertools::{zip_eq, Itertools}; use std::cmp::Ordering; @@ -166,6 +165,30 @@ impl PartialSum { } } +/// Trait implemented by value types into which [PartialValue]s can be converted, +/// so long as the PV has no [Top](PartialValue::Top), [Bottom](PartialValue::Bottom) +/// or [PartialSum]s with more than one possible tag. See [PartialSum::try_into_sum] +/// and [PartialValue::try_into_concrete]. +/// +/// `V` is the type of [AbstractValue] from which `Self` can (fallibly) be constructed, +/// `N` is the type of [HugrNode](hugr_core::core::HugrNode) for function pointers +pub trait AsConcrete: Sized { + /// Kind of error raised when creating `Self` from a value `V`, see [Self::from_value] + type ValErr: std::error::Error; + /// Kind of error that may be raised when creating `Self` from a [Sum] of `Self`s, + /// see [Self::from_sum] + type SumErr: std::error::Error; + + /// Convert an abstract value into concrete + fn from_value(val: V) -> Result; + + /// Convert a sum (of concrete values, already recursively converted) into concrete + fn from_sum(sum: Sum) -> Result; + + /// Convert a function pointer into a concrete value + fn from_func(func: LoadedFunction) -> Result>; +} + impl PartialSum { /// Turns this instance into a [Sum] of some "concrete" value type `C`, /// *if* this PartialSum has exactly one possible tag. @@ -175,15 +198,11 @@ impl PartialSum { /// If this PartialSum had multiple possible tags; or if `typ` was not a [TypeEnum::Sum] /// supporting the single possible tag with the correct number of elements and no row variables; /// or if converting a child element failed via [PartialValue::try_into_concrete]. - pub fn try_into_sum( + #[allow(clippy::type_complexity)] // Since C is a parameter, can't declare type aliases + pub fn try_into_sum>( self, typ: &Type, - ) -> Result, ExtractValueError> - where - V: TryInto, - Sum: TryInto, - LoadedFunction: TryInto>, - { + ) -> Result, ExtractValueError> { if self.0.len() != 1 { return Err(ExtractValueError::MultipleVariants(self)); } @@ -395,49 +414,26 @@ impl PartialValue { /// If this PartialValue was `Top` or `Bottom`, or was a [PartialSum](PartialValue::PartialSum) /// that could not be converted into a [Sum] by [PartialSum::try_into_sum] (e.g. if `typ` is /// incorrect), or if that [Sum] could not be converted into a `V2`. - pub fn try_into_concrete( + pub fn try_into_concrete>( self, typ: &Type, - ) -> Result> - where - V: TryInto, - Sum: TryInto, - LoadedFunction: TryInto>, - { + ) -> Result> { match self { - Self::Value(v) => v - .clone() - .try_into() - .map_err(|e| ExtractValueError::CouldNotConvert(v, e)), - Self::LoadedFunction(lf) => lf - .try_into() - .map_err(ExtractValueError::CouldNotLoadFunction), - Self::PartialSum(ps) => ps - .try_into_sum(typ)? - .try_into() - .map_err(ExtractValueError::CouldNotBuildSum), + Self::Value(v) => { + C::from_value(v.clone()).map_err(|e| ExtractValueError::CouldNotConvert(v, e)) + } + Self::LoadedFunction(lf) => { + C::from_func(lf).map_err(ExtractValueError::CouldNotLoadFunction) + } + Self::PartialSum(ps) => { + C::from_sum(ps.try_into_sum(typ)?).map_err(ExtractValueError::CouldNotBuildSum) + } Self::Top => Err(ExtractValueError::ValueIsTop), Self::Bottom => Err(ExtractValueError::ValueIsBottom), } } } -impl TryFrom> for Value { - type Error = ConstTypeError; - - fn try_from(value: Sum) -> Result { - Self::sum(value.tag, value.values, value.st) - } -} - -impl TryFrom> for Value { - type Error = LoadedFunction; - - fn try_from(value: LoadedFunction) -> Result { - Err(value) - } -} - impl Lattice for PartialValue { fn join_mut(&mut self, other: Self) -> bool { self.assert_invariants(); diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index d5d6bb2fa..c4a94a9e7 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -3,8 +3,7 @@ use std::collections::HashMap; use hugr_core::{HugrView, PortIndex, Wire}; use super::{ - datalog::InWire, partial_value::ExtractValueError, AbstractValue, LoadedFunction, PartialValue, - Sum, + datalog::InWire, partial_value::ExtractValueError, AbstractValue, AsConcrete, PartialValue, }; /// Results of a dataflow analysis, packaged with the Hugr for easy inspection. @@ -88,15 +87,10 @@ impl AnalysisResults { /// the Hugr did not have a [Type](hugr_core::types::Type) for the specified wire /// `Some(e)` if [conversion to a concrete value](PartialValue::try_into_concrete) failed with error `e` #[allow(clippy::type_complexity)] - pub fn try_read_wire_concrete( + pub fn try_read_wire_concrete>( &self, w: Wire, - ) -> Result>> - where - V2: TryFrom - + TryFrom, Error = SE> - + TryFrom, Error = LoadedFunction>, - { + ) -> Result>> { let v = self.read_out_wire(w).ok_or(None)?; let (_, typ) = self .hugr diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index f57f3d7aa..1c4b4e439 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -1,10 +1,12 @@ +use std::convert::Infallible; + use ascent::{lattice::BoundedLattice, Lattice}; use hugr_core::builder::{inout_sig, CFGBuilder, Container, DataflowHugr, ModuleBuilder}; use hugr_core::hugr::views::{DescendantsGraph, HierarchyView}; use hugr_core::ops::handle::DfgID; use hugr_core::ops::{CallIndirect, TailLoop}; -use hugr_core::types::TypeRow; +use hugr_core::types::{ConstTypeError, TypeRow}; use hugr_core::{ builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::{ @@ -19,7 +21,10 @@ use hugr_core::{ use hugr_core::{Hugr, Node, Wire}; use rstest::{fixture, rstest}; -use super::{AbstractValue, ConstLoader, DFContext, Machine, PartialValue, TailLoopTermination}; +use super::{ + AbstractValue, AsConcrete, ConstLoader, DFContext, LoadedFunction, Machine, PartialValue, Sum, + TailLoopTermination, +}; // ------- Minimal implementation of DFContext and AbstractValue ------- #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -35,10 +40,22 @@ impl ConstLoader for TestContext { impl DFContext for TestContext {} // This allows testing creation of tuple/sum Values (only) -impl From for Value { - fn from(v: Void) -> Self { +impl AsConcrete for Value { + type ValErr = Infallible; + + type SumErr = ConstTypeError; + + fn from_value(v: Void) -> Result { match v {} } + + fn from_sum(value: Sum) -> Result { + Self::sum(value.tag, value.values, value.st) + } + + fn from_func(func: LoadedFunction) -> Result> { + Err(func) + } } fn pv_false() -> PartialValue { @@ -295,9 +312,7 @@ fn test_conditional() { let cond_r1: Value = results.try_read_wire_concrete(cond_o1).unwrap(); assert_eq!(cond_r1, Value::false_val()); - assert!(results - .try_read_wire_concrete::(cond_o2) - .is_err()); + assert!(results.try_read_wire_concrete::(cond_o2).is_err()); assert_eq!(results.case_reachable(case1.node()), Some(false)); // arg_pv is variant 1 or 2 only assert_eq!(results.case_reachable(case2.node()), Some(true)); From f032848947029e6051517aa26a1c4a7fd5313ad5 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 15 Apr 2025 16:39:45 +0100 Subject: [PATCH 24/33] LatticeWrapper requires only PartialEq+PartialOrd --- hugr-passes/src/dataflow/datalog.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 986378dc5..a5911c665 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -4,7 +4,6 @@ use std::collections::HashMap; use ascent::lattice::BoundedLattice; use ascent::Lattice; -use hugr_core::core::HugrNode; use itertools::Itertools; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; @@ -378,7 +377,7 @@ enum LatticeWrapper { Top, } -impl Lattice for LatticeWrapper { +impl Lattice for LatticeWrapper { fn meet_mut(&mut self, other: Self) -> bool { if *self == other || *self == LatticeWrapper::Bottom || other == LatticeWrapper::Top { return false; From fac6c8b92bab8904f47eaa9ebf4581eb1e1f4095 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Tue, 15 Apr 2025 17:10:32 +0100 Subject: [PATCH 25/33] feat!: Mark all Error enums as non_exhaustive (#2056) #2027 ended up being breaking due to adding a new variant to an error enum missing the `non_exhaustive` marker. This (breaking) PR makes sure all error enums have the flag. BREAKING CHANGE: Marked all Error enums as `non_exhaustive` --- hugr-core/src/extension.rs | 1 + hugr-core/src/hugr/serialize/upgrade.rs | 1 + hugr-core/src/import.rs | 2 ++ hugr-model/src/v0/ast/resolve.rs | 1 + hugr-model/src/v0/table/mod.rs | 1 + hugr-passes/src/lower.rs | 1 + hugr-passes/src/non_local.rs | 1 + hugr-passes/src/replace_types/linearize.rs | 1 + hugr-passes/src/validation.rs | 1 + 9 files changed, 10 insertions(+) diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index 408c88e15..b6e059050 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -378,6 +378,7 @@ pub static EMPTY_REG: ExtensionRegistry = ExtensionRegistry { /// TODO: decide on failure modes #[derive(Debug, Clone, Error, PartialEq, Eq)] #[allow(missing_docs)] +#[non_exhaustive] pub enum SignatureError { /// Name mismatch #[error("Definition name ({0}) and instantiation name ({1}) do not match.")] diff --git a/hugr-core/src/hugr/serialize/upgrade.rs b/hugr-core/src/hugr/serialize/upgrade.rs index 2741b6175..ac1ac1eea 100644 --- a/hugr-core/src/hugr/serialize/upgrade.rs +++ b/hugr-core/src/hugr/serialize/upgrade.rs @@ -1,6 +1,7 @@ use thiserror::Error; #[derive(Debug, Error)] +#[non_exhaustive] pub enum UpgradeError { #[error(transparent)] Deserialize(#[from] serde_json::Error), diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index 642c84c41..899deb17d 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -35,6 +35,7 @@ use thiserror::Error; /// Error during import. #[derive(Debug, Clone, Error)] +#[non_exhaustive] pub enum ImportError { /// The model contains a feature that is not supported by the importer yet. /// Errors of this kind are expected to be removed as the model format and @@ -75,6 +76,7 @@ pub enum ImportError { /// Import error caused by incorrect order hints. #[derive(Debug, Clone, Error)] +#[non_exhaustive] pub enum OrderHintError { /// Duplicate order hint key in the same region. #[error("duplicate order hint key {0}")] diff --git a/hugr-model/src/v0/ast/resolve.rs b/hugr-model/src/v0/ast/resolve.rs index 2f8a5ba6e..c9be8896b 100644 --- a/hugr-model/src/v0/ast/resolve.rs +++ b/hugr-model/src/v0/ast/resolve.rs @@ -362,6 +362,7 @@ impl<'a> Context<'a> { /// Error that may occur in [`Module::resolve`]. #[derive(Debug, Clone, Error)] +#[non_exhaustive] pub enum ResolveError { /// Unknown variable. #[error("unknown var: {0}")] diff --git a/hugr-model/src/v0/table/mod.rs b/hugr-model/src/v0/table/mod.rs index 756a52c1e..55a4b9889 100644 --- a/hugr-model/src/v0/table/mod.rs +++ b/hugr-model/src/v0/table/mod.rs @@ -456,6 +456,7 @@ pub struct VarId(pub NodeId, pub VarIndex); /// Errors that can occur when traversing and interpreting the model. #[derive(Debug, Clone, Error)] +#[non_exhaustive] pub enum ModelError { /// There is a reference to a node that does not exist. #[error("node not found: {0}")] diff --git a/hugr-passes/src/lower.rs b/hugr-passes/src/lower.rs index 09e02c41d..8f8920967 100644 --- a/hugr-passes/src/lower.rs +++ b/hugr-passes/src/lower.rs @@ -35,6 +35,7 @@ pub fn replace_many_ops>( /// Errors produced by the [`lower_ops`] function. #[derive(Debug, Error)] #[error(transparent)] +#[non_exhaustive] pub enum LowerError { /// Invalid subgraph. #[error("Subgraph formed by node is invalid: {0}")] diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index fca74657b..180e9d6fc 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -23,6 +23,7 @@ pub fn nonlocal_edges(hugr: &H) -> impl Iterator { #[error("Found {} nonlocal edges", .0.len())] Edges(Vec<(N, IncomingPort)>), diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 371798dce..b3fc20da9 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -127,6 +127,7 @@ pub struct CallbackHandler<'a>(#[allow(dead_code)] &'a DelegatingLinearizer); #[derive(Clone, Debug, thiserror::Error, PartialEq, Eq)] #[allow(missing_docs)] +#[non_exhaustive] pub enum LinearizeError { #[error("Need copy/discard op for {_0}")] NeedCopyDiscard(Type), diff --git a/hugr-passes/src/validation.rs b/hugr-passes/src/validation.rs index 5f53f403c..6c3e61fb4 100644 --- a/hugr-passes/src/validation.rs +++ b/hugr-passes/src/validation.rs @@ -25,6 +25,7 @@ pub enum ValidationLevel { #[derive(Error, Debug, PartialEq)] #[allow(missing_docs)] +#[non_exhaustive] pub enum ValidatePassError { #[error("Failed to validate input HUGR: {err}\n{pretty_hugr}")] InputError { From 3066b654e10ddb8bdc0ce2305d2795d1733dbdbf Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 15 Apr 2025 17:22:35 +0100 Subject: [PATCH 26/33] PartialValue: Arbitrary+TestSumType etc. generates LoadedFunction --- hugr-passes/src/dataflow/partial_value.rs | 36 ++++++++++++++++------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index a79ad8b28..7ccb761b2 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -449,7 +449,7 @@ impl Lattice for PartialValue (Self::LoadedFunction(lf1), Self::LoadedFunction(lf2)) if lf1.func_node == lf2.func_node => { - // TODO we should also require TypeArgs to be equal by at the moment these are ignored + // TODO we should also join the TypeArgs but at the moment these are ignored (Self::LoadedFunction(lf1), false) } (Self::PartialSum(mut ps1), Self::PartialSum(ps2)) => match ps1.try_join_mut(ps2) { @@ -476,7 +476,7 @@ impl Lattice for PartialValue (Self::LoadedFunction(lf1), Self::LoadedFunction(lf2)) if lf1.func_node == lf2.func_node => { - // TODO we should also require TypeArgs to be equal by at the moment these are ignored + // TODO we should also meet the TypeArgs but at the moment these are ignored (Self::LoadedFunction(lf1), false) } (Self::PartialSum(mut ps1), Self::PartialSum(ps2)) => match ps1.try_meet_mut(ps2) { @@ -525,19 +525,20 @@ mod test { use std::sync::Arc; use ascent::{lattice::BoundedLattice, Lattice}; + use hugr_core::NodeIndex; use itertools::{zip_eq, Itertools as _}; use prop::sample::subsequence; use proptest::prelude::*; use proptest_recurse::{StrategyExt, StrategySet}; - use super::{AbstractValue, PartialSum, PartialValue}; + use super::{AbstractValue, LoadedFunction, PartialSum, PartialValue}; #[derive(Debug, PartialEq, Eq, Clone)] enum TestSumType { Branch(Vec>>), - /// None => unit, Some => TestValue <= this *usize* - Leaf(Option), + LeafVal(usize), // contains a TestValue <= this usize + LeafPtr(usize), // contains a LoadedFunction with node <= this *usize* } #[derive(Clone, Debug, PartialEq, Eq, Hash)] @@ -566,8 +567,11 @@ mod test { fn check_value(&self, pv: &PartialValue) -> bool { match (self, pv) { (_, PartialValue::Bottom) | (_, PartialValue::Top) => true, - (Self::Leaf(None), _) => pv == &PartialValue::new_unit(), - (Self::Leaf(Some(max)), PartialValue::Value(TestValue(val))) => val <= max, + (Self::LeafVal(max), PartialValue::Value(TestValue(val))) => val <= max, + ( + Self::LeafPtr(max), + PartialValue::LoadedFunction(LoadedFunction { func_node, args }), + ) => args.len() == 0 && func_node.index() <= *max, (Self::Branch(sop), PartialValue::PartialSum(ps)) => { for (k, v) in &ps.0 { if *k >= sop.len() { @@ -594,8 +598,11 @@ mod test { fn arbitrary_with(params: Self::Parameters) -> Self::Strategy { fn arb(params: SumTypeParams, set: &mut StrategySet) -> SBoxedStrategy { use proptest::collection::vec; - let int_strat = (0..usize::MAX).prop_map(|i| TestSumType::Leaf(Some(i))); - let leaf_strat = prop_oneof![Just(TestSumType::Leaf(None)), int_strat]; + let leaf_strat = prop_oneof![ + (0..usize::MAX).prop_map(TestSumType::LeafVal), + // This is the maximum value accepted by portgraph::NodeIndex::new + (0..(2usize ^ 31 - 2)).prop_map(TestSumType::LeafPtr) + ]; leaf_strat.prop_mutually_recursive( params.depth as u32, params.desired_size as u32, @@ -662,11 +669,18 @@ mod test { ust: &TestSumType, ) -> impl Strategy> { match ust { - TestSumType::Leaf(None) => Just(PartialValue::new_unit()).boxed(), - TestSumType::Leaf(Some(i)) => (0..*i) + TestSumType::LeafVal(i) => (0..=*i) .prop_map(TestValue) .prop_map(PartialValue::from) .boxed(), + TestSumType::LeafPtr(i) => (0..=*i) + .prop_map(|i| { + PartialValue::LoadedFunction(LoadedFunction { + func_node: portgraph::NodeIndex::new(i).into(), + args: vec![], + }) + }) + .boxed(), TestSumType::Branch(sop) => partial_sum_strat(sop).prop_map(PartialValue::from).boxed(), } } From 3eb0c7f97cadc6c881d4bdc249926ea97fe2704f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 15 Apr 2025 17:28:32 +0100 Subject: [PATCH 27/33] clippy, fix predecence --- hugr-passes/src/dataflow/partial_value.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 7ccb761b2..240f4f2d6 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -571,7 +571,7 @@ mod test { ( Self::LeafPtr(max), PartialValue::LoadedFunction(LoadedFunction { func_node, args }), - ) => args.len() == 0 && func_node.index() <= *max, + ) => args.is_empty() && func_node.index() <= *max, (Self::Branch(sop), PartialValue::PartialSum(ps)) => { for (k, v) in &ps.0 { if *k >= sop.len() { @@ -601,7 +601,7 @@ mod test { let leaf_strat = prop_oneof![ (0..usize::MAX).prop_map(TestSumType::LeafVal), // This is the maximum value accepted by portgraph::NodeIndex::new - (0..(2usize ^ 31 - 2)).prop_map(TestSumType::LeafPtr) + (0..((2usize ^ 31) - 2)).prop_map(TestSumType::LeafPtr) ]; leaf_strat.prop_mutually_recursive( params.depth as u32, From 887a3861253c25c39d665c0c2dc593a8f99ba371 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 15 Apr 2025 18:13:14 +0100 Subject: [PATCH 28/33] test LatticeWrapper --- hugr-passes/src/dataflow/datalog.rs | 49 ++++++++++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index a5911c665..ad1a99345 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -370,7 +370,7 @@ pub(super) fn run_datalog( } } -#[derive(PartialEq, Eq, Hash, Clone, PartialOrd)] +#[derive(Debug, PartialEq, Eq, Hash, Clone, PartialOrd)] enum LatticeWrapper { Bottom, Value(T), @@ -482,3 +482,50 @@ fn propagate_leaf_op( o => todo!("Unhandled: {:?}", o), // and OpType is non-exhaustive } } + +#[cfg(test)] +mod test { + use ascent::Lattice; + + use super::LatticeWrapper; + + #[test] + fn latwrap_join() { + for lv in [ + LatticeWrapper::Value(3), + LatticeWrapper::Value(5), + LatticeWrapper::Top, + ] { + let mut subject = LatticeWrapper::Bottom; + assert!(subject.join_mut(lv.clone())); + assert_eq!(subject, lv); + assert!(!subject.join_mut(lv.clone())); + assert_eq!(subject, lv); + assert_eq!( + subject.join_mut(LatticeWrapper::Value(11)), + lv != LatticeWrapper::Top + ); + assert_eq!(subject, LatticeWrapper::Top); + } + } + + #[test] + fn latwrap_meet() { + for lv in [ + LatticeWrapper::Bottom, + LatticeWrapper::Value(3), + LatticeWrapper::Value(5), + ] { + let mut subject = LatticeWrapper::Top; + assert!(subject.meet_mut(lv.clone())); + assert_eq!(subject, lv); + assert!(!subject.meet_mut(lv.clone())); + assert_eq!(subject, lv); + assert_eq!( + subject.meet_mut(LatticeWrapper::Value(11)), + lv != LatticeWrapper::Bottom + ); + assert_eq!(subject, LatticeWrapper::Bottom); + } + } +} From baaca02359a307f8691ab3985313272339a8c494 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 16 Apr 2025 11:11:13 +0100 Subject: [PATCH 29/33] feat!: Handle CallIndirect in Dataflow Analysis (#2059) * PartialValue now has a LoadedFunction variant, created by LoadFunction nodes (only, although other analyses are able to create PartialValues if they want) * This requires adding a type parameter to PartialValue for the type of Node, which gets everywhere :-(. * Use this to handle CallIndirects *with known targets* (it'll be a single known target or none at all) just like other Calls to the same function * deprecate (and ignore) `value_from_function` * Add a new trait `AsConcrete` for the result type of `PartialValue::try_into_concrete` and `PartialSum::try_into_sum` Note almost no change to constant folding (only to drop impl of `value_from_function`) BREAKING CHANGE: in dataflow framework, PartialValue now has additional variant; `try_into_concrete` requires the target type to implement AsConcrete. --- hugr-passes/src/const_fold.rs | 63 ++--- hugr-passes/src/const_fold/test.rs | 6 +- hugr-passes/src/const_fold/value_handle.rs | 23 +- hugr-passes/src/dataflow.rs | 17 +- hugr-passes/src/dataflow/datalog.rs | 171 +++++++++++-- hugr-passes/src/dataflow/partial_value.rs | 267 +++++++++++++-------- hugr-passes/src/dataflow/results.rs | 22 +- hugr-passes/src/dataflow/test.rs | 108 ++++++++- hugr-passes/src/dataflow/value_row.rs | 38 +-- 9 files changed, 492 insertions(+), 223 deletions(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 7552ed36f..e73e3cd0e 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -7,15 +7,11 @@ use std::{collections::HashMap, sync::Arc}; use thiserror::Error; use hugr_core::{ - hugr::{ - hugrmut::HugrMut, - views::{DescendantsGraph, ExtractHugr, HierarchyView}, - }, + hugr::hugrmut::HugrMut, ops::{ - constant::OpaqueValue, handle::FuncID, Const, DataflowOpTrait, ExtensionOp, LoadConstant, - OpType, Value, + constant::OpaqueValue, Const, DataflowOpTrait, ExtensionOp, LoadConstant, OpType, Value, }, - types::{EdgeKind, TypeArg}, + types::EdgeKind, HugrView, IncomingPort, Node, NodeIndex, OutgoingPort, PortIndex, Wire, }; use value_handle::ValueHandle; @@ -102,7 +98,7 @@ impl ConstantFoldPass { n, in_vals.iter().map(|(p, v)| { let const_with_dummy_loc = partial_from_const( - &ConstFoldContext(hugr), + &ConstFoldContext, ConstLocation::Field(p.index(), &fresh_node.into()), v, ); @@ -112,7 +108,7 @@ impl ConstantFoldPass { .map_err(|opty| ConstFoldError::InvalidEntryPoint(n, opty))?; } - let results = m.run(ConstFoldContext(hugr), []); + let results = m.run(ConstFoldContext, []); let mb_root_inp = hugr.get_io(hugr.root()).map(|[i, _]| i); let wires_to_break = hugr @@ -131,7 +127,7 @@ impl ConstantFoldPass { n, ip, results - .try_read_wire_concrete::(Wire::new(src, outp)) + .try_read_wire_concrete::(Wire::new(src, outp)) .ok()?, )) }) @@ -205,60 +201,35 @@ pub fn constant_fold_pass(h: &mut H) { c.run(h).unwrap() } -struct ConstFoldContext<'a, H>(&'a H); - -impl std::ops::Deref for ConstFoldContext<'_, H> { - type Target = H; - fn deref(&self) -> &H { - self.0 - } -} +struct ConstFoldContext; -impl> ConstLoader> for ConstFoldContext<'_, H> { - type Node = H::Node; +impl ConstLoader> for ConstFoldContext { + type Node = Node; fn value_from_opaque( &self, - loc: ConstLocation, + loc: ConstLocation, val: &OpaqueValue, - ) -> Option> { + ) -> Option> { Some(ValueHandle::new_opaque(loc, val.clone())) } fn value_from_const_hugr( &self, - loc: ConstLocation, + loc: ConstLocation, h: &hugr_core::Hugr, - ) -> Option> { + ) -> Option> { Some(ValueHandle::new_const_hugr(loc, Box::new(h.clone()))) } - - fn value_from_function( - &self, - node: H::Node, - type_args: &[TypeArg], - ) -> Option> { - if !type_args.is_empty() { - // TODO: substitution across Hugr (https://github.com/CQCL/hugr/issues/709) - return None; - }; - // Returning the function body as a value, here, would be sufficient for inlining IndirectCall - // but not for transforming to a direct Call. - let func = DescendantsGraph::>::try_new(&**self, node).ok()?; - Some(ValueHandle::new_const_hugr( - ConstLocation::Node(node), - Box::new(func.extract_hugr()), - )) - } } -impl> DFContext> for ConstFoldContext<'_, H> { +impl DFContext> for ConstFoldContext { fn interpret_leaf_op( &mut self, - node: H::Node, + node: Node, op: &ExtensionOp, - ins: &[PartialValue>], - outs: &mut [PartialValue>], + ins: &[PartialValue>], + outs: &mut [PartialValue>], ) { let sig = op.signature(); let known_ins = sig diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index b84d65d7d..58e69c568 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -42,8 +42,7 @@ fn value_handling(#[case] k: impl CustomConst + Clone, #[case] eq: bool) { let n = Node::from(portgraph::NodeIndex::new(7)); let st = SumType::new([vec![k.get_type()], vec![]]); let subject_val = Value::sum(0, [k.clone().into()], st).unwrap(); - let temp = Hugr::default(); - let ctx: ConstFoldContext = ConstFoldContext(&temp); + let ctx = ConstFoldContext; let v1 = partial_from_const(&ctx, n, &subject_val); let v1_subfield = { @@ -114,8 +113,7 @@ fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) { v.get_custom_value::().unwrap().value() } let [n, n_a, n_b] = [0, 1, 2].map(portgraph::NodeIndex::new).map(Node::from); - let temp = Hugr::default(); - let mut ctx = ConstFoldContext(&temp); + let mut ctx = ConstFoldContext; let v_a = partial_from_const(&ctx, n_a, &f2c(a)); let v_b = partial_from_const(&ctx, n_b, &f2c(b)); assert_eq!(unwrap_float(v_a.clone()), a); diff --git a/hugr-passes/src/const_fold/value_handle.rs b/hugr-passes/src/const_fold/value_handle.rs index bda7bffd2..e5c99a8e7 100644 --- a/hugr-passes/src/const_fold/value_handle.rs +++ b/hugr-passes/src/const_fold/value_handle.rs @@ -1,16 +1,18 @@ //! Total equality (and hence [AbstractValue] support for [Value]s //! (by adding a source-Node and part unhashable constants) use std::collections::hash_map::DefaultHasher; // Moves into std::hash in Rust 1.76. +use std::convert::Infallible; use std::hash::{Hash, Hasher}; use std::sync::Arc; use hugr_core::core::HugrNode; use hugr_core::ops::constant::OpaqueValue; use hugr_core::ops::Value; +use hugr_core::types::ConstTypeError; use hugr_core::{Hugr, Node}; use itertools::Either; -use crate::dataflow::{AbstractValue, ConstLocation}; +use crate::dataflow::{AbstractValue, AsConcrete, ConstLocation, LoadedFunction, Sum}; /// A custom constant that has been successfully hashed via [TryHash](hugr_core::ops::constant::TryHash) #[derive(Clone, Debug)] @@ -153,9 +155,12 @@ impl Hash for ValueHandle { // Unfortunately we need From for Value to be able to pass // Value's into interpret_leaf_op. So that probably doesn't make sense... -impl From> for Value { - fn from(value: ValueHandle) -> Self { - match value { +impl AsConcrete, N> for Value { + type ValErr = Infallible; + type SumErr = ConstTypeError; + + fn from_value(value: ValueHandle) -> Result { + Ok(match value { ValueHandle::Hashable(HashedConst { val, .. }) | ValueHandle::Unhashable { leaf: Either::Left(val), @@ -169,7 +174,15 @@ impl From> for Value { } => Value::function(Arc::try_unwrap(hugr).unwrap_or_else(|a| a.as_ref().clone())) .map_err(|e| e.to_string()) .unwrap(), - } + }) + } + + fn from_sum(value: Sum) -> Result { + Self::sum(value.tag, value.values, value.st) + } + + fn from_func(func: LoadedFunction) -> Result> { + Err(func) } } diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 43caa9c94..1f7c1ae5a 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -9,7 +9,7 @@ mod results; pub use results::{AnalysisResults, TailLoopTermination}; mod partial_value; -pub use partial_value::{AbstractValue, PartialSum, PartialValue, Sum}; +pub use partial_value::{AbstractValue, AsConcrete, LoadedFunction, PartialSum, PartialValue, Sum}; use hugr_core::ops::constant::OpaqueValue; use hugr_core::ops::{ExtensionOp, Value}; @@ -31,8 +31,8 @@ pub trait DFContext: ConstLoader { &mut self, _node: Self::Node, _e: &ExtensionOp, - _ins: &[PartialValue], - _outs: &mut [PartialValue], + _ins: &[PartialValue], + _outs: &mut [PartialValue], ) { } } @@ -55,8 +55,8 @@ impl From for ConstLocation<'_, N> { } /// Trait for loading [PartialValue]s from constant [Value]s in a Hugr. -/// Implementors will likely want to override some/all of [Self::value_from_opaque], -/// [Self::value_from_const_hugr], and [Self::value_from_function]: the defaults +/// Implementors will likely want to override either/both of [Self::value_from_opaque] +/// and [Self::value_from_const_hugr]: the defaults /// are "correct" but maximally conservative (minimally informative). pub trait ConstLoader { /// The type of nodes in the Hugr. @@ -81,6 +81,7 @@ pub trait ConstLoader { /// [FuncDefn]: hugr_core::ops::FuncDefn /// [FuncDecl]: hugr_core::ops::FuncDecl /// [LoadFunction]: hugr_core::ops::LoadFunction + #[deprecated(note = "Automatically handled by Datalog, implementation will be ignored")] fn value_from_function(&self, _node: Self::Node, _type_args: &[TypeArg]) -> Option { None } @@ -94,7 +95,7 @@ pub fn partial_from_const<'a, V, CL: ConstLoader>( cl: &CL, loc: impl Into>, cst: &Value, -) -> PartialValue +) -> PartialValue where CL::Node: 'a, { @@ -120,8 +121,8 @@ where /// A row of inputs to a node contains bottom (can't happen, the node /// can't execute) if any element [contains_bottom](PartialValue::contains_bottom). -pub fn row_contains_bottom<'a, V: AbstractValue + 'a>( - elements: impl IntoIterator>, +pub fn row_contains_bottom<'a, V: 'a, N: 'a>( + elements: impl IntoIterator>, ) -> bool { elements.into_iter().any(PartialValue::contains_bottom) } diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 13e510daf..ad1a99345 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -3,19 +3,22 @@ use std::collections::HashMap; use ascent::lattice::BoundedLattice; +use ascent::Lattice; use itertools::Itertools; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; -use hugr_core::ops::{OpTrait, OpType, TailLoop}; +use hugr_core::ops::{DataflowOpTrait, OpTrait, OpType, TailLoop}; use hugr_core::{HugrView, IncomingPort, OutgoingPort, PortIndex as _, Wire}; use super::value_row::ValueRow; use super::{ partial_from_const, row_contains_bottom, AbstractValue, AnalysisResults, DFContext, - PartialValue, + LoadedFunction, PartialValue, }; -type PV = PartialValue; +type PV = PartialValue; + +type NodeInputs = Vec<(IncomingPort, PV)>; /// Basic structure for performing an analysis. Usage: /// 1. Make a new instance via [Self::new()] @@ -25,10 +28,7 @@ type PV = PartialValue; /// [Self::prepopulate_inputs] can be used on each externally-callable /// [FuncDefn](OpType::FuncDefn) to set all inputs to [PartialValue::Top]. /// 3. Call [Self::run] to produce [AnalysisResults] -pub struct Machine( - H, - HashMap)>>, -); +pub struct Machine(H, HashMap>); impl Machine { /// Create a new Machine to analyse the given Hugr(View) @@ -40,7 +40,7 @@ impl Machine { impl Machine { /// Provide initial values for a wire - these will be `join`d with any computed /// or any value previously prepopulated for the same Wire. - pub fn prepopulate_wire(&mut self, w: Wire, v: PartialValue) { + pub fn prepopulate_wire(&mut self, w: Wire, v: PartialValue) { for (n, inp) in self.0.linked_inputs(w.node(), w.source()) { self.1.entry(n).or_default().push((inp, v.clone())); } @@ -54,7 +54,7 @@ impl Machine { pub fn prepopulate_inputs( &mut self, parent: H::Node, - in_values: impl IntoIterator)>, + in_values: impl IntoIterator)>, ) -> Result<(), OpType> { match self.0.get_optype(parent) { OpType::DataflowBlock(_) | OpType::Case(_) | OpType::FuncDefn(_) => { @@ -102,7 +102,7 @@ impl Machine { pub fn run( mut self, context: impl DFContext, - in_values: impl IntoIterator)>, + in_values: impl IntoIterator)>, ) -> AnalysisResults { let root = self.0.root(); if self.0.get_optype(root).is_module() { @@ -135,10 +135,12 @@ impl Machine { } } +pub(super) type InWire = (N, IncomingPort, PartialValue); + pub(super) fn run_datalog( mut ctx: impl DFContext, hugr: H, - in_wire_value_proto: Vec<(H::Node, IncomingPort, PV)>, + in_wire_value_proto: Vec>, ) -> AnalysisResults { // ascent-(macro-)generated code generates a bunch of warnings, // keep code in here to a minimum. @@ -155,9 +157,9 @@ pub(super) fn run_datalog( relation parent_of_node(H::Node, H::Node); // is parent of relation input_child(H::Node, H::Node); // has 1st child that is its `Input` relation output_child(H::Node, H::Node); // has 2nd child that is its `Output` - lattice out_wire_value(H::Node, OutgoingPort, PV); // produces, on , the value - lattice in_wire_value(H::Node, IncomingPort, PV); // receives, on , the value - lattice node_in_value_row(H::Node, ValueRow); // 's inputs are + lattice out_wire_value(H::Node, OutgoingPort, PV); // produces, on , the value + lattice in_wire_value(H::Node, IncomingPort, PV); // receives, on , the value + lattice node_in_value_row(H::Node, ValueRow); // 's inputs are node(n) <-- for n in hugr.nodes(); @@ -322,6 +324,37 @@ pub(super) fn run_datalog( func_call(call, func), output_child(func, outp), in_wire_value(outp, p, v); + + // CallIndirect -------------------- + lattice indirect_call(H::Node, LatticeWrapper); // is an `IndirectCall` to `FuncDefn` + indirect_call(call, tgt) <-- + node(call), + if let OpType::CallIndirect(_) = hugr.get_optype(*call), + in_wire_value(call, IncomingPort::from(0), v), + let tgt = load_func(v); + + out_wire_value(inp, OutgoingPort::from(p.index()-1), v) <-- + indirect_call(call, lv), + if let LatticeWrapper::Value(func) = lv, + input_child(func, inp), + in_wire_value(call, p, v) + if p.index() > 0; + + out_wire_value(call, OutgoingPort::from(p.index()), v) <-- + indirect_call(call, lv), + if let LatticeWrapper::Value(func) = lv, + output_child(func, outp), + in_wire_value(outp, p, v); + + // Default out-value is Bottom, but if we can't determine the called function, + // assign everything to Top + out_wire_value(call, p, PV::Top) <-- + node(call), + if let OpType::CallIndirect(ci) = hugr.get_optype(*call), + in_wire_value(call, IncomingPort::from(0), v), + // Second alternative below addresses function::Value's: + if matches!(v, PartialValue::Top | PartialValue::Value(_)), + for p in ci.signature().output_ports(); }; let out_wire_values = all_results .out_wire_value @@ -337,13 +370,58 @@ pub(super) fn run_datalog( } } +#[derive(Debug, PartialEq, Eq, Hash, Clone, PartialOrd)] +enum LatticeWrapper { + Bottom, + Value(T), + Top, +} + +impl Lattice for LatticeWrapper { + fn meet_mut(&mut self, other: Self) -> bool { + if *self == other || *self == LatticeWrapper::Bottom || other == LatticeWrapper::Top { + return false; + }; + if *self == LatticeWrapper::Top || other == LatticeWrapper::Bottom { + *self = other; + return true; + }; + // Both are `Value`s and not equal + *self = LatticeWrapper::Bottom; + true + } + + fn join_mut(&mut self, other: Self) -> bool { + if *self == other || *self == LatticeWrapper::Top || other == LatticeWrapper::Bottom { + return false; + }; + if *self == LatticeWrapper::Bottom || other == LatticeWrapper::Top { + *self = other; + return true; + }; + // Both are `Value`s and are not equal + *self = LatticeWrapper::Top; + true + } +} + +fn load_func(v: &PV) -> LatticeWrapper { + match v { + PartialValue::Bottom | PartialValue::PartialSum(_) => LatticeWrapper::Bottom, + PartialValue::LoadedFunction(LoadedFunction { func_node, .. }) => { + LatticeWrapper::Value(*func_node) + } + PartialValue::Value(_) | PartialValue::Top => LatticeWrapper::Top, + } +} + fn propagate_leaf_op( ctx: &mut impl DFContext, hugr: &H, n: H::Node, - ins: &[PV], + ins: &[PV], num_outs: usize, -) -> Option> { +) -> Option> { match hugr.get_optype(n) { // Handle basics here. We could instead leave these to DFContext, // but at least we'd want these impls to be easily reusable. @@ -362,8 +440,7 @@ fn propagate_leaf_op( ins.iter().cloned(), )])), OpType::Input(_) | OpType::Output(_) | OpType::ExitBlock(_) => None, // handled by parent - OpType::Call(_) => None, // handled via Input/Output of FuncDefn - OpType::Const(_) => None, // handled by LoadConstant: + OpType::Call(_) | OpType::CallIndirect(_) => None, // handled via Input/Output of FuncDefn OpType::LoadConstant(load_op) => { assert!(ins.is_empty()); // static edge, so need to find constant let const_node = hugr @@ -380,10 +457,10 @@ fn propagate_leaf_op( .unwrap() .0; // Node could be a FuncDefn or a FuncDecl, so do not pass the node itself - Some(ValueRow::singleton( - ctx.value_from_function(func_node, &load_op.type_args) - .map_or(PV::Top, PV::Value), - )) + Some(ValueRow::singleton(PartialValue::new_load( + func_node, + load_op.type_args.clone(), + ))) } OpType::ExtensionOp(e) => { Some(ValueRow::from_iter(if row_contains_bottom(ins) { @@ -401,6 +478,54 @@ fn propagate_leaf_op( outs })) } - o => todo!("Unhandled: {:?}", o), // At least CallIndirect, and OpType is "non-exhaustive" + // We only call propagate_leaf_op for dataflow op non-containers, + o => todo!("Unhandled: {:?}", o), // and OpType is non-exhaustive + } +} + +#[cfg(test)] +mod test { + use ascent::Lattice; + + use super::LatticeWrapper; + + #[test] + fn latwrap_join() { + for lv in [ + LatticeWrapper::Value(3), + LatticeWrapper::Value(5), + LatticeWrapper::Top, + ] { + let mut subject = LatticeWrapper::Bottom; + assert!(subject.join_mut(lv.clone())); + assert_eq!(subject, lv); + assert!(!subject.join_mut(lv.clone())); + assert_eq!(subject, lv); + assert_eq!( + subject.join_mut(LatticeWrapper::Value(11)), + lv != LatticeWrapper::Top + ); + assert_eq!(subject, LatticeWrapper::Top); + } + } + + #[test] + fn latwrap_meet() { + for lv in [ + LatticeWrapper::Bottom, + LatticeWrapper::Value(3), + LatticeWrapper::Value(5), + ] { + let mut subject = LatticeWrapper::Top; + assert!(subject.meet_mut(lv.clone())); + assert_eq!(subject, lv); + assert!(!subject.meet_mut(lv.clone())); + assert_eq!(subject, lv); + assert_eq!( + subject.meet_mut(LatticeWrapper::Value(11)), + lv != LatticeWrapper::Bottom + ); + assert_eq!(subject, LatticeWrapper::Bottom); + } } } diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index f2a497806..240f4f2d6 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -1,7 +1,7 @@ use ascent::lattice::BoundedLattice; use ascent::Lattice; -use hugr_core::ops::Value; -use hugr_core::types::{ConstTypeError, SumType, Type, TypeEnum, TypeRow}; +use hugr_core::types::{SumType, Type, TypeArg, TypeEnum, TypeRow}; +use hugr_core::Node; use itertools::{zip_eq, Itertools}; use std::cmp::Ordering; use std::collections::HashMap; @@ -51,15 +51,25 @@ pub struct Sum { pub st: SumType, } +/// The output of an [LoadFunction](hugr_core::ops::LoadFunction) - a "pointer" +/// to a function at a specific node, instantiated with the provided type-args. +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct LoadedFunction { + /// The [FuncDefn](hugr_core::ops::FuncDefn) or `FuncDecl`` that was loaded + pub func_node: N, + /// The type arguments provided when loading + pub args: Vec, +} + /// A representation of a value of [SumType], that may have one or more possible tags, /// with a [PartialValue] representation of each element-value of each possible tag. #[derive(PartialEq, Clone, Eq)] -pub struct PartialSum(pub HashMap>>); +pub struct PartialSum(pub HashMap>>); -impl PartialSum { +impl PartialSum { /// New instance for a single known tag. /// (Multi-tag instances can be created via [Self::try_join_mut].) - pub fn new_variant(tag: usize, values: impl IntoIterator>) -> Self { + pub fn new_variant(tag: usize, values: impl IntoIterator>) -> Self { Self(HashMap::from([(tag, Vec::from_iter(values))])) } @@ -75,9 +85,21 @@ impl PartialSum { pv.assert_invariants(); } } + + /// Whether this sum might have the specified tag + pub fn supports_tag(&self, tag: usize) -> bool { + self.0.contains_key(&tag) + } + + /// Can this ever occur at runtime? See [PartialValue::contains_bottom] + pub fn contains_bottom(&self) -> bool { + self.0 + .iter() + .all(|(_tag, elements)| row_contains_bottom(elements)) + } } -impl PartialSum { +impl PartialSum { /// Joins (towards `Top`) self with another [PartialSum]. If successful, returns /// whether `self` has changed. /// @@ -141,12 +163,33 @@ impl PartialSum { } Ok(changed) } +} - /// Whether this sum might have the specified tag - pub fn supports_tag(&self, tag: usize) -> bool { - self.0.contains_key(&tag) - } +/// Trait implemented by value types into which [PartialValue]s can be converted, +/// so long as the PV has no [Top](PartialValue::Top), [Bottom](PartialValue::Bottom) +/// or [PartialSum]s with more than one possible tag. See [PartialSum::try_into_sum] +/// and [PartialValue::try_into_concrete]. +/// +/// `V` is the type of [AbstractValue] from which `Self` can (fallibly) be constructed, +/// `N` is the type of [HugrNode](hugr_core::core::HugrNode) for function pointers +pub trait AsConcrete: Sized { + /// Kind of error raised when creating `Self` from a value `V`, see [Self::from_value] + type ValErr: std::error::Error; + /// Kind of error that may be raised when creating `Self` from a [Sum] of `Self`s, + /// see [Self::from_sum] + type SumErr: std::error::Error; + + /// Convert an abstract value into concrete + fn from_value(val: V) -> Result; + + /// Convert a sum (of concrete values, already recursively converted) into concrete + fn from_sum(sum: Sum) -> Result; + + /// Convert a function pointer into a concrete value + fn from_func(func: LoadedFunction) -> Result>; +} +impl PartialSum { /// Turns this instance into a [Sum] of some "concrete" value type `C`, /// *if* this PartialSum has exactly one possible tag. /// @@ -155,11 +198,11 @@ impl PartialSum { /// If this PartialSum had multiple possible tags; or if `typ` was not a [TypeEnum::Sum] /// supporting the single possible tag with the correct number of elements and no row variables; /// or if converting a child element failed via [PartialValue::try_into_concrete]. - pub fn try_into_sum(self, typ: &Type) -> Result, ExtractValueError> - where - V: TryInto, - Sum: TryInto, - { + #[allow(clippy::type_complexity)] // Since C is a parameter, can't declare type aliases + pub fn try_into_sum>( + self, + typ: &Type, + ) -> Result, ExtractValueError> { if self.0.len() != 1 { return Err(ExtractValueError::MultipleVariants(self)); } @@ -185,22 +228,15 @@ impl PartialSum { num_elements: v.len(), }) } - - /// Can this ever occur at runtime? See [PartialValue::contains_bottom] - pub fn contains_bottom(&self) -> bool { - self.0 - .iter() - .all(|(_tag, elements)| row_contains_bottom(elements)) - } } /// An error converting a [PartialValue] or [PartialSum] into a concrete value type /// via [PartialValue::try_into_concrete] or [PartialSum::try_into_sum] #[derive(Clone, Debug, PartialEq, Eq, Error)] #[allow(missing_docs)] -pub enum ExtractValueError { +pub enum ExtractValueError { #[error("PartialSum value had multiple possible tags: {0}")] - MultipleVariants(PartialSum), + MultipleVariants(PartialSum), #[error("Value contained `Bottom`")] ValueIsBottom, #[error("Value contained `Top`")] @@ -209,6 +245,8 @@ pub enum ExtractValueError { CouldNotConvert(V, #[source] VE), #[error("Could not build Sum from concrete element values")] CouldNotBuildSum(#[source] SE), + #[error("Could not convert into concrete function pointer {0}")] + CouldNotLoadFunction(LoadedFunction), #[error("Expected a SumType with tag {tag} having {num_elements} elements, found {typ}")] BadSumType { typ: Type, @@ -217,14 +255,14 @@ pub enum ExtractValueError { }, } -impl PartialSum { +impl PartialSum { /// If this Sum might have the specified `tag`, get the elements inside that tag. - pub fn variant_values(&self, variant: usize) -> Option>> { + pub fn variant_values(&self, variant: usize) -> Option>> { self.0.get(&variant).cloned() } } -impl PartialOrd for PartialSum { +impl PartialOrd for PartialSum { fn partial_cmp(&self, other: &Self) -> Option { let max_key = self.0.keys().chain(other.0.keys()).copied().max().unwrap(); let (mut keys1, mut keys2) = (vec![0; max_key + 1], vec![0; max_key + 1]); @@ -254,13 +292,13 @@ impl PartialOrd for PartialSum { } } -impl std::fmt::Debug for PartialSum { +impl std::fmt::Debug for PartialSum { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.0.fmt(f) } } -impl Hash for PartialSum { +impl Hash for PartialSum { fn hash(&self, state: &mut H) { for (k, v) in &self.0 { k.hash(state); @@ -273,30 +311,32 @@ impl Hash for PartialSum { /// for use in dataflow analysis, including that an instance may be a [PartialSum] /// of values of the underlying representation #[derive(PartialEq, Clone, Eq, Hash, Debug)] -pub enum PartialValue { +pub enum PartialValue { /// No possibilities known (so far) Bottom, + /// The output of an [LoadFunction](hugr_core::ops::LoadFunction) + LoadedFunction(LoadedFunction), /// A single value (of the underlying representation) Value(V), /// Sum (with at least one, perhaps several, possible tags) of underlying values - PartialSum(PartialSum), + PartialSum(PartialSum), /// Might be more than one distinct value of the underlying type `V` Top, } -impl From for PartialValue { +impl From for PartialValue { fn from(v: V) -> Self { Self::Value(v) } } -impl From> for PartialValue { - fn from(v: PartialSum) -> Self { +impl From> for PartialValue { + fn from(v: PartialSum) -> Self { Self::PartialSum(v) } } -impl PartialValue { +impl PartialValue { fn assert_invariants(&self) { if let Self::PartialSum(ps) = self { ps.assert_invariants(); @@ -312,33 +352,59 @@ impl PartialValue { pub fn new_unit() -> Self { Self::new_variant(0, []) } + + /// New instance of self for a [LoadFunction](hugr_core::ops::LoadFunction) + pub fn new_load(func_node: N, args: impl Into>) -> Self { + Self::LoadedFunction(LoadedFunction { + func_node, + args: args.into(), + }) + } + + /// Tells us whether this value might be a Sum with the specified `tag` + pub fn supports_tag(&self, tag: usize) -> bool { + match self { + PartialValue::Bottom | PartialValue::Value(_) | PartialValue::LoadedFunction(_) => { + false + } + PartialValue::PartialSum(ps) => ps.supports_tag(tag), + PartialValue::Top => true, + } + } + + /// A value contains bottom means that it cannot occur during execution: + /// it may be an artefact during bootstrapping of the analysis, or else + /// the value depends upon a `panic` or a loop that + /// [never terminates](super::TailLoopTermination::NeverBreaks). + pub fn contains_bottom(&self) -> bool { + match self { + PartialValue::Bottom => true, + PartialValue::Top | PartialValue::Value(_) | PartialValue::LoadedFunction(_) => false, + PartialValue::PartialSum(ps) => ps.contains_bottom(), + } + } } -impl PartialValue { +impl PartialValue { /// If this value might be a Sum with the specified `tag`, get the elements inside that tag. /// /// # Panics /// /// if the value is believed, for that tag, to have a number of values other than `len` - pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { + pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { let vals = match self { - PartialValue::Bottom | PartialValue::Value(_) => return None, + PartialValue::Bottom | PartialValue::Value(_) | PartialValue::LoadedFunction(_) => { + return None + } PartialValue::PartialSum(ps) => ps.variant_values(tag)?, PartialValue::Top => vec![PartialValue::Top; len], }; assert_eq!(vals.len(), len); Some(vals) } +} - /// Tells us whether this value might be a Sum with the specified `tag` - pub fn supports_tag(&self, tag: usize) -> bool { - match self { - PartialValue::Bottom | PartialValue::Value(_) => false, - PartialValue::PartialSum(ps) => ps.supports_tag(tag), - PartialValue::Top => true, - } - } - +impl PartialValue { /// Turns this instance into some "concrete" value type `C`, *if* it is a single value, /// or a [Sum](PartialValue::PartialSum) (of a single tag) convertible by /// [PartialSum::try_into_sum]. @@ -348,47 +414,27 @@ impl PartialValue { /// If this PartialValue was `Top` or `Bottom`, or was a [PartialSum](PartialValue::PartialSum) /// that could not be converted into a [Sum] by [PartialSum::try_into_sum] (e.g. if `typ` is /// incorrect), or if that [Sum] could not be converted into a `V2`. - pub fn try_into_concrete(self, typ: &Type) -> Result> - where - V: TryInto, - Sum: TryInto, - { + pub fn try_into_concrete>( + self, + typ: &Type, + ) -> Result> { match self { - Self::Value(v) => v - .clone() - .try_into() - .map_err(|e| ExtractValueError::CouldNotConvert(v.clone(), e)), - Self::PartialSum(ps) => ps - .try_into_sum(typ)? - .try_into() - .map_err(ExtractValueError::CouldNotBuildSum), + Self::Value(v) => { + C::from_value(v.clone()).map_err(|e| ExtractValueError::CouldNotConvert(v, e)) + } + Self::LoadedFunction(lf) => { + C::from_func(lf).map_err(ExtractValueError::CouldNotLoadFunction) + } + Self::PartialSum(ps) => { + C::from_sum(ps.try_into_sum(typ)?).map_err(ExtractValueError::CouldNotBuildSum) + } Self::Top => Err(ExtractValueError::ValueIsTop), Self::Bottom => Err(ExtractValueError::ValueIsBottom), } } - - /// A value contains bottom means that it cannot occur during execution: - /// it may be an artefact during bootstrapping of the analysis, or else - /// the value depends upon a `panic` or a loop that - /// [never terminates](super::TailLoopTermination::NeverBreaks). - pub fn contains_bottom(&self) -> bool { - match self { - PartialValue::Bottom => true, - PartialValue::Top | PartialValue::Value(_) => false, - PartialValue::PartialSum(ps) => ps.contains_bottom(), - } - } } -impl TryFrom> for Value { - type Error = ConstTypeError; - - fn try_from(value: Sum) -> Result { - Self::sum(value.tag, value.values, value.st) - } -} - -impl Lattice for PartialValue { +impl Lattice for PartialValue { fn join_mut(&mut self, other: Self) -> bool { self.assert_invariants(); let mut old_self = Self::Top; @@ -400,13 +446,17 @@ impl Lattice for PartialValue { Some((h3, b)) => (Self::Value(h3), b), None => (Self::Top, true), }, + (Self::LoadedFunction(lf1), Self::LoadedFunction(lf2)) + if lf1.func_node == lf2.func_node => + { + // TODO we should also join the TypeArgs but at the moment these are ignored + (Self::LoadedFunction(lf1), false) + } (Self::PartialSum(mut ps1), Self::PartialSum(ps2)) => match ps1.try_join_mut(ps2) { Ok(ch) => (Self::PartialSum(ps1), ch), Err(_) => (Self::Top, true), }, - (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => { - (Self::Top, true) - } + _ => (Self::Top, true), }; *self = res; ch @@ -423,20 +473,24 @@ impl Lattice for PartialValue { Some((h3, ch)) => (Self::Value(h3), ch), None => (Self::Bottom, true), }, + (Self::LoadedFunction(lf1), Self::LoadedFunction(lf2)) + if lf1.func_node == lf2.func_node => + { + // TODO we should also meet the TypeArgs but at the moment these are ignored + (Self::LoadedFunction(lf1), false) + } (Self::PartialSum(mut ps1), Self::PartialSum(ps2)) => match ps1.try_meet_mut(ps2) { Ok(ch) => (Self::PartialSum(ps1), ch), Err(_) => (Self::Bottom, true), }, - (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => { - (Self::Bottom, true) - } + _ => (Self::Bottom, true), }; *self = res; ch } } -impl BoundedLattice for PartialValue { +impl BoundedLattice for PartialValue { fn top() -> Self { Self::Top } @@ -446,7 +500,7 @@ impl BoundedLattice for PartialValue { } } -impl PartialOrd for PartialValue { +impl PartialOrd for PartialValue { fn partial_cmp(&self, other: &Self) -> Option { use std::cmp::Ordering; match (self, other) { @@ -457,6 +511,9 @@ impl PartialOrd for PartialValue { (Self::Top, _) => Some(Ordering::Greater), (_, Self::Top) => Some(Ordering::Less), (Self::Value(v1), Self::Value(v2)) => (v1 == v2).then_some(Ordering::Equal), + (Self::LoadedFunction(lf1), Self::LoadedFunction(lf2)) => { + (lf1 == lf2).then_some(Ordering::Equal) + } (Self::PartialSum(ps1), Self::PartialSum(ps2)) => ps1.partial_cmp(ps2), _ => None, } @@ -468,19 +525,20 @@ mod test { use std::sync::Arc; use ascent::{lattice::BoundedLattice, Lattice}; + use hugr_core::NodeIndex; use itertools::{zip_eq, Itertools as _}; use prop::sample::subsequence; use proptest::prelude::*; use proptest_recurse::{StrategyExt, StrategySet}; - use super::{AbstractValue, PartialSum, PartialValue}; + use super::{AbstractValue, LoadedFunction, PartialSum, PartialValue}; #[derive(Debug, PartialEq, Eq, Clone)] enum TestSumType { Branch(Vec>>), - /// None => unit, Some => TestValue <= this *usize* - Leaf(Option), + LeafVal(usize), // contains a TestValue <= this usize + LeafPtr(usize), // contains a LoadedFunction with node <= this *usize* } #[derive(Clone, Debug, PartialEq, Eq, Hash)] @@ -509,8 +567,11 @@ mod test { fn check_value(&self, pv: &PartialValue) -> bool { match (self, pv) { (_, PartialValue::Bottom) | (_, PartialValue::Top) => true, - (Self::Leaf(None), _) => pv == &PartialValue::new_unit(), - (Self::Leaf(Some(max)), PartialValue::Value(TestValue(val))) => val <= max, + (Self::LeafVal(max), PartialValue::Value(TestValue(val))) => val <= max, + ( + Self::LeafPtr(max), + PartialValue::LoadedFunction(LoadedFunction { func_node, args }), + ) => args.is_empty() && func_node.index() <= *max, (Self::Branch(sop), PartialValue::PartialSum(ps)) => { for (k, v) in &ps.0 { if *k >= sop.len() { @@ -537,8 +598,11 @@ mod test { fn arbitrary_with(params: Self::Parameters) -> Self::Strategy { fn arb(params: SumTypeParams, set: &mut StrategySet) -> SBoxedStrategy { use proptest::collection::vec; - let int_strat = (0..usize::MAX).prop_map(|i| TestSumType::Leaf(Some(i))); - let leaf_strat = prop_oneof![Just(TestSumType::Leaf(None)), int_strat]; + let leaf_strat = prop_oneof![ + (0..usize::MAX).prop_map(TestSumType::LeafVal), + // This is the maximum value accepted by portgraph::NodeIndex::new + (0..((2usize ^ 31) - 2)).prop_map(TestSumType::LeafPtr) + ]; leaf_strat.prop_mutually_recursive( params.depth as u32, params.desired_size as u32, @@ -605,11 +669,18 @@ mod test { ust: &TestSumType, ) -> impl Strategy> { match ust { - TestSumType::Leaf(None) => Just(PartialValue::new_unit()).boxed(), - TestSumType::Leaf(Some(i)) => (0..*i) + TestSumType::LeafVal(i) => (0..=*i) .prop_map(TestValue) .prop_map(PartialValue::from) .boxed(), + TestSumType::LeafPtr(i) => (0..=*i) + .prop_map(|i| { + PartialValue::LoadedFunction(LoadedFunction { + func_node: portgraph::NodeIndex::new(i).into(), + args: vec![], + }) + }) + .boxed(), TestSumType::Branch(sop) => partial_sum_strat(sop).prop_map(PartialValue::from).boxed(), } } diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index c40f1d87f..c4a94a9e7 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -1,17 +1,19 @@ use std::collections::HashMap; -use hugr_core::{HugrView, IncomingPort, PortIndex, Wire}; +use hugr_core::{HugrView, PortIndex, Wire}; -use super::{partial_value::ExtractValueError, AbstractValue, PartialValue, Sum}; +use super::{ + datalog::InWire, partial_value::ExtractValueError, AbstractValue, AsConcrete, PartialValue, +}; /// Results of a dataflow analysis, packaged with the Hugr for easy inspection. /// Methods allow inspection, specifically [read_out_wire](Self::read_out_wire). pub struct AnalysisResults { pub(super) hugr: H, - pub(super) in_wire_value: Vec<(H::Node, IncomingPort, PartialValue)>, + pub(super) in_wire_value: Vec>, pub(super) case_reachable: Vec<(H::Node, H::Node)>, pub(super) bb_reachable: Vec<(H::Node, H::Node)>, - pub(super) out_wire_values: HashMap, PartialValue>, + pub(super) out_wire_values: HashMap, PartialValue>, } impl AnalysisResults { @@ -21,7 +23,7 @@ impl AnalysisResults { } /// Gets the lattice value computed for the given wire - pub fn read_out_wire(&self, w: Wire) -> Option> { + pub fn read_out_wire(&self, w: Wire) -> Option> { self.out_wire_values.get(&w).cloned() } @@ -84,13 +86,11 @@ impl AnalysisResults { /// `None` if the analysis did not produce a result for that wire, or if /// the Hugr did not have a [Type](hugr_core::types::Type) for the specified wire /// `Some(e)` if [conversion to a concrete value](PartialValue::try_into_concrete) failed with error `e` - pub fn try_read_wire_concrete( + #[allow(clippy::type_complexity)] + pub fn try_read_wire_concrete>( &self, w: Wire, - ) -> Result>> - where - V2: TryFrom + TryFrom, Error = SE>, - { + ) -> Result>> { let v = self.read_out_wire(w).ok_or(None)?; let (_, typ) = self .hugr @@ -116,7 +116,7 @@ pub enum TailLoopTermination { } impl TailLoopTermination { - fn from_control_value(v: &PartialValue) -> Self { + fn from_control_value(v: &PartialValue) -> Self { let (may_continue, may_break) = (v.supports_tag(0), v.supports_tag(1)); if may_break { if may_continue { diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 3af0097f7..1c4b4e439 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -1,10 +1,12 @@ +use std::convert::Infallible; + use ascent::{lattice::BoundedLattice, Lattice}; -use hugr_core::builder::{CFGBuilder, Container, DataflowHugr, ModuleBuilder}; +use hugr_core::builder::{inout_sig, CFGBuilder, Container, DataflowHugr, ModuleBuilder}; use hugr_core::hugr::views::{DescendantsGraph, HierarchyView}; use hugr_core::ops::handle::DfgID; -use hugr_core::ops::TailLoop; -use hugr_core::types::TypeRow; +use hugr_core::ops::{CallIndirect, TailLoop}; +use hugr_core::types::{ConstTypeError, TypeRow}; use hugr_core::{ builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::{ @@ -19,7 +21,10 @@ use hugr_core::{ use hugr_core::{Hugr, Node, Wire}; use rstest::{fixture, rstest}; -use super::{AbstractValue, ConstLoader, DFContext, Machine, PartialValue, TailLoopTermination}; +use super::{ + AbstractValue, AsConcrete, ConstLoader, DFContext, LoadedFunction, Machine, PartialValue, Sum, + TailLoopTermination, +}; // ------- Minimal implementation of DFContext and AbstractValue ------- #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -35,10 +40,22 @@ impl ConstLoader for TestContext { impl DFContext for TestContext {} // This allows testing creation of tuple/sum Values (only) -impl From for Value { - fn from(v: Void) -> Self { +impl AsConcrete for Value { + type ValErr = Infallible; + + type SumErr = ConstTypeError; + + fn from_value(v: Void) -> Result { match v {} } + + fn from_sum(value: Sum) -> Result { + Self::sum(value.tag, value.values, value.st) + } + + fn from_func(func: LoadedFunction) -> Result> { + Err(func) + } } fn pv_false() -> PartialValue { @@ -295,9 +312,7 @@ fn test_conditional() { let cond_r1: Value = results.try_read_wire_concrete(cond_o1).unwrap(); assert_eq!(cond_r1, Value::false_val()); - assert!(results - .try_read_wire_concrete::(cond_o2) - .is_err()); + assert!(results.try_read_wire_concrete::(cond_o2).is_err()); assert_eq!(results.case_reachable(case1.node()), Some(false)); // arg_pv is variant 1 or 2 only assert_eq!(results.case_reachable(case2.node()), Some(true)); @@ -547,3 +562,78 @@ fn test_module() { ); } } + +#[rstest] +#[case(pv_false(), pv_false())] +#[case(pv_false(), pv_true())] +#[case(pv_true(), pv_false())] +#[case(pv_true(), pv_true())] +fn call_indirect(#[case] inp1: PartialValue, #[case] inp2: PartialValue) { + let b2b = || Signature::new_endo(bool_t()); + let mut dfb = DFGBuilder::new(inout_sig(vec![bool_t(); 3], vec![bool_t(); 2])).unwrap(); + + let [id1, id2] = ["id1", "[id2]"].map(|name| { + let fb = dfb.define_function(name, b2b()).unwrap(); + let [inp] = fb.input_wires_arr(); + fb.finish_with_outputs([inp]).unwrap() + }); + + let [inp_direct, which, inp_indirect] = dfb.input_wires_arr(); + let [res1] = dfb + .call(id1.handle(), &[], [inp_direct]) + .unwrap() + .outputs_arr(); + + // We'll unconditionally load both functions, to demonstrate that it's + // the CallIndirect that matters, not just which functions are loaded. + let lf1 = dfb.load_func(id1.handle(), &[]).unwrap(); + let lf2 = dfb.load_func(id2.handle(), &[]).unwrap(); + let bool_func = || Type::new_function(b2b()); + let mut cond = dfb + .conditional_builder( + (vec![type_row![]; 2], which), + [(bool_func(), lf1), (bool_func(), lf2)], + bool_func().into(), + ) + .unwrap(); + let case_false = cond.case_builder(0).unwrap(); + let [f0, _f1] = case_false.input_wires_arr(); + case_false.finish_with_outputs([f0]).unwrap(); + let case_true = cond.case_builder(1).unwrap(); + let [_f0, f1] = case_true.input_wires_arr(); + case_true.finish_with_outputs([f1]).unwrap(); + let [tgt] = cond.finish_sub_container().unwrap().outputs_arr(); + let [res2] = dfb + .add_dataflow_op(CallIndirect { signature: b2b() }, [tgt, inp_indirect]) + .unwrap() + .outputs_arr(); + let h = dfb.finish_hugr_with_outputs([res1, res2]).unwrap(); + + let run = |which| { + Machine::new(&h).run( + TestContext, + [ + (0.into(), inp1.clone()), + (1.into(), which), + (2.into(), inp2.clone()), + ], + ) + }; + let (w1, w2) = (Wire::new(h.root(), 0), Wire::new(h.root(), 1)); + + // 1. Test with `which` unknown -> second output unknown + let results = run(PartialValue::Top); + assert_eq!(results.read_out_wire(w1), Some(inp1.clone())); + assert_eq!(results.read_out_wire(w2), Some(PartialValue::Top)); + + // 2. Test with `which` selecting second function -> both passthrough + let results = run(pv_true()); + assert_eq!(results.read_out_wire(w1), Some(inp1.clone())); + assert_eq!(results.read_out_wire(w2), Some(inp2.clone())); + + //3. Test with `which` selecting first function -> alias + let results = run(pv_false()); + let out = Some(inp1.join(inp2)); + assert_eq!(results.read_out_wire(w1), out); + assert_eq!(results.read_out_wire(w2), out); +} diff --git a/hugr-passes/src/dataflow/value_row.rs b/hugr-passes/src/dataflow/value_row.rs index 50cf10318..43c842d91 100644 --- a/hugr-passes/src/dataflow/value_row.rs +++ b/hugr-passes/src/dataflow/value_row.rs @@ -5,25 +5,25 @@ use std::{ ops::{Index, IndexMut}, }; -use ascent::{lattice::BoundedLattice, Lattice}; +use ascent::Lattice; use itertools::zip_eq; use super::{AbstractValue, PartialValue}; #[derive(PartialEq, Clone, Debug, Eq, Hash)] -pub(super) struct ValueRow(Vec>); +pub(super) struct ValueRow(Vec>); -impl ValueRow { +impl ValueRow { pub fn new(len: usize) -> Self { - Self(vec![PartialValue::bottom(); len]) + Self(vec![PartialValue::Bottom; len]) } - pub fn set(mut self, idx: usize, v: PartialValue) -> Self { + pub fn set(mut self, idx: usize, v: PartialValue) -> Self { *self.0.get_mut(idx).unwrap() = v; self } - pub fn singleton(v: PartialValue) -> Self { + pub fn singleton(v: PartialValue) -> Self { Self(vec![v]) } @@ -34,25 +34,25 @@ impl ValueRow { &self, variant: usize, len: usize, - ) -> Option>> { + ) -> Option>> { let vals = self[0].variant_values(variant, len)?; Some(vals.into_iter().chain(self.0[1..].to_owned())) } } -impl FromIterator> for ValueRow { - fn from_iter>>(iter: T) -> Self { +impl FromIterator> for ValueRow { + fn from_iter>>(iter: T) -> Self { Self(iter.into_iter().collect()) } } -impl PartialOrd for ValueRow { +impl PartialOrd for ValueRow { fn partial_cmp(&self, other: &Self) -> Option { self.0.partial_cmp(&other.0) } } -impl Lattice for ValueRow { +impl Lattice for ValueRow { fn join_mut(&mut self, other: Self) -> bool { assert_eq!(self.0.len(), other.0.len()); let mut changed = false; @@ -72,30 +72,30 @@ impl Lattice for ValueRow { } } -impl IntoIterator for ValueRow { - type Item = PartialValue; +impl IntoIterator for ValueRow { + type Item = PartialValue; - type IntoIter = > as IntoIterator>::IntoIter; + type IntoIter = > as IntoIterator>::IntoIter; fn into_iter(self) -> Self::IntoIter { self.0.into_iter() } } -impl Index for ValueRow +impl Index for ValueRow where - Vec>: Index, + Vec>: Index, { - type Output = > as Index>::Output; + type Output = > as Index>::Output; fn index(&self, index: Idx) -> &Self::Output { self.0.index(index) } } -impl IndexMut for ValueRow +impl IndexMut for ValueRow where - Vec>: IndexMut, + Vec>: IndexMut, { fn index_mut(&mut self, index: Idx) -> &mut Self::Output { self.0.index_mut(index) From 63477565de0dbfb8027736cf905f6f148e2ddcab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Wed, 16 Apr 2025 14:52:21 +0100 Subject: [PATCH 30/33] feat: Make NodeHandle generic (#2092) Adds a generic node type to the `NodeHandle` type. This is a required change for #2029. drive-by: Implement the "Link the NodeHandles to the OpType" TODO --- hugr-core/src/ops.rs | 16 +++++++- hugr-core/src/ops/handle.rs | 73 ++++++++++++++++++++----------------- 2 files changed, 54 insertions(+), 35 deletions(-) diff --git a/hugr-core/src/ops.rs b/hugr-core/src/ops.rs index 0c7d3bb3f..ce0d44de0 100644 --- a/hugr-core/src/ops.rs +++ b/hugr-core/src/ops.rs @@ -9,6 +9,7 @@ pub mod module; pub mod sum; pub mod tag; pub mod validate; +use crate::core::HugrNode; use crate::extension::resolution::{ collect_op_extension, collect_op_types_extensions, ExtensionCollectionError, }; @@ -20,6 +21,7 @@ use crate::types::{EdgeKind, Signature, Substitution}; use crate::{Direction, OutgoingPort, Port}; use crate::{IncomingPort, PortIndex}; use derive_more::Display; +use handle::NodeHandle; use paste::paste; use portgraph::NodeIndex; @@ -41,7 +43,6 @@ pub use tag::OpTag; #[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)] #[cfg_attr(test, derive(proptest_derive::Arbitrary))] /// The concrete operation types for a node in the HUGR. -// TODO: Link the NodeHandles to the OpType. #[non_exhaustive] #[allow(missing_docs)] #[serde(tag = "op")] @@ -377,6 +378,19 @@ pub trait OpTrait: Sized + Clone { /// Tag identifying the operation. fn tag(&self) -> OpTag; + /// Tries to create a specific [`NodeHandle`] for a node with this operation + /// type. + /// + /// Fails if the operation's [`OpTrait::tag`] does not match the + /// [`NodeHandle::TAG`] of the requested handle. + fn try_node_handle(&self, node: N) -> Option + where + N: HugrNode, + H: NodeHandle + From, + { + H::TAG.is_superset(self.tag()).then(|| node.into()) + } + /// The signature of the operation. /// /// Only dataflow operations have a signature, otherwise returns None. diff --git a/hugr-core/src/ops/handle.rs b/hugr-core/src/ops/handle.rs index d7fe16419..a5a3c294a 100644 --- a/hugr-core/src/ops/handle.rs +++ b/hugr-core/src/ops/handle.rs @@ -1,4 +1,5 @@ //! Handles to nodes in HUGR. +use crate::core::HugrNode; use crate::types::{Type, TypeBound}; use crate::Node; @@ -9,12 +10,12 @@ use super::{AliasDecl, OpTag}; /// Common trait for handles to a node. /// Typically wrappers around [`Node`]. -pub trait NodeHandle: Clone { +pub trait NodeHandle: Clone { /// The most specific operation tag associated with the handle. const TAG: OpTag; /// Index of underlying node. - fn node(&self) -> Node; + fn node(&self) -> N; /// Operation tag for the handle. #[inline] @@ -23,7 +24,7 @@ pub trait NodeHandle: Clone { } /// Cast the handle to a different more general tag. - fn try_cast>(&self) -> Option { + fn try_cast + From>(&self) -> Option { T::TAG.is_superset(Self::TAG).then(|| self.node().into()) } @@ -36,30 +37,30 @@ pub trait NodeHandle: Clone { /// Trait for handles that contain children. /// /// The allowed children handles are defined by the associated type. -pub trait ContainerHandle: NodeHandle { +pub trait ContainerHandle: NodeHandle { /// Handle type for the children of this node. - type ChildrenHandle: NodeHandle; + type ChildrenHandle: NodeHandle; } #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [DataflowOp](crate::ops::dataflow). -pub struct DataflowOpID(Node); +pub struct DataflowOpID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [DFG](crate::ops::DFG) node. -pub struct DfgID(Node); +pub struct DfgID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [CFG](crate::ops::CFG) node. -pub struct CfgID(Node); +pub struct CfgID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a module [Module](crate::ops::Module) node. -pub struct ModuleRootID(Node); +pub struct ModuleRootID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [module op](crate::ops::module) node. -pub struct ModuleID(Node); +pub struct ModuleID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [def](crate::ops::OpType::FuncDefn) @@ -67,7 +68,7 @@ pub struct ModuleID(Node); /// /// The `DEF` const generic is used to indicate whether the function is /// defined or just declared. -pub struct FuncID(Node); +pub struct FuncID(N); #[derive(Debug, Clone, PartialEq, Eq)] /// Handle to an [AliasDefn](crate::ops::OpType::AliasDefn) @@ -75,15 +76,15 @@ pub struct FuncID(Node); /// /// The `DEF` const generic is used to indicate whether the function is /// defined or just declared. -pub struct AliasID { - node: Node, +pub struct AliasID { + node: N, name: SmolStr, bound: TypeBound, } -impl AliasID { +impl AliasID { /// Construct new AliasID - pub fn new(node: Node, name: SmolStr, bound: TypeBound) -> Self { + pub fn new(node: N, name: SmolStr, bound: TypeBound) -> Self { Self { node, name, bound } } @@ -99,27 +100,27 @@ impl AliasID { #[derive(DerFrom, Debug, Clone, PartialEq, Eq)] /// Handle to a [Const](crate::ops::OpType::Const) node. -pub struct ConstID(Node); +pub struct ConstID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [DataflowBlock](crate::ops::DataflowBlock) or [Exit](crate::ops::ExitBlock) node. -pub struct BasicBlockID(Node); +pub struct BasicBlockID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [Case](crate::ops::Case) node. -pub struct CaseID(Node); +pub struct CaseID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [TailLoop](crate::ops::TailLoop) node. -pub struct TailLoopID(Node); +pub struct TailLoopID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [Conditional](crate::ops::Conditional) node. -pub struct ConditionalID(Node); +pub struct ConditionalID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a dataflow container node. -pub struct DataflowParentID(Node); +pub struct DataflowParentID(N); /// Implements the `NodeHandle` trait for a tuple struct that contains just a /// NodeIndex. Takes the name of the struct, and the corresponding OpTag. @@ -131,11 +132,11 @@ macro_rules! impl_nodehandle { impl_nodehandle!($name, $tag, 0); }; ($name:ident, $tag:expr, $node_attr:tt) => { - impl NodeHandle for $name { + impl NodeHandle for $name { const TAG: OpTag = $tag; #[inline] - fn node(&self) -> Node { + fn node(&self) -> N { self.$node_attr } } @@ -156,35 +157,35 @@ impl_nodehandle!(ConstID, OpTag::Const); impl_nodehandle!(BasicBlockID, OpTag::DataflowBlock); -impl NodeHandle for FuncID { +impl NodeHandle for FuncID { const TAG: OpTag = OpTag::Function; #[inline] - fn node(&self) -> Node { + fn node(&self) -> N { self.0 } } -impl NodeHandle for AliasID { +impl NodeHandle for AliasID { const TAG: OpTag = OpTag::Alias; #[inline] - fn node(&self) -> Node { + fn node(&self) -> N { self.node } } -impl NodeHandle for Node { +impl NodeHandle for N { const TAG: OpTag = OpTag::Any; #[inline] - fn node(&self) -> Node { + fn node(&self) -> N { *self } } /// Implements the `ContainerHandle` trait, with the given child handle type. macro_rules! impl_containerHandle { - ($name:path, $children:ident) => { - impl ContainerHandle for $name { - type ChildrenHandle = $children; + ($name:ident, $children:ident) => { + impl ContainerHandle for $name { + type ChildrenHandle = $children; } }; } @@ -197,5 +198,9 @@ impl_containerHandle!(CaseID, DataflowOpID); impl_containerHandle!(ModuleRootID, ModuleID); impl_containerHandle!(CfgID, BasicBlockID); impl_containerHandle!(BasicBlockID, DataflowOpID); -impl_containerHandle!(FuncID, DataflowOpID); -impl_containerHandle!(AliasID, DataflowOpID); +impl ContainerHandle for FuncID { + type ChildrenHandle = DataflowOpID; +} +impl ContainerHandle for AliasID { + type ChildrenHandle = DataflowOpID; +} From 5b43c0d351720a2d1ba66053467d64573bbbb9c6 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 17 Apr 2025 10:38:12 +0100 Subject: [PATCH 31/33] feat!: remove ExtensionValue (#2093) Closes #1595 BREAKING CHANGE: `values` field in `Extension` and `ExtensionValue` struct/class removed in rust and python. Use 0-input ops that return constant values. --- hugr-core/src/extension.rs | 64 +------------------ .../src/extension/resolution/extension.rs | 11 +--- hugr-core/src/hugr/validate/test.rs | 8 +-- hugr-core/src/std_extensions/logic.rs | 26 +------- hugr-py/src/hugr/_serialization/extension.py | 21 ------ hugr-py/src/hugr/ext.py | 42 +----------- .../_json_defs/arithmetic/conversions.json | 1 - .../hugr/std/_json_defs/arithmetic/float.json | 1 - .../_json_defs/arithmetic/float/types.json | 1 - .../hugr/std/_json_defs/arithmetic/int.json | 1 - .../std/_json_defs/arithmetic/int/types.json | 1 - .../std/_json_defs/collections/array.json | 1 - .../hugr/std/_json_defs/collections/list.json | 1 - .../_json_defs/collections/static_array.json | 1 - hugr-py/src/hugr/std/_json_defs/logic.json | 28 -------- hugr-py/src/hugr/std/_json_defs/prelude.json | 1 - hugr-py/src/hugr/std/_json_defs/ptr.json | 1 - specification/schema/hugr_schema_live.json | 30 --------- .../schema/hugr_schema_strict_live.json | 30 --------- .../schema/testing_hugr_schema_live.json | 30 --------- .../testing_hugr_schema_strict_live.json | 30 --------- .../arithmetic/conversions.json | 1 - .../std_extensions/arithmetic/float.json | 1 - .../arithmetic/float/types.json | 1 - .../std_extensions/arithmetic/int.json | 1 - .../std_extensions/arithmetic/int/types.json | 1 - .../std_extensions/collections/array.json | 1 - .../std_extensions/collections/list.json | 1 - .../collections/static_array.json | 1 - specification/std_extensions/logic.json | 28 -------- specification/std_extensions/prelude.json | 1 - specification/std_extensions/ptr.json | 1 - 32 files changed, 7 insertions(+), 361 deletions(-) diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index b6e059050..23238ccfd 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -19,9 +19,8 @@ use derive_more::Display; use thiserror::Error; use crate::hugr::IdentList; -use crate::ops::constant::{ValueName, ValueNameRef}; use crate::ops::custom::{ExtensionOp, OpaqueOp}; -use crate::ops::{self, OpName, OpNameRef}; +use crate::ops::{OpName, OpNameRef}; use crate::types::type_param::{TypeArg, TypeArgError, TypeParam}; use crate::types::RowVariable; use crate::types::{check_typevar_decl, CustomType, Substitution, TypeBound, TypeName}; @@ -497,37 +496,6 @@ impl CustomConcrete for CustomType { } } -/// A constant value provided by a extension. -/// Must be an instance of a type available to the extension. -#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] -pub struct ExtensionValue { - extension: ExtensionId, - name: ValueName, - typed_value: ops::Value, -} - -impl ExtensionValue { - /// Returns a reference to the typed value of this [`ExtensionValue`]. - pub fn typed_value(&self) -> &ops::Value { - &self.typed_value - } - - /// Returns a mutable reference to the typed value of this [`ExtensionValue`]. - pub(super) fn typed_value_mut(&mut self) -> &mut ops::Value { - &mut self.typed_value - } - - /// Returns a reference to the name of this [`ExtensionValue`]. - pub fn name(&self) -> &str { - self.name.as_str() - } - - /// Returns a reference to the extension this [`ExtensionValue`] belongs to. - pub fn extension(&self) -> &ExtensionId { - &self.extension - } -} - /// A unique identifier for a extension. /// /// The actual [`Extension`] is stored externally. @@ -583,8 +551,6 @@ pub struct Extension { pub runtime_reqs: ExtensionSet, /// Types defined by this extension. types: BTreeMap, - /// Static values defined by this extension. - values: BTreeMap, /// Operation declarations with serializable definitions. // Note: serde will serialize this because we configure with `features=["rc"]`. // That will clone anything that has multiple references, but each @@ -608,7 +574,6 @@ impl Extension { version, runtime_reqs: Default::default(), types: Default::default(), - values: Default::default(), operations: Default::default(), } } @@ -680,11 +645,6 @@ impl Extension { self.types.get(type_name) } - /// Allows read-only access to the values in this Extension - pub fn get_value(&self, value_name: &ValueNameRef) -> Option<&ExtensionValue> { - self.values.get(value_name) - } - /// Returns the name of the extension. pub fn name(&self) -> &ExtensionId { &self.name @@ -705,25 +665,6 @@ impl Extension { self.types.iter() } - /// Add a named static value to the extension. - pub fn add_value( - &mut self, - name: impl Into, - typed_value: ops::Value, - ) -> Result<&mut ExtensionValue, ExtensionBuildError> { - let extension_value = ExtensionValue { - extension: self.name.clone(), - name: name.into(), - typed_value, - }; - match self.values.entry(extension_value.name.clone()) { - btree_map::Entry::Occupied(_) => { - Err(ExtensionBuildError::ValueExists(extension_value.name)) - } - btree_map::Entry::Vacant(ve) => Ok(ve.insert(extension_value)), - } - } - /// Instantiate an [`ExtensionOp`] which references an [`OpDef`] in this extension. pub fn instantiate_extension_op( &self, @@ -784,9 +725,6 @@ pub enum ExtensionBuildError { /// Existing [`TypeDef`] #[error("Extension already has an type called {0}.")] TypeDefExists(TypeName), - /// Existing [`ExtensionValue`] - #[error("Extension already has an extension value called {0}.")] - ValueExists(ValueName), } /// A set of extensions identified by their unique [`ExtensionId`]. diff --git a/hugr-core/src/extension/resolution/extension.rs b/hugr-core/src/extension/resolution/extension.rs index 61adc1dea..05c0faf69 100644 --- a/hugr-core/src/extension/resolution/extension.rs +++ b/hugr-core/src/extension/resolution/extension.rs @@ -9,7 +9,7 @@ use std::sync::Arc; use crate::extension::{Extension, ExtensionId, ExtensionRegistry, OpDef, SignatureFunc, TypeDef}; -use super::types_mut::{resolve_signature_exts, resolve_value_exts}; +use super::types_mut::resolve_signature_exts; use super::{ExtensionResolutionError, WeakExtensionRegistry}; impl ExtensionRegistry { @@ -59,14 +59,7 @@ impl Extension { for type_def in self.types.values_mut() { resolve_typedef_exts(&self.name, type_def, extensions, &mut used_extensions)?; } - for val in self.values.values_mut() { - resolve_value_exts( - None, - val.typed_value_mut(), - extensions, - &mut used_extensions, - )?; - } + let ops = mem::take(&mut self.operations); for (op_id, mut op_def) in ops { // TODO: We should be able to clone the definition if needed by using `make_mut`, diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index ecb417ec5..37157020d 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -20,7 +20,6 @@ use crate::ops::handle::NodeHandle; use crate::ops::{self, OpType, Value}; use crate::std_extensions::logic::test::{and_op, or_op}; use crate::std_extensions::logic::LogicOp; -use crate::std_extensions::logic::{self}; use crate::types::type_param::{TypeArg, TypeArgError}; use crate::types::{ CustomType, FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, Type, TypeBound, TypeRV, @@ -307,12 +306,7 @@ fn test_local_const() { port_kind: EdgeKind::Value(bool_t()) }) ); - let const_op: ops::Const = logic::EXTENSION - .get_value(&logic::TRUE_NAME) - .unwrap() - .typed_value() - .clone() - .into(); + let const_op: ops::Const = ops::Value::from_bool(true).into(); // Second input of Xor from a constant let cst = h.add_node_with_parent(h.root(), const_op); let lcst = h.add_node_with_parent(h.root(), ops::LoadConstant { datatype: bool_t() }); diff --git a/hugr-core/src/std_extensions/logic.rs b/hugr-core/src/std_extensions/logic.rs index fcc8be9d3..20977cb51 100644 --- a/hugr-core/src/std_extensions/logic.rs +++ b/hugr-core/src/std_extensions/logic.rs @@ -124,13 +124,6 @@ pub const VERSION: semver::Version = semver::Version::new(0, 1, 0); fn extension() -> Arc { Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { LogicOp::load_all_ops(extension, extension_ref).unwrap(); - - extension - .add_value(FALSE_NAME, ops::Value::false_val()) - .unwrap(); - extension - .add_value(TRUE_NAME, ops::Value::true_val()) - .unwrap(); }) } @@ -172,12 +165,9 @@ fn read_inputs(consts: &[(IncomingPort, ops::Value)]) -> Option> { pub(crate) mod test { use std::sync::Arc; - use super::{extension, LogicOp, FALSE_NAME, TRUE_NAME}; + use super::{extension, LogicOp}; use crate::{ - extension::{ - prelude::bool_t, - simple_op::{MakeOpDef, MakeRegisteredOp}, - }, + extension::simple_op::{MakeOpDef, MakeRegisteredOp}, ops::{NamedOp, Value}, Extension, }; @@ -207,18 +197,6 @@ pub(crate) mod test { } } - #[test] - fn test_values() { - let r: Arc = extension(); - let false_val = r.get_value(&FALSE_NAME).unwrap(); - let true_val = r.get_value(&TRUE_NAME).unwrap(); - - for v in [false_val, true_val] { - let simpl = v.typed_value().get_type(); - assert_eq!(simpl, bool_t()); - } - } - /// Generate a logic extension "and" operation over [`crate::prelude::bool_t()`] pub(crate) fn and_op() -> LogicOp { LogicOp::And diff --git a/hugr-py/src/hugr/_serialization/extension.py b/hugr-py/src/hugr/_serialization/extension.py index 429bdd785..95e59754e 100644 --- a/hugr-py/src/hugr/_serialization/extension.py +++ b/hugr-py/src/hugr/_serialization/extension.py @@ -8,7 +8,6 @@ from hugr.hugr.base import Hugr from hugr.utils import deser_it -from .ops import Value from .serial_hugr import SerialHugr, serialization_version from .tys import ( ConfiguredBaseModel, @@ -20,7 +19,6 @@ ) if TYPE_CHECKING: - from .ops import Value from .serial_hugr import SerialHugr @@ -62,20 +60,6 @@ def deserialize(self, extension: ext.Extension) -> ext.TypeDef: ) -class ExtensionValue(ConfiguredBaseModel): - extension: ExtensionId - name: str - typed_value: Value - - def deserialize(self, extension: ext.Extension) -> ext.ExtensionValue: - return extension.add_extension_value( - ext.ExtensionValue( - name=self.name, - val=self.typed_value.deserialize(), - ) - ) - - # -------------------------------------- # --------------- OpDef ---------------- # -------------------------------------- @@ -124,7 +108,6 @@ class Extension(ConfiguredBaseModel): name: ExtensionId runtime_reqs: set[ExtensionId] types: dict[str, TypeDef] - values: dict[str, ExtensionValue] operations: dict[str, OpDef] @classmethod @@ -146,10 +129,6 @@ def deserialize(self) -> ext.Extension: assert k == o.name, "Operation name must match key" e.add_op_def(o.deserialize(e)) - for k, v in self.values.items(): - assert k == v.name, "Value name must match key" - e.add_extension_value(v.deserialize(e)) - return e diff --git a/hugr-py/src/hugr/ext.py b/hugr-py/src/hugr/ext.py index 494ea3c69..7bd02f982 100644 --- a/hugr-py/src/hugr/ext.py +++ b/hugr-py/src/hugr/ext.py @@ -8,7 +8,7 @@ from semver import Version import hugr._serialization.extension as ext_s -from hugr import ops, tys, val +from hugr import ops, tys from hugr.utils import ser_it __all__ = [ @@ -18,7 +18,6 @@ "FixedHugr", "OpDefSig", "OpDef", - "ExtensionValue", "Extension", "Version", ] @@ -246,23 +245,6 @@ def instantiate( return ops.ExtOp(self, concrete_signature, list(args or [])) -@dataclass -class ExtensionValue(ExtensionObject): - """A value defined in an :class:`Extension`.""" - - #: The name of the value. - name: str - #: Value payload. - val: val.Value - - def _to_serial(self) -> ext_s.ExtensionValue: - return ext_s.ExtensionValue( - extension=self.get_extension().name, - name=self.name, - typed_value=self.val._to_serial_root(), - ) - - T = TypeVar("T", bound=ops.RegisteredOp) @@ -278,8 +260,6 @@ class Extension: runtime_reqs: set[ExtensionId] = field(default_factory=set) #: Type definitions in the extension. types: dict[str, TypeDef] = field(default_factory=dict) - #: Values defined in the extension. - values: dict[str, ExtensionValue] = field(default_factory=dict) #: Operation definitions in the extension. operations: dict[str, OpDef] = field(default_factory=dict) @@ -295,7 +275,6 @@ def _to_serial(self) -> ext_s.Extension: version=self.version, # type: ignore[arg-type] runtime_reqs=self.runtime_reqs, types={k: v._to_serial() for k, v in self.types.items()}, - values={k: v._to_serial() for k, v in self.values.items()}, operations={k: v._to_serial() for k, v in self.operations.items()}, ) @@ -347,19 +326,6 @@ def add_type_def(self, type_def: TypeDef) -> TypeDef: self.types[type_def.name] = type_def return self.types[type_def.name] - def add_extension_value(self, extension_value: ExtensionValue) -> ExtensionValue: - """Add a value to the extension. - - Args: - extension_value: The value to add. - - Returns: - The added value, now associated with the extension. - """ - extension_value._extension = self - self.values[extension_value.name] = extension_value - return self.values[extension_value.name] - @dataclass class OperationNotFound(NotFound): """Operation not found in extension.""" @@ -406,12 +372,6 @@ def get_type(self, name: str) -> TypeDef: class ValueNotFound(NotFound): """Value not found in extension.""" - def get_value(self, name: str) -> ExtensionValue: - try: - return self.values[name] - except KeyError as e: - raise self.ValueNotFound(name) from e - T = TypeVar("T", bound=ops.RegisteredOp) def register_op( diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/conversions.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/conversions.json index 9c0054354..1d310df25 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/conversions.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/conversions.json @@ -6,7 +6,6 @@ "arithmetic.int.types" ], "types": {}, - "values": {}, "operations": { "bytecast_float64_to_int64": { "extension": "arithmetic.conversions", diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/float.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/float.json index 31ccaaa59..8da056772 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/float.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/float.json @@ -5,7 +5,6 @@ "arithmetic.int.types" ], "types": {}, - "values": {}, "operations": { "fabs": { "extension": "arithmetic.float", diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/float/types.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/float/types.json index 56e35c50b..0c563c474 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/float/types.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/float/types.json @@ -14,6 +14,5 @@ } } }, - "values": {}, "operations": {} } diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/int.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/int.json index 62d0a6663..5b1a81250 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/int.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/int.json @@ -5,7 +5,6 @@ "arithmetic.int.types" ], "types": {}, - "values": {}, "operations": { "iabs": { "extension": "arithmetic.int", diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/int/types.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/int/types.json index 60cf69f63..36df125a6 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/int/types.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/int/types.json @@ -19,6 +19,5 @@ } } }, - "values": {}, "operations": {} } diff --git a/hugr-py/src/hugr/std/_json_defs/collections/array.json b/hugr-py/src/hugr/std/_json_defs/collections/array.json index 21e405151..375e13c72 100644 --- a/hugr-py/src/hugr/std/_json_defs/collections/array.json +++ b/hugr-py/src/hugr/std/_json_defs/collections/array.json @@ -25,7 +25,6 @@ } } }, - "values": {}, "operations": { "discard_empty": { "extension": "collections.array", diff --git a/hugr-py/src/hugr/std/_json_defs/collections/list.json b/hugr-py/src/hugr/std/_json_defs/collections/list.json index 0fbafc638..8a60d3544 100644 --- a/hugr-py/src/hugr/std/_json_defs/collections/list.json +++ b/hugr-py/src/hugr/std/_json_defs/collections/list.json @@ -21,7 +21,6 @@ } } }, - "values": {}, "operations": { "get": { "extension": "collections.list", diff --git a/hugr-py/src/hugr/std/_json_defs/collections/static_array.json b/hugr-py/src/hugr/std/_json_defs/collections/static_array.json index e4669f671..53b8e61c7 100644 --- a/hugr-py/src/hugr/std/_json_defs/collections/static_array.json +++ b/hugr-py/src/hugr/std/_json_defs/collections/static_array.json @@ -19,7 +19,6 @@ } } }, - "values": {}, "operations": { "get": { "extension": "collections.static_array", diff --git a/hugr-py/src/hugr/std/_json_defs/logic.json b/hugr-py/src/hugr/std/_json_defs/logic.json index ad9f02019..ff29d2c21 100644 --- a/hugr-py/src/hugr/std/_json_defs/logic.json +++ b/hugr-py/src/hugr/std/_json_defs/logic.json @@ -3,34 +3,6 @@ "name": "logic", "runtime_reqs": [], "types": {}, - "values": { - "FALSE": { - "extension": "logic", - "name": "FALSE", - "typed_value": { - "v": "Sum", - "tag": 0, - "vs": [], - "typ": { - "s": "Unit", - "size": 2 - } - } - }, - "TRUE": { - "extension": "logic", - "name": "TRUE", - "typed_value": { - "v": "Sum", - "tag": 1, - "vs": [], - "typ": { - "s": "Unit", - "size": 2 - } - } - } - }, "operations": { "And": { "extension": "logic", diff --git a/hugr-py/src/hugr/std/_json_defs/prelude.json b/hugr-py/src/hugr/std/_json_defs/prelude.json index e11ba2388..ec392b155 100644 --- a/hugr-py/src/hugr/std/_json_defs/prelude.json +++ b/hugr-py/src/hugr/std/_json_defs/prelude.json @@ -44,7 +44,6 @@ } } }, - "values": {}, "operations": { "Barrier": { "extension": "prelude", diff --git a/hugr-py/src/hugr/std/_json_defs/ptr.json b/hugr-py/src/hugr/std/_json_defs/ptr.json index 18b1f26b6..614b6aecf 100644 --- a/hugr-py/src/hugr/std/_json_defs/ptr.json +++ b/hugr-py/src/hugr/std/_json_defs/ptr.json @@ -19,7 +19,6 @@ } } }, - "values": {}, "operations": { "New": { "extension": "ptr", diff --git a/specification/schema/hugr_schema_live.json b/specification/schema/hugr_schema_live.json index 9e7d8c40c..ea08dff5b 100644 --- a/specification/schema/hugr_schema_live.json +++ b/specification/schema/hugr_schema_live.json @@ -517,13 +517,6 @@ "title": "Types", "type": "object" }, - "values": { - "additionalProperties": { - "$ref": "#/$defs/ExtensionValue" - }, - "title": "Values", - "type": "object" - }, "operations": { "additionalProperties": { "$ref": "#/$defs/OpDef" @@ -537,7 +530,6 @@ "name", "runtime_reqs", "types", - "values", "operations" ], "title": "Extension", @@ -589,28 +581,6 @@ "title": "ExtensionOp", "type": "object" }, - "ExtensionValue": { - "properties": { - "extension": { - "title": "Extension", - "type": "string" - }, - "name": { - "title": "Name", - "type": "string" - }, - "typed_value": { - "$ref": "#/$defs/Value" - } - }, - "required": [ - "extension", - "name", - "typed_value" - ], - "title": "ExtensionValue", - "type": "object" - }, "ExtensionsArg": { "additionalProperties": true, "properties": { diff --git a/specification/schema/hugr_schema_strict_live.json b/specification/schema/hugr_schema_strict_live.json index 6f436f969..8b65bae94 100644 --- a/specification/schema/hugr_schema_strict_live.json +++ b/specification/schema/hugr_schema_strict_live.json @@ -517,13 +517,6 @@ "title": "Types", "type": "object" }, - "values": { - "additionalProperties": { - "$ref": "#/$defs/ExtensionValue" - }, - "title": "Values", - "type": "object" - }, "operations": { "additionalProperties": { "$ref": "#/$defs/OpDef" @@ -537,7 +530,6 @@ "name", "runtime_reqs", "types", - "values", "operations" ], "title": "Extension", @@ -589,28 +581,6 @@ "title": "ExtensionOp", "type": "object" }, - "ExtensionValue": { - "properties": { - "extension": { - "title": "Extension", - "type": "string" - }, - "name": { - "title": "Name", - "type": "string" - }, - "typed_value": { - "$ref": "#/$defs/Value" - } - }, - "required": [ - "extension", - "name", - "typed_value" - ], - "title": "ExtensionValue", - "type": "object" - }, "ExtensionsArg": { "additionalProperties": false, "properties": { diff --git a/specification/schema/testing_hugr_schema_live.json b/specification/schema/testing_hugr_schema_live.json index bc067d40e..91b121da6 100644 --- a/specification/schema/testing_hugr_schema_live.json +++ b/specification/schema/testing_hugr_schema_live.json @@ -517,13 +517,6 @@ "title": "Types", "type": "object" }, - "values": { - "additionalProperties": { - "$ref": "#/$defs/ExtensionValue" - }, - "title": "Values", - "type": "object" - }, "operations": { "additionalProperties": { "$ref": "#/$defs/OpDef" @@ -537,7 +530,6 @@ "name", "runtime_reqs", "types", - "values", "operations" ], "title": "Extension", @@ -589,28 +581,6 @@ "title": "ExtensionOp", "type": "object" }, - "ExtensionValue": { - "properties": { - "extension": { - "title": "Extension", - "type": "string" - }, - "name": { - "title": "Name", - "type": "string" - }, - "typed_value": { - "$ref": "#/$defs/Value" - } - }, - "required": [ - "extension", - "name", - "typed_value" - ], - "title": "ExtensionValue", - "type": "object" - }, "ExtensionsArg": { "additionalProperties": true, "properties": { diff --git a/specification/schema/testing_hugr_schema_strict_live.json b/specification/schema/testing_hugr_schema_strict_live.json index 47c9778d3..eae6a13a7 100644 --- a/specification/schema/testing_hugr_schema_strict_live.json +++ b/specification/schema/testing_hugr_schema_strict_live.json @@ -517,13 +517,6 @@ "title": "Types", "type": "object" }, - "values": { - "additionalProperties": { - "$ref": "#/$defs/ExtensionValue" - }, - "title": "Values", - "type": "object" - }, "operations": { "additionalProperties": { "$ref": "#/$defs/OpDef" @@ -537,7 +530,6 @@ "name", "runtime_reqs", "types", - "values", "operations" ], "title": "Extension", @@ -589,28 +581,6 @@ "title": "ExtensionOp", "type": "object" }, - "ExtensionValue": { - "properties": { - "extension": { - "title": "Extension", - "type": "string" - }, - "name": { - "title": "Name", - "type": "string" - }, - "typed_value": { - "$ref": "#/$defs/Value" - } - }, - "required": [ - "extension", - "name", - "typed_value" - ], - "title": "ExtensionValue", - "type": "object" - }, "ExtensionsArg": { "additionalProperties": false, "properties": { diff --git a/specification/std_extensions/arithmetic/conversions.json b/specification/std_extensions/arithmetic/conversions.json index 9c0054354..1d310df25 100644 --- a/specification/std_extensions/arithmetic/conversions.json +++ b/specification/std_extensions/arithmetic/conversions.json @@ -6,7 +6,6 @@ "arithmetic.int.types" ], "types": {}, - "values": {}, "operations": { "bytecast_float64_to_int64": { "extension": "arithmetic.conversions", diff --git a/specification/std_extensions/arithmetic/float.json b/specification/std_extensions/arithmetic/float.json index 31ccaaa59..8da056772 100644 --- a/specification/std_extensions/arithmetic/float.json +++ b/specification/std_extensions/arithmetic/float.json @@ -5,7 +5,6 @@ "arithmetic.int.types" ], "types": {}, - "values": {}, "operations": { "fabs": { "extension": "arithmetic.float", diff --git a/specification/std_extensions/arithmetic/float/types.json b/specification/std_extensions/arithmetic/float/types.json index 56e35c50b..0c563c474 100644 --- a/specification/std_extensions/arithmetic/float/types.json +++ b/specification/std_extensions/arithmetic/float/types.json @@ -14,6 +14,5 @@ } } }, - "values": {}, "operations": {} } diff --git a/specification/std_extensions/arithmetic/int.json b/specification/std_extensions/arithmetic/int.json index 62d0a6663..5b1a81250 100644 --- a/specification/std_extensions/arithmetic/int.json +++ b/specification/std_extensions/arithmetic/int.json @@ -5,7 +5,6 @@ "arithmetic.int.types" ], "types": {}, - "values": {}, "operations": { "iabs": { "extension": "arithmetic.int", diff --git a/specification/std_extensions/arithmetic/int/types.json b/specification/std_extensions/arithmetic/int/types.json index 60cf69f63..36df125a6 100644 --- a/specification/std_extensions/arithmetic/int/types.json +++ b/specification/std_extensions/arithmetic/int/types.json @@ -19,6 +19,5 @@ } } }, - "values": {}, "operations": {} } diff --git a/specification/std_extensions/collections/array.json b/specification/std_extensions/collections/array.json index 21e405151..375e13c72 100644 --- a/specification/std_extensions/collections/array.json +++ b/specification/std_extensions/collections/array.json @@ -25,7 +25,6 @@ } } }, - "values": {}, "operations": { "discard_empty": { "extension": "collections.array", diff --git a/specification/std_extensions/collections/list.json b/specification/std_extensions/collections/list.json index 0fbafc638..8a60d3544 100644 --- a/specification/std_extensions/collections/list.json +++ b/specification/std_extensions/collections/list.json @@ -21,7 +21,6 @@ } } }, - "values": {}, "operations": { "get": { "extension": "collections.list", diff --git a/specification/std_extensions/collections/static_array.json b/specification/std_extensions/collections/static_array.json index e4669f671..53b8e61c7 100644 --- a/specification/std_extensions/collections/static_array.json +++ b/specification/std_extensions/collections/static_array.json @@ -19,7 +19,6 @@ } } }, - "values": {}, "operations": { "get": { "extension": "collections.static_array", diff --git a/specification/std_extensions/logic.json b/specification/std_extensions/logic.json index ad9f02019..ff29d2c21 100644 --- a/specification/std_extensions/logic.json +++ b/specification/std_extensions/logic.json @@ -3,34 +3,6 @@ "name": "logic", "runtime_reqs": [], "types": {}, - "values": { - "FALSE": { - "extension": "logic", - "name": "FALSE", - "typed_value": { - "v": "Sum", - "tag": 0, - "vs": [], - "typ": { - "s": "Unit", - "size": 2 - } - } - }, - "TRUE": { - "extension": "logic", - "name": "TRUE", - "typed_value": { - "v": "Sum", - "tag": 1, - "vs": [], - "typ": { - "s": "Unit", - "size": 2 - } - } - } - }, "operations": { "And": { "extension": "logic", diff --git a/specification/std_extensions/prelude.json b/specification/std_extensions/prelude.json index e11ba2388..ec392b155 100644 --- a/specification/std_extensions/prelude.json +++ b/specification/std_extensions/prelude.json @@ -44,7 +44,6 @@ } } }, - "values": {}, "operations": { "Barrier": { "extension": "prelude", diff --git a/specification/std_extensions/ptr.json b/specification/std_extensions/ptr.json index 18b1f26b6..614b6aecf 100644 --- a/specification/std_extensions/ptr.json +++ b/specification/std_extensions/ptr.json @@ -19,7 +19,6 @@ } } }, - "values": {}, "operations": { "New": { "extension": "ptr", From 89c2680912b47950ffd73a7c29a21386fdd0aee7 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 22 Apr 2025 17:16:31 +0100 Subject: [PATCH 32/33] feat!: ComposablePass trait allowing sequencing and validation (#1895) Currently We have several "passes": monomorphization, dead function removal, constant folding. Each has its own code to allow setting a validation level (before and after that pass). This PR adds the ability chain (sequence) passes;, and to add validation before+after any pass or sequence; and commons up validation code. The top-level `constant_fold_pass` (etc.) functions are left as wrappers that do a single pass with validation only in test. I've left ConstFoldPass as always including DCE, but an alternative could be to return a sequence of the two - ATM that means a tuple `(ConstFoldPass, DeadCodeElimPass)`. I also wondered about including a method `add_entry_point` in ComposablePass (e.g. for ConstFoldPass, that means `with_inputs` but no inputs, i.e. all Top). I feel this is not applicable to *all* passes, but near enough. This could be done in a later PR but `add_entry_point` would need a no-op default for that to be a non-breaking change. So if we wouldn't be happy with the no-op default then I could just add it here... Finally...docs are extremely minimal ATM (this is hugr-passes), I am hoping that most of this is reasonably obvious (it doesn't really do a lot!), but please flag anything you think is particularly in need of a doc comment! BREAKING CHANGE: quite a lot of calls to current pass routines will break, specific cases include (a) `with_validation_level` should be done by wrapping a ValidatingPass around the receiver; (b) XXXPass::run() requires `use ...ComposablePass` (however, such calls will cease to do any validation). closes #1832 --- hugr-passes/src/composable.rs | 361 +++++++++++++++++++++ hugr-passes/src/const_fold.rs | 45 +-- hugr-passes/src/const_fold/test.rs | 1 + hugr-passes/src/dead_code.rs | 50 ++- hugr-passes/src/dead_funcs.rs | 77 ++--- hugr-passes/src/lib.rs | 12 +- hugr-passes/src/monomorphize.rs | 92 ++---- hugr-passes/src/replace_types.rs | 105 +++--- hugr-passes/src/replace_types/linearize.rs | 2 +- hugr-passes/src/untuple.rs | 70 ++-- 10 files changed, 550 insertions(+), 265 deletions(-) create mode 100644 hugr-passes/src/composable.rs diff --git a/hugr-passes/src/composable.rs b/hugr-passes/src/composable.rs new file mode 100644 index 000000000..fb3319155 --- /dev/null +++ b/hugr-passes/src/composable.rs @@ -0,0 +1,361 @@ +//! Compiler passes and utilities for composing them + +use std::{error::Error, marker::PhantomData}; + +use hugr_core::hugr::{hugrmut::HugrMut, ValidationError}; +use hugr_core::HugrView; +use itertools::Either; + +/// An optimization pass that can be sequenced with another and/or wrapped +/// e.g. by [ValidatingPass] +pub trait ComposablePass: Sized { + type Error: Error; + type Result; // Would like to default to () but currently unstable + + fn run(&self, hugr: &mut impl HugrMut) -> Result; + + fn map_err( + self, + f: impl Fn(Self::Error) -> E2, + ) -> impl ComposablePass { + ErrMapper::new(self, f) + } + + /// Returns a [ComposablePass] that does "`self` then `other`", so long as + /// `other::Err` can be combined with ours. + fn then>( + self, + other: P, + ) -> impl ComposablePass { + struct Sequence(P1, P2, PhantomData); + impl ComposablePass for Sequence + where + P1: ComposablePass, + P2: ComposablePass, + E: ErrorCombiner, + { + type Error = E; + + type Result = (P1::Result, P2::Result); + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + let res1 = self.0.run(hugr).map_err(E::from_first)?; + let res2 = self.1.run(hugr).map_err(E::from_second)?; + Ok((res1, res2)) + } + } + + Sequence(self, other, PhantomData) + } +} + +/// Trait for combining the error types from two different passes +/// into a single error. +pub trait ErrorCombiner: Error { + fn from_first(a: A) -> Self; + fn from_second(b: B) -> Self; +} + +impl> ErrorCombiner for A { + fn from_first(a: A) -> Self { + a + } + + fn from_second(b: B) -> Self { + b.into() + } +} + +impl ErrorCombiner for Either { + fn from_first(a: A) -> Self { + Either::Left(a) + } + + fn from_second(b: B) -> Self { + Either::Right(b) + } +} + +// Note: in the short term we could wish for two more impls: +// impl ErrorCombiner for E +// impl ErrorCombiner for E +// however, these aren't possible as they conflict with +// impl> ErrorCombiner for A +// when A=E=Infallible, boo :-(. +// However this will become possible, indeed automatic, when Infallible is replaced +// by ! (never_type) as (unlike Infallible) ! converts Into anything + +// ErrMapper ------------------------------ +struct ErrMapper(P, F, PhantomData); + +impl E> ErrMapper { + fn new(pass: P, err_fn: F) -> Self { + Self(pass, err_fn, PhantomData) + } +} + +impl E> ComposablePass for ErrMapper { + type Error = E; + type Result = P::Result; + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + self.0.run(hugr).map_err(&self.1) + } +} + +// ValidatingPass ------------------------------ + +/// Error from a [ValidatingPass] +#[derive(thiserror::Error, Debug)] +pub enum ValidatePassError { + #[error("Failed to validate input HUGR: {err}\n{pretty_hugr}")] + Input { + #[source] + err: ValidationError, + pretty_hugr: String, + }, + #[error("Failed to validate output HUGR: {err}\n{pretty_hugr}")] + Output { + #[source] + err: ValidationError, + pretty_hugr: String, + }, + #[error(transparent)] + Underlying(#[from] E), +} + +/// Runs an underlying pass, but with validation of the Hugr +/// both before and afterwards. +pub struct ValidatingPass

(P, bool); + +impl ValidatingPass

{ + pub fn new_default(underlying: P) -> Self { + // Self(underlying, cfg!(feature = "extension_inference")) + // Sadly, many tests fail with extension inference, hence: + Self(underlying, false) + } + + pub fn new_validating_extensions(underlying: P) -> Self { + Self(underlying, true) + } + + pub fn new(underlying: P, validate_extensions: bool) -> Self { + Self(underlying, validate_extensions) + } + + fn validation_impl( + &self, + hugr: &impl HugrView, + mk_err: impl FnOnce(ValidationError, String) -> ValidatePassError, + ) -> Result<(), ValidatePassError> { + match self.1 { + false => hugr.validate_no_extensions(), + true => hugr.validate(), + } + .map_err(|err| mk_err(err, hugr.mermaid_string())) + } +} + +impl ComposablePass for ValidatingPass

{ + type Error = ValidatePassError; + type Result = P::Result; + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Input { + err, + pretty_hugr, + })?; + let res = self.0.run(hugr).map_err(ValidatePassError::Underlying)?; + self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Output { + err, + pretty_hugr, + })?; + Ok(res) + } +} + +// IfThen ------------------------------ +/// [ComposablePass] that executes a first pass that returns a `bool` +/// result; and then, if-and-only-if that first result was true, +/// executes a second pass +pub struct IfThen(A, B, PhantomData); + +impl, B: ComposablePass, E: ErrorCombiner> + IfThen +{ + /// Make a new instance given the [ComposablePass] to run first + /// and (maybe) second + pub fn new(fst: A, opt_snd: B) -> Self { + Self(fst, opt_snd, PhantomData) + } +} + +impl, B: ComposablePass, E: ErrorCombiner> + ComposablePass for IfThen +{ + type Error = E; + + type Result = Option; + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + let res: bool = self.0.run(hugr).map_err(ErrorCombiner::from_first)?; + res.then(|| self.1.run(hugr).map_err(ErrorCombiner::from_second)) + .transpose() + } +} + +pub(crate) fn validate_if_test( + pass: P, + hugr: &mut impl HugrMut, +) -> Result> { + if cfg!(test) { + ValidatingPass::new_default(pass).run(hugr) + } else { + pass.run(hugr).map_err(ValidatePassError::Underlying) + } +} + +#[cfg(test)] +mod test { + use itertools::{Either, Itertools}; + use std::convert::Infallible; + + use hugr_core::builder::{ + Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, + ModuleBuilder, + }; + use hugr_core::extension::prelude::{ + bool_t, usize_t, ConstUsize, MakeTuple, UnpackTuple, PRELUDE_ID, + }; + use hugr_core::hugr::hugrmut::HugrMut; + use hugr_core::ops::{handle::NodeHandle, Input, OpType, Output, DEFAULT_OPTYPE, DFG}; + use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES; + use hugr_core::types::{Signature, TypeRow}; + use hugr_core::{Hugr, HugrView, IncomingPort}; + + use crate::const_fold::{ConstFoldError, ConstantFoldPass}; + use crate::untuple::{UntupleRecursive, UntupleResult}; + use crate::{DeadCodeElimPass, ReplaceTypes, UntuplePass}; + + use super::{validate_if_test, ComposablePass, IfThen, ValidatePassError, ValidatingPass}; + + #[test] + fn test_then() { + let mut mb = ModuleBuilder::new(); + let id1 = mb + .define_function("id1", Signature::new_endo(usize_t())) + .unwrap(); + let inps = id1.input_wires(); + let id1 = id1.finish_with_outputs(inps).unwrap(); + let id2 = mb + .define_function("id2", Signature::new_endo(usize_t())) + .unwrap(); + let inps = id2.input_wires(); + let id2 = id2.finish_with_outputs(inps).unwrap(); + let hugr = mb.finish_hugr().unwrap(); + + let dce = DeadCodeElimPass::default().with_entry_points([id1.node()]); + let cfold = + ConstantFoldPass::default().with_inputs(id2.node(), [(0, ConstUsize::new(2).into())]); + + cfold.run(&mut hugr.clone()).unwrap(); + + let exp_err = ConstFoldError::InvalidEntryPoint(id2.node(), DEFAULT_OPTYPE); + let r: Result<_, Either> = + dce.clone().then(cfold.clone()).run(&mut hugr.clone()); + assert_eq!(r, Err(Either::Right(exp_err.clone()))); + + let r = dce + .clone() + .map_err(|inf| match inf {}) + .then(cfold.clone()) + .run(&mut hugr.clone()); + assert_eq!(r, Err(exp_err)); + + let r2: Result<_, Either<_, _>> = cfold.then(dce).run(&mut hugr.clone()); + r2.unwrap(); + } + + #[test] + fn test_validation() { + let mut h = Hugr::new(DFG { + signature: Signature::new(usize_t(), bool_t()), + }); + let inp = h.add_node_with_parent( + h.root(), + Input { + types: usize_t().into(), + }, + ); + let outp = h.add_node_with_parent( + h.root(), + Output { + types: bool_t().into(), + }, + ); + h.connect(inp, 0, outp, 0); + let backup = h.clone(); + let err = backup.validate().unwrap_err(); + + let no_inputs: [(IncomingPort, _); 0] = []; + let cfold = ConstantFoldPass::default().with_inputs(backup.root(), no_inputs); + cfold.run(&mut h).unwrap(); + assert_eq!(h, backup); // Did nothing + + let r = ValidatingPass(cfold, false).run(&mut h); + assert!(matches!(r, Err(ValidatePassError::Input { err: e, .. }) if e == err)); + } + + #[test] + fn test_if_then() { + let tr = TypeRow::from(vec![usize_t(); 2]); + + let h = { + let sig = Signature::new_endo(tr.clone()).with_extension_delta(PRELUDE_ID); + let mut fb = FunctionBuilder::new("tupuntup", sig).unwrap(); + let [a, b] = fb.input_wires_arr(); + let tup = fb + .add_dataflow_op(MakeTuple::new(tr.clone()), [a, b]) + .unwrap(); + let untup = fb + .add_dataflow_op(UnpackTuple::new(tr.clone()), tup.outputs()) + .unwrap(); + fb.finish_hugr_with_outputs(untup.outputs()).unwrap() + }; + + let untup = UntuplePass::new(UntupleRecursive::Recursive); + { + // Change usize_t to INT_TYPES[6], and if that did anything (it will!), then Untuple + let mut repl = ReplaceTypes::default(); + let usize_custom_t = usize_t().as_extension().unwrap().clone(); + repl.replace_type(usize_custom_t, INT_TYPES[6].clone()); + let ifthen = IfThen::, _, _>::new(repl, untup.clone()); + + let mut h = h.clone(); + let r = validate_if_test(ifthen, &mut h).unwrap(); + assert_eq!( + r, + Some(UntupleResult { + rewrites_applied: 1 + }) + ); + let [tuple_in, tuple_out] = h.children(h.root()).collect_array().unwrap(); + assert_eq!(h.output_neighbours(tuple_in).collect_vec(), [tuple_out; 2]); + } + + // Change INT_TYPES[5] to INT_TYPES[6]; that won't do anything, so don't Untuple + let mut repl = ReplaceTypes::default(); + let i32_custom_t = INT_TYPES[5].as_extension().unwrap().clone(); + repl.replace_type(i32_custom_t, INT_TYPES[6].clone()); + let ifthen = IfThen::, _, _>::new(repl, untup); + let mut h = h; + let r = validate_if_test(ifthen, &mut h).unwrap(); + assert_eq!(r, None); + assert_eq!(h.children(h.root()).count(), 4); + let mktup = h + .output_neighbours(h.first_child(h.root()).unwrap()) + .next() + .unwrap(); + assert_eq!(h.get_optype(mktup), &OpType::from(MakeTuple::new(tr))); + } +} diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index e73e3cd0e..99ccc180c 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -21,12 +21,11 @@ use crate::dataflow::{ TailLoopTermination, }; use crate::dead_code::{DeadCodeElimPass, PreserveNode}; -use crate::validation::{ValidatePassError, ValidationLevel}; +use crate::{composable::validate_if_test, ComposablePass}; #[derive(Debug, Clone, Default)] /// A configuration for the Constant Folding pass. pub struct ConstantFoldPass { - validation: ValidationLevel, allow_increase_termination: bool, /// Each outer key Node must be either: /// - a FuncDefn child of the root, if the root is a module; or @@ -34,13 +33,10 @@ pub struct ConstantFoldPass { inputs: HashMap>, } -#[derive(Debug, Error)] +#[derive(Clone, Debug, Error, PartialEq)] #[non_exhaustive] /// Errors produced by [ConstantFoldPass]. pub enum ConstFoldError { - #[error(transparent)] - #[allow(missing_docs)] - ValidationError(#[from] ValidatePassError), /// Error raised when a Node is specified as an entry-point but /// is neither a dataflow parent, nor a [CFG](OpType::CFG), nor /// a [Conditional](OpType::Conditional). @@ -49,12 +45,6 @@ pub enum ConstFoldError { } impl ConstantFoldPass { - /// Sets the validation level used before and after the pass is run - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - /// Allows the pass to remove potentially-non-terminating [TailLoop]s and [CFG] if their /// result (if/when they do terminate) is either known or not needed. /// @@ -86,9 +76,19 @@ impl ConstantFoldPass { .extend(inputs.into_iter().map(|(p, v)| (p.into(), v))); self } +} + +impl ComposablePass for ConstantFoldPass { + type Error = ConstFoldError; + type Result = (); /// Run the Constant Folding pass. - fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), ConstFoldError> { + /// + /// # Errors + /// + /// [ConstFoldError::InvalidEntryPoint] if an entry-point added by [Self::with_inputs] + /// was of an invalid [OpType] + fn run(&self, hugr: &mut impl HugrMut) -> Result<(), ConstFoldError> { let fresh_node = Node::from(portgraph::NodeIndex::new( hugr.nodes().max().map_or(0, |n| n.index() + 1), )); @@ -164,23 +164,10 @@ impl ConstantFoldPass { } }) }) - .run(hugr)?; + .run(hugr) + .map_err(|inf| match inf {})?; // TODO use into_ok when available Ok(()) } - - /// Run the pass using this configuration. - /// - /// # Errors - /// - /// [ConstFoldError::ValidationError] if the Hugr does not validate before/afnerwards - /// (if [Self::validation_level] is set, or in tests) - /// - /// [ConstFoldError::InvalidEntryPoint] if an entry-point added by [Self::with_inputs] - /// was of an invalid OpType - pub fn run(&self, hugr: &mut H) -> Result<(), ConstFoldError> { - self.validation - .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr)) - } } /// Exhaustively apply constant folding to a HUGR. @@ -198,7 +185,7 @@ pub fn constant_fold_pass(h: &mut H) { } else { c }; - c.run(h).unwrap() + validate_if_test(c, h).unwrap() } struct ConstFoldContext; diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index 58e69c568..ff5cd93a5 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -32,6 +32,7 @@ use hugr_core::types::{Signature, SumType, Type, TypeBound, TypeRow, TypeRowRV}; use hugr_core::{type_row, Hugr, HugrView, IncomingPort, Node}; use crate::dataflow::{partial_from_const, DFContext, PartialValue}; +use crate::ComposablePass as _; use super::{constant_fold_pass, ConstFoldContext, ConstantFoldPass, ValueHandle}; diff --git a/hugr-passes/src/dead_code.rs b/hugr-passes/src/dead_code.rs index b714dd6fd..899e30243 100644 --- a/hugr-passes/src/dead_code.rs +++ b/hugr-passes/src/dead_code.rs @@ -1,13 +1,14 @@ //! Pass for removing dead code, i.e. that computes values that are then discarded use hugr_core::{hugr::hugrmut::HugrMut, ops::OpType, Hugr, HugrView, Node}; +use std::convert::Infallible; use std::fmt::{Debug, Formatter}; use std::{ collections::{HashMap, HashSet, VecDeque}, sync::Arc, }; -use crate::validation::{ValidatePassError, ValidationLevel}; +use crate::ComposablePass; /// Configuration for Dead Code Elimination pass #[derive(Clone)] @@ -18,7 +19,6 @@ pub struct DeadCodeElimPass { /// Callback identifying nodes that must be preserved even if their /// results are not used. Defaults to [PreserveNode::default_for]. preserve_callback: Arc, - validation: ValidationLevel, } impl Default for DeadCodeElimPass { @@ -26,7 +26,6 @@ impl Default for DeadCodeElimPass { Self { entry_points: Default::default(), preserve_callback: Arc::new(PreserveNode::default_for), - validation: ValidationLevel::default(), } } } @@ -39,13 +38,11 @@ impl Debug for DeadCodeElimPass { #[derive(Debug)] struct DCEDebug<'a> { entry_points: &'a Vec, - validation: ValidationLevel, } Debug::fmt( &DCEDebug { entry_points: &self.entry_points, - validation: self.validation, }, f, ) @@ -86,13 +83,6 @@ impl PreserveNode { } impl DeadCodeElimPass { - /// Sets the validation level used before and after the pass is run - #[allow(unused)] - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - /// Allows setting a callback that determines whether a node must be preserved /// (even when its result is not used) pub fn set_preserve_callback(mut self, cb: Arc) -> Self { @@ -146,24 +136,6 @@ impl DeadCodeElimPass { needed } - pub fn run(&self, hugr: &mut impl HugrMut) -> Result<(), ValidatePassError> { - self.validation.run_validated_pass(hugr, |h, _| { - self.run_no_validate(h); - Ok(()) - }) - } - - fn run_no_validate(&self, hugr: &mut impl HugrMut) { - let needed = self.find_needed_nodes(&*hugr); - let remove = hugr - .nodes() - .filter(|n| !needed.contains(n)) - .collect::>(); - for n in remove { - hugr.remove_node(n); - } - } - fn must_preserve( &self, h: &impl HugrView, @@ -185,6 +157,22 @@ impl DeadCodeElimPass { } } +impl ComposablePass for DeadCodeElimPass { + type Error = Infallible; + type Result = (); + + fn run(&self, hugr: &mut impl HugrMut) -> Result<(), Infallible> { + let needed = self.find_needed_nodes(&*hugr); + let remove = hugr + .nodes() + .filter(|n| !needed.contains(n)) + .collect::>(); + for n in remove { + hugr.remove_node(n); + } + Ok(()) + } +} #[cfg(test)] mod test { use std::sync::Arc; @@ -196,6 +184,8 @@ mod test { use hugr_core::{ops::Value, type_row, HugrView}; use itertools::Itertools; + use crate::ComposablePass; + use super::{DeadCodeElimPass, PreserveNode}; #[test] diff --git a/hugr-passes/src/dead_funcs.rs b/hugr-passes/src/dead_funcs.rs index b114a9e42..7071d5335 100644 --- a/hugr-passes/src/dead_funcs.rs +++ b/hugr-passes/src/dead_funcs.rs @@ -10,7 +10,10 @@ use hugr_core::{ }; use petgraph::visit::{Dfs, Walker}; -use crate::validation::{ValidatePassError, ValidationLevel}; +use crate::{ + composable::{validate_if_test, ValidatePassError}, + ComposablePass, +}; use super::call_graph::{CallGraph, CallGraphNode}; @@ -26,9 +29,6 @@ pub enum RemoveDeadFuncsError { /// The invalid node. node: N, }, - #[error(transparent)] - #[allow(missing_docs)] - ValidationError(#[from] ValidatePassError), } fn reachable_funcs<'a, H: HugrView>( @@ -64,17 +64,10 @@ fn reachable_funcs<'a, H: HugrView>( #[derive(Debug, Clone, Default)] /// A configuration for the Dead Function Removal pass. pub struct RemoveDeadFuncsPass { - validation: ValidationLevel, entry_points: Vec, } impl RemoveDeadFuncsPass { - /// Sets the validation level used before and after the pass is run - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - /// Adds new entry points - these must be [FuncDefn] nodes /// that are children of the [Module] at the root of the Hugr. /// @@ -87,16 +80,32 @@ impl RemoveDeadFuncsPass { self.entry_points.extend(entry_points); self } +} - /// Runs the pass (see [remove_dead_funcs]) with this configuration - pub fn run(&self, hugr: &mut H) -> Result<(), RemoveDeadFuncsError> { - self.validation.run_validated_pass(hugr, |hugr: &mut H, _| { - remove_dead_funcs(hugr, self.entry_points.iter().cloned()) - }) +impl ComposablePass for RemoveDeadFuncsPass { + type Error = RemoveDeadFuncsError; + type Result = (); + fn run(&self, hugr: &mut impl HugrMut) -> Result<(), RemoveDeadFuncsError> { + let reachable = reachable_funcs( + &CallGraph::new(hugr), + hugr, + self.entry_points.iter().cloned(), + )? + .collect::>(); + let unreachable = hugr + .nodes() + .filter(|n| { + OpTag::Function.is_superset(hugr.get_optype(*n).tag()) && !reachable.contains(n) + }) + .collect::>(); + for n in unreachable { + hugr.remove_subtree(n); + } + Ok(()) } } -/// Delete from the Hugr any functions that are not used by either [Call] or +/// Deletes from the Hugr any functions that are not used by either [Call] or /// [LoadFunction] nodes in reachable parts. /// /// For [Module]-rooted Hugrs, `entry_points` may provide a list of entry points, @@ -118,16 +127,11 @@ impl RemoveDeadFuncsPass { pub fn remove_dead_funcs( h: &mut impl HugrMut, entry_points: impl IntoIterator, -) -> Result<(), RemoveDeadFuncsError> { - let reachable = reachable_funcs(&CallGraph::new(h), h, entry_points)?.collect::>(); - let unreachable = h - .nodes() - .filter(|n| OpTag::Function.is_superset(h.get_optype(*n).tag()) && !reachable.contains(n)) - .collect::>(); - for n in unreachable { - h.remove_subtree(n); - } - Ok(()) +) -> Result<(), ValidatePassError> { + validate_if_test( + RemoveDeadFuncsPass::default().with_module_entry_points(entry_points), + h, + ) } #[cfg(test)] @@ -142,7 +146,7 @@ mod test { }; use hugr_core::{extension::prelude::usize_t, types::Signature, HugrView}; - use super::RemoveDeadFuncsPass; + use super::remove_dead_funcs; #[rstest] #[case([], vec![])] // No entry_points removes everything! @@ -182,15 +186,14 @@ mod test { }) .collect::>(); - RemoveDeadFuncsPass::default() - .with_module_entry_points( - entry_points - .into_iter() - .map(|name| *avail_funcs.get(name).unwrap()) - .collect::>(), - ) - .run(&mut hugr) - .unwrap(); + remove_dead_funcs( + &mut hugr, + entry_points + .into_iter() + .map(|name| *avail_funcs.get(name).unwrap()) + .collect::>(), + ) + .unwrap(); let remaining_funcs = hugr .nodes() diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 961c4da47..83ff71b67 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -1,6 +1,8 @@ //! Compilation passes acting on the HUGR program representation. pub mod call_graph; +pub mod composable; +pub use composable::ComposablePass; pub mod const_fold; pub mod dataflow; pub mod dead_code; @@ -21,19 +23,11 @@ pub mod untuple; )] #[allow(deprecated)] pub use monomorphize::remove_polyfuncs; -// TODO: Deprecated re-export. Remove on a breaking release. -#[deprecated( - since = "0.14.1", - note = "Use `hugr_passes::MonomorphizePass` instead." -)] -#[allow(deprecated)] -pub use monomorphize::monomorphize; -pub use monomorphize::{MonomorphizeError, MonomorphizePass}; +pub use monomorphize::{monomorphize, MonomorphizePass}; pub mod replace_types; pub use replace_types::ReplaceTypes; pub mod nest_cfgs; pub mod non_local; -pub mod validation; pub use force_order::{force_order, force_order_by_key}; pub use lower::{lower_ops, replace_many_ops}; pub use non_local::{ensure_no_nonlocal_edges, nonlocal_edges}; diff --git a/hugr-passes/src/monomorphize.rs b/hugr-passes/src/monomorphize.rs index 4f4e9bda2..875ee9355 100644 --- a/hugr-passes/src/monomorphize.rs +++ b/hugr-passes/src/monomorphize.rs @@ -1,5 +1,6 @@ use std::{ collections::{hash_map::Entry, HashMap}, + convert::Infallible, fmt::Write, ops::Deref, }; @@ -12,7 +13,9 @@ use hugr_core::{ use hugr_core::hugr::{hugrmut::HugrMut, Hugr, HugrView, OpType}; use itertools::Itertools as _; -use thiserror::Error; + +use crate::composable::{validate_if_test, ValidatePassError}; +use crate::ComposablePass; /// Replaces calls to polymorphic functions with calls to new monomorphic /// instantiations of the polymorphic ones. @@ -30,26 +33,8 @@ use thiserror::Error; /// children of the root node. We make best effort to ensure that names (derived /// from parent function names and concrete type args) of new functions are unique /// whenever the names of their parents are unique, but this is not guaranteed. -#[deprecated( - since = "0.14.1", - note = "Use `hugr_passes::MonomorphizePass` instead." -)] -// TODO: Deprecated. Remove on a breaking release and rename private `monomorphize_ref` to `monomorphize`. -pub fn monomorphize(mut h: Hugr) -> Hugr { - monomorphize_ref(&mut h); - h -} - -fn monomorphize_ref(h: &mut impl HugrMut) { - let root = h.root(); - // If the root is a polymorphic function, then there are no external calls, so nothing to do - if !is_polymorphic_funcdefn(h.get_optype(root)) { - mono_scan(h, root, None, &mut HashMap::new()); - if !h.get_optype(root).is_module() { - #[allow(deprecated)] // TODO remove in next breaking release and update docs - remove_polyfuncs_ref(h); - } - } +pub fn monomorphize(hugr: &mut impl HugrMut) -> Result<(), ValidatePassError> { + validate_if_test(MonomorphizePass, hugr) } /// Removes any polymorphic [FuncDefn]s from the Hugr. Note that if these have @@ -254,8 +239,6 @@ fn instantiate( mono_tgt } -use crate::validation::{ValidatePassError, ValidationLevel}; - /// Replaces calls to polymorphic functions with calls to new monomorphic /// instantiations of the polymorphic ones. /// @@ -271,38 +254,25 @@ use crate::validation::{ValidatePassError, ValidationLevel}; /// children of the root node. We make best effort to ensure that names (derived /// from parent function names and concrete type args) of new functions are unique /// whenever the names of their parents are unique, but this is not guaranteed. -#[derive(Debug, Clone, Default)] -pub struct MonomorphizePass { - validation: ValidationLevel, -} - -#[derive(Debug, Error)] -#[non_exhaustive] -/// Errors produced by [MonomorphizePass]. -pub enum MonomorphizeError { - #[error(transparent)] - #[allow(missing_docs)] - ValidationError(#[from] ValidatePassError), -} - -impl MonomorphizePass { - /// Sets the validation level used before and after the pass is run. - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - - /// Run the Monomorphization pass. - fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), MonomorphizeError> { - monomorphize_ref(hugr); +#[derive(Debug, Clone)] +pub struct MonomorphizePass; + +impl ComposablePass for MonomorphizePass { + type Error = Infallible; + type Result = (); + + fn run(&self, h: &mut impl HugrMut) -> Result<(), Self::Error> { + let root = h.root(); + // If the root is a polymorphic function, then there are no external calls, so nothing to do + if !is_polymorphic_funcdefn(h.get_optype(root)) { + mono_scan(h, root, None, &mut HashMap::new()); + if !h.get_optype(root).is_module() { + #[allow(deprecated)] // TODO remove in next breaking release and update docs + remove_polyfuncs_ref(h); + } + } Ok(()) } - - /// Run the pass using specified configuration. - pub fn run(&self, hugr: &mut H) -> Result<(), MonomorphizeError> { - self.validation - .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr)) - } } struct TypeArgsList<'a>(&'a [TypeArg]); @@ -387,9 +357,9 @@ mod test { use hugr_core::{Hugr, HugrView, Node}; use rstest::rstest; - use crate::remove_dead_funcs; + use crate::{monomorphize, remove_dead_funcs}; - use super::{is_polymorphic, mangle_inner_func, mangle_name, MonomorphizePass}; + use super::{is_polymorphic, mangle_inner_func, mangle_name}; fn pair_type(ty: Type) -> Type { Type::new_tuple(vec![ty.clone(), ty]) @@ -410,7 +380,7 @@ mod test { let [i1] = dfg_builder.input_wires_arr(); let hugr = dfg_builder.finish_hugr_with_outputs([i1]).unwrap(); let mut hugr2 = hugr.clone(); - MonomorphizePass::default().run(&mut hugr2).unwrap(); + monomorphize(&mut hugr2).unwrap(); assert_eq!(hugr, hugr2); } @@ -472,7 +442,7 @@ mod test { .count(), 3 ); - MonomorphizePass::default().run(&mut hugr)?; + monomorphize(&mut hugr)?; let mono = hugr; mono.validate()?; @@ -493,7 +463,7 @@ mod test { ["double", "main", "triple"] ); let mut mono2 = mono.clone(); - MonomorphizePass::default().run(&mut mono2)?; + monomorphize(&mut mono2)?; assert_eq!(mono2, mono); // Idempotent @@ -601,7 +571,7 @@ mod test { .outputs_arr(); let mut hugr = outer.finish_hugr_with_outputs([e1, e2]).unwrap(); - MonomorphizePass::default().run(&mut hugr).unwrap(); + monomorphize(&mut hugr).unwrap(); let mono_hugr = hugr; mono_hugr.validate().unwrap(); let funcs = list_funcs(&mono_hugr); @@ -662,7 +632,7 @@ mod test { let mono = mono.finish_with_outputs([a, b]).unwrap(); let c = dfg.call(mono.handle(), &[], dfg.input_wires()).unwrap(); let mut hugr = dfg.finish_hugr_with_outputs(c.outputs()).unwrap(); - MonomorphizePass::default().run(&mut hugr)?; + monomorphize(&mut hugr)?; let mono_hugr = hugr; let mut funcs = list_funcs(&mono_hugr); @@ -719,7 +689,7 @@ mod test { module_builder.finish_hugr().unwrap() }; - MonomorphizePass::default().run(&mut hugr).unwrap(); + monomorphize(&mut hugr).unwrap(); remove_dead_funcs(&mut hugr, []).unwrap(); let funcs = list_funcs(&hugr); diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 3ed7337a9..e81a640e3 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -26,7 +26,7 @@ use hugr_core::types::{ }; use hugr_core::{Hugr, HugrView, Node, Wire}; -use crate::validation::{ValidatePassError, ValidationLevel}; +use crate::ComposablePass; mod linearize; pub use linearize::{CallbackHandler, DelegatingLinearizer, LinearizeError, Linearizer}; @@ -143,7 +143,6 @@ pub struct ReplaceTypes { ParametricType, Arc Result, ReplaceTypesError>>, >, - validation: ValidationLevel, } impl Default for ReplaceTypes { @@ -184,8 +183,6 @@ pub enum ReplaceTypesError { #[error(transparent)] SignatureError(#[from] SignatureError), #[error(transparent)] - ValidationError(#[from] ValidatePassError), - #[error(transparent)] ConstError(#[from] ConstTypeError), #[error(transparent)] LinearizeError(#[from] LinearizeError), @@ -203,16 +200,9 @@ impl ReplaceTypes { param_ops: Default::default(), consts: Default::default(), param_consts: Default::default(), - validation: Default::default(), } } - /// Sets the validation level used before and after the pass is run. - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - /// Configures this instance to replace occurrences of type `src` with `dest`. /// Note that if `src` is an instance of a *parametrized* [TypeDef], this takes /// precedence over [Self::replace_parametrized_type] where the `src`s overlap. Thus, this @@ -323,36 +313,6 @@ impl ReplaceTypes { self.param_consts.insert(src_ty.into(), Arc::new(const_fn)); } - /// Run the pass using specified configuration. - pub fn run(&self, hugr: &mut H) -> Result { - self.validation - .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr)) - } - - fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result { - let mut changed = false; - for n in hugr.nodes().collect::>() { - changed |= self.change_node(hugr, n)?; - let new_dfsig = hugr.get_optype(n).dataflow_signature(); - if let Some(new_sig) = new_dfsig - .filter(|_| changed && n != hugr.root()) - .map(Cow::into_owned) - { - for outp in new_sig.output_ports() { - if !new_sig.out_port_type(outp).unwrap().copyable() { - let targets = hugr.linked_inputs(n, outp).collect::>(); - if targets.len() != 1 { - hugr.disconnect(n, outp); - let src = Wire::new(n, outp); - self.linearize.insert_copy_discard(hugr, src, &targets)?; - } - } - } - } - } - Ok(changed) - } - fn change_node(&self, hugr: &mut impl HugrMut, n: Node) -> Result { match hugr.optype_mut(n) { OpType::FuncDefn(FuncDefn { signature, .. }) @@ -472,11 +432,40 @@ impl ReplaceTypes { false } }), - Value::Function { hugr } => self.run_no_validate(&mut **hugr), + Value::Function { hugr } => self.run(&mut **hugr), } } } +impl ComposablePass for ReplaceTypes { + type Error = ReplaceTypesError; + type Result = bool; + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + let mut changed = false; + for n in hugr.nodes().collect::>() { + changed |= self.change_node(hugr, n)?; + let new_dfsig = hugr.get_optype(n).dataflow_signature(); + if let Some(new_sig) = new_dfsig + .filter(|_| changed && n != hugr.root()) + .map(Cow::into_owned) + { + for outp in new_sig.output_ports() { + if !new_sig.out_port_type(outp).unwrap().copyable() { + let targets = hugr.linked_inputs(n, outp).collect::>(); + if targets.len() != 1 { + hugr.disconnect(n, outp); + let src = Wire::new(n, outp); + self.linearize.insert_copy_discard(hugr, src, &targets)?; + } + } + } + } + } + Ok(changed) + } +} + pub mod handlers; #[derive(Clone, Hash, PartialEq, Eq)] @@ -532,29 +521,26 @@ mod test { use hugr_core::extension::prelude::{ bool_t, option_type, qb_t, usize_t, ConstUsize, UnwrapBuilder, }; - use hugr_core::extension::simple_op::MakeExtensionOp; - use hugr_core::extension::{TypeDefBound, Version}; - - use hugr_core::ops::constant::OpaqueValue; - use hugr_core::ops::{ExtensionOp, NamedOp, OpTrait, OpType, Tag, Value}; - use hugr_core::std_extensions::arithmetic::int_types::ConstInt; - use hugr_core::std_extensions::arithmetic::{conversions::ConvertOpDef, int_types::INT_TYPES}; + use hugr_core::extension::{simple_op::MakeExtensionOp, TypeDefBound, Version}; + use hugr_core::hugr::{IdentList, ValidationError}; + use hugr_core::ops::{ + constant::OpaqueValue, ExtensionOp, NamedOp, OpTrait, OpType, Tag, Value, + }; + use hugr_core::std_extensions::arithmetic::conversions::ConvertOpDef; + use hugr_core::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; use hugr_core::std_extensions::collections::array::{ array_type, array_type_def, ArrayOp, ArrayOpDef, ArrayValue, }; use hugr_core::std_extensions::collections::list::{ list_type, list_type_def, ListOp, ListValue, }; - - use hugr_core::hugr::ValidationError; use hugr_core::types::{PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeRow}; - use hugr_core::{hugr::IdentList, type_row, Extension, HugrView}; + use hugr_core::{type_row, Extension, HugrView}; use itertools::Itertools; use rstest::rstest; - use crate::validation::ValidatePassError; + use crate::ComposablePass; - use super::ReplaceTypesError; use super::{handlers::list_const, NodeTemplate, ReplaceTypes}; const PACKED_VEC: &str = "PackedVec"; @@ -979,13 +965,16 @@ mod test { let cu = cst.value().downcast_ref::().unwrap(); Ok(ConstInt::new_u(6, cu.value())?.into()) }); + + let mut h = backup.clone(); + repl.run(&mut h).unwrap(); // No validation here assert!( - matches!(repl.run(&mut backup.clone()), Err(ReplaceTypesError::ValidationError(ValidatePassError::OutputError { - err: ValidationError::IncompatiblePorts {from, to, ..}, .. - })) if backup.get_optype(from).is_const() && to == c.node()) + matches!(h.validate(), Err(ValidationError::IncompatiblePorts {from, to, ..}) + if backup.get_optype(from).is_const() && to == c.node()) ); repl.replace_consts_parametrized(array_type_def(), array_const); let mut h = backup; - repl.run(&mut h).unwrap(); // Includes validation + repl.run(&mut h).unwrap(); + h.validate_no_extensions().unwrap(); } } diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 5b4da7184..bc508bd53 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -377,7 +377,7 @@ mod test { use crate::replace_types::handlers::linearize_array; use crate::replace_types::{LinearizeError, NodeTemplate, ReplaceTypesError}; - use crate::ReplaceTypes; + use crate::{ComposablePass, ReplaceTypes}; const LIN_T: &str = "Lin"; diff --git a/hugr-passes/src/untuple.rs b/hugr-passes/src/untuple.rs index dbe04edd1..874fd9ec3 100644 --- a/hugr-passes/src/untuple.rs +++ b/hugr-passes/src/untuple.rs @@ -10,19 +10,19 @@ use hugr_core::hugr::views::SiblingSubgraph; use hugr_core::hugr::SimpleReplacementError; use hugr_core::ops::{NamedOp, OpTrait, OpType}; use hugr_core::types::Type; -use hugr_core::{HugrView, SimpleReplacement}; +use hugr_core::{HugrView, Node, SimpleReplacement}; use itertools::Itertools; -use crate::validation::{ValidatePassError, ValidationLevel}; +use crate::ComposablePass; /// Configuration enum for the untuple rewrite pass. /// /// Indicates whether the pattern match should traverse the HUGR recursively. #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub enum UntupleRecursive { - /// Traverse the HUGR recursively. + /// Traverse the HUGR recursively, i.e. consider the entire subtree Recursive, - /// Do not traverse the HUGR recursively. + /// Do not traverse the HUGR recursively, i.e. consider only the sibling subgraph #[default] NonRecursive, } @@ -48,22 +48,20 @@ pub enum UntupleRecursive { pub struct UntuplePass { /// Whether to traverse the HUGR recursively. recursive: UntupleRecursive, - /// The level of validation to perform on the rewrite. - validation: ValidationLevel, + /// Parent node under which to operate; None indicates the Hugr root + parent: Option, } #[derive(Debug, derive_more::Display, derive_more::Error, derive_more::From)] #[non_exhaustive] /// Errors produced by [UntuplePass]. pub enum UntupleError { - /// An error occurred while validating the rewrite. - ValidationError(ValidatePassError), /// Rewriting the circuit failed. RewriteError(SimpleReplacementError), } /// Result type for the untuple pass. -#[derive(Debug, Clone, Copy, Default)] +#[derive(Debug, Clone, Copy, Default, PartialEq)] pub struct UntupleResult { /// Number of `MakeTuple` rewrites applied. pub rewrites_applied: usize, @@ -71,16 +69,16 @@ pub struct UntupleResult { impl UntuplePass { /// Create a new untuple pass with the given configuration. - pub fn new(recursive: UntupleRecursive, validation: ValidationLevel) -> Self { + pub fn new(recursive: UntupleRecursive) -> Self { Self { recursive, - validation, + parent: None, } } - /// Sets the validation level used before and after the pass is run. - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; + /// Sets the parent node to optimize (overwrites any previous setting) + pub fn set_parent(mut self, parent: impl Into>) -> Self { + self.parent = parent.into(); self } @@ -90,31 +88,6 @@ impl UntuplePass { self } - /// Run the pass using specified configuration. - pub fn run( - &self, - hugr: &mut H, - parent: H::Node, - ) -> Result { - self.validation - .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr, parent)) - } - - /// Run the Monomorphization pass. - fn run_no_validate( - &self, - hugr: &mut H, - parent: H::Node, - ) -> Result { - let rewrites = self.find_rewrites(hugr, parent); - let rewrites_applied = rewrites.len(); - // The rewrites are independent, so we can always apply them all. - for rewrite in rewrites { - hugr.apply_rewrite(rewrite)?; - } - Ok(UntupleResult { rewrites_applied }) - } - /// Find tuple pack operations followed by tuple unpack operations /// and generate rewrites to remove them. /// @@ -148,6 +121,22 @@ impl UntuplePass { } } +impl ComposablePass for UntuplePass { + type Error = UntupleError; + + type Result = UntupleResult; + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + let rewrites = self.find_rewrites(hugr, self.parent.unwrap_or(hugr.root())); + let rewrites_applied = rewrites.len(); + // The rewrites are independent, so we can always apply them all. + for rewrite in rewrites { + hugr.apply_rewrite(rewrite)?; + } + Ok(UntupleResult { rewrites_applied }) + } +} + /// Returns true if the given optype is a MakeTuple operation. /// /// Boilerplate required due to https://github.com/CQCL/hugr/issues/1496 @@ -421,7 +410,8 @@ mod test { let parent = hugr.root(); let res = pass - .run(&mut hugr, parent) + .set_parent(parent) + .run(&mut hugr) .unwrap_or_else(|e| panic!("{e}")); assert_eq!(res.rewrites_applied, expected_rewrites); assert_eq!(hugr.children(parent).count(), remaining_nodes); From d8a5d6794526f22bc99d7a5489cbcc2d39e3c59a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 23 Apr 2025 10:55:43 +0100 Subject: [PATCH 33/33] feat!: ReplaceTypes: allow lowering ops into a Call to a function already in the Hugr (#2094) There are two issues: * Errors. The previous NodeTemplates still always work, but the Call one can fail if the Hugr doesn't contain the target function node. ATM there is no channel for reporting that error so I've had to panic. Otherwise it's an even-more-breaking change to add an error type to `NodeTemplate::add()` and `NodeTemplate::add_hugr()`. Should we? (I note `HugrMut::connect` panics if the node isn't there, but could make the `NodeTemplate::add` builder method return a BuildError...and propagate that everywhere of course) * There's a big limitation in `linearize_array` that it'll break if the *element* says it should be copied/discarded via a NodeTemplate::Call, as `linearize_array` puts the elementwise copy/discard function into a *nested Hugr* (`Value::Function`) that won't contain the function. This could be fixed via lifting those to toplevel FuncDefns with name-mangling, but I'd rather leave that for #2086 .... BREAKING CHANGE: Add new variant NodeTemplate::Call; LinearizeError no longer derives Eq. --- hugr-passes/src/replace_types.rs | 234 ++++++++++++++++----- hugr-passes/src/replace_types/handlers.rs | 4 +- hugr-passes/src/replace_types/linearize.rs | 104 +++++++-- 3 files changed, 268 insertions(+), 74 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index e81a640e3..df4c14075 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -15,16 +15,17 @@ use hugr_core::builder::{BuildError, BuildHandle, Dataflow}; use hugr_core::extension::{ExtensionId, OpDef, SignatureError, TypeDef}; use hugr_core::hugr::hugrmut::HugrMut; use hugr_core::ops::constant::{OpaqueValue, Sum}; -use hugr_core::ops::handle::DataflowOpID; +use hugr_core::ops::handle::{DataflowOpID, FuncID}; use hugr_core::ops::{ AliasDefn, Call, CallIndirect, Case, Conditional, Const, DataflowBlock, ExitBlock, ExtensionOp, FuncDecl, FuncDefn, Input, LoadConstant, LoadFunction, OpTrait, OpType, Output, Tag, TailLoop, Value, CFG, DFG, }; use hugr_core::types::{ - ConstTypeError, CustomType, Signature, Transformable, Type, TypeArg, TypeEnum, TypeTransformer, + ConstTypeError, CustomType, Signature, Transformable, Type, TypeArg, TypeEnum, TypeRow, + TypeTransformer, }; -use hugr_core::{Hugr, HugrView, Node, Wire}; +use hugr_core::{Direction, Hugr, HugrView, Node, PortIndex, Wire}; use crate::ComposablePass; @@ -45,21 +46,37 @@ pub enum NodeTemplate { /// Note this will be of limited use before [monomorphization](super::monomorphize()) /// because the new subtree will not be able to use type variables present in the /// parent Hugr or previous op. - // TODO: store also a vec, and update Hugr::validate to take &[TypeParam]s - // (defaulting to empty list) - see https://github.com/CQCL/hugr/issues/709 CompoundOp(Box), - // TODO allow also Call to a Node in the existing Hugr - // (can't see any other way to achieve multiple calls to the same decl. - // So client should add the functions before replacement, then remove unused ones afterwards.) + /// A Call to an existing function. + Call(Node, Vec), } impl NodeTemplate { /// Adds this instance to the specified [HugrMut] as a new node or subtree under a /// given parent, returning the unique new child (of that parent) thus created - pub fn add_hugr(self, hugr: &mut impl HugrMut, parent: Node) -> Node { + /// + /// # Panics + /// + /// * If `parent` is not in the `hugr` + /// + /// # Errors + /// + /// * If `self` is a [Self::Call] and the target Node either + /// * is neither a [FuncDefn] nor a [FuncDecl] + /// * has a [`signature`] which the type-args of the [Self::Call] do not match + /// + /// [`signature`]: hugr_core::types::PolyFuncType + pub fn add_hugr(self, hugr: &mut impl HugrMut, parent: Node) -> Result { match self { - NodeTemplate::SingleOp(op_type) => hugr.add_node_with_parent(parent, op_type), - NodeTemplate::CompoundOp(new_h) => hugr.insert_hugr(parent, *new_h).new_root, + NodeTemplate::SingleOp(op_type) => Ok(hugr.add_node_with_parent(parent, op_type)), + NodeTemplate::CompoundOp(new_h) => Ok(hugr.insert_hugr(parent, *new_h).new_root), + NodeTemplate::Call(target, type_args) => { + let c = call(hugr, target, type_args)?; + let tgt_port = c.called_function_port(); + let n = hugr.add_node_with_parent(parent, c); + hugr.connect(target, 0, n, tgt_port); + Ok(n) + } } } @@ -72,10 +89,15 @@ impl NodeTemplate { match self { NodeTemplate::SingleOp(opty) => dfb.add_dataflow_op(opty, inputs), NodeTemplate::CompoundOp(h) => dfb.add_hugr_with_wires(*h, inputs), + // Really we should check whether func points at a FuncDecl or FuncDefn and create + // the appropriate variety of FuncID but it doesn't matter for the purpose of making a Call. + NodeTemplate::Call(func, type_args) => { + dfb.call(&FuncID::::from(func), &type_args, inputs) + } } } - fn replace(&self, hugr: &mut impl HugrMut, n: Node) { + fn replace(&self, hugr: &mut impl HugrMut, n: Node) -> Result<(), BuildError> { assert_eq!(hugr.children(n).count(), 0); let new_optype = match self.clone() { NodeTemplate::SingleOp(op_type) => op_type, @@ -88,19 +110,57 @@ impl NodeTemplate { } root_opty } + NodeTemplate::Call(func, type_args) => { + let c = call(hugr, func, type_args)?; + let static_inport = c.called_function_port(); + // insert an input for the Call static input + hugr.insert_ports(n, Direction::Incoming, static_inport.index(), 1); + // connect the function to (what will be) the call + hugr.connect(func, 0, n, static_inport); + c.into() + } }; *hugr.optype_mut(n) = new_optype; + Ok(()) } - fn signature(&self) -> Option> { - match self { + fn check_signature( + &self, + inputs: &TypeRow, + outputs: &TypeRow, + ) -> Result<(), Option> { + let sig = match self { NodeTemplate::SingleOp(op_type) => op_type, NodeTemplate::CompoundOp(hugr) => hugr.root_type(), + NodeTemplate::Call(_, _) => return Ok(()), // no way to tell + } + .dataflow_signature(); + if sig.as_deref().map(Signature::io) == Some((inputs, outputs)) { + Ok(()) + } else { + Err(sig.map(Cow::into_owned)) } - .dataflow_signature() } } +fn call>( + h: &H, + func: Node, + type_args: Vec, +) -> Result { + let func_sig = match h.get_optype(func) { + OpType::FuncDecl(fd) => fd.signature.clone(), + OpType::FuncDefn(fd) => fd.signature.clone(), + _ => { + return Err(BuildError::UnexpectedType { + node: func, + op_desc: "func defn/decl", + }) + } + }; + Ok(Call::try_new(func_sig, type_args)?) +} + /// A configuration of what types, ops, and constants should be replaced with what. /// May be applied to a Hugr via [Self::run]. /// @@ -186,6 +246,8 @@ pub enum ReplaceTypesError { ConstError(#[from] ConstTypeError), #[error(transparent)] LinearizeError(#[from] LinearizeError), + #[error("Replacement op for {0} could not be added because {1}")] + AddTemplateError(Node, BuildError), } impl ReplaceTypes { @@ -370,8 +432,11 @@ impl ReplaceTypes { OpType::Const(Const { value, .. }) => self.change_value(value), OpType::ExtensionOp(ext_op) => Ok( + // Copy/discard insertion done by caller if let Some(replacement) = self.op_map.get(&OpHashWrapper::from(&*ext_op)) { - replacement.replace(hugr, n); // Copy/discard insertion done by caller + replacement + .replace(hugr, n) + .map_err(|e| ReplaceTypesError::AddTemplateError(n, e))?; true } else { let def = ext_op.def_arc(); @@ -382,7 +447,9 @@ impl ReplaceTypes { .get(&def.as_ref().into()) .and_then(|rep_fn| rep_fn(&args)) { - replacement.replace(hugr, n); + replacement + .replace(hugr, n) + .map_err(|e| ReplaceTypesError::AddTemplateError(n, e))?; true } else { if ch { @@ -515,24 +582,22 @@ mod test { use std::sync::Arc; use hugr_core::builder::{ - inout_sig, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, - HugrBuilder, ModuleBuilder, SubContainer, TailLoopBuilder, + inout_sig, BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, + FunctionBuilder, HugrBuilder, ModuleBuilder, SubContainer, TailLoopBuilder, }; use hugr_core::extension::prelude::{ - bool_t, option_type, qb_t, usize_t, ConstUsize, UnwrapBuilder, + bool_t, option_type, qb_t, usize_t, ConstUsize, UnwrapBuilder, PRELUDE_ID, }; - use hugr_core::extension::{simple_op::MakeExtensionOp, TypeDefBound, Version}; + use hugr_core::extension::{simple_op::MakeExtensionOp, ExtensionSet, TypeDefBound, Version}; + use hugr_core::hugr::hugrmut::HugrMut; use hugr_core::hugr::{IdentList, ValidationError}; - use hugr_core::ops::{ - constant::OpaqueValue, ExtensionOp, NamedOp, OpTrait, OpType, Tag, Value, - }; - use hugr_core::std_extensions::arithmetic::conversions::ConvertOpDef; + use hugr_core::ops::constant::OpaqueValue; + use hugr_core::ops::{ExtensionOp, NamedOp, OpTrait, OpType, Tag, Value}; + use hugr_core::std_extensions::arithmetic::conversions::{self, ConvertOpDef}; use hugr_core::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; - use hugr_core::std_extensions::collections::array::{ - array_type, array_type_def, ArrayOp, ArrayOpDef, ArrayValue, - }; - use hugr_core::std_extensions::collections::list::{ - list_type, list_type_def, ListOp, ListValue, + use hugr_core::std_extensions::collections::{ + array::{self, array_type, array_type_def, ArrayOp, ArrayOpDef, ArrayValue}, + list::{list_type, list_type_def, ListOp, ListValue}, }; use hugr_core::types::{PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeRow}; use hugr_core::{type_row, Extension, HugrView}; @@ -601,30 +666,37 @@ mod test { ) } - fn lowerer(ext: &Arc) -> ReplaceTypes { - fn lowered_read(args: &[TypeArg]) -> Option { - let ty = just_elem_type(args); - let mut dfb = DFGBuilder::new(inout_sig( - vec![array_type(64, ty.clone()), i64_t()], - ty.clone(), - )) + fn lowered_read( + elem_ty: Type, + new: impl Fn(Signature) -> Result, + ) -> T { + let mut dfb = new(Signature::new( + vec![array_type(64, elem_ty.clone()), i64_t()], + elem_ty.clone(), + ) + .with_extension_delta(ExtensionSet::from_iter([ + PRELUDE_ID, + array::EXTENSION_ID, + conversions::EXTENSION_ID, + ]))) + .unwrap(); + let [val, idx] = dfb.input_wires_arr(); + let [idx] = dfb + .add_dataflow_op(ConvertOpDef::itousize.without_log_width(), [idx]) + .unwrap() + .outputs_arr(); + let [opt] = dfb + .add_dataflow_op(ArrayOpDef::get.to_concrete(elem_ty.clone(), 64), [val, idx]) + .unwrap() + .outputs_arr(); + let [res] = dfb + .build_unwrap_sum(1, option_type(Type::from(elem_ty)), opt) .unwrap(); - let [val, idx] = dfb.input_wires_arr(); - let [idx] = dfb - .add_dataflow_op(ConvertOpDef::itousize.without_log_width(), [idx]) - .unwrap() - .outputs_arr(); - let [opt] = dfb - .add_dataflow_op(ArrayOpDef::get.to_concrete(ty.clone(), 64), [val, idx]) - .unwrap() - .outputs_arr(); - let [res] = dfb - .build_unwrap_sum(1, option_type(Type::from(ty.clone())), opt) - .unwrap(); - Some(NodeTemplate::CompoundOp(Box::new( - dfb.finish_hugr_with_outputs([res]).unwrap(), - ))) - } + dfb.set_outputs([res]).unwrap(); + dfb + } + + fn lowerer(ext: &Arc) -> ReplaceTypes { let pv = ext.get_type(PACKED_VEC).unwrap(); let mut lw = ReplaceTypes::default(); lw.replace_type(pv.instantiate([bool_t().into()]).unwrap(), i64_t()); @@ -640,7 +712,13 @@ mod test { .into(), ), ); - lw.replace_parametrized_op(ext.get_op(READ).unwrap().as_ref(), Box::new(lowered_read)); + lw.replace_parametrized_op(ext.get_op(READ).unwrap().as_ref(), |type_args| { + Some(NodeTemplate::CompoundOp(Box::new( + lowered_read(just_elem_type(type_args).clone(), DFGBuilder::new) + .finish_hugr() + .unwrap(), + ))) + }); lw } @@ -977,4 +1055,52 @@ mod test { repl.run(&mut h).unwrap(); h.validate_no_extensions().unwrap(); } + + #[test] + fn op_to_call() { + let e = ext(); + let pv = e.get_type(PACKED_VEC).unwrap(); + let inner = pv.instantiate([usize_t().into()]).unwrap(); + let outer = pv + .instantiate([Type::new_extension(inner.clone()).into()]) + .unwrap(); + let mut dfb = DFGBuilder::new(inout_sig(vec![outer.into(), i64_t()], usize_t())).unwrap(); + let [outer, idx] = dfb.input_wires_arr(); + let [inner] = dfb + .add_dataflow_op(read_op(&e, inner.clone().into()), [outer, idx]) + .unwrap() + .outputs_arr(); + let res = dfb + .add_dataflow_op(read_op(&e, usize_t()), [inner, idx]) + .unwrap(); + let mut h = dfb.finish_hugr_with_outputs(res.outputs()).unwrap(); + let read_func = h + .insert_hugr( + h.root(), + lowered_read(Type::new_var_use(0, TypeBound::Copyable), |sig| { + FunctionBuilder::new( + "lowered_read", + PolyFuncType::new([TypeBound::Copyable.into()], sig), + ) + }) + .finish_hugr() + .unwrap(), + ) + .new_root; + + let mut lw = lowerer(&e); + lw.replace_parametrized_op(e.get_op(READ).unwrap().as_ref(), move |args| { + Some(NodeTemplate::Call(read_func, args.to_owned())) + }); + lw.run(&mut h).unwrap(); + + assert_eq!(h.output_neighbours(read_func).count(), 2); + let ext_op_names = h + .nodes() + .filter_map(|n| h.get_optype(n).as_extension_op()) + .map(|e| e.def().name()) + .sorted() + .collect_vec(); + assert_eq!(ext_op_names, ["get", "itousize", "panic",]); + } } diff --git a/hugr-passes/src/replace_types/handlers.rs b/hugr-passes/src/replace_types/handlers.rs index e835a2d9b..b6e6e6780 100644 --- a/hugr-passes/src/replace_types/handlers.rs +++ b/hugr-passes/src/replace_types/handlers.rs @@ -92,7 +92,7 @@ pub fn linearize_array( let [to_discard] = dfb.input_wires_arr(); lin.copy_discard_op(ty, 0)? .add(&mut dfb, [to_discard]) - .unwrap(); + .map_err(|e| LinearizeError::NestedTemplateError(ty.clone(), e))?; let ret = dfb.add_load_value(Value::unary_unit_sum()); dfb.finish_hugr_with_outputs([ret]).unwrap() }; @@ -162,7 +162,7 @@ pub fn linearize_array( let mut copies = lin .copy_discard_op(ty, num_outports)? .add(&mut dfb, [elem]) - .unwrap() + .map_err(|e| LinearizeError::NestedTemplateError(ty.clone(), e))? .outputs(); let copy0 = copies.next().unwrap(); // We'll return this directly diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index bc508bd53..5c4a4a707 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -1,10 +1,9 @@ -use std::borrow::Cow; use std::iter::repeat; use std::{collections::HashMap, sync::Arc}; use hugr_core::builder::{ - inout_sig, ConditionalBuilder, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, - HugrBuilder, + inout_sig, BuildError, ConditionalBuilder, DFGBuilder, Dataflow, DataflowHugr, + DataflowSubContainer, HugrBuilder, }; use hugr_core::extension::{SignatureError, TypeDef}; use hugr_core::std_extensions::collections::array::array_type_def; @@ -76,9 +75,11 @@ pub trait Linearizer { tgt_parent, }); } + let typ = typ.clone(); // Stop borrowing hugr in order to add_hugr to it let copy_discard_op = self - .copy_discard_op(typ, targets.len())? - .add_hugr(hugr, src_parent); + .copy_discard_op(&typ, targets.len())? + .add_hugr(hugr, src_parent) + .map_err(|e| LinearizeError::NestedTemplateError(typ, e))?; for (n, (tgt_node, tgt_port)) in targets.iter().enumerate() { hugr.connect(copy_discard_op, n, *tgt_node, *tgt_port); } @@ -133,7 +134,7 @@ impl Default for DelegatingLinearizer { // rather than passing a &DelegatingLinearizer directly) pub struct CallbackHandler<'a>(#[allow(dead_code)] &'a DelegatingLinearizer); -#[derive(Clone, Debug, thiserror::Error, PartialEq, Eq)] +#[derive(Clone, Debug, thiserror::Error, PartialEq)] #[allow(missing_docs)] #[non_exhaustive] pub enum LinearizeError { @@ -163,6 +164,10 @@ pub enum LinearizeError { /// Neither does linearization make sense for copyable types #[error("Type {_0} is copyable")] CopyableType(Type), + /// Error may be returned by a callback for e.g. a container because it could + /// not generate a [NodeTemplate] because of a problem with an element + #[error("Could not generate NodeTemplate for contained type {0} because {1}")] + NestedTemplateError(Type, BuildError), } impl DelegatingLinearizer { @@ -185,8 +190,10 @@ impl DelegatingLinearizer { /// /// * [LinearizeError::CopyableType] If `typ` is /// [Copyable](hugr_core::types::TypeBound::Copyable) - /// * [LinearizeError::WrongSignature] if `copy` or `discard` do not have the - /// expected inputs or outputs + /// * [LinearizeError::WrongSignature] if `copy` or `discard` do not have the expected + /// inputs or outputs (for [NodeTemplate::SingleOp] and [NodeTemplate::CompoundOp] + /// only: the signature for a [NodeTemplate::Call] cannot be checked until it is used + /// in a Hugr). pub fn register_simple( &mut self, cty: CustomType, @@ -230,18 +237,12 @@ impl DelegatingLinearizer { } fn check_sig(tmpl: &NodeTemplate, typ: &Type, num_outports: usize) -> Result<(), LinearizeError> { - let sig = tmpl.signature(); - if sig.as_ref().is_some_and(|sig| { - sig.io() == (&typ.clone().into(), &vec![typ.clone(); num_outports].into()) - }) { - Ok(()) - } else { - Err(LinearizeError::WrongSignature { + tmpl.check_signature(&typ.clone().into(), &vec![typ.clone(); num_outports].into()) + .map_err(|sig| LinearizeError::WrongSignature { typ: typ.clone(), num_outports, - sig: sig.map(Cow::into_owned), + sig, }) - } } impl Linearizer for DelegatingLinearizer { @@ -353,7 +354,10 @@ mod test { use std::iter::successors; use std::sync::Arc; - use hugr_core::builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer}; + use hugr_core::builder::{ + inout_sig, BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, + HugrBuilder, + }; use hugr_core::extension::prelude::{option_type, usize_t}; use hugr_core::extension::simple_op::MakeExtensionOp; @@ -768,4 +772,68 @@ mod test { )); assert_eq!(copy_sig.input[2..], copy_sig.output[1..]); } + + #[test] + fn call_ok_except_in_array() { + let (e, _) = ext_lowerer(); + let lin_ct = e.get_type(LIN_T).unwrap().instantiate([]).unwrap(); + let lin_t: Type = lin_ct.clone().into(); + + // A simple Hugr that discards a usize_t, with a "drop" function + let mut dfb = DFGBuilder::new(inout_sig(usize_t(), type_row![])).unwrap(); + let discard_fn = { + let mut fb = dfb + .define_function( + "drop", + Signature::new(lin_t.clone(), type_row![]) + .with_extension_delta(e.name().clone()), + ) + .unwrap(); + let ins = fb.input_wires(); + fb.add_dataflow_op( + ExtensionOp::new(e.get_op("discard").unwrap().clone(), []).unwrap(), + ins, + ) + .unwrap(); + fb.finish_with_outputs([]).unwrap() + } + .node(); + let backup = dfb.finish_hugr().unwrap(); + + let mut lower_discard_to_call = ReplaceTypes::default(); + // The `copy_fn` here will break completely, but we don't use it + lower_discard_to_call + .linearizer() + .register_simple( + lin_ct.clone(), + NodeTemplate::Call(backup.root(), vec![]), + NodeTemplate::Call(discard_fn, vec![]), + ) + .unwrap(); + + // Ok to lower usize_t to lin_t and call that function + { + let mut lowerer = lower_discard_to_call.clone(); + lowerer.replace_type(usize_t().as_extension().unwrap().clone(), lin_t.clone()); + let mut h = backup.clone(); + lowerer.run(&mut h).unwrap(); + assert_eq!(h.output_neighbours(discard_fn).count(), 1); + } + + // But if we lower usize_t to array, the call will fail + lower_discard_to_call.replace_type( + usize_t().as_extension().unwrap().clone(), + array_type(4, lin_ct.into()), + ); + let r = lower_discard_to_call.run(&mut backup.clone()); + assert!(matches!( + r, + Err(ReplaceTypesError::LinearizeError( + LinearizeError::NestedTemplateError( + nested_t, + BuildError::UnexpectedType { node, .. } + ) + )) if nested_t == lin_t && node == discard_fn + )); + } }