Skip to content

feat: Introduce new enum for constant folding; deprecate Value::Function #2060

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 39 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
692a759
No need for ConstFoldContext to impl Deref
acl-cqc Mar 31, 2025
48143f5
Add PartialValue::LoadedFunction
acl-cqc Apr 4, 2025
5809594
Docs, clippy::type_complexity
acl-cqc Apr 4, 2025
cccefe4
datalog handles LoadFunction, drop value_from_function (TODO ConstFol…
acl-cqc Apr 4, 2025
ca442c4
And handle CallIndirect
acl-cqc Apr 4, 2025
c4f4c80
Remove Hugr from ConstFoldContext, deparametrize/drop bounds
acl-cqc Apr 4, 2025
c6abad4
Add FoldVal, constant_fold2 (taking FoldVals)
acl-cqc Apr 4, 2025
f6196e5
Switch fold2 to take &[FoldVal] not Vec<FoldVal>
acl-cqc Apr 4, 2025
7c5466a
Deprecate old fold routines; port tests, add helpers
acl-cqc Apr 4, 2025
c9b8960
Deprecate Value::Function
acl-cqc Apr 4, 2025
a90dc97
First bit of test - default to Top if called-func unknown
acl-cqc Apr 4, 2025
6e2ad7f
Test2, fix port numbering and narrow unknown-called-func case
acl-cqc Apr 4, 2025
2cd8547
Test3, plus refactor test
acl-cqc Apr 4, 2025
ba33062
refactor into cases of rstest
acl-cqc Apr 4, 2025
5c9c6ec
Merge remote-tracking branch 'origin/main' into acl/dataflow_call_ind…
acl-cqc Apr 4, 2025
ee128ce
PartialSum default N=Node
acl-cqc Apr 4, 2025
2757a33
require TryFrom<LoadedFunction<N>,Error=LoadedFunction<N>>, format w/…
acl-cqc Apr 5, 2025
3c79739
Use PartialValue::new_load
acl-cqc Apr 7, 2025
1191150
Make call_indirect a lattice, add load_func
acl-cqc Apr 7, 2025
4ba5603
Drop Hugr from inside ConstFoldContext
acl-cqc Apr 8, 2025
4db3f5d
ci: Run ci checks on PRs to any branch
aborgna-q Apr 15, 2025
209b2ea
Merge branch 'main' into release-rs-v0.16.0
aborgna-q Apr 15, 2025
81447ec
feat!: Allow generic Nodes in HugrMut insert operations (#2075)
aborgna-q Apr 15, 2025
ef1cba0
fix!: Don't expose `HugrMutInternals` (#2071)
aborgna-q Apr 15, 2025
a8e4553
trait AsConcrete
acl-cqc Apr 8, 2025
f032848
LatticeWrapper requires only PartialEq+PartialOrd
acl-cqc Apr 15, 2025
fac6c8b
feat!: Mark all Error enums as non_exhaustive (#2056)
aborgna-q Apr 15, 2025
3066b65
PartialValue: Arbitrary+TestSumType etc. generates LoadedFunction
acl-cqc Apr 15, 2025
3eb0c7f
clippy, fix predecence
acl-cqc Apr 15, 2025
887a386
test LatticeWrapper
acl-cqc Apr 15, 2025
baaca02
feat!: Handle CallIndirect in Dataflow Analysis (#2059)
acl-cqc Apr 16, 2025
195f30c
Merge branch 'main' into release-rs-v0.16.0
aborgna-q Apr 16, 2025
6347756
feat: Make NodeHandle generic (#2092)
aborgna-q Apr 16, 2025
5b43c0d
feat!: remove ExtensionValue (#2093)
ss2165 Apr 17, 2025
89c2680
feat!: ComposablePass trait allowing sequencing and validation (#1895)
acl-cqc Apr 22, 2025
d8a5d67
feat!: ReplaceTypes: allow lowering ops into a Call to a function alr…
acl-cqc Apr 23, 2025
d352662
Merge commit 'baaca02359a307f8691ab3985313272339a8c494' into acl/data…
acl-cqc Apr 23, 2025
ce87f72
Merge remote-tracking branch 'origin/release-rs-v0.16.0' into acl/dat…
acl-cqc Apr 23, 2025
e147cbe
Merge branch 'acl/dataflow_call_indirect' into acl/foldval2 (define A…
acl-cqc Apr 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion hugr-core/src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ pub trait Container {
}

/// Insert a copy of a HUGR as a child of the container.
fn add_hugr_view(&mut self, child: &impl HugrView) -> InsertionResult {
fn add_hugr_view<H: HugrView>(&mut self, child: &H) -> InsertionResult<H::Node, Node> {
let parent = self.container_node();
self.hugr_mut().insert_from_view(parent, child)
}
Expand Down
67 changes: 3 additions & 64 deletions hugr-core/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@ use derive_more::Display;
use thiserror::Error;

use crate::hugr::IdentList;
use crate::ops::constant::{ValueName, ValueNameRef};
use crate::ops::custom::{ExtensionOp, OpaqueOp};
use crate::ops::{self, OpName, OpNameRef};
use crate::ops::{OpName, OpNameRef};
use crate::types::type_param::{TypeArg, TypeArgError, TypeParam};
use crate::types::RowVariable;
use crate::types::{check_typevar_decl, CustomType, Substitution, TypeBound, TypeName};
Expand All @@ -34,7 +33,7 @@ pub mod resolution;
pub mod simple_op;
mod type_def;

pub use const_fold::{fold_out_row, ConstFold, ConstFoldResult, Folder};
pub use const_fold::{fold_out_row, ConstFold, ConstFoldResult, FoldVal, Folder};
pub use op_def::{
CustomSignatureFunc, CustomValidator, LowerFunc, OpDef, SignatureFromArgs, SignatureFunc,
ValidateJustArgs, ValidateTypeArgs,
Expand Down Expand Up @@ -378,6 +377,7 @@ pub static EMPTY_REG: ExtensionRegistry = ExtensionRegistry {
/// TODO: decide on failure modes
#[derive(Debug, Clone, Error, PartialEq, Eq)]
#[allow(missing_docs)]
#[non_exhaustive]
pub enum SignatureError {
/// Name mismatch
#[error("Definition name ({0}) and instantiation name ({1}) do not match.")]
Expand Down Expand Up @@ -496,37 +496,6 @@ impl CustomConcrete for CustomType {
}
}

/// A constant value provided by a extension.
/// Must be an instance of a type available to the extension.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ExtensionValue {
extension: ExtensionId,
name: ValueName,
typed_value: ops::Value,
}

impl ExtensionValue {
/// Returns a reference to the typed value of this [`ExtensionValue`].
pub fn typed_value(&self) -> &ops::Value {
&self.typed_value
}

/// Returns a mutable reference to the typed value of this [`ExtensionValue`].
pub(super) fn typed_value_mut(&mut self) -> &mut ops::Value {
&mut self.typed_value
}

/// Returns a reference to the name of this [`ExtensionValue`].
pub fn name(&self) -> &str {
self.name.as_str()
}

/// Returns a reference to the extension this [`ExtensionValue`] belongs to.
pub fn extension(&self) -> &ExtensionId {
&self.extension
}
}

/// A unique identifier for a extension.
///
/// The actual [`Extension`] is stored externally.
Expand Down Expand Up @@ -582,8 +551,6 @@ pub struct Extension {
pub runtime_reqs: ExtensionSet,
/// Types defined by this extension.
types: BTreeMap<TypeName, TypeDef>,
/// Static values defined by this extension.
values: BTreeMap<ValueName, ExtensionValue>,
/// Operation declarations with serializable definitions.
// Note: serde will serialize this because we configure with `features=["rc"]`.
// That will clone anything that has multiple references, but each
Expand All @@ -607,7 +574,6 @@ impl Extension {
version,
runtime_reqs: Default::default(),
types: Default::default(),
values: Default::default(),
operations: Default::default(),
}
}
Expand Down Expand Up @@ -679,11 +645,6 @@ impl Extension {
self.types.get(type_name)
}

/// Allows read-only access to the values in this Extension
pub fn get_value(&self, value_name: &ValueNameRef) -> Option<&ExtensionValue> {
self.values.get(value_name)
}

/// Returns the name of the extension.
pub fn name(&self) -> &ExtensionId {
&self.name
Expand All @@ -704,25 +665,6 @@ impl Extension {
self.types.iter()
}

/// Add a named static value to the extension.
pub fn add_value(
&mut self,
name: impl Into<ValueName>,
typed_value: ops::Value,
) -> Result<&mut ExtensionValue, ExtensionBuildError> {
let extension_value = ExtensionValue {
extension: self.name.clone(),
name: name.into(),
typed_value,
};
match self.values.entry(extension_value.name.clone()) {
btree_map::Entry::Occupied(_) => {
Err(ExtensionBuildError::ValueExists(extension_value.name))
}
btree_map::Entry::Vacant(ve) => Ok(ve.insert(extension_value)),
}
}

/// Instantiate an [`ExtensionOp`] which references an [`OpDef`] in this extension.
pub fn instantiate_extension_op(
&self,
Expand Down Expand Up @@ -783,9 +725,6 @@ pub enum ExtensionBuildError {
/// Existing [`TypeDef`]
#[error("Extension already has an type called {0}.")]
TypeDefExists(TypeName),
/// Existing [`ExtensionValue`]
#[error("Extension already has an extension value called {0}.")]
ValueExists(ValueName),
}

/// A set of extensions identified by their unique [`ExtensionId`].
Expand Down
147 changes: 142 additions & 5 deletions hugr-core/src/extension/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,135 @@ use std::fmt::Formatter;

use std::fmt::Debug;

use crate::ops::constant::CustomConst;
use crate::ops::constant::{OpaqueValue, Sum};
use crate::ops::Value;
use crate::types::TypeArg;
use crate::types::{SumType, TypeArg};
use crate::{IncomingPort, Node, OutgoingPort, PortIndex};

use crate::IncomingPort;
use crate::OutgoingPort;
/// Representation of values used for constant folding.
// Should we be non-exhaustive??
// No point in parametrizing by HugrNode since then ConstFold would not be dyn/object-safe
#[derive(Clone, Debug, PartialEq, Default)]
pub enum FoldVal {
/// Value is unknown, must assume that it could be anything
#[default]
Unknown,
/// A variant of a [SumType]
Sum {
/// Which variant of the sum type this value is.
tag: usize,
/// Describes the type of the whole value.
// Can we deprecate this immediately? It is only for converting to Value
sum_type: SumType,
/// A value for each element (type) within the variant
items: Vec<FoldVal>,
},
/// A constant value defined by an extension
Extension(OpaqueValue),
/// A function pointer loaded from a [FuncDefn](crate::ops::FuncDefn) or `FuncDecl`
LoadedFunction(Node, Vec<TypeArg>), // Deliberately skipping Function(Box<Hugr>) ATM
}

impl<T> From<T> for FoldVal
where
T: CustomConst,
{
fn from(value: T) -> Self {
Self::Extension(value.into())
}
}

use crate::ops;
impl FoldVal {
/// Returns a constant "false" value, i.e. the first variant of Sum((), ()).
pub const fn false_val() -> Self {
Self::Sum {
tag: 0,
sum_type: SumType::Unit { size: 2 },
items: vec![],
}
}

/// Returns a constant "true" value, i.e. the second variant of Sum((), ()).
pub const fn true_val() -> Self {
Self::Sum {
tag: 1,
sum_type: SumType::Unit { size: 2 },
items: vec![],
}
}

/// Returns a constant boolean - either [Self::false_val] or [Self::true_val]
pub const fn from_bool(b: bool) -> Self {
if b {
Self::true_val()
} else {
Self::false_val()
}
}

/// Extract the specified type of [CustomConst] fro this instance, if it is one
pub fn get_custom_value<T: CustomConst>(&self) -> Option<&T> {
let Self::Extension(e) = self else {
return None;
};
e.value().downcast_ref()
}
}

impl TryFrom<FoldVal> for Value {
type Error = Option<Node>;

fn try_from(value: FoldVal) -> Result<Self, Self::Error> {
match value {
FoldVal::Unknown => Err(None),
FoldVal::Sum {
tag,
sum_type,
items,
} => {
let values = items
.into_iter()
.map(Value::try_from)
.collect::<Result<Vec<_>, _>>()?;
Ok(Value::Sum(Sum {
tag,
values,
sum_type,
}))
}
FoldVal::Extension(e) => Ok(Value::Extension { e }),
FoldVal::LoadedFunction(node, _) => Err(Some(node)),
}
}
}

impl From<Value> for FoldVal {
fn from(value: Value) -> Self {
match value {
Value::Extension { e } => FoldVal::Extension(e),
#[allow(deprecated)] // remove when Value::Function removed
Value::Function { .. } => FoldVal::Unknown,
Value::Sum(Sum {
tag,
values,
sum_type,
}) => {
let items = values.into_iter().map(FoldVal::from).collect();
FoldVal::Sum {
tag,
sum_type,
items,
}
}
}
}
}

/// Output of constant folding an operation, None indicates folding was either
/// not possible or unsuccessful. An empty vector indicates folding was
/// successful and no values are output.
pub type ConstFoldResult = Option<Vec<(OutgoingPort, ops::Value)>>;
pub type ConstFoldResult = Option<Vec<(OutgoingPort, Value)>>;

/// Tag some output constants with [`OutgoingPort`] inferred from the ordering.
pub fn fold_out_row(consts: impl IntoIterator<Item = Value>) -> ConstFoldResult {
Expand All @@ -27,9 +144,29 @@ pub fn fold_out_row(consts: impl IntoIterator<Item = Value>) -> ConstFoldResult

/// Trait implemented by extension operations that can perform constant folding.
pub trait ConstFold: Send + Sync {
/// Given type arguments `type_args` and [`FoldVal`]s for each input,
/// update the outputs (these will be initialized to [FoldVal::Unknown]).
///
/// Defaults to calling [Self::fold] with those arguments that can be converted ---
/// [FoldVal::LoadedFunction]s will be lost as these are not representable as [Value]s.
fn fold2(&self, type_args: &[TypeArg], inputs: &[FoldVal], outputs: &mut [FoldVal]) {
let consts = inputs
.iter()
.cloned()
.enumerate()
.filter_map(|(p, fv)| Some((p.into(), fv.try_into().ok()?)))
.collect::<Vec<_>>();
#[allow(deprecated)] // remove this when fold is removed
let outs = self.fold(type_args, &consts);
for (p, v) in outs.unwrap_or_default() {
outputs[p.index()] = v.into();
}
}

/// Given type arguments `type_args` and
/// [`crate::ops::Const`] values for inputs at [`crate::IncomingPort`]s,
/// try to evaluate the operation.
#[deprecated(note = "Use fold2")]
fn fold(
&self,
type_args: &[TypeArg],
Expand Down
23 changes: 20 additions & 3 deletions hugr-core/src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@ use std::collections::HashMap;
use std::fmt::{Debug, Formatter};
use std::sync::{Arc, Weak};

use super::const_fold::FoldVal;
use super::{
ConstFold, ConstFoldResult, Extension, ExtensionBuildError, ExtensionId, ExtensionSet,
SignatureError,
};

use crate::ops::{OpName, OpNameRef};
use crate::ops::{OpName, OpNameRef, Value};
use crate::types::type_param::{check_type_args, TypeArg, TypeParam};
use crate::types::{FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature};
use crate::Hugr;
use crate::{Hugr, IncomingPort};
mod serialize_signature_func;

/// Trait necessary for binary computations of OpDef signature
Expand Down Expand Up @@ -457,14 +458,30 @@ impl OpDef {

/// Evaluate an instance of this [`OpDef`] defined by the `type_args`, given
/// [`crate::ops::Const`] values for inputs at [`crate::IncomingPort`]s.
#[deprecated(note = "use constant_fold2")]
pub fn constant_fold(
&self,
type_args: &[TypeArg],
consts: &[(crate::IncomingPort, crate::ops::Value)],
consts: &[(IncomingPort, Value)],
) -> ConstFoldResult {
#[allow(deprecated)] // we are in deprecated function, remove at same time
(self.constant_folder.as_ref())?.fold(type_args, consts)
}

/// Evaluate an instance of this [`OpDef`] defined by the `type_args`, given
/// [FoldVal] values for each input, and update the outputs, which should be
/// initialised to [FoldVal::Unknown].
pub fn constant_fold2(
&self,
type_args: &[TypeArg],
inputs: &[FoldVal],
outputs: &mut [FoldVal],
) {
if let Some(cf) = self.constant_folder.as_ref() {
cf.fold2(type_args, inputs, outputs)
}
}

/// Returns a reference to the signature function of this [`OpDef`].
pub fn signature_func(&self) -> &SignatureFunc {
&self.signature_func
Expand Down
17 changes: 10 additions & 7 deletions hugr-core/src/extension/prelude/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,14 @@ impl HasConcrete for LoadNatDef {
mod tests {
use crate::{
builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr},
extension::prelude::{usize_t, ConstUsize},
ops::{constant, OpType},
extension::{
prelude::{usize_t, ConstUsize},
FoldVal,
},
ops::OpType,
type_row,
types::TypeArg,
HugrView, OutgoingPort,
HugrView,
};

use super::LoadNat;
Expand Down Expand Up @@ -209,10 +212,10 @@ mod tests {

match optype {
OpType::ExtensionOp(ext_op) => {
let result = ext_op.constant_fold(&[]);
let exp_port: OutgoingPort = 0.into();
let exp_val: constant::Value = ConstUsize::new(5).into();
assert_eq!(result, Some(vec![(exp_port, exp_val)]))
let mut out = [FoldVal::Unknown];
ext_op.constant_fold2(&[], &mut out);
let exp_val: FoldVal = ConstUsize::new(5).into();
assert_eq!(out, [exp_val])
}
_ => panic!(),
}
Expand Down
Loading
Loading