diff --git a/Cargo.lock b/Cargo.lock index bed0362bb..64f80f855 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1939,9 +1939,9 @@ dependencies = [ [[package]] name = "portgraph" -version = "0.13.2" +version = "0.13.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62f923403923182a2ddc882051fc292cdc8900ba98e98b0aa8ff2743c865bebc" +checksum = "395d61e69996c0fe4b187149ab0e610a6be2af1b37bd13df8a567ae629f7359f" dependencies = [ "bitvec", "delegate", diff --git a/Cargo.toml b/Cargo.toml index da1c4b0a0..06e78328a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,7 +31,7 @@ missing_docs = "warn" debug_assert_with_mut_call = "warn" [workspace.dependencies] -portgraph = { version = "0.13.0" } +portgraph = { version = "0.13.3" } insta = { version = "1.42.2" } bitvec = "1.0.1" capnp = "0.20.4" diff --git a/hugr-core/src/hugr/hugrmut.rs b/hugr-core/src/hugr/hugrmut.rs index f2ccfc27a..f3ef094be 100644 --- a/hugr-core/src/hugr/hugrmut.rs +++ b/hugr-core/src/hugr/hugrmut.rs @@ -1,7 +1,7 @@ //! Low-level interface for modifying a HUGR. use core::panic; -use std::collections::HashMap; +use std::collections::{BTreeMap, HashMap}; use std::sync::Arc; use portgraph::view::{NodeFilter, NodeFiltered}; @@ -11,6 +11,8 @@ use crate::extension::ExtensionRegistry; use crate::hugr::views::SiblingSubgraph; use crate::hugr::{HugrView, Node, OpType, RootTagged}; use crate::hugr::{NodeMetadata, Rewrite}; +use crate::ops::OpTrait; +use crate::types::Substitution; use crate::{Extension, Hugr, IncomingPort, OutgoingPort, Port, PortIndex}; use super::internal::HugrMutInternals; @@ -146,6 +148,29 @@ pub trait HugrMut: HugrMutInternals { self.hugr_mut().remove_node(node); } + /// Copies the strict descendants of `root` to under the `new_parent`, optionally applying a + /// [Substitution] to the [OpType]s of the copied nodes. + /// + /// That is, the immediate children of root, are copied to make children of `new_parent`. + /// + /// Note this may invalidate the Hugr in two ways: + /// * Adding children of `root` may make the children-list of `new_parent` invalid e.g. + /// leading to multiple [Input](OpType::Input), [Output](OpType::Output) or + /// [ExitBlock](OpType::ExitBlock) nodes or Input/Output in the wrong positions + /// * Nonlocal edges incoming to the subtree of `root` will be copied to target the subtree under `new_parent` + /// which may be invalid if `new_parent` is not a child of `root`s parent (for `Ext` edges - or + /// correspondingly for `Dom` edges) + fn copy_descendants( + &mut self, + root: Node, + new_parent: Node, + subst: Option, + ) -> BTreeMap { + panic_invalid_node(self, root); + panic_invalid_node(self, new_parent); + self.hugr_mut().copy_descendants(root, new_parent, subst) + } + /// Connect two nodes at the given ports. /// /// # Panics @@ -294,8 +319,8 @@ pub struct InsertionResult { fn translate_indices( node_map: HashMap, -) -> HashMap { - HashMap::from_iter(node_map.into_iter().map(|(k, v)| (k.into(), v.into()))) +) -> impl Iterator { + node_map.into_iter().map(|(k, v)| (k.into(), v.into())) } /// Impl for non-wrapped Hugrs. Overwrites the recursive default-impls to directly use the hugr. @@ -398,7 +423,7 @@ impl + AsMut> HugrMut for T ); InsertionResult { new_root, - node_map: translate_indices(node_map), + node_map: translate_indices(node_map).collect(), } } @@ -419,7 +444,7 @@ impl + AsMut> HugrMut for T ); InsertionResult { new_root, - node_map: translate_indices(node_map), + node_map: translate_indices(node_map).collect(), } } @@ -448,7 +473,44 @@ impl + AsMut> HugrMut for T self.use_extensions(exts); } } - translate_indices(node_map) + translate_indices(node_map).collect() + } + + fn copy_descendants( + &mut self, + root: Node, + new_parent: Node, + subst: Option, + ) -> BTreeMap { + let mut descendants = self.base_hugr().hierarchy.descendants(root.pg_index()); + let root2 = descendants.next(); + debug_assert_eq!(root2, Some(root.pg_index())); + let nodes = Vec::from_iter(descendants); + let node_map = translate_indices( + portgraph::view::Subgraph::with_nodes(&mut self.as_mut().graph, nodes) + .copy_in_parent() + .expect("Is a MultiPortGraph"), + ) + .collect::>(); + + for node in self.children(root).collect::>() { + self.set_parent(*node_map.get(&node).unwrap(), new_parent); + } + + // Copy the optypes, metadata, and hierarchy + for (&node, &new_node) in node_map.iter() { + for ch in self.children(node).collect::>() { + self.set_parent(*node_map.get(&ch).unwrap(), new_node); + } + let new_optype = match (&subst, self.get_optype(node)) { + (None, op) => op.clone(), + (Some(subst), op) => op.substitute(subst), + }; + self.as_mut().op_types.set(new_node.pg_index(), new_optype); + let meta = self.base_hugr().metadata.get(node.pg_index()).clone(); + self.as_mut().metadata.set(new_node.pg_index(), meta); + } + node_map } } diff --git a/hugr-core/src/hugr/rewrite.rs b/hugr-core/src/hugr/rewrite.rs index 5f252c68f..1f2be7a9a 100644 --- a/hugr-core/src/hugr/rewrite.rs +++ b/hugr-core/src/hugr/rewrite.rs @@ -1,6 +1,7 @@ //! Rewrite operations on the HUGR - replacement, outlining, etc. pub mod consts; +pub mod inline_call; pub mod inline_dfg; pub mod insert_identity; pub mod outline_cfg; diff --git a/hugr-core/src/hugr/rewrite/inline_call.rs b/hugr-core/src/hugr/rewrite/inline_call.rs new file mode 100644 index 000000000..9af9cd70a --- /dev/null +++ b/hugr-core/src/hugr/rewrite/inline_call.rs @@ -0,0 +1,336 @@ +//! Rewrite to inline a Call to a FuncDefn by copying the body of the function +//! into a DFG which replaces the Call node. +use derive_more::{Display, Error}; + +use crate::ops::{DataflowParent, OpType, DFG}; +use crate::types::Substitution; +use crate::{Direction, HugrView, Node}; + +use super::{HugrMut, Rewrite}; + +/// Rewrite to inline a [Call](OpType::Call) to a known [FuncDefn](OpType::FuncDefn) +pub struct InlineCall(Node); + +/// Error in performing [InlineCall] rewrite. +#[derive(Clone, Debug, Display, Error, PartialEq)] +#[non_exhaustive] +pub enum InlineCallError { + /// The specified Node was not a [Call](OpType::Call) + #[display("Node to inline {_0} expected to be a Call but actually {_1}")] + NotCallNode(Node, OpType), + /// The node was a Call, but the target was not a [FuncDefn](OpType::FuncDefn) + /// - presumably a [FuncDecl](OpType::FuncDecl), if the Hugr is valid. + #[display("Call targetted node {_0} which must be a FuncDefn but was {_1}")] + CallTargetNotFuncDefn(Node, OpType), +} + +impl InlineCall { + /// Create a new instance that will inline the specified node + /// (i.e. that should be a [Call](OpType::Call)) + pub fn new(node: Node) -> Self { + Self(node) + } +} + +impl Rewrite for InlineCall { + type ApplyResult = (); + type Error = InlineCallError; + fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { + let call_ty = h.get_optype(self.0); + if !call_ty.is_call() { + return Err(InlineCallError::NotCallNode(self.0, call_ty.clone())); + } + let func = h.static_source(self.0).unwrap(); + let func_ty = h.get_optype(func); + if !func_ty.is_func_defn() { + return Err(InlineCallError::CallTargetNotFuncDefn( + func, + func_ty.clone(), + )); + } + Ok(()) + } + + fn apply(self, h: &mut impl HugrMut) -> Result<(), Self::Error> { + self.verify(h)?; // Now we know we have a Call to a FuncDefn. + let orig_func = h.static_source(self.0).unwrap(); + + h.disconnect(self.0, h.get_optype(self.0).static_input_port().unwrap()); + + // The order input port gets renumbered because the static input + // (which comes between the value inports and the order inport) gets removed + let old_order_in = h.get_optype(self.0).other_input_port().unwrap(); + let order_preds = h.linked_outputs(self.0, old_order_in).collect::>(); + h.disconnect(self.0, old_order_in); // PortGraph currently does this anyway + + let new_op = OpType::from(DFG { + signature: h + .get_optype(orig_func) + .as_func_defn() + .unwrap() + .inner_signature() + .into_owned(), + }); + let new_order_in = new_op.other_input_port().unwrap(); + + let ty_args = h + .replace_op(self.0, new_op) + .unwrap() + .as_call() + .unwrap() + .type_args + .clone(); + + h.add_ports(self.0, Direction::Incoming, -1); + + // Reconnect order predecessors + for (src, srcp) in order_preds { + h.connect(src, srcp, self.0, new_order_in); + } + + h.copy_descendants( + orig_func, + self.0, + (!ty_args.is_empty()).then_some(Substitution::new(&ty_args)), + ); + Ok(()) + } + + /// Failure only occurs if the node is not a Call, or the target not a FuncDefn. + /// (Any later failure means an invalid Hugr and `panic`.) + const UNCHANGED_ON_FAILURE: bool = true; + + fn invalidation_set(&self) -> impl Iterator { + Some(self.0).into_iter() + } +} + +#[cfg(test)] +mod test { + use std::iter::successors; + + use itertools::Itertools; + + use crate::builder::{ + Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, + ModuleBuilder, + }; + use crate::extension::prelude::usize_t; + use crate::hugr::views::RootChecked; + use crate::ops::handle::{FuncID, ModuleRootID, NodeHandle}; + use crate::ops::{Input, OpType, Value}; + use crate::std_extensions::arithmetic::{ + int_ops::{self, IntOpDef}, + int_types::{self, ConstInt, INT_TYPES}, + }; + + use crate::types::{PolyFuncType, Signature, Type, TypeBound}; + use crate::{HugrView, Node}; + + use super::{HugrMut, InlineCall, InlineCallError}; + + fn calls(h: &impl HugrView) -> Vec { + h.nodes().filter(|n| h.get_optype(*n).is_call()).collect() + } + + fn extension_ops(h: &impl HugrView) -> Vec { + h.nodes() + .filter(|n| h.get_optype(*n).is_extension_op()) + .collect() + } + + #[test] + fn test_inline() -> Result<(), Box> { + let mut mb = ModuleBuilder::new(); + let cst3 = mb.add_constant(Value::from(ConstInt::new_u(4, 3)?)); + let sig = Signature::new_endo(INT_TYPES[4].clone()) + .with_extension_delta(int_ops::EXTENSION_ID) + .with_extension_delta(int_types::EXTENSION_ID); + let func = { + let mut fb = mb.define_function("foo", sig.clone())?; + let c1 = fb.load_const(&cst3); + let mut inner = fb.dfg_builder(sig.clone(), fb.input_wires())?; + let [i] = inner.input_wires_arr(); + let add = inner.add_dataflow_op(IntOpDef::iadd.with_log_width(4), [i, c1])?; + let inner_res = inner.finish_with_outputs(add.outputs())?; + fb.finish_with_outputs(inner_res.outputs())? + }; + let mut main = mb.define_function("main", sig)?; + let call1 = main.call(func.handle(), &[], main.input_wires())?; + main.add_other_wire(main.input().node(), call1.node()); + let call2 = main.call(func.handle(), &[], call1.outputs())?; + main.finish_with_outputs(call2.outputs())?; + let mut hugr = mb.finish_hugr()?; + let call1 = call1.node(); + let call2 = call2.node(); + assert_eq!( + hugr.output_neighbours(func.node()).collect_vec(), + [call1, call2] + ); + assert_eq!(calls(&hugr), [call1, call2]); + assert_eq!(extension_ops(&hugr).len(), 1); + + assert_eq!( + hugr.linked_outputs( + call1.node(), + hugr.get_optype(call1).other_input_port().unwrap() + ) + .count(), + 1 + ); + RootChecked::<_, ModuleRootID>::try_new(&mut hugr) + .unwrap() + .apply_rewrite(InlineCall(call1.node())) + .unwrap(); + hugr.validate().unwrap(); + assert_eq!(hugr.output_neighbours(func.node()).collect_vec(), [call2]); + assert_eq!(calls(&hugr), [call2]); + assert_eq!(extension_ops(&hugr).len(), 2); + assert_eq!( + hugr.linked_outputs( + call1.node(), + hugr.get_optype(call1).other_input_port().unwrap() + ) + .count(), + 1 + ); + hugr.apply_rewrite(InlineCall(call2.node())).unwrap(); + hugr.validate().unwrap(); + assert_eq!(hugr.output_neighbours(func.node()).next(), None); + assert_eq!(calls(&hugr), []); + assert_eq!(extension_ops(&hugr).len(), 3); + + Ok(()) + } + + #[test] + fn test_recursion() -> Result<(), Box> { + let mut mb = ModuleBuilder::new(); + let sig = Signature::new_endo(INT_TYPES[5].clone()) + .with_extension_delta(int_ops::EXTENSION_ID) + .with_extension_delta(int_types::EXTENSION_ID); + let (func, rec_call) = { + let mut fb = mb.define_function("foo", sig.clone())?; + let cst1 = fb.add_load_value(ConstInt::new_u(5, 1)?); + let [i] = fb.input_wires_arr(); + let add = fb.add_dataflow_op(IntOpDef::iadd.with_log_width(5), [i, cst1])?; + let call = fb.call( + &FuncID::::from(fb.container_node()), + &[], + add.outputs(), + )?; + (fb.finish_with_outputs(call.outputs())?, call) + }; + let mut main = mb.define_function("main", sig)?; + let call = main.call(func.handle(), &[], main.input_wires())?; + let main = main.finish_with_outputs(call.outputs())?; + let mut hugr = mb.finish_hugr()?; + + let func = func.node(); + let mut call = call.node(); + for i in 2..10 { + hugr.apply_rewrite(InlineCall(call))?; + hugr.validate().unwrap(); + assert_eq!(extension_ops(&hugr).len(), i); + let v = calls(&hugr); + assert!(v.iter().all(|n| hugr.static_source(*n) == Some(func))); + + let [rec, nonrec] = v.try_into().expect("Should be two"); + assert_eq!(rec, rec_call.node()); + assert_eq!(hugr.output_neighbours(func).collect_vec(), [rec, nonrec]); + call = nonrec; + + let mut ancestors = successors(hugr.get_parent(call), |n| hugr.get_parent(*n)); + for _ in 1..i { + assert!(hugr.get_optype(ancestors.next().unwrap()).is_dfg()); + } + assert_eq!(ancestors.next(), Some(main.node())); + assert_eq!(ancestors.next(), Some(hugr.root())); + assert_eq!(ancestors.next(), None); + } + Ok(()) + } + + #[test] + fn test_bad() { + let mut modb = ModuleBuilder::new(); + let decl = modb + .declare( + "UndefinedFunc", + Signature::new_endo(INT_TYPES[3].clone()).into(), + ) + .unwrap(); + let mut main = modb + .define_function("main", Signature::new_endo(INT_TYPES[3].clone())) + .unwrap(); + let call = main.call(&decl, &[], main.input_wires()).unwrap(); + let main = main.finish_with_outputs(call.outputs()).unwrap(); + let h = modb.finish_hugr().unwrap(); + let mut h2 = h.clone(); + assert_eq!( + h2.apply_rewrite(InlineCall(call.node())), + Err(InlineCallError::CallTargetNotFuncDefn( + decl.node(), + h.get_optype(decl.node()).clone() + )) + ); + assert_eq!(h, h2); + let [inp, _out, _call] = h + .children(main.node()) + .collect::>() + .try_into() + .unwrap(); + assert_eq!( + h2.apply_rewrite(InlineCall(inp)), + Err(InlineCallError::NotCallNode( + inp, + Input { + types: INT_TYPES[3].clone().into() + } + .into() + )) + ) + } + + #[test] + fn test_polymorphic() -> Result<(), Box> { + let tuple_ty = Type::new_tuple(vec![usize_t(); 2]); + let mut fb = FunctionBuilder::new( + "mkpair", + Signature::new(usize_t(), tuple_ty.clone()).with_prelude(), + )?; + let inner = fb.define_function( + "id", + PolyFuncType::new( + [TypeBound::Copyable.into()], + Signature::new_endo(Type::new_var_use(0, TypeBound::Copyable)), + ), + )?; + let inps = inner.input_wires(); + let inner = inner.finish_with_outputs(inps)?; + let call1 = fb.call(inner.handle(), &[usize_t().into()], fb.input_wires())?; + let [call1_out] = call1.outputs_arr(); + let tup = fb.make_tuple([call1_out, call1_out])?; + let call2 = fb.call(inner.handle(), &[tuple_ty.into()], [tup])?; + let mut hugr = fb.finish_hugr_with_outputs(call2.outputs()).unwrap(); + + assert_eq!( + hugr.output_neighbours(inner.node()).collect::>(), + [call1.node(), call2.node()] + ); + hugr.apply_rewrite(InlineCall::new(call1.node()))?; + + assert_eq!( + hugr.output_neighbours(inner.node()).collect::>(), + [call2.node()] + ); + assert!(hugr.get_optype(call1.node()).is_dfg()); + assert!(matches!( + hugr.children(call1.node()) + .map(|n| hugr.get_optype(n).clone()) + .collect::>()[..], + [OpType::Input(_), OpType::Output(_)] + )); + Ok(()) + } +}