diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index fb72b86f8..480975fab 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -1,39 +1,47 @@ #![allow(clippy::type_complexity)] #![warn(missing_docs)] -//! Replace types with other types across the Hugr. +//! Replace types with other types across the Hugr. See [ReplaceTypes] and [Linearizer]. //! -//! Parametrized types and ops will be reparametrized taking into account the replacements, -//! but any ops taking/returning the replaced types *not* as a result of parametrization, -//! will also need to be replaced - see [ReplaceTypes::replace_op]. (Similarly [Const]s.) +use std::borrow::Cow; use std::collections::HashMap; use std::sync::Arc; use thiserror::Error; +use hugr_core::builder::{BuildError, BuildHandle, Dataflow}; use hugr_core::extension::{ExtensionId, OpDef, SignatureError, TypeDef}; use hugr_core::hugr::hugrmut::HugrMut; use hugr_core::ops::constant::{OpaqueValue, Sum}; +use hugr_core::ops::handle::DataflowOpID; use hugr_core::ops::{ AliasDefn, Call, CallIndirect, Case, Conditional, Const, DataflowBlock, ExitBlock, ExtensionOp, - FuncDecl, FuncDefn, Input, LoadConstant, LoadFunction, OpType, Output, Tag, TailLoop, Value, - CFG, DFG, + FuncDecl, FuncDefn, Input, LoadConstant, LoadFunction, OpTrait, OpType, Output, Tag, TailLoop, + Value, CFG, DFG, }; -use hugr_core::types::{CustomType, Transformable, Type, TypeArg, TypeEnum, TypeTransformer}; -use hugr_core::{Hugr, Node}; +use hugr_core::types::{ + CustomType, Signature, Transformable, Type, TypeArg, TypeEnum, TypeTransformer, +}; +use hugr_core::{Hugr, HugrView, Node, Wire}; use crate::validation::{ValidatePassError, ValidationLevel}; -/// A thing with which an Op (i.e. node) can be replaced +mod linearize; +pub use linearize::{CallbackHandler, DelegatingLinearizer, LinearizeError, Linearizer}; + +/// A recipe for creating a dataflow Node - as a new child of a [DataflowParent] +/// or in order to replace an existing node. +/// +/// [DataflowParent]: hugr_core::ops::OpTag::DataflowParent #[derive(Clone, Debug, PartialEq)] -pub enum OpReplacement { - /// Keep the same node, change only the op (updating types of inputs/outputs) +pub enum NodeTemplate { + /// A single node - so if replacing an existing node, change only the op SingleOp(OpType), - /// Defines a sub-Hugr to splice in place of the op - a [CFG], [Conditional], [DFG] - /// or [TailLoop], which must have the same inputs and outputs as the original op, - /// modulo replacement. + /// Defines a sub-Hugr to insert, whose root becomes (or replaces) the desired Node. + /// The root must be a [CFG], [Conditional], [DFG] or [TailLoop]. // Not a FuncDefn, nor Case/DataflowBlock - /// Note this will be of limited use before [monomorphization](super::monomorphize()) because - /// the sub-Hugr will not be able to use type variables present in the op. + /// Note this will be of limited use before [monomorphization](super::monomorphize()) + /// because the new subtree will not be able to use type variables present in the + /// parent Hugr or previous op. // TODO: store also a vec, and update Hugr::validate to take &[TypeParam]s // (defaulting to empty list) - see https://github.com/CQCL/hugr/issues/709 CompoundOp(Box), @@ -42,12 +50,33 @@ pub enum OpReplacement { // So client should add the functions before replacement, then remove unused ones afterwards.) } -impl OpReplacement { +impl NodeTemplate { + /// Adds this instance to the specified [HugrMut] as a new node or subtree under a + /// given parent, returning the unique new child (of that parent) thus created + pub fn add_hugr(self, hugr: &mut impl HugrMut, parent: Node) -> Node { + match self { + NodeTemplate::SingleOp(op_type) => hugr.add_node_with_parent(parent, op_type), + NodeTemplate::CompoundOp(new_h) => hugr.insert_hugr(parent, *new_h).new_root, + } + } + + /// Adds this instance to the specified [Dataflow] builder as a new node or subtree + pub fn add( + self, + dfb: &mut impl Dataflow, + inputs: impl IntoIterator, + ) -> Result, BuildError> { + match self { + NodeTemplate::SingleOp(opty) => dfb.add_dataflow_op(opty, inputs), + NodeTemplate::CompoundOp(h) => dfb.add_hugr_with_wires(*h, inputs), + } + } + fn replace(&self, hugr: &mut impl HugrMut, n: Node) { assert_eq!(hugr.children(n).count(), 0); let new_optype = match self.clone() { - OpReplacement::SingleOp(op_type) => op_type, - OpReplacement::CompoundOp(new_h) => { + NodeTemplate::SingleOp(op_type) => op_type, + NodeTemplate::CompoundOp(new_h) => { let new_root = hugr.insert_hugr(n, *new_h).new_root; let children = hugr.children(new_root).collect::>(); let root_opty = hugr.remove_node(new_root); @@ -59,16 +88,50 @@ impl OpReplacement { }; *hugr.optype_mut(n) = new_optype; } + + fn signature(&self) -> Option> { + match self { + NodeTemplate::SingleOp(op_type) => op_type, + NodeTemplate::CompoundOp(hugr) => hugr.root_type(), + } + .dataflow_signature() + } } /// A configuration of what types, ops, and constants should be replaced with what. /// May be applied to a Hugr via [Self::run]. +/// +/// Parametrized types and ops will be reparametrized taking into account the +/// replacements, but any ops taking/returning the replaced types *not* as a result of +/// parametrization, will also need to be replaced - see [Self::replace_op]. +/// Similarly [Const]s. +/// +/// Types that are [Copyable](hugr_core::types::TypeBound::Copyable) may also be replaced +/// with types that are not, see [Linearizer]. +/// +/// Note that although this pass may be used before [monomorphization], there are some +/// limitations (that do not apply if done after [monomorphization]): +/// * [NodeTemplate::CompoundOp] only works for operations that do not use type variables +/// * "Overrides" of specific instantiations of polymorphic types will not be detected if +/// the instantiations are created inside polymorphic functions. For example, suppose +/// we [Self::replace_type] type `A` with `X`, [Self::replace_parametrized_type] +/// container `MyList` with `List`, and [Self::replace_type] `MyList` with +/// `SpecialListOfXs`. If a function `foo` polymorphic over a type variable `T` dealing +/// with `MyList`s, that is called with type argument `A`, then `foo` will be +/// updated to deal with `List`s and the call `foo` updated to `foo`, but this +/// will still result in using `List` rather than `SpecialListOfXs`. (However this +/// would be fine *after* [monomorphization]: the monomorphic definition of `foo_A` +/// would use `SpecialListOfXs`.) +/// * See also limitations noted for [Linearizer]. +/// +/// [monomorphization]: super::monomorphize() #[derive(Clone, Default)] pub struct ReplaceTypes { type_map: HashMap, param_types: HashMap Option>>, - op_map: HashMap, - param_ops: HashMap Option>>, + linearize: DelegatingLinearizer, + op_map: HashMap, + param_ops: HashMap Option>>, consts: HashMap< CustomType, Arc Result>, @@ -109,6 +172,8 @@ pub enum ReplaceTypesError { SignatureError(#[from] SignatureError), #[error(transparent)] ValidationError(#[from] ValidatePassError), + #[error(transparent)] + LinearizeError(#[from] LinearizeError), } impl ReplaceTypes { @@ -157,16 +222,33 @@ impl ReplaceTypes { // (depending on arguments - i.e. if src's TypeDefBound is anything other than // `TypeDefBound::Explicit(TypeBound::Copyable)`) but that seems an annoying // overapproximation. Moreover, these depend upon the *return type* of the Fn. + // It would be too awkward to require: + // dest_fn: impl Fn(&TypeArg) -> (Type, + // Fn(&Linearizer) -> NodeTemplate, // copy + // Fn(&Linearizer) -> NodeTemplate)` // discard self.param_types.insert(src.into(), Arc::new(dest_fn)); } + /// Allows to configure how to deal with types/wires that were [Copyable] + /// but have become linear as a result of type-changing. Specifically, + /// the [Linearizer] is used whenever lowering produces an outport which both + /// * has a non-[Copyable] type - perhaps a direct substitution, or perhaps e.g. + /// as a result of changing the element type of a collection such as an [`array`] + /// * has other than one connected inport, + /// + /// [Copyable]: hugr_core::types::TypeBound::Copyable + /// [`array`]: hugr_core::std_extensions::collections::array::array_type + pub fn linearizer(&mut self) -> &mut DelegatingLinearizer { + &mut self.linearize + } + /// Configures this instance to change occurrences of `src` to `dest`. /// Note that if `src` is an instance of a *parametrized* [OpDef], this takes /// precedence over [Self::replace_parametrized_op] where the `src`s overlap. Thus, /// this should only be used on already-*[monomorphize](super::monomorphize())d* /// Hugrs, as substitution (parametric polymorphism) happening later will not respect /// this replacement. - pub fn replace_op(&mut self, src: &ExtensionOp, dest: OpReplacement) { + pub fn replace_op(&mut self, src: &ExtensionOp, dest: NodeTemplate) { self.op_map.insert(OpHashWrapper::from(src), dest); } @@ -179,7 +261,7 @@ impl ReplaceTypes { pub fn replace_parametrized_op( &mut self, src: &OpDef, - dest_fn: impl Fn(&[TypeArg]) -> Option + 'static, + dest_fn: impl Fn(&[TypeArg]) -> Option + 'static, ) { self.param_ops.insert(src.into(), Arc::new(dest_fn)); } @@ -221,6 +303,22 @@ impl ReplaceTypes { let mut changed = false; for n in hugr.nodes().collect::>() { changed |= self.change_node(hugr, n)?; + let new_dfsig = hugr.get_optype(n).dataflow_signature(); + if let Some(new_sig) = new_dfsig + .filter(|_| changed && n != hugr.root()) + .map(Cow::into_owned) + { + for outp in new_sig.output_ports() { + if !new_sig.out_port_type(outp).unwrap().copyable() { + let targets = hugr.linked_inputs(n, outp).collect::>(); + if targets.len() != 1 { + hugr.disconnect(n, outp); + let src = Wire::new(n, outp); + self.linearize.insert_copy_discard(hugr, src, &targets)?; + } + } + } + } } Ok(changed) } @@ -452,7 +550,7 @@ mod test { use hugr_core::{hugr::IdentList, type_row, Extension, HugrView}; use itertools::Itertools; - use super::{handlers::list_const, OpReplacement, ReplaceTypes}; + use super::{handlers::list_const, NodeTemplate, ReplaceTypes}; const PACKED_VEC: &str = "PackedVec"; const READ: &str = "read"; @@ -513,7 +611,7 @@ mod test { } fn lowerer(ext: &Arc) -> ReplaceTypes { - fn lowered_read(args: &[TypeArg]) -> Option { + fn lowered_read(args: &[TypeArg]) -> Option { let ty = just_elem_type(args); let mut dfb = DFGBuilder::new(inout_sig( vec![array_type(64, ty.clone()), i64_t()], @@ -532,7 +630,7 @@ mod test { let [res] = dfb .build_unwrap_sum(1, option_type(Type::from(ty.clone())), opt) .unwrap(); - Some(OpReplacement::CompoundOp(Box::new( + Some(NodeTemplate::CompoundOp(Box::new( dfb.finish_hugr_with_outputs([res]).unwrap(), ))) } @@ -545,7 +643,7 @@ mod test { ); lw.replace_op( &read_op(ext, bool_t()), - OpReplacement::SingleOp( + NodeTemplate::SingleOp( ExtensionOp::new(ext.get_op("lowered_read_bool").unwrap().clone(), []) .unwrap() .into(), @@ -824,7 +922,7 @@ mod test { e.get_op(READ).unwrap().as_ref(), Box::new(|args: &[TypeArg]| { option_contents(just_elem_type(args)).map(|elem| { - OpReplacement::SingleOp( + NodeTemplate::SingleOp( ListOp::get .with_type(elem) .to_extension_op() diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs new file mode 100644 index 000000000..371798dce --- /dev/null +++ b/hugr-passes/src/replace_types/linearize.rs @@ -0,0 +1,648 @@ +use std::borrow::Cow; +use std::iter::repeat; +use std::{collections::HashMap, sync::Arc}; + +use hugr_core::builder::{ + inout_sig, ConditionalBuilder, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, + HugrBuilder, +}; +use hugr_core::extension::{SignatureError, TypeDef}; +use hugr_core::types::{CustomType, Signature, Type, TypeArg, TypeEnum, TypeRow}; +use hugr_core::Wire; +use hugr_core::{hugr::hugrmut::HugrMut, ops::Tag, IncomingPort, Node}; +use itertools::Itertools; + +use super::{NodeTemplate, ParametricType}; + +/// Trait for things that know how to wire up linear outports to other than one +/// target. Used to restore Hugr validity when a [ReplaceTypes](super::ReplaceTypes) +/// results in types of such outports changing from [Copyable] to linear (i.e. +/// [hugr_core::types::TypeBound::Any]). +/// +/// Note that this is not really effective before [monomorphization]: if a +/// function polymorphic over a [Copyable] becomes called with a +/// non-Copyable type argument, [Linearizer] cannot insert copy/discard +/// operations for such a case. However, following [monomorphization], there +/// would be a specific instantiation of the function for the +/// type-that-becomes-linear, into which copy/discard can be inserted. +/// +/// [monomorphization]: crate::monomorphize() +/// [Copyable]: hugr_core::types::TypeBound::Copyable +pub trait Linearizer { + /// Insert copy or discard operations (as appropriate) enough to wire `src` + /// up to all `targets`. + /// + /// The default implementation + /// * if `targets.len() == 1`, wires `src` to the unique target + /// * otherwise, makes a single call to [Self::copy_discard_op], inserts that op, + /// and wires its outputs 1:1 to each target + /// + /// # Errors + /// + /// Most variants of [LinearizeError] can be raised, specifically including + /// [LinearizeError::CopyableType] if the type is [Copyable], in which case the Hugr + /// will be unchanged. + /// + /// [Copyable]: hugr_core::types::TypeBound::Copyable + /// + /// # Panics + /// + /// if `src` is not a valid Wire (does not identify a dataflow out-port) + fn insert_copy_discard( + &self, + hugr: &mut impl HugrMut, + src: Wire, + targets: &[(Node, IncomingPort)], + ) -> Result<(), LinearizeError> { + let sig = hugr.signature(src.node()).unwrap(); + let typ = sig.port_type(src.source()).unwrap(); + let (tgt_node, tgt_inport) = if targets.len() == 1 { + *targets.first().unwrap() + } else { + // Fail fast if the edges are nonlocal. (TODO transform to local edges!) + let src_parent = hugr + .get_parent(src.node()) + .expect("Root node cannot have out edges"); + if let Some((tgt, tgt_parent)) = targets.iter().find_map(|(tgt, _)| { + let tgt_parent = hugr + .get_parent(*tgt) + .expect("Root node cannot have incoming edges"); + (tgt_parent != src_parent).then_some((*tgt, tgt_parent)) + }) { + return Err(LinearizeError::NoLinearNonLocalEdges { + src: src.node(), + src_parent, + tgt, + tgt_parent, + }); + } + let copy_discard_op = self + .copy_discard_op(typ, targets.len())? + .add_hugr(hugr, src_parent); + for (n, (tgt_node, tgt_port)) in targets.iter().enumerate() { + hugr.connect(copy_discard_op, n, *tgt_node, *tgt_port); + } + (copy_discard_op, 0.into()) + }; + hugr.connect(src.node(), src.source(), tgt_node, tgt_inport); + Ok(()) + } + + /// Gets an [NodeTemplate] for copying or discarding a value of type `typ`, i.e. + /// a recipe for a node with one input of that type and the specified number of + /// outports. + /// + /// Implementations are free to panic if `num_outports == 1`, such calls should never + /// occur as source/target can be directly wired without any node/op being required. + fn copy_discard_op( + &self, + typ: &Type, + num_outports: usize, + ) -> Result; +} + +/// A configuration for implementing [Linearizer] by delegating to +/// type-specific callbacks, and by composing them in order to handle compound types +/// such as [TypeEnum::Sum]s. +#[derive(Clone, Default)] +pub struct DelegatingLinearizer { + // Keyed by lowered type, as only needed when there is an op outputting such + copy_discard: HashMap, + // Copy/discard of parametric types handled by a function that receives the new/lowered type. + // We do not allow overriding copy/discard of non-extension types, but that + // can be achieved by *firstly* lowering to a custom linear type, with copy/discard + // inserted; *secondly* by lowering that to the desired non-extension linear type, + // including lowering of the copy/discard operations to...whatever. + copy_discard_parametric: HashMap< + ParametricType, + Arc Result>, + >, +} + +/// Implementation of [Linearizer] passed to callbacks, (e.g.) so that callbacks for +/// handling collection types can use it to generate copy/discards of elements. +// (Note, this is its own type just to give a bit of room for future expansion, +// rather than passing a &DelegatingLinearizer directly) +pub struct CallbackHandler<'a>(#[allow(dead_code)] &'a DelegatingLinearizer); + +#[derive(Clone, Debug, thiserror::Error, PartialEq, Eq)] +#[allow(missing_docs)] +pub enum LinearizeError { + #[error("Need copy/discard op for {_0}")] + NeedCopyDiscard(Type), + #[error("Copy/discard op for {typ} with {num_outports} outputs had wrong signature {sig:?}")] + WrongSignature { + typ: Type, + num_outports: usize, + sig: Option, + }, + #[error("Cannot add nonlocal edge for linear type from {src} (with parent {src_parent}) to {tgt} (with parent {tgt_parent})")] + NoLinearNonLocalEdges { + src: Node, + src_parent: Node, + tgt: Node, + tgt_parent: Node, + }, + /// SignatureError's can happen when converting nested types e.g. Sums + #[error(transparent)] + SignatureError(#[from] SignatureError), + /// We cannot linearize (insert copy and discard functions) for + /// [Variable](TypeEnum::Variable)s, [Row variables](TypeEnum::RowVar), + /// or [Alias](TypeEnum::Alias)es. + #[error("Cannot linearize type {_0}")] + UnsupportedType(Type), + /// Neither does linearization make sense for copyable types + #[error("Type {_0} is copyable")] + CopyableType(Type), +} + +impl DelegatingLinearizer { + /// Configures this instance that the specified monomorphic type can be copied and/or + /// discarded via the provided [NodeTemplate]s - directly or as part of a compound type + /// e.g. [TypeEnum::Sum]. + /// `copy` should have exactly one inport, of type `src`, and two outports, of same type; + /// `discard` should have exactly one inport, of type 'src', and no outports. + /// + /// # Errors + /// + /// * [LinearizeError::CopyableType] If `typ` is + /// [Copyable](hugr_core::types::TypeBound::Copyable) + /// * [LinearizeError::WrongSignature] if `copy` or `discard` do not have the + /// expected inputs or outputs + pub fn register_simple( + &mut self, + cty: CustomType, + copy: NodeTemplate, + discard: NodeTemplate, + ) -> Result<(), LinearizeError> { + let typ = Type::new_extension(cty.clone()); + if typ.copyable() { + return Err(LinearizeError::CopyableType(typ)); + } + check_sig(©, &typ, 2)?; + check_sig(&discard, &typ, 0)?; + self.copy_discard.insert(cty, (copy, discard)); + Ok(()) + } + + /// Configures this instance that instances of the specified [TypeDef] (perhaps + /// polymorphic) can be copied and/or discarded by using the provided callback + /// to generate a [NodeTemplate] for an appropriate copy/discard operation. + /// + /// The callback is given + /// * the type arguments (as appropriate for the [TypeDef], so perhaps empty) + /// * the desired number of outports (this will never be 1) + /// * A [CallbackHandler] that the callback can use it to generate + /// `copy`/`discard` ops for other types (e.g. the elements of a collection), + /// as part of an [NodeTemplate::CompoundOp]. + /// + /// Note that [Self::register_simple] takes precedence when the `src` types overlap. + pub fn register_callback( + &mut self, + src: &TypeDef, + copy_discard_fn: impl Fn(&[TypeArg], usize, &CallbackHandler) -> Result + + 'static, + ) { + // We could look for `src`s TypeDefBound being explicit Copyable, otherwise + // it depends on the arguments. Since there is no method to get the TypeDefBound + // from a TypeDef, leaving this for now. + self.copy_discard_parametric + .insert(src.into(), Arc::new(copy_discard_fn)); + } +} + +fn check_sig(tmpl: &NodeTemplate, typ: &Type, num_outports: usize) -> Result<(), LinearizeError> { + let sig = tmpl.signature(); + if sig.as_ref().is_some_and(|sig| { + sig.io() == (&typ.clone().into(), &vec![typ.clone(); num_outports].into()) + }) { + Ok(()) + } else { + Err(LinearizeError::WrongSignature { + typ: typ.clone(), + num_outports, + sig: sig.map(Cow::into_owned), + }) + } +} + +impl Linearizer for DelegatingLinearizer { + fn copy_discard_op( + &self, + typ: &Type, + num_outports: usize, + ) -> Result { + if typ.copyable() { + return Err(LinearizeError::CopyableType(typ.clone())); + }; + assert!(num_outports != 1); + + match typ.as_type_enum() { + TypeEnum::Sum(sum_type) => { + let variants = sum_type + .variants() + .map(|trv| trv.clone().try_into()) + .collect::, _>>()?; + let mut cb = ConditionalBuilder::new( + variants.clone(), + vec![], + vec![sum_type.clone().into(); num_outports], + ) + .unwrap(); + for (tag, variant) in variants.iter().enumerate() { + let mut case_b = cb.case_builder(tag).unwrap(); + let mut elems_for_copy = vec![vec![]; num_outports]; + for (inp, ty) in case_b.input_wires().zip_eq(variant.iter()) { + let inp_copies = if ty.copyable() { + repeat(inp).take(num_outports).collect::>() + } else { + self.copy_discard_op(ty, num_outports)? + .add(&mut case_b, [inp]) + .unwrap() + .outputs() + .collect() + }; + for (src, elems) in inp_copies.into_iter().zip_eq(elems_for_copy.iter_mut()) + { + elems.push(src) + } + } + let t = Tag::new(tag, variants.clone()); + let outputs = elems_for_copy + .into_iter() + .map(|elems| { + let [copy] = case_b + .add_dataflow_op(t.clone(), elems) + .unwrap() + .outputs_arr(); + copy + }) + .collect::>(); // must collect to end borrow of `case_b` by closure + case_b.finish_with_outputs(outputs).unwrap(); + } + Ok(NodeTemplate::CompoundOp(Box::new( + cb.finish_hugr().unwrap(), + ))) + } + TypeEnum::Extension(cty) => match self.copy_discard.get(cty) { + Some((copy, discard)) => Ok(if num_outports == 0 { + discard.clone() + } else { + let mut dfb = + DFGBuilder::new(inout_sig(typ.clone(), vec![typ.clone(); num_outports])) + .unwrap(); + let [mut src] = dfb.input_wires_arr(); + let mut outputs = vec![]; + for _ in 0..num_outports - 1 { + let [out0, out1] = copy.clone().add(&mut dfb, [src]).unwrap().outputs_arr(); + outputs.push(out0); + src = out1; + } + outputs.push(src); + NodeTemplate::CompoundOp(Box::new( + dfb.finish_hugr_with_outputs(outputs).unwrap(), + )) + }), + None => { + let copy_discard_fn = self + .copy_discard_parametric + .get(&cty.into()) + .ok_or_else(|| LinearizeError::NeedCopyDiscard(typ.clone()))?; + let tmpl = copy_discard_fn(cty.args(), num_outports, &CallbackHandler(self))?; + check_sig(&tmpl, typ, num_outports)?; + Ok(tmpl) + } + }, + TypeEnum::Function(_) => panic!("Ruled out above as copyable"), + _ => Err(LinearizeError::UnsupportedType(typ.clone())), + } + } +} + +impl Linearizer for CallbackHandler<'_> { + fn copy_discard_op( + &self, + typ: &Type, + num_outports: usize, + ) -> Result { + self.0.copy_discard_op(typ, num_outports) + } +} + +#[cfg(test)] +mod test { + use std::collections::HashMap; + use std::sync::Arc; + + use hugr_core::builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer}; + + use hugr_core::extension::prelude::{option_type, usize_t}; + use hugr_core::extension::{ + CustomSignatureFunc, OpDef, SignatureError, SignatureFunc, TypeDefBound, Version, + }; + use hugr_core::hugr::views::{DescendantsGraph, HierarchyView}; + use hugr_core::ops::{handle::NodeHandle, ExtensionOp, NamedOp, OpName}; + use hugr_core::ops::{DataflowOpTrait, OpType}; + use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES; + use hugr_core::std_extensions::collections::array::{array_type, ArrayOpDef}; + use hugr_core::types::type_param::TypeParam; + use hugr_core::types::{FuncValueType, PolyFuncTypeRV, Signature, Type, TypeArg, TypeRow}; + use hugr_core::{hugr::IdentList, Extension, Hugr, HugrView, Node}; + use itertools::Itertools; + use rstest::rstest; + + use crate::replace_types::{LinearizeError, NodeTemplate, ReplaceTypesError}; + use crate::ReplaceTypes; + + const LIN_T: &str = "Lin"; + + struct NWayCopySigFn(Type); + impl CustomSignatureFunc for NWayCopySigFn { + fn compute_signature<'o, 'a: 'o>( + &'a self, + arg_values: &[TypeArg], + _def: &'o OpDef, + ) -> Result { + let [TypeArg::BoundedNat { n }] = arg_values else { + panic!() + }; + let outs = vec![self.0.clone(); *n as usize]; + Ok(FuncValueType::new(self.0.clone(), outs).into()) + } + + fn static_params(&self) -> &[TypeParam] { + const JUST_NAT: &[TypeParam] = &[TypeParam::max_nat()]; + JUST_NAT + } + } + + fn ext_lowerer() -> (Arc, ReplaceTypes) { + // Extension with a linear type, an n-way parametric copy op, and a discard op + let e = Extension::new_arc( + IdentList::new_unchecked("TestExt"), + Version::new(0, 0, 0), + |e, w| { + let lin = Type::new_extension( + e.add_type(LIN_T.into(), vec![], String::new(), TypeDefBound::any(), w) + .unwrap() + .instantiate([]) + .unwrap(), + ); + e.add_op( + "discard".into(), + String::new(), + Signature::new(lin.clone(), vec![]), + w, + ) + .unwrap(); + e.add_op( + "copy".into(), + String::new(), + SignatureFunc::CustomFunc(Box::new(NWayCopySigFn(lin))), + w, + ) + .unwrap(); + }, + ); + + let lin_custom_t = e.get_type(LIN_T).unwrap().instantiate([]).unwrap(); + + // Configure to lower usize_t to the linear type above, using a 2-way copy only + let copy_op = ExtensionOp::new(e.get_op("copy").unwrap().clone(), [2.into()]).unwrap(); + let discard_op = ExtensionOp::new(e.get_op("discard").unwrap().clone(), []).unwrap(); + let mut lowerer = ReplaceTypes::default(); + let usize_custom_t = usize_t().as_extension().unwrap().clone(); + lowerer.replace_type(usize_custom_t, Type::new_extension(lin_custom_t.clone())); + lowerer + .linearizer() + .register_simple( + lin_custom_t, + NodeTemplate::SingleOp(copy_op.into()), + NodeTemplate::SingleOp(discard_op.into()), + ) + .unwrap(); + (e, lowerer) + } + + #[test] + fn single_values() { + let (_e, lowerer) = ext_lowerer(); + // Build Hugr - uses first input three times, discards second input (both usize) + let mut outer = DFGBuilder::new(inout_sig( + vec![usize_t(); 2], + vec![usize_t(), array_type(2, usize_t())], + )) + .unwrap(); + let [inp, _] = outer.input_wires_arr(); + let new_array = outer + .add_dataflow_op(ArrayOpDef::new_array.to_concrete(usize_t(), 2), [inp, inp]) + .unwrap(); + let [arr] = new_array.outputs_arr(); + let mut h = outer.finish_hugr_with_outputs([inp, arr]).unwrap(); + + assert!(lowerer.run(&mut h).unwrap()); + + let ext_ops = h.nodes().filter_map(|n| h.get_optype(n).as_extension_op()); + let mut counts = HashMap::::new(); + for e in ext_ops { + *counts.entry(e.name()).or_default() += 1; + } + assert_eq!( + counts, + HashMap::from([ + ("TestExt.copy".into(), 2), + ("TestExt.discard".into(), 1), + ("collections.array.new_array".into(), 1) + ]) + ); + } + + fn copy_n_discard_one(ty: Type, n: usize) -> (Hugr, Node) { + let mut outer = DFGBuilder::new(inout_sig(ty.clone(), vec![ty.clone(); n - 1])).unwrap(); + let [inp] = outer.input_wires_arr(); + let inner = outer + .dfg_builder(inout_sig(ty, vec![]), [inp]) + .unwrap() + .finish_with_outputs([]) + .unwrap(); + let h = outer.finish_hugr_with_outputs(vec![inp; n - 1]).unwrap(); + (h, inner.node()) + } + + #[rstest] + fn sums_2way_copy(#[values(2, 3, 4)] num_copies: usize) { + let (mut h, inner) = copy_n_discard_one(option_type(usize_t()).into(), num_copies); + + let (e, lowerer) = ext_lowerer(); + assert!(lowerer.run(&mut h).unwrap()); + + let lin_t = Type::from(e.get_type(LIN_T).unwrap().instantiate([]).unwrap()); + let sum_ty: Type = option_type(lin_t.clone()).into(); + let count_tags = |n| h.children(n).filter(|n| h.get_optype(*n).is_tag()).count(); + + // Check we've inserted one Conditional into outer (for copy) and inner (for discard)... + for (dfg, num_tags, expected_ext_ops) in [ + (inner.node(), 0, vec!["TestExt.discard"]), + (h.root(), num_copies, vec!["TestExt.copy"; num_copies - 1]), // 2 copy nodes -> 3 outputs, etc. + ] { + let [(cond_node, cond)] = h + .children(dfg) + .filter_map(|n| h.get_optype(n).as_conditional().map(|c| (n, c))) + .collect_array() + .unwrap(); + assert_eq!( + cond.signature().output(), + &TypeRow::from(vec![sum_ty.clone(); num_tags]) + ); + let [case0, case1] = h.children(cond_node).collect_array().unwrap(); + // first is for empty variant + assert_eq!(h.children(case0).count(), 2 + num_tags); // Input, Output + assert_eq!(count_tags(case0), num_tags); + + // second is for variant of a LIN_T + assert_eq!(h.children(case1).count(), 3 + num_tags); // Input, Output, copy/discard + assert_eq!(count_tags(case1), num_tags); + let ext_ops = DescendantsGraph::::try_new(&h, case1) + .unwrap() + .nodes() + .filter_map(|n| h.get_optype(n).as_extension_op().map(ExtensionOp::name)) + .collect_vec(); + assert_eq!(ext_ops, expected_ext_ops); + } + } + + #[rstest] + fn sum_nway_copy(#[values(2, 5, 9)] num_copies: usize) { + let i8_t = || INT_TYPES[3].clone(); + let sum_ty = Type::new_sum([vec![i8_t()], vec![usize_t(); 2]]); + + let (mut h, inner) = copy_n_discard_one(sum_ty, num_copies); + let (e, _) = ext_lowerer(); + let mut lowerer = ReplaceTypes::default(); + let lin_t_def = e.get_type(LIN_T).unwrap(); + lowerer.replace_type( + usize_t().as_extension().unwrap().clone(), + lin_t_def.instantiate([]).unwrap().into(), + ); + let opdef = e.get_op("copy").unwrap(); + let opdef2 = opdef.clone(); + lowerer + .linearizer() + .register_callback(lin_t_def, move |args, num_outs, _| { + assert!(args.is_empty()); + Ok(NodeTemplate::SingleOp( + ExtensionOp::new(opdef2.clone(), [(num_outs as u64).into()]) + .unwrap() + .into(), + )) + }); + assert!(lowerer.run(&mut h).unwrap()); + + let lin_t = Type::from(e.get_type(LIN_T).unwrap().instantiate([]).unwrap()); + let sum_ty = Type::new_sum([vec![i8_t()], vec![lin_t.clone(); 2]]); + let count_tags = |n| h.children(n).filter(|n| h.get_optype(*n).is_tag()).count(); + + // Check we've inserted one Conditional into outer (for copy) and inner (for discard)... + for (dfg, num_tags) in [(inner.node(), 0), (h.root(), num_copies)] { + let [cond] = h + .children(dfg) + .filter(|n| h.get_optype(*n).is_conditional()) + .collect_array() + .unwrap(); + let [case0, case1] = h.children(cond).collect_array().unwrap(); + let out_row = vec![sum_ty.clone(); num_tags].into(); + // first is for empty variant - the only input is Copyable so can be directly wired or ignored + assert_eq!(h.children(case0).count(), 2 + num_tags); // Input, Output + assert_eq!(count_tags(case0), num_tags); + let case0 = h.get_optype(case0).as_case().unwrap(); + assert_eq!(case0.signature.io(), (&vec![i8_t()].into(), &out_row)); + + // second is for variant of two elements + assert_eq!(h.children(case1).count(), 4 + num_tags); // Input, Output, two leaf copies/discards: + assert_eq!(count_tags(case1), num_tags); + let ext_ops = h + .children(case1) + .filter_map(|n| h.get_optype(n).as_extension_op()) + .collect_vec(); + let expected_op = ExtensionOp::new(opdef.clone(), [(num_tags as u64).into()]).unwrap(); + assert_eq!(ext_ops, vec![&expected_op; 2]); + + let case1 = h.get_optype(case1).as_case().unwrap(); + assert_eq!( + case1.signature.io(), + (&vec![lin_t.clone(); 2].into(), &out_row) + ); + } + } + + #[test] + fn bad_sig() { + // Change usize to QB_T + let (ext, _) = ext_lowerer(); + let lin_ct = ext.get_type(LIN_T).unwrap().instantiate([]).unwrap(); + let lin_t = Type::from(lin_ct.clone()); + let copy3 = OpType::from( + ExtensionOp::new(ext.get_op("copy").unwrap().clone(), [3.into()]).unwrap(), + ); + let copy2 = ExtensionOp::new(ext.get_op("copy").unwrap().clone(), [2.into()]).unwrap(); + let discard = ExtensionOp::new(ext.get_op("discard").unwrap().clone(), []).unwrap(); + let mut replacer = ReplaceTypes::default(); + replacer.replace_type(usize_t().as_extension().unwrap().clone(), lin_t.clone()); + + let bad_copy = replacer.linearizer().register_simple( + lin_ct.clone(), + NodeTemplate::SingleOp(copy3.clone()), + NodeTemplate::SingleOp(discard.clone().into()), + ); + let sig3 = Some( + Signature::new(lin_t.clone(), vec![lin_t.clone(); 3]) + .with_extension_delta(ext.name().clone()), + ); + assert_eq!( + bad_copy, + Err(LinearizeError::WrongSignature { + typ: lin_t.clone(), + num_outports: 2, + sig: sig3.clone() + }) + ); + + let bad_discard = replacer.linearizer().register_simple( + lin_ct.clone(), + NodeTemplate::SingleOp(copy2.into()), + NodeTemplate::SingleOp(copy3.clone()), + ); + + assert_eq!( + bad_discard, + Err(LinearizeError::WrongSignature { + typ: lin_t.clone(), + num_outports: 0, + sig: sig3.clone() + }) + ); + + // Try parametrized instead, but this version always returns 3 outports + replacer + .linearizer() + .register_callback(ext.get_type(LIN_T).unwrap(), move |_args, _, _| { + Ok(NodeTemplate::SingleOp(copy3.clone())) + }); + + // A hugr that copies a usize + let dfb = DFGBuilder::new(inout_sig(usize_t(), vec![usize_t(); 2])).unwrap(); + let [inp] = dfb.input_wires_arr(); + let mut h = dfb.finish_hugr_with_outputs([inp, inp]).unwrap(); + + assert_eq!( + replacer.run(&mut h), + Err(ReplaceTypesError::LinearizeError( + LinearizeError::WrongSignature { + typ: lin_t.clone(), + num_outports: 2, + sig: sig3.clone() + } + )) + ); + } +}