diff --git a/Cargo.lock b/Cargo.lock index 7198a4ea5..79f16982b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1212,6 +1212,7 @@ dependencies = [ "proptest", "proptest-recurse", "rstest", + "strum", "thiserror 2.0.12", ] diff --git a/hugr-core/src/std_extensions/collections/array.rs b/hugr-core/src/std_extensions/collections/array.rs index 655507b09..3c8812f67 100644 --- a/hugr-core/src/std_extensions/collections/array.rs +++ b/hugr-core/src/std_extensions/collections/array.rs @@ -15,14 +15,15 @@ use std::sync::Arc; use delegate::delegate; use lazy_static::lazy_static; +use crate::builder::{BuildError, Dataflow}; use crate::extension::resolution::{ExtensionResolutionError, WeakExtensionRegistry}; -use crate::extension::simple_op::{MakeOpDef, MakeRegisteredOp}; +use crate::extension::simple_op::{HasConcrete, MakeOpDef, MakeRegisteredOp}; use crate::extension::{ExtensionId, ExtensionSet, SignatureError, TypeDef, TypeDefBound}; use crate::ops::constant::{CustomConst, ValueName}; use crate::ops::{ExtensionOp, OpName}; use crate::types::type_param::{TypeArg, TypeParam}; use crate::types::{CustomCheckFailure, Type, TypeBound, TypeName}; -use crate::Extension; +use crate::{Extension, Wire}; pub use array_clone::{GenericArrayClone, GenericArrayCloneDef, ARRAY_CLONE_OP_ID}; pub use array_conversion::{Direction, GenericArrayConvert, GenericArrayConvertDef, FROM, INTO}; @@ -32,7 +33,8 @@ pub use array_op::{GenericArrayOp, GenericArrayOpDef}; pub use array_repeat::{GenericArrayRepeat, GenericArrayRepeatDef, ARRAY_REPEAT_OP_ID}; pub use array_scan::{GenericArrayScan, GenericArrayScanDef, ARRAY_SCAN_OP_ID}; pub use array_value::GenericArrayValue; -pub use op_builder::ArrayOpBuilder; + +use op_builder::GenericArrayOpBuilder; /// Reported unique name of the array type. pub const ARRAY_TYPENAME: TypeName = TypeName::new_inline("array"); @@ -170,6 +172,236 @@ pub fn new_array_op(element_ty: Type, size: u64) -> ExtensionOp { op.to_extension_op().unwrap() } +/// Trait for building array operations in a dataflow graph. +pub trait ArrayOpBuilder: GenericArrayOpBuilder { + /// Adds a new array operation to the dataflow graph and return the wire + /// representing the new array. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `values` - An iterator over the values to initialize the array with. + /// + /// # Errors + /// + /// If building the operation fails. + /// + /// # Returns + /// + /// The wire representing the new array. + fn add_new_array( + &mut self, + elem_ty: Type, + values: impl IntoIterator, + ) -> Result { + self.add_new_generic_array::(elem_ty, values) + } + + /// Adds an array clone operation to the dataflow graph and return the wires + /// representing the originala and cloned array. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// + /// # Errors + /// + /// If building the operation fails. + /// + /// # Returns + /// + /// The wires representing the original and cloned array. + fn add_array_clone( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + ) -> Result<(Wire, Wire), BuildError> { + self.add_generic_array_clone::(elem_ty, size, input) + } + + /// Adds an array discard operation to the dataflow graph. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// + /// # Errors + /// + /// If building the operation fails. + fn add_array_discard( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + ) -> Result<(), BuildError> { + self.add_generic_array_discard::(elem_ty, size, input) + } + + /// Adds an array get operation to the dataflow graph. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// * `index` - The wire representing the index to get. + /// + /// # Errors + /// + /// If building the operation fails. + /// + /// # Returns + /// + /// * The wire representing the value at the specified index in the array + /// * The wire representing the array + fn add_array_get( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + index: Wire, + ) -> Result<(Wire, Wire), BuildError> { + self.add_generic_array_get::(elem_ty, size, input, index) + } + + /// Adds an array set operation to the dataflow graph. + /// + /// This operation sets the value at a specified index in the array. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// * `index` - The wire representing the index to set. + /// * `value` - The wire representing the value to set at the specified index. + /// + /// # Errors + /// + /// Returns an error if building the operation fails. + /// + /// # Returns + /// + /// The wire representing the updated array after the set operation. + fn add_array_set( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + index: Wire, + value: Wire, + ) -> Result { + self.add_generic_array_set::(elem_ty, size, input, index, value) + } + + /// Adds an array swap operation to the dataflow graph. + /// + /// This operation swaps the values at two specified indices in the array. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// * `index1` - The wire representing the first index to swap. + /// * `index2` - The wire representing the second index to swap. + /// + /// # Errors + /// + /// Returns an error if building the operation fails. + /// + /// # Returns + /// + /// The wire representing the updated array after the swap operation. + fn add_array_swap( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + index1: Wire, + index2: Wire, + ) -> Result { + let op = GenericArrayOpDef::::swap.instantiate(&[size.into(), elem_ty.into()])?; + let [out] = self + .add_dataflow_op(op, vec![input, index1, index2])? + .outputs_arr(); + Ok(out) + } + + /// Adds an array pop-left operation to the dataflow graph. + /// + /// This operation removes the leftmost element from the array. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// + /// # Errors + /// + /// Returns an error if building the operation fails. + /// + /// # Returns + /// + /// The wire representing the Option> + fn add_array_pop_left( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + ) -> Result { + self.add_generic_array_pop_left::(elem_ty, size, input) + } + + /// Adds an array pop-right operation to the dataflow graph. + /// + /// This operation removes the rightmost element from the array. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// + /// # Errors + /// + /// Returns an error if building the operation fails. + /// + /// # Returns + /// + /// The wire representing the Option> + fn add_array_pop_right( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + ) -> Result { + self.add_generic_array_pop_right::(elem_ty, size, input) + } + + /// Adds an operation to discard an empty array from the dataflow graph. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `input` - The wire representing the array. + /// + /// # Errors + /// + /// Returns an error if building the operation fails. + fn add_array_discard_empty(&mut self, elem_ty: Type, input: Wire) -> Result<(), BuildError> { + self.add_generic_array_discard_empty::(elem_ty, input) + } +} + +impl ArrayOpBuilder for D {} + #[cfg(test)] mod test { use crate::builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr}; diff --git a/hugr-core/src/std_extensions/collections/array/array_kind.rs b/hugr-core/src/std_extensions/collections/array/array_kind.rs index 61a8cd3ae..88c729d3c 100644 --- a/hugr-core/src/std_extensions/collections/array/array_kind.rs +++ b/hugr-core/src/std_extensions/collections/array/array_kind.rs @@ -1,10 +1,12 @@ use std::sync::Arc; +use crate::std_extensions::collections::array::op_builder::GenericArrayOpBuilder; use crate::{ + builder::{BuildError, Dataflow}, extension::{ExtensionId, SignatureError, TypeDef}, ops::constant::ValueName, types::{CustomType, Type, TypeArg, TypeName}, - Extension, + Extension, Wire, }; /// Trait capturing a concrete array implementation in an extension. @@ -90,4 +92,28 @@ pub trait ArrayKind: ) -> Result { Self::instantiate_ty(Self::type_def(), size, element_ty) } + + /// Adds a operation to a dataflow graph that clones an array of copyable values. + /// + /// The default implementation uses the array clone operation. + fn build_clone( + builder: &mut D, + elem_ty: Type, + size: u64, + arr: Wire, + ) -> Result<(Wire, Wire), BuildError> { + builder.add_generic_array_clone::(elem_ty, size, arr) + } + + /// Adds a operation to a dataflow graph that clones an array of copyable values. + /// + /// The default implementation uses the array clone operation. + fn build_discard( + builder: &mut D, + elem_ty: Type, + size: u64, + arr: Wire, + ) -> Result<(), BuildError> { + builder.add_generic_array_discard::(elem_ty, size, arr) + } } diff --git a/hugr-core/src/std_extensions/collections/array/op_builder.rs b/hugr-core/src/std_extensions/collections/array/op_builder.rs index 135c96cfa..2e2a7acb1 100644 --- a/hugr-core/src/std_extensions/collections/array/op_builder.rs +++ b/hugr-core/src/std_extensions/collections/array/op_builder.rs @@ -1,6 +1,7 @@ //! Builder trait for array operations in the dataflow graph. use crate::std_extensions::collections::array::GenericArrayOpDef; +use crate::std_extensions::collections::value_array::ValueArray; use crate::{ builder::{BuildError, Dataflow}, extension::simple_op::HasConcrete as _, @@ -11,8 +12,13 @@ use itertools::Itertools as _; use super::{Array, ArrayKind, GenericArrayClone, GenericArrayDiscard}; -/// Trait for building array operations in a dataflow graph. -pub trait ArrayOpBuilder: Dataflow { +use crate::extension::prelude::{ + either_type, option_type, usize_t, ConstUsize, UnwrapBuilder as _, +}; + +/// Trait for building array operations in a dataflow graph that are generic +/// over the concrete array implementation. +pub trait GenericArrayOpBuilder: Dataflow { /// Adds a new array operation to the dataflow graph and return the wire /// representing the new array. /// @@ -28,7 +34,7 @@ pub trait ArrayOpBuilder: Dataflow { /// # Returns /// /// The wire representing the new array. - fn add_new_array( + fn add_new_generic_array( &mut self, elem_ty: Type, values: impl IntoIterator, @@ -59,7 +65,7 @@ pub trait ArrayOpBuilder: Dataflow { /// # Returns /// /// The wires representing the original and cloned array. - fn add_array_clone( + fn add_generic_array_clone( &mut self, elem_ty: Type, size: u64, @@ -81,7 +87,7 @@ pub trait ArrayOpBuilder: Dataflow { /// # Errors /// /// If building the operation fails. - fn add_array_discard( + fn add_generic_array_discard( &mut self, elem_ty: Type, size: u64, @@ -109,7 +115,7 @@ pub trait ArrayOpBuilder: Dataflow { /// /// * The wire representing the value at the specified index in the array /// * The wire representing the array - fn add_array_get( + fn add_generic_array_get( &mut self, elem_ty: Type, size: u64, @@ -140,7 +146,7 @@ pub trait ArrayOpBuilder: Dataflow { /// # Returns /// /// The wire representing the updated array after the set operation. - fn add_array_set( + fn add_generic_array_set( &mut self, elem_ty: Type, size: u64, @@ -174,7 +180,7 @@ pub trait ArrayOpBuilder: Dataflow { /// # Returns /// /// The wire representing the updated array after the swap operation. - fn add_array_swap( + fn add_generic_array_swap( &mut self, elem_ty: Type, size: u64, @@ -206,7 +212,7 @@ pub trait ArrayOpBuilder: Dataflow { /// # Returns /// /// The wire representing the Option> - fn add_array_pop_left( + fn add_generic_array_pop_left( &mut self, elem_ty: Type, size: u64, @@ -233,7 +239,7 @@ pub trait ArrayOpBuilder: Dataflow { /// # Returns /// /// The wire representing the Option> - fn add_array_pop_right( + fn add_generic_array_pop_right( &mut self, elem_ty: Type, size: u64, @@ -253,7 +259,11 @@ pub trait ArrayOpBuilder: Dataflow { /// # Errors /// /// Returns an error if building the operation fails. - fn add_array_discard_empty(&mut self, elem_ty: Type, input: Wire) -> Result<(), BuildError> { + fn add_generic_array_discard_empty( + &mut self, + elem_ty: Type, + input: Wire, + ) -> Result<(), BuildError> { self.add_dataflow_op( GenericArrayOpDef::::discard_empty .instantiate(&[elem_ty.into()]) @@ -264,86 +274,111 @@ pub trait ArrayOpBuilder: Dataflow { } } -impl ArrayOpBuilder for D {} +impl GenericArrayOpBuilder for D {} -#[cfg(test)] -mod test { - use crate::extension::prelude::PRELUDE_ID; - use crate::extension::ExtensionSet; - use crate::std_extensions::collections::array::{self, array_type}; - use crate::{ - builder::{DFGBuilder, HugrBuilder}, - extension::prelude::{either_type, option_type, usize_t, ConstUsize, UnwrapBuilder as _}, - types::Signature, - Hugr, +/// Helper function to build a Hugr that contains all basic array operations. +/// +/// Generic over the concrete array implementation. +pub fn build_all_array_ops_generic(mut builder: B) -> B { + let us0 = builder.add_load_value(ConstUsize::new(0)); + let us1 = builder.add_load_value(ConstUsize::new(1)); + let us2 = builder.add_load_value(ConstUsize::new(2)); + let arr = builder + .add_new_generic_array::(usize_t(), [us1, us2]) + .unwrap(); + let [arr] = { + let r = builder + .add_generic_array_swap::(usize_t(), 2, arr, us0, us1) + .unwrap(); + let res_sum_ty = { + let array_type = AK::ty(2, usize_t()); + either_type(array_type.clone(), array_type) + }; + builder.build_unwrap_sum(1, res_sum_ty, r).unwrap() }; - use rstest::rstest; - use super::*; + let ([elem_0], arr) = { + let (r, arr) = builder + .add_generic_array_get::(usize_t(), 2, arr, us0) + .unwrap(); + ( + builder + .build_unwrap_sum(1, option_type(usize_t()), r) + .unwrap(), + arr, + ) + }; - #[rstest::fixture] - #[default(DFGBuilder)] - pub fn all_array_ops( - #[default(DFGBuilder::new(Signature::new_endo(Type::EMPTY_TYPEROW) - .with_extension_delta(ExtensionSet::from_iter([ - PRELUDE_ID, - array::EXTENSION_ID - ]))).unwrap())] - mut builder: B, - ) -> B { - let us0 = builder.add_load_value(ConstUsize::new(0)); - let us1 = builder.add_load_value(ConstUsize::new(1)); - let us2 = builder.add_load_value(ConstUsize::new(2)); - let arr = builder.add_new_array(usize_t(), [us1, us2]).unwrap(); - let [arr] = { - let r = builder.add_array_swap(usize_t(), 2, arr, us0, us1).unwrap(); - let res_sum_ty = { - let array_type = array_type(2, usize_t()); - either_type(array_type.clone(), array_type) - }; - builder.build_unwrap_sum(1, res_sum_ty, r).unwrap() + let [_elem_1, arr] = { + let r = builder + .add_generic_array_set::(usize_t(), 2, arr, us1, elem_0) + .unwrap(); + let res_sum_ty = { + let row = vec![usize_t(), AK::ty(2, usize_t())]; + either_type(row.clone(), row) }; + builder.build_unwrap_sum(1, res_sum_ty, r).unwrap() + }; - let ([elem_0], arr) = { - let (r, arr) = builder.add_array_get(usize_t(), 2, arr, us0).unwrap(); - ( - builder - .build_unwrap_sum(1, option_type(usize_t()), r) - .unwrap(), - arr, - ) - }; + let [_elem_left, arr] = { + let r = builder + .add_generic_array_pop_left::(usize_t(), 2, arr) + .unwrap(); + builder + .build_unwrap_sum(1, option_type(vec![usize_t(), AK::ty(1, usize_t())]), r) + .unwrap() + }; + let [_elem_right, arr] = { + let r = builder + .add_generic_array_pop_right::(usize_t(), 1, arr) + .unwrap(); + builder + .build_unwrap_sum(1, option_type(vec![usize_t(), AK::ty(0, usize_t())]), r) + .unwrap() + }; - let [_elem_1, arr] = { - let r = builder - .add_array_set(usize_t(), 2, arr, us1, elem_0) - .unwrap(); - let res_sum_ty = { - let row = vec![usize_t(), array_type(2, usize_t())]; - either_type(row.clone(), row) - }; - builder.build_unwrap_sum(1, res_sum_ty, r).unwrap() - }; + builder + .add_generic_array_discard_empty::(usize_t(), arr) + .unwrap(); + builder +} - let [_elem_left, arr] = { - let r = builder.add_array_pop_left(usize_t(), 2, arr).unwrap(); - builder - .build_unwrap_sum(1, option_type(vec![usize_t(), array_type(1, usize_t())]), r) - .unwrap() - }; - let [_elem_right, arr] = { - let r = builder.add_array_pop_right(usize_t(), 1, arr).unwrap(); - builder - .build_unwrap_sum(1, option_type(vec![usize_t(), array_type(0, usize_t())]), r) - .unwrap() - }; +/// Helper function to build a Hugr that contains all basic array operations. +pub fn build_all_array_ops(builder: B) -> B { + build_all_array_ops_generic::(builder) +} - builder.add_array_discard_empty(usize_t(), arr).unwrap(); - builder +/// Helper function to build a Hugr that contains all basic array operations. +pub fn build_all_value_array_ops(builder: B) -> B { + build_all_array_ops_generic::(builder) +} + +/// Testing utilities to generate Hugrs that contain array operations. +#[cfg(test)] +mod test { + use crate::builder::{DFGBuilder, HugrBuilder}; + use crate::extension::prelude::PRELUDE_ID; + use crate::extension::ExtensionSet; + use crate::std_extensions::collections::array::{self}; + use crate::std_extensions::collections::value_array::{self}; + use crate::types::Signature; + + use super::*; + + #[test] + fn all_array_ops() { + let sig = Signature::new_endo(Type::EMPTY_TYPEROW) + .with_extension_delta(ExtensionSet::from_iter([PRELUDE_ID, array::EXTENSION_ID])); + let builder = DFGBuilder::new(sig).unwrap(); + build_all_array_ops(builder).finish_hugr().unwrap(); } - #[rstest] - fn build_all_ops(all_array_ops: DFGBuilder) { - all_array_ops.finish_hugr().unwrap(); + #[test] + fn all_value_array_ops() { + let sig = Signature::new_endo(Type::EMPTY_TYPEROW).with_extension_delta( + ExtensionSet::from_iter([PRELUDE_ID, value_array::EXTENSION_ID]), + ); + let builder = DFGBuilder::new(sig).unwrap(); + build_all_value_array_ops(builder).finish_hugr().unwrap(); } } diff --git a/hugr-core/src/std_extensions/collections/value_array.rs b/hugr-core/src/std_extensions/collections/value_array.rs index 15c8359a1..20a20204d 100644 --- a/hugr-core/src/std_extensions/collections/value_array.rs +++ b/hugr-core/src/std_extensions/collections/value_array.rs @@ -8,14 +8,16 @@ use std::sync::Arc; use delegate::delegate; use lazy_static::lazy_static; +use crate::builder::{BuildError, Dataflow}; use crate::extension::resolution::{ExtensionResolutionError, WeakExtensionRegistry}; -use crate::extension::simple_op::MakeOpDef; +use crate::extension::simple_op::{HasConcrete, MakeOpDef}; use crate::extension::{ExtensionId, ExtensionSet, SignatureError, TypeDef, TypeDefBound}; use crate::ops::constant::{CustomConst, ValueName}; use crate::types::type_param::{TypeArg, TypeParam}; use crate::types::{CustomCheckFailure, Type, TypeBound, TypeName}; -use crate::Extension; +use crate::{Extension, Wire}; +use super::array::op_builder::GenericArrayOpBuilder; use super::array::{ Array, ArrayKind, GenericArrayConvert, GenericArrayConvertDef, GenericArrayOp, GenericArrayOpDef, GenericArrayRepeat, GenericArrayRepeatDef, GenericArrayScan, @@ -49,6 +51,24 @@ impl ArrayKind for ValueArray { fn type_def() -> &'static TypeDef { EXTENSION.get_type(&VALUE_ARRAY_TYPENAME).unwrap() } + + fn build_clone( + _builder: &mut D, + _elem_ty: Type, + _size: u64, + arr: Wire, + ) -> Result<(Wire, Wire), BuildError> { + Ok((arr, arr)) + } + + fn build_discard( + _builder: &mut D, + _elem_ty: Type, + _size: u64, + _arr: Wire, + ) -> Result<(), BuildError> { + Ok(()) + } } /// Value array operation definitions. @@ -142,3 +162,189 @@ pub fn value_array_type_parametric( ) -> Result { ValueArray::ty_parametric(size, element_ty) } + +/// Trait for building value array operations in a dataflow graph. +pub trait VArrayOpBuilder: GenericArrayOpBuilder { + /// Adds a new array operation to the dataflow graph and return the wire + /// representing the new array. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `values` - An iterator over the values to initialize the array with. + /// + /// # Errors + /// + /// If building the operation fails. + /// + /// # Returns + /// + /// The wire representing the new array. + fn add_new_value_array( + &mut self, + elem_ty: Type, + values: impl IntoIterator, + ) -> Result { + self.add_new_generic_array::(elem_ty, values) + } + + /// Adds an array get operation to the dataflow graph. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// * `index` - The wire representing the index to get. + /// + /// # Errors + /// + /// If building the operation fails. + /// + /// # Returns + /// + /// * The wire representing the value at the specified index in the array + /// * The wire representing the array + fn add_value_array_get( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + index: Wire, + ) -> Result<(Wire, Wire), BuildError> { + self.add_generic_array_get::(elem_ty, size, input, index) + } + + /// Adds an array set operation to the dataflow graph. + /// + /// This operation sets the value at a specified index in the array. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// * `index` - The wire representing the index to set. + /// * `value` - The wire representing the value to set at the specified index. + /// + /// # Errors + /// + /// Returns an error if building the operation fails. + /// + /// # Returns + /// + /// The wire representing the updated array after the set operation. + fn add_value_array_set( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + index: Wire, + value: Wire, + ) -> Result { + self.add_generic_array_set::(elem_ty, size, input, index, value) + } + + /// Adds an array swap operation to the dataflow graph. + /// + /// This operation swaps the values at two specified indices in the array. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// * `index1` - The wire representing the first index to swap. + /// * `index2` - The wire representing the second index to swap. + /// + /// # Errors + /// + /// Returns an error if building the operation fails. + /// + /// # Returns + /// + /// The wire representing the updated array after the swap operation. + fn add_value_array_swap( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + index1: Wire, + index2: Wire, + ) -> Result { + let op = + GenericArrayOpDef::::swap.instantiate(&[size.into(), elem_ty.into()])?; + let [out] = self + .add_dataflow_op(op, vec![input, index1, index2])? + .outputs_arr(); + Ok(out) + } + + /// Adds an array pop-left operation to the dataflow graph. + /// + /// This operation removes the leftmost element from the array. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// + /// # Errors + /// + /// Returns an error if building the operation fails. + /// + /// # Returns + /// + /// The wire representing the Option> + fn add_array_pop_left( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + ) -> Result { + self.add_generic_array_pop_left::(elem_ty, size, input) + } + + /// Adds an array pop-right operation to the dataflow graph. + /// + /// This operation removes the rightmost element from the array. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// + /// # Errors + /// + /// Returns an error if building the operation fails. + /// + /// # Returns + /// + /// The wire representing the Option> + fn add_array_pop_right( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + ) -> Result { + self.add_generic_array_pop_right::(elem_ty, size, input) + } + + /// Adds an operation to discard an empty array from the dataflow graph. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `input` - The wire representing the array. + /// + /// # Errors + /// + /// Returns an error if building the operation fails. + fn add_array_discard_empty(&mut self, elem_ty: Type, input: Wire) -> Result<(), BuildError> { + self.add_generic_array_discard_empty::(elem_ty, input) + } +} + +impl VArrayOpBuilder for D {} diff --git a/hugr-llvm/src/extension/collections/array.rs b/hugr-llvm/src/extension/collections/array.rs index 67f89e001..4f0881a23 100644 --- a/hugr-llvm/src/extension/collections/array.rs +++ b/hugr-llvm/src/extension/collections/array.rs @@ -892,6 +892,7 @@ mod test { use hugr_core::extension::prelude::either_type; use hugr_core::extension::ExtensionSet; use hugr_core::ops::Tag; + use hugr_core::std_extensions::collections::array::op_builder::build_all_array_ops; use hugr_core::std_extensions::collections::array::{self, array_type, ArrayRepeat, ArrayScan}; use hugr_core::std_extensions::STD_REG; use hugr_core::types::Type; @@ -922,66 +923,12 @@ mod test { utils::{ArrayOpBuilder, IntOpBuilder, LogicOpBuilder}, }; - /// Build all array ops - /// Copied from `hugr_core::std_extensions::collections::array::builder::test` - fn all_array_ops(mut builder: B) -> B { - let us0 = builder.add_load_value(ConstUsize::new(0)); - let us1 = builder.add_load_value(ConstUsize::new(1)); - let us2 = builder.add_load_value(ConstUsize::new(2)); - let arr = builder.add_new_array(usize_t(), [us1, us2]).unwrap(); - let [arr] = { - let r = builder.add_array_swap(usize_t(), 2, arr, us0, us1).unwrap(); - let res_sum_ty = { - let array_type = array_type(2, usize_t()); - either_type(array_type.clone(), array_type) - }; - builder.build_unwrap_sum(1, res_sum_ty, r).unwrap() - }; - - let ([elem_0], arr) = { - let (r, arr) = builder.add_array_get(usize_t(), 2, arr, us0).unwrap(); - ( - builder - .build_unwrap_sum(1, option_type(usize_t()), r) - .unwrap(), - arr, - ) - }; - - let [_elem_1, arr] = { - let r = builder - .add_array_set(usize_t(), 2, arr, us1, elem_0) - .unwrap(); - let res_sum_ty = { - let row = vec![usize_t(), array_type(2, usize_t())]; - either_type(row.clone(), row) - }; - builder.build_unwrap_sum(1, res_sum_ty, r).unwrap() - }; - - let [_elem_left, arr] = { - let r = builder.add_array_pop_left(usize_t(), 2, arr).unwrap(); - builder - .build_unwrap_sum(1, option_type(vec![usize_t(), array_type(1, usize_t())]), r) - .unwrap() - }; - let [_elem_right, arr] = { - let r = builder.add_array_pop_right(usize_t(), 1, arr).unwrap(); - builder - .build_unwrap_sum(1, option_type(vec![usize_t(), array_type(0, usize_t())]), r) - .unwrap() - }; - - builder.add_array_discard_empty(usize_t(), arr).unwrap(); - builder - } - #[rstest] fn emit_all_ops(mut llvm_ctx: TestContext) { let hugr = SimpleHugrConfig::new() .with_extensions(STD_REG.to_owned()) .finish(|mut builder| { - all_array_ops(builder.dfg_builder_endo([]).unwrap()) + build_all_array_ops(builder.dfg_builder_endo([]).unwrap()) .finish_sub_container() .unwrap(); builder.finish_sub_container().unwrap() diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index 7a2c50367..43aa6376b 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -25,6 +25,7 @@ lazy_static = { workspace = true } paste = { workspace = true } thiserror = { workspace = true } petgraph = { workspace = true } +strum = { workspace = true } [features] extension_inference = ["hugr-core/extension_inference"] diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 83ff71b67..d803b817c 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -11,6 +11,8 @@ mod dead_funcs; pub use dead_funcs::{remove_dead_funcs, RemoveDeadFuncsError, RemoveDeadFuncsPass}; pub mod force_order; mod half_node; +pub mod linearize_array; +pub use linearize_array::LinearizeArrayPass; pub mod lower; pub mod merge_bbs; mod monomorphize; diff --git a/hugr-passes/src/linearize_array.rs b/hugr-passes/src/linearize_array.rs new file mode 100644 index 000000000..fff37d326 --- /dev/null +++ b/hugr-passes/src/linearize_array.rs @@ -0,0 +1,432 @@ +//! Provides [LinearizeArrayPass] which turns 'value_array`s into regular linear `array`s. + +use hugr_core::{ + extension::{ + prelude::Noop, + simple_op::{HasConcrete, MakeRegisteredOp}, + }, + hugr::hugrmut::HugrMut, + ops::NamedOp, + std_extensions::collections::{ + array::{ + array_type_def, array_type_parametric, Array, ArrayKind, ArrayOpDef, ArrayRepeatDef, + ArrayScanDef, ArrayValue, ARRAY_REPEAT_OP_ID, ARRAY_SCAN_OP_ID, + }, + value_array::{self, VArrayFromArrayDef, VArrayToArrayDef, VArrayValue, ValueArray}, + }, + types::Transformable, + Node, +}; +use itertools::Itertools; +use strum::IntoEnumIterator; + +use crate::{ + replace_types::{ + handlers::copy_discard_array, DelegatingLinearizer, NodeTemplate, ReplaceTypesError, + }, + ComposablePass, ReplaceTypes, +}; + +/// A HUGR -> HUGR pass that turns 'value_array`s into regular linear `array`s. +/// +/// # Panics +/// +/// - If the Hugr has inter-graph edges whose type contains `value_array`s +/// - If the Hugr contains [`ArrayOpDef::get`] operations on `value_array`s that +/// contain nested `value_array`s. +#[derive(Clone)] +pub struct LinearizeArrayPass(ReplaceTypes); + +impl Default for LinearizeArrayPass { + fn default() -> Self { + let mut pass = ReplaceTypes::default(); + pass.replace_parametrized_type(ValueArray::type_def(), |args| { + Some(Array::ty_parametric(args[0].clone(), args[1].clone()).unwrap()) + }); + pass.replace_consts_parametrized(ValueArray::type_def(), |v, replacer| { + let v: &VArrayValue = v.value().downcast_ref().unwrap(); + let mut ty = v.get_element_type().clone(); + let mut contents = v.get_contents().iter().cloned().collect_vec(); + ty.transform(replacer).unwrap(); + contents.iter_mut().for_each(|v| { + replacer.change_value(v).unwrap(); + }); + Ok(Some(ArrayValue::new(ty, contents).into())) + }); + for op_def in ArrayOpDef::iter() { + pass.replace_parametrized_op( + value_array::EXTENSION.get_op(&op_def.name()).unwrap(), + move |args| { + // `get` is only allowed for copyable elements. Assuming the Hugr was + // valid when we started, the only way for the element to become linear + // is if it used to contain nested `value_array`s. In that case, we + // have to get rid of the `get`. + // TODO: But what should we replace it with? Can't be a `set` since we + // don't have anything to put in. Maybe we need a new `get_copy` op + // that takes a function ptr to copy the element? For now, let's just + // error out and make sure we're not emitting `get`s for nested value + // arrays. + if op_def == ArrayOpDef::get && !args[1].as_type().unwrap().copyable() { + panic!( + "Cannot linearise arrays in this Hugr: \ + Contains a `get` operation on nested value arrays" + ); + } + Some(NodeTemplate::SingleOp( + op_def.instantiate(args).unwrap().into(), + )) + }, + ); + } + pass.replace_parametrized_op( + value_array::EXTENSION.get_op(&ARRAY_REPEAT_OP_ID).unwrap(), + |args| { + Some(NodeTemplate::SingleOp( + ArrayRepeatDef::new().instantiate(args).unwrap().into(), + )) + }, + ); + pass.replace_parametrized_op( + value_array::EXTENSION.get_op(&ARRAY_SCAN_OP_ID).unwrap(), + |args| { + Some(NodeTemplate::SingleOp( + ArrayScanDef::new().instantiate(args).unwrap().into(), + )) + }, + ); + pass.replace_parametrized_op( + value_array::EXTENSION + .get_op(&VArrayFromArrayDef::new().name()) + .unwrap(), + |args| { + let array_ty = array_type_parametric(args[0].clone(), args[1].clone()).unwrap(); + Some(NodeTemplate::SingleOp( + Noop::new(array_ty).to_extension_op().unwrap().into(), + )) + }, + ); + pass.replace_parametrized_op( + value_array::EXTENSION + .get_op(&VArrayToArrayDef::new().name()) + .unwrap(), + |args| { + let array_ty = array_type_parametric(args[0].clone(), args[1].clone()).unwrap(); + Some(NodeTemplate::SingleOp( + Noop::new(array_ty).to_extension_op().unwrap().into(), + )) + }, + ); + pass.linearizer() + .register_callback(array_type_def(), copy_discard_array); + Self(pass) + } +} + +impl ComposablePass for LinearizeArrayPass { + type Node = Node; + type Error = ReplaceTypesError; + type Result = bool; + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + self.0.run(hugr) + } +} + +impl LinearizeArrayPass { + /// Returns a new [`LinearizeArrayPass`] that handles all standard extensions. + pub fn new() -> Self { + Self::default() + } + + /// Allows to configure how to clone and discard arrays that are nested + /// inside opaque extension values. + pub fn linearizer(&mut self) -> &mut DelegatingLinearizer { + self.0.linearizer() + } +} + +#[cfg(test)] +mod test { + use hugr_core::builder::ModuleBuilder; + use hugr_core::extension::prelude::{ConstUsize, Noop}; + use hugr_core::ops::handle::NodeHandle; + use hugr_core::ops::{Const, OpType}; + use hugr_core::std_extensions::collections::array::{ + self, array_type, ArrayValue, Direction, FROM, INTO, + }; + use hugr_core::std_extensions::collections::value_array::{ + VArrayFromArray, VArrayRepeat, VArrayScan, VArrayToArray, VArrayValue, + }; + use hugr_core::types::Transformable; + use hugr_core::{ + builder::{Container, DFGBuilder, Dataflow, HugrBuilder}, + extension::{ + prelude::{qb_t, usize_t, PRELUDE_ID}, + ExtensionSet, + }, + std_extensions::collections::{ + array::{ + op_builder::{build_all_array_ops, build_all_value_array_ops}, + ArrayRepeat, ArrayScan, + }, + value_array::{self, value_array_type}, + }, + types::{Signature, Type}, + HugrView, + }; + use itertools::Itertools; + use rstest::rstest; + + use crate::{composable::ValidatingPass, ComposablePass}; + + use super::LinearizeArrayPass; + + #[test] + fn all_value_array_ops() { + let sig = Signature::new_endo(Type::EMPTY_TYPEROW).with_extension_delta( + ExtensionSet::from_iter([PRELUDE_ID, value_array::EXTENSION_ID, array::EXTENSION_ID]), + ); + let mut hugr = build_all_value_array_ops(DFGBuilder::new(sig.clone()).unwrap()) + .finish_hugr() + .unwrap(); + ValidatingPass::new_default(LinearizeArrayPass::default()) + .run(&mut hugr) + .unwrap(); + + let target_hugr = build_all_array_ops(DFGBuilder::new(sig).unwrap()) + .finish_hugr() + .unwrap(); + for (n1, n2) in hugr.nodes().zip_eq(target_hugr.nodes()) { + assert_eq!(hugr.get_optype(n1), target_hugr.get_optype(n2)); + } + } + + #[rstest] + #[case(usize_t(), 2)] + #[case(qb_t(), 2)] + #[case(value_array_type(4, usize_t()), 2)] + fn repeat(#[case] elem_ty: Type, #[case] size: u64) { + let reqs = + ExtensionSet::from_iter([PRELUDE_ID, value_array::EXTENSION_ID, array::EXTENSION_ID]); + let mut builder = ModuleBuilder::new(); + let repeat_decl = builder + .declare( + "foo", + Signature::new(Type::EMPTY_TYPEROW, elem_ty.clone()) + .with_extension_delta(reqs.clone()) + .into(), + ) + .unwrap(); + let mut f = builder + .define_function( + "bar", + Signature::new(Type::EMPTY_TYPEROW, value_array_type(size, elem_ty.clone())) + .with_extension_delta(reqs.clone()), + ) + .unwrap(); + let repeat_f = f.load_func(&repeat_decl, &[]).unwrap(); + let repeat = f + .add_dataflow_op( + VArrayRepeat::new(elem_ty.clone(), size, reqs.clone()), + [repeat_f], + ) + .unwrap(); + let [arr] = repeat.outputs_arr(); + f.set_outputs([arr]).unwrap(); + let mut hugr = builder.finish_hugr().unwrap(); + + let pass = LinearizeArrayPass::default(); + ValidatingPass::new_default(pass.clone()) + .run(&mut hugr) + .unwrap(); + let new_repeat: ArrayRepeat = hugr.get_optype(repeat.node()).cast().unwrap(); + let mut new_elem_ty = elem_ty.clone(); + new_elem_ty.transform(&pass.0).unwrap(); + assert_eq!(new_repeat, ArrayRepeat::new(new_elem_ty, size, reqs)); + } + + #[rstest] + #[case(usize_t(), qb_t(), 2)] + #[case(usize_t(), value_array_type(4, usize_t()), 2)] + #[case(value_array_type(4, usize_t()), value_array_type(8, usize_t()), 2)] + fn scan(#[case] src_ty: Type, #[case] tgt_ty: Type, #[case] size: u64) { + let reqs = + ExtensionSet::from_iter([PRELUDE_ID, value_array::EXTENSION_ID, array::EXTENSION_ID]); + let mut builder = ModuleBuilder::new(); + let scan_decl = builder + .declare( + "foo", + Signature::new(src_ty.clone(), tgt_ty.clone()) + .with_extension_delta(reqs.clone()) + .into(), + ) + .unwrap(); + let mut f = builder + .define_function( + "bar", + Signature::new( + value_array_type(size, src_ty.clone()), + value_array_type(size, tgt_ty.clone()), + ) + .with_extension_delta(reqs.clone()), + ) + .unwrap(); + let [arr] = f.input_wires_arr(); + let scan_f = f.load_func(&scan_decl, &[]).unwrap(); + let scan = f + .add_dataflow_op( + VArrayScan::new(src_ty.clone(), tgt_ty.clone(), vec![], size, reqs.clone()), + [arr, scan_f], + ) + .unwrap(); + let [arr] = scan.outputs_arr(); + f.set_outputs([arr]).unwrap(); + let mut hugr = builder.finish_hugr().unwrap(); + + let pass = LinearizeArrayPass::default(); + ValidatingPass::new_default(pass.clone()) + .run(&mut hugr) + .unwrap(); + let new_scan: ArrayScan = hugr.get_optype(scan.node()).cast().unwrap(); + let mut new_src_ty = src_ty.clone(); + let mut new_tgt_ty = tgt_ty.clone(); + new_src_ty.transform(&pass.0).unwrap(); + new_tgt_ty.transform(&pass.0).unwrap(); + + assert_eq!( + new_scan, + ArrayScan::new(new_src_ty, new_tgt_ty, vec![], size, reqs) + ); + } + + #[rstest] + #[case(INTO, usize_t(), 2)] + #[case(FROM, usize_t(), 2)] + #[case(INTO, array_type(4, usize_t()), 2)] + #[case(FROM, array_type(4, usize_t()), 2)] + #[case(INTO, value_array_type(4, usize_t()), 2)] + #[case(FROM, value_array_type(4, usize_t()), 2)] + fn convert(#[case] dir: Direction, #[case] elem_ty: Type, #[case] size: u64) { + let (src, tgt) = match dir { + INTO => ( + value_array_type(size, elem_ty.clone()), + array_type(size, elem_ty.clone()), + ), + FROM => ( + array_type(size, elem_ty.clone()), + value_array_type(size, elem_ty.clone()), + ), + }; + let sig = Signature::new(src, tgt).with_extension_delta(ExtensionSet::from_iter([ + PRELUDE_ID, + value_array::EXTENSION_ID, + array::EXTENSION_ID, + ])); + let mut builder = DFGBuilder::new(sig).unwrap(); + let [arr] = builder.input_wires_arr(); + let op: OpType = match dir { + INTO => VArrayToArray::new(elem_ty.clone(), size).into(), + FROM => VArrayFromArray::new(elem_ty.clone(), size).into(), + }; + let convert = builder.add_dataflow_op(op, [arr]).unwrap(); + let [arr] = convert.outputs_arr(); + builder.set_outputs(vec![arr]).unwrap(); + let mut hugr = builder.finish_hugr().unwrap(); + + let pass = LinearizeArrayPass::default(); + ValidatingPass::new_default(pass.clone()) + .run(&mut hugr) + .unwrap(); + let new_convert: Noop = hugr.get_optype(convert.node()).cast().unwrap(); + let mut new_elem_ty = elem_ty.clone(); + new_elem_ty.transform(&pass.0).unwrap(); + + assert_eq!(new_convert, Noop::new(array_type(size, new_elem_ty))); + } + + #[rstest] + #[case(value_array_type(2, usize_t()))] + #[case(value_array_type(2, value_array_type(4, usize_t())))] + #[case(value_array_type(2, Type::new_tuple(vec![usize_t(), value_array_type(4, usize_t())])))] + fn implicit_clone(#[case] array_ty: Type) { + let sig = Signature::new(array_ty.clone(), vec![array_ty; 2]).with_extension_delta( + ExtensionSet::from_iter([PRELUDE_ID, value_array::EXTENSION_ID, array::EXTENSION_ID]), + ); + let mut builder = DFGBuilder::new(sig).unwrap(); + let [arr] = builder.input_wires_arr(); + builder.set_outputs(vec![arr, arr]).unwrap(); + + let mut hugr = builder.finish_hugr().unwrap(); + ValidatingPass::new_default(LinearizeArrayPass::default()) + .run(&mut hugr) + .unwrap(); + } + + #[rstest] + #[case(value_array_type(2, usize_t()))] + #[case(value_array_type(2, value_array_type(4, usize_t())))] + #[case(value_array_type(2, Type::new_tuple(vec![usize_t(), value_array_type(4, usize_t())])))] + fn implicit_discard(#[case] array_ty: Type) { + let sig = Signature::new(array_ty, Type::EMPTY_TYPEROW).with_extension_delta( + ExtensionSet::from_iter([PRELUDE_ID, value_array::EXTENSION_ID, array::EXTENSION_ID]), + ); + let mut builder = DFGBuilder::new(sig).unwrap(); + builder.set_outputs(vec![]).unwrap(); + + let mut hugr = builder.finish_hugr().unwrap(); + ValidatingPass::new_default(LinearizeArrayPass::default()) + .run(&mut hugr) + .unwrap(); + } + + #[test] + fn array_value() { + let mut builder = ModuleBuilder::new(); + let array_v = VArrayValue::new(usize_t(), vec![ConstUsize::new(1).into()]); + let c = builder.add_constant(Const::new(array_v.clone().into())); + + let mut hugr = builder.finish_hugr().unwrap(); + ValidatingPass::new_default(LinearizeArrayPass::default()) + .run(&mut hugr) + .unwrap(); + + let new_array_v: &ArrayValue = hugr + .get_optype(c.node()) + .as_const() + .unwrap() + .get_custom_value() + .unwrap(); + + assert_eq!(new_array_v.get_element_type(), array_v.get_element_type()); + assert_eq!(new_array_v.get_contents(), array_v.get_contents()); + } + + #[test] + fn array_value_nested() { + let mut builder = ModuleBuilder::new(); + let array_v_inner = VArrayValue::new(usize_t(), vec![ConstUsize::new(1).into()]); + let array_v: array::GenericArrayValue = VArrayValue::new( + value_array_type(1, usize_t()), + vec![array_v_inner.clone().into()], + ); + let c = builder.add_constant(Const::new(array_v.clone().into())); + + let mut hugr = builder.finish_hugr().unwrap(); + ValidatingPass::new_default(LinearizeArrayPass::default()) + .run(&mut hugr) + .unwrap(); + + let new_array_v: &ArrayValue = hugr + .get_optype(c.node()) + .as_const() + .unwrap() + .get_custom_value() + .unwrap(); + + assert_eq!(new_array_v.get_element_type(), &array_type(1, usize_t())); + assert_eq!( + new_array_v.get_contents()[0], + ArrayValue::new(usize_t(), vec![ConstUsize::new(1).into()]).into() + ); + } +} diff --git a/hugr-passes/src/replace_types/handlers.rs b/hugr-passes/src/replace_types/handlers.rs index b16e0d4b6..0b0452269 100644 --- a/hugr-passes/src/replace_types/handlers.rs +++ b/hugr-passes/src/replace_types/handlers.rs @@ -10,11 +10,12 @@ use hugr_core::ops::{OpTrait, OpType, Tag}; use hugr_core::std_extensions::arithmetic::conversions::ConvertOpDef; use hugr_core::std_extensions::arithmetic::int_ops::IntOpDef; use hugr_core::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; -use hugr_core::std_extensions::collections::array::{Array, ArrayKind, GenericArrayValue}; -use hugr_core::std_extensions::collections::list::ListValue; -use hugr_core::std_extensions::collections::value_array::{ - value_array_type, VArrayOpDef, VArrayRepeat, VArrayScan, ValueArray, +use hugr_core::std_extensions::collections::array::{ + array_type, Array, ArrayClone, ArrayDiscard, ArrayKind, ArrayOpBuilder, GenericArrayOpDef, + GenericArrayRepeat, GenericArrayScan, GenericArrayValue, }; +use hugr_core::std_extensions::collections::list::ListValue; +use hugr_core::std_extensions::collections::value_array::ValueArray; use hugr_core::types::{SumType, Transformable, Type, TypeArg}; use hugr_core::{type_row, Hugr, HugrView}; use itertools::Itertools; @@ -100,11 +101,10 @@ fn runtime_reqs(h: &Hugr) -> ExtensionSet { h.signature(h.root()).unwrap().runtime_reqs.clone() } -/// Handler for copying/discarding value arrays if their elements have become linear. -/// Included in [ReplaceTypes::default] and [DelegatingLinearizer::default]. +/// Handler for copying/discarding arrays if their elements have become linear. /// -/// [DelegatingLinearizer::default]: super::DelegatingLinearizer::default -pub fn linearize_value_array( +/// Generic over the concrete array implementation. +pub fn linearize_generic_array( args: &[TypeArg], num_outports: usize, lin: &CallbackHandler, @@ -126,8 +126,9 @@ pub fn linearize_value_array( dfb.finish_hugr_with_outputs([ret]).unwrap() }; // Now array.scan that over the input array to get an array of unit (which can be discarded) - let array_scan = VArrayScan::new(ty.clone(), Type::UNIT, vec![], *n, runtime_reqs(&map_fn)); - let in_type = value_array_type(*n, ty.clone()); + let array_scan = + GenericArrayScan::::new(ty.clone(), Type::UNIT, vec![], *n, runtime_reqs(&map_fn)); + let in_type = AK::ty(*n, ty.clone()); return Ok(NodeTemplate::CompoundOp(Box::new({ let mut dfb = DFGBuilder::new(inout_sig(in_type, type_row![])).unwrap(); let [in_array] = dfb.input_wires_arr(); @@ -135,14 +136,18 @@ pub fn linearize_value_array( hugr: Box::new(map_fn), }); // scan has one output, an array of unit, so just ignore/discard that - dfb.add_dataflow_op(array_scan, [in_array, map_fn]).unwrap(); + let unit_arr = dfb + .add_dataflow_op(array_scan, [in_array, map_fn]) + .unwrap() + .out_wire(0); + AK::build_discard(&mut dfb, Type::UNIT, *n, unit_arr).unwrap(); dfb.finish_hugr_with_outputs([]).unwrap() }))); }; // The num_outports>1 case will simplify, and unify with the previous, when we have a // more general ArrayScan https://github.com/CQCL/hugr/issues/2041. In the meantime: let num_new = num_outports - 1; - let array_ty = value_array_type(*n, ty.clone()); + let array_ty = AK::ty(*n, ty.clone()); let mut dfb = DFGBuilder::new(inout_sig( array_ty.clone(), vec![array_ty.clone(); num_outports], @@ -161,7 +166,10 @@ pub fn linearize_value_array( dfb.finish_hugr_with_outputs(none.outputs()).unwrap() }; let repeats = - vec![VArrayRepeat::new(option_ty.clone(), *n, runtime_reqs(&fn_none)); num_new]; + vec![ + GenericArrayRepeat::::new(option_ty.clone(), *n, runtime_reqs(&fn_none)); + num_new + ]; let fn_none = dfb.add_load_value(Value::function(fn_none).unwrap()); repeats .into_iter() @@ -175,7 +183,7 @@ pub fn linearize_value_array( // 2. use a scan through the input array, copying the element num_outputs times; // return the first copy, and put each of the other copies into one of the array