diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index dff471cc5..9c886325a 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -1036,6 +1036,7 @@ impl<'a> Context<'a> { self.make_term(table::Term::Apply(symbol, args)) } + #[allow(deprecated)] // Remove when Value::Function removed Value::Function { hugr } => { let outer_hugr = std::mem::replace(&mut self.hugr, hugr); let outer_node_to_id = std::mem::take(&mut self.node_to_id); diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index bb5034e1b..af585b692 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -33,7 +33,10 @@ pub mod resolution; pub mod simple_op; mod type_def; -pub use const_fold::{ConstFold, ConstFoldResult, Folder, fold_out_row}; +#[deprecated(note = "Use ConstFolder")] +#[allow(deprecated)] // Remove when ConstFold removed +pub use const_fold::{ConstFold, Folder}; +pub use const_fold::{ConstFoldResult, ConstFolder, FoldVal, fold_out_row}; pub use op_def::{ CustomSignatureFunc, CustomValidator, LowerFunc, OpDef, SignatureFromArgs, SignatureFunc, ValidateJustArgs, ValidateTypeArgs, diff --git a/hugr-core/src/extension/const_fold.rs b/hugr-core/src/extension/const_fold.rs index a2cd66f42..6c3b6bf1d 100644 --- a/hugr-core/src/extension/const_fold.rs +++ b/hugr-core/src/extension/const_fold.rs @@ -1,19 +1,134 @@ -use std::fmt::Formatter; - -use std::fmt::Debug; +use std::fmt::{Debug, Formatter}; use crate::ops::Value; -use crate::types::TypeArg; +use crate::ops::constant::{CustomConst, OpaqueValue, Sum}; +use crate::types::{SumType, TypeArg}; +use crate::{IncomingPort, Node, OutgoingPort, PortIndex}; + +/// Representation of values used for constant folding. +/// See [ConstFold], which is used as `dyn` so we cannot parametrize by +/// [HugrNode](crate::core::HugrNode). +// Should we be non-exhaustive?? +#[derive(Clone, Debug, PartialEq, Default)] +pub enum FoldVal { + /// Value is unknown, must assume that it could be anything + #[default] + Unknown, + /// A variant of a [SumType] + Sum { + /// Which variant of the sum type this value is. + tag: usize, + /// Describes the type of the whole value. + // Can we deprecate this immediately? It is only for converting to Value + sum_type: SumType, + /// A value for each element (type) within the variant + items: Vec, + }, + /// A constant value defined by an extension + Extension(OpaqueValue), + /// A function pointer loaded from a [FuncDefn](crate::ops::FuncDefn) or `FuncDecl` + LoadedFunction(Node, Vec), // Deliberately skipping Function(Box) ATM +} + +impl From for FoldVal +where + T: CustomConst, +{ + fn from(value: T) -> Self { + Self::Extension(value.into()) + } +} + +impl FoldVal { + /// Returns a constant "false" value, i.e. the first variant of Sum((), ()). + pub const fn false_val() -> Self { + Self::Sum { + tag: 0, + sum_type: SumType::Unit { size: 2 }, + items: vec![], + } + } + + /// Returns a constant "true" value, i.e. the second variant of Sum((), ()). + pub const fn true_val() -> Self { + Self::Sum { + tag: 1, + sum_type: SumType::Unit { size: 2 }, + items: vec![], + } + } + + /// Returns a constant boolean - either [Self::false_val] or [Self::true_val] + pub const fn from_bool(b: bool) -> Self { + if b { + Self::true_val() + } else { + Self::false_val() + } + } + + /// Extract the specified type of [CustomConst] fro this instance, if it is one + pub fn get_custom_value(&self) -> Option<&T> { + let Self::Extension(e) = self else { + return None; + }; + e.value().downcast_ref() + } +} -use crate::IncomingPort; -use crate::OutgoingPort; +impl TryFrom for Value { + type Error = Option; -use crate::ops; + fn try_from(value: FoldVal) -> Result { + match value { + FoldVal::Unknown => Err(None), + FoldVal::Sum { + tag, + sum_type, + items, + } => { + let values = items + .into_iter() + .map(Value::try_from) + .collect::, _>>()?; + Ok(Value::Sum(Sum { + tag, + values, + sum_type, + })) + } + FoldVal::Extension(e) => Ok(Value::Extension { e }), + FoldVal::LoadedFunction(node, _) => Err(Some(node)), + } + } +} + +impl From for FoldVal { + fn from(value: Value) -> Self { + match value { + Value::Extension { e } => FoldVal::Extension(e), + #[allow(deprecated)] // remove when Value::Function removed + Value::Function { .. } => FoldVal::Unknown, + Value::Sum(Sum { + tag, + values, + sum_type, + }) => { + let items = values.into_iter().map(FoldVal::from).collect(); + FoldVal::Sum { + tag, + sum_type, + items, + } + } + } + } +} /// Output of constant folding an operation, None indicates folding was either /// not possible or unsuccessful. An empty vector indicates folding was /// successful and no values are output. -pub type ConstFoldResult = Option>; +pub type ConstFoldResult = Option>; /// Tag some output constants with [`OutgoingPort`] inferred from the ordering. pub fn fold_out_row(consts: impl IntoIterator) -> ConstFoldResult { @@ -25,7 +140,9 @@ pub fn fold_out_row(consts: impl IntoIterator) -> ConstFoldResult Some(vec) } -/// Trait implemented by extension operations that can perform constant folding. +#[deprecated(note = "Use ConstFolder")] +/// Old trait implemented by extension operations that can perform constant folding. +/// Deprecated: see [ConstFolder] pub trait ConstFold: Send + Sync { /// Given type arguments `type_args` and /// [`crate::ops::Const`] values for inputs at [`crate::IncomingPort`]s, @@ -37,14 +154,58 @@ pub trait ConstFold: Send + Sync { ) -> ConstFoldResult; } +/// Trait implemented by extension operations that can perform constant folding. +pub trait ConstFolder: Send + Sync { + /// Given type arguments `type_args` and [`FoldVal`]s for each input, + /// update the outputs (these will be initialized to [FoldVal::Unknown]). + fn fold(&self, type_args: &[TypeArg], inputs: &[FoldVal], outputs: &mut [FoldVal]); +} + +pub(super) fn fold_vals_to_indexed_vals>(fvs: &[FoldVal]) -> Vec<(P, Value)> { + fvs.iter() + .cloned() + .enumerate() + .filter_map(|(p, fv)| Some((p.into(), fv.try_into().ok()?))) + .collect::>() +} + +#[allow(deprecated)] // Legacy conversion routine, remove when ConstFold removed +fn do_fold( + old_fold: &impl ConstFold, + type_args: &[TypeArg], + inputs: &[FoldVal], + outputs: &mut [FoldVal], +) { + let consts = fold_vals_to_indexed_vals(inputs); + let outs = old_fold.fold(type_args, &consts); + for (p, v) in outs.unwrap_or_default() { + outputs[p.index()] = v.into(); + } +} + +#[allow(deprecated)] // Remove when ConstFold removed +impl ConstFolder for T { + fn fold(&self, type_args: &[TypeArg], inputs: &[FoldVal], outputs: &mut [FoldVal]) { + do_fold(self, type_args, inputs, outputs) + } +} + +#[allow(deprecated)] // Remove when ConstFold removed impl Debug for Box { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "") } } +impl Debug for Box { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "") + } +} + /// Blanket implementation for functions that only require the constants to /// evaluate - type arguments are not relevant. +#[allow(deprecated)] // Remove when ConstFold removed impl ConstFold for T where T: Fn(&[(crate::IncomingPort, crate::ops::Value)]) -> ConstFoldResult + Send + Sync, @@ -60,14 +221,25 @@ where type FoldFn = dyn Fn(&[TypeArg], &[(IncomingPort, Value)]) -> ConstFoldResult + Send + Sync; -/// Type holding a boxed const-folding function. +/// Legacy type holding a boxed const-folding function. +/// Deprecated: use [BoxedFolder] instead. +#[deprecated(note = "Use BoxedFolder")] pub struct Folder { /// Const-folding function. pub folder: Box, } +#[allow(deprecated)] // Remove when ConstFold removed impl ConstFold for Folder { fn fold(&self, type_args: &[TypeArg], consts: &[(IncomingPort, Value)]) -> ConstFoldResult { (self.folder)(type_args, consts) } } + +pub struct BoxedFolder(Box); + +impl ConstFolder for BoxedFolder { + fn fold(&self, type_args: &[TypeArg], inputs: &[FoldVal], outputs: &mut [FoldVal]) { + self.0(type_args, inputs, outputs) + } +} diff --git a/hugr-core/src/extension/op_def.rs b/hugr-core/src/extension/op_def.rs index 9c30cbdd4..bba164c57 100644 --- a/hugr-core/src/extension/op_def.rs +++ b/hugr-core/src/extension/op_def.rs @@ -6,16 +6,19 @@ use std::sync::{Arc, Weak}; use serde_with::serde_as; +use crate::envelope::serde_with::AsStringEnvelope; +use crate::extension::const_fold::fold_vals_to_indexed_vals; +use crate::ops::{OpName, OpNameRef, Value}; +use crate::types::type_param::{TypeArg, TypeParam, check_type_args}; +use crate::types::{FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature}; +use crate::{Hugr, IncomingPort, PortIndex}; + +use super::const_fold::FoldVal; use super::{ - ConstFold, ConstFoldResult, Extension, ExtensionBuildError, ExtensionId, ExtensionSet, + ConstFoldResult, ConstFolder, Extension, ExtensionBuildError, ExtensionId, ExtensionSet, SignatureError, }; -use crate::Hugr; -use crate::envelope::serde_with::AsStringEnvelope; -use crate::ops::{OpName, OpNameRef}; -use crate::types::type_param::{TypeArg, TypeParam, check_type_args}; -use crate::types::{FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature}; mod serialize_signature_func; /// Trait necessary for binary computations of `OpDef` signature @@ -327,7 +330,7 @@ pub struct OpDef { /// Operations can optionally implement [`ConstFold`] to implement constant folding. #[serde(skip)] - constant_folder: Option>, + constant_folder: Option>, } impl OpDef { @@ -457,19 +460,37 @@ impl OpDef { /// Set the constant folding function for this Op, which can evaluate it /// given constant inputs. - pub fn set_constant_folder(&mut self, fold: impl ConstFold + 'static) { + pub fn set_constant_folder(&mut self, fold: impl ConstFolder + 'static) { self.constant_folder = Some(Box::new(fold)); } /// Evaluate an instance of this [`OpDef`] defined by the `type_args`, given /// [`crate::ops::Const`] values for inputs at [`crate::IncomingPort`]s. #[must_use] + #[deprecated(note = "use const_fold")] pub fn constant_fold( &self, type_args: &[TypeArg], - consts: &[(crate::IncomingPort, crate::ops::Value)], + consts: &[(IncomingPort, Value)], ) -> ConstFoldResult { - (self.constant_folder.as_ref())?.fold(type_args, consts) + let folder = self.constant_folder.as_ref()?; + let sig = self.compute_signature(type_args).unwrap(); + let mut inputs = vec![FoldVal::Unknown; sig.input_count()]; + for (p, v) in consts { + inputs[p.index()] = v.clone().into(); + } + let mut outputs = vec![FoldVal::Unknown; sig.output_count()]; + folder.fold(type_args, &inputs, &mut outputs); + Some(fold_vals_to_indexed_vals(&outputs)) + } + + /// Evaluate an instance of this [`OpDef`] defined by the `type_args`, given + /// [FoldVal] values for each input, and update the outputs, which should be + /// initialised to [FoldVal::Unknown]. + pub fn const_fold(&self, type_args: &[TypeArg], inputs: &[FoldVal], outputs: &mut [FoldVal]) { + if let Some(cf) = self.constant_folder.as_ref() { + cf.fold(type_args, inputs, outputs) + } } /// Returns a reference to the signature function of this [`OpDef`]. diff --git a/hugr-core/src/extension/prelude.rs b/hugr-core/src/extension/prelude.rs index 3af70b75b..37d00fe76 100644 --- a/hugr-core/src/extension/prelude.rs +++ b/hugr-core/src/extension/prelude.rs @@ -11,7 +11,7 @@ use crate::extension::simple_op::{ MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError, try_from_name, }; use crate::extension::{ - ConstFold, ExtensionId, OpDef, SignatureError, SignatureFunc, TypeDefBound, + ConstFolder, ExtensionId, FoldVal, OpDef, SignatureError, SignatureFunc, TypeDefBound, }; use crate::ops::OpName; use crate::ops::constant::{CustomCheckFailure, CustomConst, ValueName}; @@ -585,7 +585,8 @@ pub enum TupleOpDef { UnpackTuple, } -impl ConstFold for TupleOpDef { +#[allow(deprecated)] // TODO: need a way to handle types of tuples. Or drop that SumType... +impl super::ConstFold for TupleOpDef { fn fold( &self, _type_args: &[TypeArg], @@ -823,13 +824,9 @@ impl MakeOpDef for NoopDef { } } -impl ConstFold for NoopDef { - fn fold( - &self, - _type_args: &[TypeArg], - consts: &[(crate::IncomingPort, Value)], - ) -> crate::extension::ConstFoldResult { - fold_out_row([consts.first()?.1.clone()]) +impl ConstFolder for NoopDef { + fn fold(&self, _type_args: &[TypeArg], inputs: &[FoldVal], outputs: &mut [FoldVal]) { + outputs[0] = inputs[0].clone() } } diff --git a/hugr-core/src/extension/prelude/generic.rs b/hugr-core/src/extension/prelude/generic.rs index 9ea231e1b..01454120b 100644 --- a/hugr-core/src/extension/prelude/generic.rs +++ b/hugr-core/src/extension/prelude/generic.rs @@ -1,30 +1,19 @@ use std::str::FromStr; use std::sync::{Arc, Weak}; -use crate::extension::OpDef; -use crate::extension::SignatureFunc; +use crate::Extension; use crate::extension::prelude::usize_custom_t; use crate::extension::simple_op::{ HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError, }; -use crate::extension::{ConstFold, ExtensionId}; -use crate::ops::ExtensionOp; -use crate::ops::OpName; +use crate::extension::{ConstFolder, ExtensionId, FoldVal, OpDef, SignatureError, SignatureFunc}; +use crate::ops::{ExtensionOp, OpName}; use crate::type_row; -use crate::types::FuncValueType; - -use crate::types::Type; - -use crate::extension::SignatureError; - -use crate::types::PolyFuncTypeRV; - -use crate::Extension; -use crate::types::type_param::TypeArg; +use crate::types::type_param::{TypeArg, TypeParam}; +use crate::types::{FuncValueType, PolyFuncTypeRV, Type}; use super::PRELUDE; use super::{ConstUsize, PRELUDE_ID}; -use crate::types::type_param::TypeParam; /// Name of the operation for loading generic `BoundedNat` parameters. pub static LOAD_NAT_OP_ID: OpName = OpName::new_inline("load_nat"); @@ -41,21 +30,11 @@ impl FromStr for LoadNatDef { } } -impl ConstFold for LoadNatDef { - fn fold( - &self, - type_args: &[TypeArg], - _consts: &[(crate::IncomingPort, crate::ops::Value)], - ) -> crate::extension::ConstFoldResult { - let [arg] = type_args else { - return None; - }; - let nat = arg.as_nat(); - if let Some(n) = nat { - let n_const = ConstUsize::new(n); - Some(vec![(0.into(), n_const.into())]) - } else { - None +impl ConstFolder for LoadNatDef { + fn fold(&self, type_args: &[TypeArg], _inputs: &[FoldVal], outputs: &mut [FoldVal]) { + let [arg] = type_args else { return }; + if let Some(n) = arg.as_nat() { + outputs[0] = ConstUsize::new(n).into(); } } } @@ -161,10 +140,11 @@ impl HasConcrete for LoadNatDef { #[cfg(test)] mod tests { use crate::{ - HugrView, OutgoingPort, + HugrView, builder::{DFGBuilder, Dataflow, DataflowHugr, inout_sig}, + extension::FoldVal, extension::prelude::{ConstUsize, usize_t}, - ops::{OpType, constant}, + ops::OpType, type_row, types::TypeArg, }; @@ -201,10 +181,10 @@ mod tests { let optype: OpType = op.into(); if let OpType::ExtensionOp(ext_op) = optype { - let result = ext_op.constant_fold(&[]); - let exp_port: OutgoingPort = 0.into(); - let exp_val: constant::Value = ConstUsize::new(5).into(); - assert_eq!(result, Some(vec![(exp_port, exp_val)])); + let mut out = [FoldVal::Unknown]; + ext_op.const_fold(&[], &mut out); + let exp_val: FoldVal = ConstUsize::new(5).into(); + assert_eq!(out, [exp_val]) } else { panic!() } diff --git a/hugr-core/src/extension/resolution/types.rs b/hugr-core/src/extension/resolution/types.rs index 531509d6e..1011ee5fe 100644 --- a/hugr-core/src/extension/resolution/types.rs +++ b/hugr-core/src/extension/resolution/types.rs @@ -246,6 +246,7 @@ fn collect_value_exts( let typ = e.get_type(); collect_type_exts(&typ, used_extensions, missing_extensions); } + #[allow(deprecated)] // remove when Value::Function removed Value::Function { hugr: _ } => { // The extensions used by nested hugrs do not need to be counted for the root hugr. } diff --git a/hugr-core/src/extension/resolution/types_mut.rs b/hugr-core/src/extension/resolution/types_mut.rs index c4093a18c..f98fba83d 100644 --- a/hugr-core/src/extension/resolution/types_mut.rs +++ b/hugr-core/src/extension/resolution/types_mut.rs @@ -257,6 +257,7 @@ pub(super) fn resolve_value_exts( }); } } + #[allow(deprecated)] // remove when Value::Function removed Value::Function { hugr } => { // We don't need to add the nested hugr's extensions to the main one here, // but we run resolution on it independently. diff --git a/hugr-core/src/hugr/serialize/test.rs b/hugr-core/src/hugr/serialize/test.rs index 2b500ed03..5ee511fd4 100644 --- a/hugr-core/src/hugr/serialize/test.rs +++ b/hugr-core/src/hugr/serialize/test.rs @@ -451,6 +451,7 @@ fn roundtrip_sumtype(#[case] sum_type: SumType) { #[case(Value::extension(ConstInt::new_u(2,1).unwrap()))] #[case(Value::sum(1,[Value::extension(ConstInt::new_u(2,1).unwrap())], SumType::new([vec![], vec![INT_TYPES[2].clone()]])).unwrap())] #[case(Value::tuple([Value::false_val(), Value::extension(ConstInt::new_s(2,1).unwrap())]))] +#[allow(deprecated)] // remove when Value::Function removed #[case(Value::function(crate::builder::test::simple_dfg_hugr()).unwrap())] fn roundtrip_value(#[case] value: Value) { check_testing_roundtrip(value); @@ -511,6 +512,7 @@ fn roundtrip_polyfunctype_varlen(#[case] poly_func_type: PolyFuncTypeRV) { #[case(ops::AliasDefn { name: "aliasdefn".into(), definition: Type::new_unit_sum(4)})] #[case(ops::AliasDecl { name: "aliasdecl".into(), bound: TypeBound::Any})] #[case(ops::Const::new(Value::false_val()))] +#[allow(deprecated)] // remove when Value::Function removed #[case(ops::Const::new(Value::function(crate::builder::test::simple_dfg_hugr()).unwrap()))] #[case(ops::Input::new(vec![Type::new_var_use(3,TypeBound::Copyable)]))] #[case(ops::Output::new(vec![Type::new_function(FuncValueType::new_endo(type_row![]))]))] diff --git a/hugr-core/src/ops/constant.rs b/hugr-core/src/ops/constant.rs index d27a4a0ad..dc1619da0 100644 --- a/hugr-core/src/ops/constant.rs +++ b/hugr-core/src/ops/constant.rs @@ -15,7 +15,6 @@ use crate::{Hugr, HugrView}; use delegate::delegate; use itertools::Itertools; use serde::{Deserialize, Serialize}; -use serde_with::serde_as; use smol_str::SmolStr; use thiserror::Error; @@ -197,29 +196,37 @@ impl From for SerialSum { } } -#[serde_as] -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] -#[serde(tag = "v")] -/// A value that can be stored as a static constant. Representing core types and -/// extension types. -pub enum Value { - /// An extension constant value, that can check it is of a given [CustomType]. - Extension { - #[serde(flatten)] - /// The custom constant value. - e: OpaqueValue, - }, - /// A higher-order function value. - Function { - /// A Hugr defining the function. - #[serde_as(as = "Box")] - hugr: Box, - }, - /// A Sum variant, with a tag indicating the index of the variant and its - /// value. - #[serde(alias = "Tuple")] - Sum(Sum), -} +mod inner { + #![allow(deprecated)] // serde-generated code refers to the deprecated Value::Function + use super::{AsStringEnvelope, Hugr, OpaqueValue, Sum}; + use serde_with::serde_as; + + #[serde_as] + #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] + #[serde(tag = "v")] + /// A value that can be stored as a static constant. Representing core types and + /// extension types. + pub enum Value { + /// An extension constant value + Extension { + #[serde(flatten)] + /// The custom constant value. + e: OpaqueValue, + }, + /// A higher-order function value. + #[deprecated(note = "Flatten and lift contents to a FuncDefn")] + Function { + /// A Hugr defining the function. + #[serde_as(as = "Box")] + hugr: Box, + }, + /// A Sum variant, with a tag indicating the index of the variant and its + /// value. + #[serde(alias = "Tuple")] + Sum(Sum), + } +} //end mod inner +pub use inner::Value; /// An opaque newtype around a [`Box`](CustomConst). /// @@ -381,6 +388,7 @@ impl Value { match self { Self::Extension { e } => e.get_type(), Self::Sum(Sum { sum_type, .. }) => sum_type.clone().into(), + #[allow(deprecated)] // remove when Value::Function removed Self::Function { hugr } => { let func_type = mono_fn_type(hugr).unwrap_or_else(|e| panic!("{}", e)); Type::new_function(func_type.into_owned()) @@ -418,9 +426,11 @@ impl Value { /// # Errors /// /// Returns an error if the Hugr root node does not define a function. + #[deprecated(note = "Flatten and lift contents to a FuncDefn")] pub fn function(hugr: impl Into) -> Result { let hugr = hugr.into(); mono_fn_type(&hugr)?; + #[allow(deprecated)] // In deprecated function, remove at same time Ok(Self::Function { hugr: Box::new(hugr), }) @@ -506,6 +516,7 @@ impl Value { fn name(&self) -> OpName { match self { Self::Extension { e } => format!("const:custom:{}", e.name()), + #[allow(deprecated)] // remove when Value::Function removed Self::Function { hugr: h } => { let Ok(t) = mono_fn_type(h) else { panic!("HUGR root node isn't a valid function parent."); @@ -532,6 +543,7 @@ impl Value { pub fn validate(&self) -> Result<(), ConstTypeError> { match self { Self::Extension { e } => Ok(e.value().validate()?), + #[allow(deprecated)] // remove when Value::Function removed Self::Function { hugr } => { mono_fn_type(hugr)?; Ok(()) @@ -563,6 +575,7 @@ impl Value { pub fn try_hash(&self, st: &mut H) -> bool { match self { Value::Extension { e } => e.value().try_hash(&mut *st), + #[allow(deprecated)] // remove when Value::Function removed Value::Function { .. } => false, Value::Sum(s) => s.try_hash(st), } @@ -732,6 +745,7 @@ pub(crate) mod test { ); } + #[allow(deprecated)] // remove when Value::Function removed #[rstest] fn function_value(simple_dfg_hugr: Hugr) { let v = Value::function(simple_dfg_hugr).unwrap(); @@ -960,10 +974,8 @@ pub(crate) mod test { type Strategy = BoxedStrategy; fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { use ::proptest::collection::vec; - let leaf_strat = prop_oneof![ - any::().prop_map(|e| Self::Extension { e }), - crate::proptest::any_hugr().prop_map(|x| Value::function(x).unwrap()) - ]; + let leaf_strat = any::().prop_map(|e| Self::Extension { e }); + leaf_strat .prop_recursive( 3, // No more than 3 branch levels deep diff --git a/hugr-core/src/ops/custom.rs b/hugr-core/src/ops/custom.rs index f63958478..6b784843a 100644 --- a/hugr-core/src/ops/custom.rs +++ b/hugr-core/src/ops/custom.rs @@ -12,7 +12,7 @@ use { use crate::core::HugrNode; use crate::extension::simple_op::MakeExtensionOp; -use crate::extension::{ConstFoldResult, ExtensionId, OpDef, SignatureError}; +use crate::extension::{ConstFoldResult, ExtensionId, FoldVal, OpDef, SignatureError}; use crate::types::{Signature, type_param::TypeArg}; use crate::{IncomingPort, ops}; @@ -95,11 +95,19 @@ impl ExtensionOp { } /// Attempt to evaluate this operation. See [`OpDef::constant_fold`]. + /// Deprecated: use [Self::const_fold] #[must_use] + #[deprecated(note = "Use const_fold")] pub fn constant_fold(&self, consts: &[(IncomingPort, ops::Value)]) -> ConstFoldResult { + #[allow(deprecated)] // in deprecated function, remove at same time self.def().constant_fold(self.args(), consts) } + /// Attempt to evaluate this operation, See [`OpDef::const_fold`] + pub fn const_fold(&self, inputs: &[FoldVal], outputs: &mut [FoldVal]) { + self.def().const_fold(self.args(), inputs, outputs) + } + /// Creates a new [`OpaqueOp`] as a downgraded version of this /// [`ExtensionOp`]. /// diff --git a/hugr-core/src/std_extensions/arithmetic/conversions.rs b/hugr-core/src/std_extensions/arithmetic/conversions.rs index c954ef2bb..b114344d4 100644 --- a/hugr-core/src/std_extensions/arithmetic/conversions.rs +++ b/hugr-core/src/std_extensions/arithmetic/conversions.rs @@ -214,9 +214,7 @@ impl HasDef for ConvertOpType { mod test { use rstest::rstest; - use crate::IncomingPort; - use crate::extension::prelude::ConstUsize; - use crate::ops::Value; + use crate::extension::{FoldVal, prelude::ConstUsize}; use crate::std_extensions::arithmetic::int_types::ConstInt; use super::*; @@ -251,35 +249,21 @@ mod test { } #[rstest] - #[case::itobool_false(ConvertOpDef::itobool.without_log_width(), &[ConstInt::new_u(0, 0).unwrap().into()], &[Value::false_val()])] - #[case::itobool_true(ConvertOpDef::itobool.without_log_width(), &[ConstInt::new_u(0, 1).unwrap().into()], &[Value::true_val()])] - #[case::ifrombool_false(ConvertOpDef::ifrombool.without_log_width(), &[Value::false_val()], &[ConstInt::new_u(0, 0).unwrap().into()])] - #[case::ifrombool_true(ConvertOpDef::ifrombool.without_log_width(), &[Value::true_val()], &[ConstInt::new_u(0, 1).unwrap().into()])] + #[case::itobool_false(ConvertOpDef::itobool.without_log_width(), &[ConstInt::new_u(0, 0).unwrap().into()], &[FoldVal::false_val()])] + #[case::itobool_true(ConvertOpDef::itobool.without_log_width(), &[ConstInt::new_u(0, 1).unwrap().into()], &[FoldVal::true_val()])] + #[case::ifrombool_false(ConvertOpDef::ifrombool.without_log_width(), &[FoldVal::false_val()], &[ConstInt::new_u(0, 0).unwrap().into()])] + #[case::ifrombool_true(ConvertOpDef::ifrombool.without_log_width(), &[FoldVal::true_val()], &[ConstInt::new_u(0, 1).unwrap().into()])] #[case::itousize(ConvertOpDef::itousize.without_log_width(), &[ConstInt::new_u(6, 42).unwrap().into()], &[ConstUsize::new(42).into()])] #[case::ifromusize(ConvertOpDef::ifromusize.without_log_width(), &[ConstUsize::new(42).into()], &[ConstInt::new_u(6, 42).unwrap().into()])] fn convert_fold( #[case] op: ConvertOpType, - #[case] inputs: &[Value], - #[case] outputs: &[Value], + #[case] inputs: &[FoldVal], + #[case] outputs: &[FoldVal], ) { - use crate::ops::Value; - - let consts: Vec<(IncomingPort, Value)> = inputs - .iter() - .enumerate() - .map(|(i, v)| (i.into(), v.clone())) - .collect(); - - let res = op - .to_extension_op() + let mut working = vec![FoldVal::Unknown; outputs.len()]; + op.to_extension_op() .unwrap() - .constant_fold(&consts) - .unwrap(); - - for (i, expected) in outputs.iter().enumerate() { - let res_val: &Value = &res.get(i).unwrap().1; - - assert_eq!(res_val, expected); - } + .const_fold(inputs, &mut working); + assert_eq!(&working, outputs); } } diff --git a/hugr-core/src/std_extensions/arithmetic/conversions/const_fold.rs b/hugr-core/src/std_extensions/arithmetic/conversions/const_fold.rs index 90615d730..3f9a52ffd 100644 --- a/hugr-core/src/std_extensions/arithmetic/conversions/const_fold.rs +++ b/hugr-core/src/std_extensions/arithmetic/conversions/const_fold.rs @@ -1,3 +1,4 @@ +#![allow(deprecated)] // TODO use crate::extension::prelude::{ConstString, ConstUsize}; use crate::ops::Value; use crate::ops::constant::get_single_input_value; diff --git a/hugr-core/src/std_extensions/arithmetic/float_ops.rs b/hugr-core/src/std_extensions/arithmetic/float_ops.rs index ecf99cda0..80f46cb2f 100644 --- a/hugr-core/src/std_extensions/arithmetic/float_ops.rs +++ b/hugr-core/src/std_extensions/arithmetic/float_ops.rs @@ -133,7 +133,10 @@ impl MakeRegisteredOp for FloatOps { #[cfg(test)] mod test { + use crate::extension::FoldVal; + use crate::std_extensions::arithmetic::float_types::ConstF64; use cgmath::AbsDiffEq; + use itertools::Itertools; use rstest::rstest; use super::*; @@ -158,28 +161,20 @@ mod test { #[case::fceil(FloatOps::fceil, &[42.42], &[43.])] #[case::fround(FloatOps::fround, &[42.42], &[42.])] fn float_fold(#[case] op: FloatOps, #[case] inputs: &[f64], #[case] outputs: &[f64]) { - use crate::ops::Value; - use crate::std_extensions::arithmetic::float_types::ConstF64; - let consts: Vec<_> = inputs .iter() - .enumerate() - .map(|(i, &x)| (i.into(), Value::extension(ConstF64::new(x)))) + .map(|&f| FoldVal::from(ConstF64::new(f))) .collect(); - let res = op - .to_extension_op() + let mut actual = vec![FoldVal::Unknown; outputs.len()]; + op.to_extension_op() .unwrap() - .constant_fold(&consts) - .unwrap(); - - for (i, expected) in outputs.iter().enumerate() { - let res_val: f64 = res - .get(i) - .unwrap() - .1 + .const_fold(&consts, &mut actual); + + for (act, expected) in actual.into_iter().zip_eq(outputs) { + let res_val: f64 = act .get_custom_value::() - .expect("This function assumes all incoming constants are floats.") + .expect("This function assumes all output constants are floats.") .value(); assert!( diff --git a/hugr-core/src/std_extensions/arithmetic/float_ops/const_fold.rs b/hugr-core/src/std_extensions/arithmetic/float_ops/const_fold.rs index 40cc82a55..af36ac548 100644 --- a/hugr-core/src/std_extensions/arithmetic/float_ops/const_fold.rs +++ b/hugr-core/src/std_extensions/arithmetic/float_ops/const_fold.rs @@ -1,3 +1,4 @@ +#![allow(deprecated)] // TODO use crate::{ IncomingPort, extension::{ConstFold, ConstFoldResult, OpDef, prelude::ConstString}, diff --git a/hugr-core/src/std_extensions/arithmetic/int_ops.rs b/hugr-core/src/std_extensions/arithmetic/int_ops.rs index 0b4619fbc..3c27cf8df 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_ops.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_ops.rs @@ -357,12 +357,11 @@ fn sum_ty_with_err(t: Type) -> Type { #[cfg(test)] mod test { + use itertools::Itertools; use rstest::rstest; - use crate::{ - ops::dataflow::DataflowOpTrait, std_extensions::arithmetic::int_types::int_type, - types::Signature, - }; + use crate::std_extensions::arithmetic::int_types::{ConstInt, int_type}; + use crate::{extension::FoldVal, ops::dataflow::DataflowOpTrait, types::Signature}; use super::*; @@ -449,33 +448,18 @@ mod test { #[case] outputs: &[u64], #[case] log_width: u8, ) { - use crate::ops::Value; - use crate::std_extensions::arithmetic::int_types::ConstInt; - let consts: Vec<_> = inputs .iter() - .enumerate() - .map(|(i, &x)| { - ( - i.into(), - Value::extension(ConstInt::new_u(log_width, x).unwrap()), - ) - }) + .map(|&x| FoldVal::from(ConstInt::new_u(log_width, x).unwrap())) .collect(); - let res = op - .to_extension_op() - .unwrap() - .constant_fold(&consts) - .unwrap(); + let mut outs = vec![FoldVal::Unknown; outputs.len()]; + op.to_extension_op().unwrap().const_fold(&consts, &mut outs); - for (i, &expected) in outputs.iter().enumerate() { - let res_val: u64 = res - .get(i) - .unwrap() - .1 + for (act, &expected) in outs.into_iter().zip_eq(outputs) { + let res_val: u64 = act .get_custom_value::() - .expect("This function assumes all incoming constants are floats.") + .expect("This function assumes all result constants are floats.") .value_u(); assert_eq!(res_val, expected); diff --git a/hugr-core/src/std_extensions/arithmetic/int_ops/const_fold.rs b/hugr-core/src/std_extensions/arithmetic/int_ops/const_fold.rs index 5e91d5057..6806c1934 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_ops/const_fold.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_ops/const_fold.rs @@ -1,3 +1,4 @@ +#![allow(deprecated)] // TODO use std::cmp::{max, min}; use crate::{ diff --git a/hugr-core/src/std_extensions/collections/list.rs b/hugr-core/src/std_extensions/collections/list.rs index 05d05048a..b484b9810 100644 --- a/hugr-core/src/std_extensions/collections/list.rs +++ b/hugr-core/src/std_extensions/collections/list.rs @@ -387,13 +387,12 @@ impl ListOpInst { mod test { use rstest::rstest; - use crate::PortIndex; use crate::extension::prelude::{ - const_fail_tuple, const_none, const_ok_tuple, const_some_tuple, + ConstUsize, const_fail_tuple, const_none, const_ok_tuple, const_some_tuple, qb_t, usize_t, }; - use crate::ops::OpTrait; use crate::{ - extension::prelude::{ConstUsize, qb_t, usize_t}, + extension::FoldVal, + ops::OpTrait, std_extensions::arithmetic::float_types::{ConstF64, float64_type}, types::TypeRow, }; @@ -513,29 +512,23 @@ mod test { #[case::insert_invalid(ListOp::insert, &[TestVal::List(vec![77,88,42]), TestVal::Idx(52), TestVal::Elem(99)], &[TestVal::List(vec![77,88,42]), TestVal::Err(Type::UNIT.into(), vec![TestVal::Elem(99)])])] #[case::length(ListOp::length, &[TestVal::List(vec![77,88,42])], &[TestVal::Elem(3)])] fn list_fold(#[case] op: ListOp, #[case] inputs: &[TestVal], #[case] outputs: &[TestVal]) { - let consts: Vec<_> = inputs + let inputs: Vec<_> = inputs .iter() - .enumerate() - .map(|(i, x)| (i.into(), x.to_value())) + .map(TestVal::to_value) + .map(FoldVal::from) .collect(); - let res = op - .with_type(usize_t()) + let mut actual = vec![FoldVal::Unknown; outputs.len()]; + + op.with_type(usize_t()) .to_extension_op() .unwrap() - .constant_fold(&consts) - .unwrap(); + .const_fold(&inputs, &mut actual); - for (i, expected) in outputs.iter().enumerate() { - let expected = expected.to_value(); - let res_val = res - .iter() - .find(|(port, _)| port.index() == i) - .unwrap() - .1 - .clone(); + for (act, expected) in actual.into_iter().zip_eq(outputs) { + let expected = expected.to_value().into(); - assert_eq!(res_val, expected); + assert_eq!(act, expected); } } } diff --git a/hugr-core/src/std_extensions/collections/list/list_fold.rs b/hugr-core/src/std_extensions/collections/list/list_fold.rs index 4bd88a691..5b0dd3a5c 100644 --- a/hugr-core/src/std_extensions/collections/list/list_fold.rs +++ b/hugr-core/src/std_extensions/collections/list/list_fold.rs @@ -1,14 +1,12 @@ //! Folding definitions for list operations. -use crate::IncomingPort; use crate::extension::prelude::{ ConstUsize, const_fail, const_none, const_ok, const_ok_tuple, const_some, }; -use crate::extension::{ConstFold, ConstFoldResult, OpDef}; +use crate::extension::{ConstFolder, FoldVal, OpDef}; use crate::ops::Value; use crate::types::Type; use crate::types::type_param::TypeArg; -use crate::utils::sorted_consts; use super::{ListOp, ListValue}; @@ -25,114 +23,128 @@ pub(super) fn set_fold(op: &ListOp, def: &mut OpDef) { pub struct PopFold; -impl ConstFold for PopFold { - fn fold( - &self, - _type_args: &[TypeArg], - consts: &[(crate::IncomingPort, Value)], - ) -> ConstFoldResult { - let [list]: [&Value; 1] = sorted_consts(consts).try_into().ok()?; - let list: &ListValue = list.get_custom_value().expect("Should be list value."); +impl ConstFolder for PopFold { + fn fold(&self, _type_args: &[TypeArg], inputs: &[FoldVal], outputs: &mut [FoldVal]) { + let [fv] = inputs else { + panic!("Expected one input") + }; + let list: &ListValue = fv.get_custom_value().expect("Should be list value."); let mut list = list.clone(); - if let Some(elem) = list.0.pop() { - Some(vec![(0.into(), list.into()), (1.into(), const_some(elem))]) - } else { - let elem_type = list.1.clone(); - Some(vec![ - (0.into(), list.into()), - (1.into(), const_none(elem_type)), - ]) - } + let elem_type = list.1.clone(); + outputs[1] = list + .0 + .pop() + .map_or(const_none(elem_type), const_some) + .into(); + outputs[0] = list.into(); } } pub struct PushFold; -impl ConstFold for PushFold { - fn fold( - &self, - _type_args: &[TypeArg], - consts: &[(crate::IncomingPort, Value)], - ) -> ConstFoldResult { - let [list, elem]: [&Value; 2] = sorted_consts(consts).try_into().ok()?; - let list: &ListValue = list.get_custom_value().expect("Should be list value."); - let mut list = list.clone(); - list.0.push(elem.clone()); - - Some(vec![(0.into(), list.into())]) +impl ConstFolder for PushFold { + fn fold(&self, _type_args: &[TypeArg], inputs: &[FoldVal], outputs: &mut [FoldVal]) { + let [list, elem] = inputs else { + panic!("Expected two inputs") + }; + if let Some(list) = list.get_custom_value::() { + // We have to convert `elem` to a Value to store it in the list (TODO) + // So e.g. a LoadedFunction would mean we can't constant-fold. + if let Ok(elem) = elem.clone().try_into() { + let mut list = list.clone(); + list.0.push(elem); + outputs[0] = list.into(); + } + } } } pub struct GetFold; -impl ConstFold for GetFold { - fn fold(&self, _type_args: &[TypeArg], consts: &[(IncomingPort, Value)]) -> ConstFoldResult { - let [list, index]: [&Value; 2] = sorted_consts(consts).try_into().ok()?; - let list: &ListValue = list.get_custom_value().expect("Should be list value."); - let index: &ConstUsize = index.get_custom_value().expect("Should be int value."); - let idx = index.value() as usize; - - match list.0.get(idx) { - Some(elem) => Some(vec![(0.into(), const_some(elem.clone()))]), - None => Some(vec![(0.into(), const_none(list.1.clone()))]), +impl ConstFolder for GetFold { + fn fold(&self, _type_args: &[TypeArg], inputs: &[FoldVal], outputs: &mut [FoldVal]) { + let [list, index] = inputs else { + panic!("Expected two inputs") + }; + if let Some(list) = list.get_custom_value::() { + if let Some(index) = index.get_custom_value::() { + let idx = index.value() as usize; + + outputs[0] = match list.0.get(idx) { + Some(elem) => const_some(elem.clone()), + None => const_none(list.1.clone()), + } + .into(); + } } } } pub struct SetFold; -impl ConstFold for SetFold { - fn fold(&self, _type_args: &[TypeArg], consts: &[(IncomingPort, Value)]) -> ConstFoldResult { - let [list, idx, elem]: [&Value; 3] = sorted_consts(consts).try_into().ok()?; - let list: &ListValue = list.get_custom_value().expect("Should be list value."); - - let idx: &ConstUsize = idx.get_custom_value().expect("Should be int value."); - let idx = idx.value() as usize; - - let mut list = list.clone(); - let mut elem = elem.clone(); - let res_elem: Value = match list.0.get_mut(idx) { - Some(old_elem) => { - std::mem::swap(old_elem, &mut elem); - const_ok(elem, list.1.clone()) - } - None => const_fail(elem, list.1.clone()), +impl ConstFolder for SetFold { + fn fold(&self, _type_args: &[TypeArg], inputs: &[FoldVal], outputs: &mut [FoldVal]) { + let [list, idx, elem] = inputs else { + panic!("Expected 3 inputs") }; - Some(vec![(0.into(), list.into()), (1.into(), res_elem)]) + if let Some(list) = list.get_custom_value::() { + if let Some(idx) = idx.get_custom_value::() { + if let Ok(mut elem) = Value::try_from(elem.clone()) { + let idx = idx.value() as usize; + + let mut list = list.clone(); + let res_elem: Value = match list.0.get_mut(idx) { + Some(old_elem) => { + std::mem::swap(old_elem, &mut elem); + const_ok(elem, list.1.clone()) + } + None => const_fail(elem, list.1.clone()), + }; + outputs[0] = list.into(); + outputs[1] = res_elem.into(); + } + } + } } } pub struct InsertFold; -impl ConstFold for InsertFold { - fn fold(&self, _type_args: &[TypeArg], consts: &[(IncomingPort, Value)]) -> ConstFoldResult { - let [list, idx, elem]: [&Value; 3] = sorted_consts(consts).try_into().ok()?; - let list: &ListValue = list.get_custom_value().expect("Should be list value."); - - let idx: &ConstUsize = idx.get_custom_value().expect("Should be int value."); - let idx = idx.value() as usize; - - let mut list = list.clone(); - let elem = elem.clone(); - let res_elem: Value = if list.0.len() > idx { - list.0.insert(idx, elem); - const_ok_tuple([], list.1.clone()) - } else { - const_fail(elem, Type::UNIT) +impl ConstFolder for InsertFold { + fn fold(&self, _type_args: &[TypeArg], inputs: &[FoldVal], outputs: &mut [FoldVal]) { + let [list, idx, elem] = inputs else { + panic!("Expected 3 inputs") }; - Some(vec![(0.into(), list.into()), (1.into(), res_elem)]) + if let Some(list) = list.get_custom_value::() { + if let Some(idx) = idx.get_custom_value::() { + if let Ok(elem) = Value::try_from(elem.clone()) { + let idx = idx.value() as usize; + + let mut list = list.clone(); + let res_elem: Value = if list.0.len() > idx { + list.0.insert(idx, elem); + const_ok_tuple([], list.1.clone()) + } else { + const_fail(elem, Type::UNIT) + }; + outputs[0] = list.into(); + outputs[1] = res_elem.into(); + } + } + } } } pub struct LengthFold; -impl ConstFold for LengthFold { - fn fold(&self, _type_args: &[TypeArg], consts: &[(IncomingPort, Value)]) -> ConstFoldResult { - let [list]: [&Value; 1] = sorted_consts(consts).try_into().ok()?; - let list: &ListValue = list.get_custom_value().expect("Should be list value."); - let len = list.0.len(); - - Some(vec![(0.into(), ConstUsize::new(len as u64).into())]) +impl ConstFolder for LengthFold { + fn fold(&self, _type_args: &[TypeArg], inputs: &[FoldVal], outputs: &mut [FoldVal]) { + let [list] = inputs else { + panic!("Expected one input") + }; + if let Some(list) = list.get_custom_value::() { + outputs[0] = ConstUsize::new(list.0.len() as u64).into(); + } } } diff --git a/hugr-core/src/std_extensions/logic.rs b/hugr-core/src/std_extensions/logic.rs index 438dcf802..a3ff04a75 100644 --- a/hugr-core/src/std_extensions/logic.rs +++ b/hugr-core/src/std_extensions/logic.rs @@ -4,62 +4,51 @@ use std::sync::{Arc, Weak}; use strum::{EnumIter, EnumString, IntoStaticStr}; -use crate::extension::{ConstFold, ConstFoldResult}; -use crate::ops::constant::ValueName; -use crate::ops::{OpName, Value}; -use crate::types::Signature; -use crate::{ - Extension, IncomingPort, - extension::{ - ExtensionId, OpDef, SignatureFunc, - prelude::bool_t, - simple_op::{MakeOpDef, MakeRegisteredOp, OpLoadError, try_from_name}, - }, - ops, - types::type_param::TypeArg, - utils::sorted_consts, +use crate::Extension; +use crate::extension::{ + ConstFolder, ExtensionId, FoldVal, OpDef, SignatureFunc, + prelude::bool_t, + simple_op::{MakeOpDef, MakeRegisteredOp, OpLoadError, try_from_name}, }; +use crate::ops::{OpName, constant::ValueName}; +use crate::types::{Signature, type_param::TypeArg}; + use lazy_static::lazy_static; /// Name of extension false value. pub const FALSE_NAME: ValueName = ValueName::new_inline("FALSE"); /// Name of extension true value. pub const TRUE_NAME: ValueName = ValueName::new_inline("TRUE"); -impl ConstFold for LogicOp { - fn fold(&self, _type_args: &[TypeArg], consts: &[(IncomingPort, Value)]) -> ConstFoldResult { - match self { +impl ConstFolder for LogicOp { + fn fold(&self, _type_args: &[TypeArg], inputs: &[FoldVal], outputs: &mut [FoldVal]) { + let inps = read_known_inputs(inputs); + let out = match self { Self::And => { - let inps = read_inputs(consts)?; let res = inps.iter().all(|x| *x); // We can only fold to true if we have a const for all our inputs. - (!res || inps.len() as u64 == 2) - .then_some(vec![(0.into(), ops::Value::from_bool(res))]) + (!res || inps.len() as u64 == 2).then_some(res) } Self::Or => { - let inps = read_inputs(consts)?; let res = inps.iter().any(|x| *x); // We can only fold to false if we have a const for all our inputs - (res || inps.len() as u64 == 2) - .then_some(vec![(0.into(), ops::Value::from_bool(res))]) + (res || inps.len() == 2).then_some(res) } Self::Eq => { - let inps = read_inputs(consts)?; - let res = inps.iter().copied().reduce(|a, b| a == b)?; - // If we have only some inputs, we can still fold to false, but not to true - (!res || inps.len() as u64 == 2) - .then_some(vec![(0.into(), ops::Value::from_bool(res))]) + debug_assert_eq!(inputs.len(), 2); + (inps.len() == 2).then(|| inps[0] == inps[1]) } Self::Not => { - let inps = read_inputs(consts)?; - let res = inps.iter().all(|x| !*x); - (!res || inps.len() as u64 == 1) - .then_some(vec![(0.into(), ops::Value::from_bool(res))]) + debug_assert_eq!(inputs.len(), 1); + inps.first().map(|b| !*b) } Self::Xor => { - let inps = read_inputs(consts)?; - let res = inps.iter().fold(false, |acc, x| acc ^ *x); - (inps.len() as u64 == 2).then_some(vec![(0.into(), ops::Value::from_bool(res))]) + debug_assert_eq!(inputs.len(), 2); + (inps.len() == 2).then(|| inps[0] ^ inps[1]) } + }; + debug_assert_eq!(outputs, &[FoldVal::Unknown]); + if let Some(res) = out { + outputs[0] = FoldVal::from_bool(res); } } } @@ -84,9 +73,9 @@ impl MakeOpDef for LogicOp { fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { match self { LogicOp::And | LogicOp::Or | LogicOp::Eq | LogicOp::Xor => { - Signature::new(vec![bool_t(); 2], vec![bool_t()]) + Signature::new(vec![bool_t(); 2], bool_t()) } - LogicOp::Not => Signature::new_endo(vec![bool_t()]), + LogicOp::Not => Signature::new_endo(bool_t()), } .into() } @@ -146,23 +135,22 @@ impl MakeRegisteredOp for LogicOp { } } -fn read_inputs(consts: &[(IncomingPort, ops::Value)]) -> Option> { - let true_val = ops::Value::true_val(); - let false_val = ops::Value::false_val(); - let inps: Option> = sorted_consts(consts) - .into_iter() - .map(|c| { - if c == &true_val { - Some(true) - } else if c == &false_val { - Some(false) - } else { - None - } - }) - .collect(); - let inps = inps?; - Some(inps) +fn read_known_inputs(consts: &[FoldVal]) -> Vec { + let true_val = FoldVal::true_val(); + let false_val = FoldVal::false_val(); + let mut res = Vec::new(); + for c in consts { + if c == &true_val { + res.push(true) + } else if c == &false_val { + res.push(false) + } else if c != &FoldVal::Unknown { + // Preserving legacy behaviour, but if any input is not true/false, + // bail completely. + return vec![]; + } + } + res } #[cfg(test)] @@ -173,9 +161,10 @@ pub(crate) mod test { use crate::{ Extension, extension::simple_op::{MakeOpDef, MakeRegisteredOp}, - ops::Value, + extension::{ConstFolder, FoldVal}, }; + use itertools::Itertools; use rstest::rstest; use strum::IntoEnumIterator; @@ -229,16 +218,11 @@ pub(crate) mod test { ) { use itertools::Itertools; - use crate::extension::ConstFold; - let in_vals = ins - .into_iter() - .enumerate() - .map(|(i, b)| (i.into(), Value::from_bool(b))) - .collect_vec(); - assert_eq!( - Some(vec![(0.into(), Value::from_bool(out))]), - op.fold(&[(in_vals.len() as u64).into()], &in_vals) - ); + let in_vals = ins.into_iter().map(FoldVal::from_bool).collect_vec(); + let type_args = [(in_vals.len() as u64).into()]; + let mut outs = [FoldVal::Unknown]; + op.fold(&type_args, &in_vals, &mut outs); + assert_eq!(outs, [FoldVal::from_bool(out)]); } #[rstest] @@ -254,18 +238,13 @@ pub(crate) mod test { #[case] ins: impl IntoIterator>, #[case] mb_out: Option, ) { - use itertools::Itertools; - - use crate::extension::ConstFold; - let in_vals0 = ins.into_iter().enumerate().collect_vec(); - let num_args = in_vals0.len() as u64; - let in_vals = in_vals0 + let in_vals = ins .into_iter() - .filter_map(|(i, mb_b)| mb_b.map(|b| (i.into(), Value::from_bool(b)))) + .map(|mb_b| mb_b.map_or(FoldVal::Unknown, FoldVal::from_bool)) .collect_vec(); - assert_eq!( - mb_out.map(|out| vec![(0.into(), Value::from_bool(out))]), - op.fold(&[num_args.into()], &in_vals) - ); + let type_args = [(in_vals.len() as u64).into()]; + let mut outs = [FoldVal::Unknown]; + op.fold(&type_args, &in_vals, &mut outs); + assert_eq!(outs, [mb_out.map_or(FoldVal::Unknown, FoldVal::from_bool)]); } } diff --git a/hugr-llvm/src/emit/ops.rs b/hugr-llvm/src/emit/ops.rs index 54765a38d..6c3b33e44 100644 --- a/hugr-llvm/src/emit/ops.rs +++ b/hugr-llvm/src/emit/ops.rs @@ -118,6 +118,7 @@ pub fn emit_value<'c, H: HugrView>( ) -> Result> { match v { Value::Extension { e } => context.emit_custom_const(e.value()), + #[allow(deprecated)] // Yay, will be able to remove this Value::Function { .. } => bail!( "Value::Function Const nodes are not supported. \ Ensure you eliminate these from the HUGR before lowering to LLVM. \ diff --git a/hugr-llvm/src/utils/inline_constant_functions.rs b/hugr-llvm/src/utils/inline_constant_functions.rs index 5e9cf158a..73bd64b22 100644 --- a/hugr-llvm/src/utils/inline_constant_functions.rs +++ b/hugr-llvm/src/utils/inline_constant_functions.rs @@ -1,3 +1,4 @@ +#![allow(deprecated)] // Remove pass when Value::Function is removed use hugr_core::{ HugrView, Node, NodeIndex as _, hugr::{hugrmut::HugrMut, internal::HugrMutInternals}, diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 11a92faa4..c9a7e18f1 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -3,11 +3,13 @@ //! An (example) use of the [dataflow analysis framework](super::dataflow). pub mod value_handle; +use itertools::Itertools; use std::{collections::HashMap, sync::Arc}; use thiserror::Error; use hugr_core::{ HugrView, IncomingPort, Node, NodeIndex, OutgoingPort, PortIndex, Wire, + extension::FoldVal, hugr::hugrmut::HugrMut, ops::{ Const, DataflowOpTrait, ExtensionOp, LoadConstant, OpType, Value, constant::OpaqueValue, @@ -17,7 +19,7 @@ use hugr_core::{ use value_handle::ValueHandle; use crate::dataflow::{ - ConstLoader, ConstLocation, DFContext, Machine, PartialValue, TailLoopTermination, + ConstLoader, ConstLocation, DFContext, Machine, PartialSum, PartialValue, TailLoopTermination, partial_from_const, }; use crate::dead_code::{DeadCodeElimPass, PreserveNode}; @@ -141,6 +143,9 @@ impl + 'static> ComposablePass for ConstantFoldPass { (!hugr.get_optype(src).is_load_constant() && Some(src) != mb_root_inp).then_some(( n, ip, + // TODO switch to using FoldVal rather than Value here, thus enabling turning CallIndirect + // into Call when the function input is known. (This will mean we will be unable to handle + // Value::Function's, so best to wait until those are removed.) results .try_read_wire_concrete::(Wire::new(src, outp)) .ok()?, @@ -235,22 +240,43 @@ impl DFContext> for ConstFoldContext { outs: &mut [PartialValue>], ) { let sig = op.signature(); - let known_ins = sig + let inp_fvs = sig .input_types() .iter() - .enumerate() - .zip(ins.iter()) - .filter_map(|((i, ty), pv)| { - pv.clone() - .try_into_concrete(ty) - .ok() - .map(|v| (IncomingPort::from(i), v)) - }) + .zip_eq(ins.iter()) + .map(|(ty, pv)| pv.clone().try_into_concrete(ty).unwrap_or_default()) .collect::>(); - for (p, v) in op.constant_fold(&known_ins).unwrap_or_default() { - outs[p.index()] = - partial_from_const(self, ConstLocation::Field(p.index(), &node.into()), &v); + let mut out_fvs = vec![FoldVal::Unknown; outs.len()]; + op.const_fold(&inp_fvs, &mut out_fvs); + for ((p, out), out_fv) in outs.iter_mut().enumerate().zip(out_fvs) { + // UGH. Need a partial_from_const for FoldVal, *as well* as the one from Value + // 'coz we need to keep the latter for constants!! + *out = pv_from_fold_val(ConstLocation::Field(p, &node.into()), out_fv); + } + } +} + +fn pv_from_fold_val( + loc: ConstLocation, + value: FoldVal, +) -> PartialValue, Node> { + match value { + FoldVal::Unknown => PartialValue::Top, + FoldVal::Sum { + tag, + sum_type: _, + items, + } => PartialValue::PartialSum(PartialSum::new_variant( + tag, + items + .into_iter() + .enumerate() + .map(|(i, v)| pv_from_fold_val(ConstLocation::Field(i, &loc), v)), + )), + FoldVal::Extension(opaque_value) => { + PartialValue::Value(ValueHandle::new_opaque(loc, opaque_value)) } + FoldVal::LoadedFunction(node, args) => PartialValue::new_load(node, args), } } diff --git a/hugr-passes/src/const_fold/value_handle.rs b/hugr-passes/src/const_fold/value_handle.rs index 03464459e..8c445b95f 100644 --- a/hugr-passes/src/const_fold/value_handle.rs +++ b/hugr-passes/src/const_fold/value_handle.rs @@ -6,6 +6,7 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use hugr_core::core::HugrNode; +use hugr_core::extension::FoldVal; use hugr_core::ops::Value; use hugr_core::ops::constant::OpaqueValue; use hugr_core::types::ConstTypeError; @@ -168,6 +169,8 @@ impl AsConcrete, N> for Value { } => Value::Extension { e: Arc::try_unwrap(val).unwrap_or_else(|a| a.as_ref().clone()), }, + #[allow(deprecated)] + // When we remove Value::Function, have to change `leaf` to be OpaqueValue only ValueHandle::Unhashable { leaf: Either::Right(hugr), .. @@ -186,6 +189,37 @@ impl AsConcrete, N> for Value { } } +impl AsConcrete, Node> for FoldVal { + type ValErr = Infallible; + + type SumErr = Infallible; + + fn from_value(val: ValueHandle) -> Result { + Ok(match val { + ValueHandle::Hashable(HashedConst { val, .. }) + | ValueHandle::Unhashable { + leaf: Either::Left(val), + .. + } => FoldVal::Extension(Arc::try_unwrap(val).unwrap_or_else(|a| a.as_ref().clone())), + _ => FoldVal::Unknown, + }) + } + + fn from_sum(sum: Sum) -> Result { + let Sum { tag, values, st } = sum; + Ok(Self::Sum { + tag, + sum_type: st, + items: values, + }) + } + + fn from_func(func: LoadedFunction) -> Result> { + let LoadedFunction { func_node, args } = func; + Ok(FoldVal::LoadedFunction(func_node, args)) + } +} + #[cfg(test)] mod test { use hugr_core::{ diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index a97901c61..d07dbb4d1 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -69,6 +69,7 @@ pub trait ConstLoader { /// Produces an abstract value from a Hugr in a [`Value::Function`], if possible. /// The default just returns `None`, which will be interpreted as [`PartialValue::Top`]. + #[deprecated(note = "Remove along with Value::Function")] fn value_from_const_hugr(&self, _loc: ConstLocation, _h: &Hugr) -> Option { None } @@ -98,6 +99,7 @@ where Value::Extension { e } => cl .value_from_opaque(loc, e) .map_or(PartialValue::Top, PartialValue::from), + #[allow(deprecated)] // remove when Value::Function removed Value::Function { hugr } => cl .value_from_const_hugr(loc, hugr) .map_or(PartialValue::Top, PartialValue::from), diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index ac19094c1..00ea6f010 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -516,6 +516,7 @@ impl ReplaceTypes { false } }), + #[allow(deprecated)] // remove when Value::Function removed Value::Function { hugr } => self.run(&mut **hugr), } } diff --git a/hugr-passes/src/replace_types/handlers.rs b/hugr-passes/src/replace_types/handlers.rs index 25abb846b..2d8eff4e6 100644 --- a/hugr-passes/src/replace_types/handlers.rs +++ b/hugr-passes/src/replace_types/handlers.rs @@ -1,7 +1,9 @@ //! Callbacks for use with [`ReplaceTypes::replace_consts_parametrized`] //! and [`DelegatingLinearizer::register_callback`](super::DelegatingLinearizer::register_callback) -use hugr_core::builder::{DFGBuilder, Dataflow, DataflowHugr, endo_sig, inout_sig}; +use hugr_core::builder::{ + DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, endo_sig, inout_sig, +}; use hugr_core::extension::prelude::{UnwrapBuilder, option_type}; use hugr_core::ops::constant::CustomConst; use hugr_core::ops::{OpTrait, OpType, Tag}; @@ -110,25 +112,27 @@ pub fn linearize_generic_array( panic!("Illegal TypeArgs to array: {args:?}") }; if num_outports == 0 { - // "Simple" discard - first map each element to unit (via type-specific discard): - let map_fn = { - let mut dfb = DFGBuilder::new(inout_sig(ty.clone(), Type::UNIT)).unwrap(); - let [to_discard] = dfb.input_wires_arr(); - lin.copy_discard_op(ty, 0)? - .add(&mut dfb, [to_discard]) - .map_err(|e| LinearizeError::NestedTemplateError(ty.clone(), e))?; - let ret = dfb.add_load_value(Value::unary_unit_sum()); - dfb.finish_hugr_with_outputs([ret]).unwrap() - }; - // Now array.scan that over the input array to get an array of unit (which can be discarded) - let array_scan = GenericArrayScan::::new(ty.clone(), Type::UNIT, vec![], *n); - let in_type = AK::ty(*n, ty.clone()); return Ok(NodeTemplate::CompoundOp(Box::new({ + let in_type = AK::ty(*n, ty.clone()); let mut dfb = DFGBuilder::new(inout_sig(in_type, type_row![])).unwrap(); + // "Simple" discard - first map each element to unit (via type-specific discard): + let map_fn = { + let mut mb = dfb.module_root_builder(); + let mut fb = mb + .define_function("discard_elem", inout_sig(ty.clone(), Type::UNIT)) + .unwrap(); + let [to_discard] = fb.input_wires_arr(); + lin.copy_discard_op(ty, 0)? + .add(&mut fb, [to_discard]) + .map_err(|e| LinearizeError::NestedTemplateError(ty.clone(), e))?; + let ret = fb.add_load_value(Value::unary_unit_sum()); + fb.finish_with_outputs([ret]).unwrap() + }; + // Now array.scan that over the input array to get an array of unit (which can be discarded) + let array_scan = GenericArrayScan::::new(ty.clone(), Type::UNIT, vec![], *n); + let [in_array] = dfb.input_wires_arr(); - let map_fn = dfb.add_load_value(Value::Function { - hugr: Box::new(map_fn), - }); + let map_fn = dfb.load_func(map_fn.handle(), &[]).unwrap(); // scan has one output, an array of unit, so just ignore/discard that let unit_arr = dfb .add_dataflow_op(array_scan, [in_array, map_fn]) @@ -153,14 +157,17 @@ pub fn linearize_generic_array( let option_ty = Type::from(option_sty.clone()); let arrays_of_none = { let fn_none = { - let mut dfb = DFGBuilder::new(inout_sig(vec![], option_ty.clone())).unwrap(); - let none = dfb + let mut mb = dfb.module_root_builder(); + let mut fb = mb + .define_function("mk_none", inout_sig(vec![], option_ty.clone())) + .unwrap(); + let none = fb .add_dataflow_op(Tag::new(0, vec![type_row![], ty.clone().into()]), []) .unwrap(); - dfb.finish_hugr_with_outputs(none.outputs()).unwrap() + fb.finish_with_outputs(none.outputs()).unwrap() }; let repeats = vec![GenericArrayRepeat::::new(option_ty.clone(), *n); num_new]; - let fn_none = dfb.add_load_value(Value::function(fn_none).unwrap()); + let fn_none = dfb.load_func(fn_none.handle(), &[]).unwrap(); repeats .into_iter() .map(|rpt| { @@ -177,18 +184,21 @@ pub fn linearize_generic_array( let copy_elem = { let mut io = vec![ty.clone(), i64_t.clone()]; io.extend(vec![option_array.clone(); num_new]); - let mut dfb = DFGBuilder::new(endo_sig(io)).unwrap(); - let mut inputs = dfb.input_wires(); + let mut mb = dfb.module_root_builder(); + let mut fb = mb + .define_function(format!("copy{num_outports}"), endo_sig(io)) + .unwrap(); + let mut inputs = fb.input_wires(); let elem = inputs.next().unwrap(); let idx = inputs.next().unwrap(); let opt_arrays = inputs.collect::>(); - let [idx_usz] = dfb + let [idx_usz] = fb .add_dataflow_op(ConvertOpDef::itousize.without_log_width(), [idx]) .unwrap() .outputs_arr(); let mut copies = lin .copy_discard_op(ty, num_outports)? - .add(&mut dfb, [elem]) + .add(&mut fb, [elem]) .map_err(|e| LinearizeError::NestedTemplateError(ty.clone(), e))? .outputs(); let copy0 = copies.next().unwrap(); // We'll return this directly @@ -203,32 +213,32 @@ pub fn linearize_generic_array( .into_iter() .zip_eq(copies) .map(|(opt_array, copy1)| { - let [tag] = dfb + let [tag] = fb .add_dataflow_op(Tag::new(1, vec![type_row![], ty.clone().into()]), [copy1]) .unwrap() .outputs_arr(); - let [set_result] = dfb + let [set_result] = fb .add_dataflow_op(set_op.clone(), [opt_array, idx_usz, tag]) .unwrap() .outputs_arr(); // set should always be successful - let [none, opt_array] = dfb + let [none, opt_array] = fb .build_unwrap_sum(1, either_st.clone(), set_result) .unwrap(); //the removed element is an option, which should always be none (and thus discardable) - let [] = dfb + let [] = fb .build_unwrap_sum(0, SumType::new_option(ty.clone()), none) .unwrap(); opt_array }) .collect::>(); // stop borrowing dfb - let cst1 = dfb.add_load_value(ConstInt::new_u(6, 1).unwrap()); - let [new_idx] = dfb + let cst1 = fb.add_load_value(ConstInt::new_u(6, 1).unwrap()); + let [new_idx] = fb .add_dataflow_op(IntOpDef::iadd.with_log_width(6), [idx, cst1]) .unwrap() .outputs_arr(); - dfb.finish_hugr_with_outputs([copy0, new_idx].into_iter().chain(opt_arrays)) + fb.finish_with_outputs([copy0, new_idx].into_iter().chain(opt_arrays)) .unwrap() }; let [in_array] = dfb.input_wires_arr(); @@ -241,7 +251,7 @@ pub fn linearize_generic_array( *n, ); - let copy_elem = dfb.add_load_value(Value::function(copy_elem).unwrap()); + let copy_elem = dfb.load_func(copy_elem.handle(), &[]).unwrap(); let cst0 = dfb.add_load_value(ConstInt::new_u(6, 0).unwrap()); let mut outs = dfb @@ -259,15 +269,20 @@ pub fn linearize_generic_array( //3. Scan each array-of-options, 'unwrapping' each element into a non-option let unwrap_elem = { - let mut dfb = - DFGBuilder::new(inout_sig(Type::from(option_ty.clone()), ty.clone())).unwrap(); + let mut mb = dfb.module_root_builder(); + let mut dfb = mb + .define_function( + "unwrap", + inout_sig(Type::from(option_ty.clone()), ty.clone()), + ) + .unwrap(); let [opt] = dfb.input_wires_arr(); let [val] = dfb.build_unwrap_sum(1, option_sty.clone(), opt).unwrap(); - dfb.finish_hugr_with_outputs([val]).unwrap() + dfb.finish_with_outputs([val]).unwrap() }; let unwrap_scan = GenericArrayScan::::new(option_ty.clone(), ty.clone(), vec![], *n); - let unwrap_elem = dfb.add_load_value(Value::function(unwrap_elem).unwrap()); + let unwrap_elem = dfb.load_func(unwrap_elem.handle(), &[]).unwrap(); let out_arrays = std::iter::once(out_array1) .chain(opt_arrays.map(|opt_array| { diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index bd76b2801..6d8819326 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -367,8 +367,8 @@ mod test { use std::sync::Arc; use hugr_core::builder::{ - BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, - inout_sig, + BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, + HugrBuilder, inout_sig, }; use hugr_core::extension::prelude::{option_type, usize_t}; @@ -798,9 +798,9 @@ mod test { let lin_t: Type = lin_ct.clone().into(); // A simple Hugr that discards a usize_t, with a "drop" function - let mut dfb = DFGBuilder::new(inout_sig(usize_t(), type_row![])).unwrap(); + let mut fb = FunctionBuilder::new("main", inout_sig(usize_t(), type_row![])).unwrap(); let discard_fn = { - let mut mb = dfb.module_root_builder(); + let mut mb = fb.module_root_builder(); let mut fb = mb .define_function("drop", Signature::new(lin_t.clone(), type_row![])) .unwrap(); @@ -813,7 +813,7 @@ mod test { fb.finish_with_outputs([]).unwrap() } .node(); - let backup = dfb.finish_hugr().unwrap(); + let backup = fb.finish_hugr().unwrap(); let mut lower_discard_to_call = ReplaceTypes::default(); // The `copy_fn` here is just any random node, we don't use it @@ -842,16 +842,15 @@ mod test { ); let r = lower_discard_to_call.run(&mut backup.clone()); // Note the error (or success) can be quite fragile, according to what the `discard_fn` - // Node points at in the (hidden here) inner Hugr inside the Value::Function built by - // the array linearization helper. + // Node points at in the (hidden here) inner Hugr built by the array linearization helper. assert!(matches!( r, Err(ReplaceTypesError::LinearizeError( LinearizeError::NestedTemplateError( nested_t, - BuildError::NodeNotFound { node } // Note `..` would be somewhat less fragile + BuildError::UnexpectedType { .. } ) - )) if nested_t == lin_t && node == discard_fn + )) if nested_t == lin_t )); } }