diff --git a/hugr-core/src/hugr/views.rs b/hugr-core/src/hugr/views.rs index 352b3d0400..ff53ca0cd9 100644 --- a/hugr-core/src/hugr/views.rs +++ b/hugr-core/src/hugr/views.rs @@ -19,7 +19,7 @@ pub use self::petgraph::PetgraphWrapper; use self::render::{MermaidFormatter, RenderConfig}; pub use nodes_iter::NodesIter; pub use rerooted::Rerooted; -pub use root_checked::{InvalidSignature, RootCheckable, RootChecked, check_tag}; +pub use root_checked::{InvalidSignature, RootChecked, check_tag}; pub use sibling_subgraph::SiblingSubgraph; use itertools::Itertools; diff --git a/hugr-core/src/hugr/views/root_checked.rs b/hugr-core/src/hugr/views/root_checked.rs index c7a01e4108..acd592450b 100644 --- a/hugr-core/src/hugr/views/root_checked.rs +++ b/hugr-core/src/hugr/views/root_checked.rs @@ -67,26 +67,6 @@ impl, Handle> AsRef for RootChecked { } } -/// A trait for types that can be checked for a specific [`OpTag`] at their entrypoint node. -/// -/// This is used mainly specifying function inputs that may either be a [`HugrView`] or an already checked [`RootChecked`]. -pub trait RootCheckable>: Sized { - /// Wrap the Hugr in a [`RootChecked`] if it is valid for the required [`OpTag`]. - /// - /// If `Self` is already a [`RootChecked`], it is a no-op. - fn try_into_checked(self) -> Result, HugrError>; -} -impl> RootCheckable for H { - fn try_into_checked(self) -> Result, HugrError> { - RootChecked::try_new(self) - } -} -impl> RootCheckable for RootChecked { - fn try_into_checked(self) -> Result, HugrError> { - Ok(self) - } -} - /// Check that the node in a HUGR can be represented by the required tag. pub fn check_tag, N>( hugr: &impl HugrView, diff --git a/hugr-core/src/hugr/views/sibling_subgraph.rs b/hugr-core/src/hugr/views/sibling_subgraph.rs index b1d0d3e1c3..f5342aba2b 100644 --- a/hugr-core/src/hugr/views/sibling_subgraph.rs +++ b/hugr-core/src/hugr/views/sibling_subgraph.rs @@ -27,7 +27,7 @@ use crate::ops::{NamedOp, OpTag, OpTrait, OpType}; use crate::types::{Signature, Type}; use crate::{Hugr, IncomingPort, Node, OutgoingPort, Port, SimpleReplacement}; -use super::root_checked::RootCheckable; +use super::RootChecked; /// A non-empty convex subgraph of a HUGR sibling graph. /// @@ -110,15 +110,12 @@ impl SiblingSubgraph { /// This will return an [`InvalidSubgraph::EmptySubgraph`] error if the /// subgraph is empty. pub fn try_new_dataflow_subgraph<'h, H, Root>( - dfg_graph: impl RootCheckable<&'h H, Root>, + dfg_graph: RootChecked<&'h H, Root>, ) -> Result> where H: 'h + Clone + HugrView, Root: ContainerHandle, { - let Ok(dfg_graph) = dfg_graph.try_into_checked() else { - return Err(InvalidSubgraph::NonDataflowRegion); - }; let dfg_graph = dfg_graph.into_hugr(); let parent = HugrView::entrypoint(&dfg_graph); @@ -1651,7 +1648,9 @@ mod tests { fn construct_simple_replacement() -> Result<(), InvalidSubgraph> { let (mut hugr, func_root) = build_hugr().unwrap(); let func = hugr.with_entrypoint(func_root); - let sub = SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID>(&func)?; + let sub = SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID>( + RootChecked::try_new(&func).expect("Root should be FuncDefn."), + )?; assert!(sub.validate(&func, Default::default()).is_ok()); let empty_dfg = { @@ -1699,7 +1698,9 @@ mod tests { fn test_signature() -> Result<(), InvalidSubgraph> { let (hugr, dfg) = build_hugr().unwrap(); let func = hugr.with_entrypoint(dfg); - let sub = SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID>(&func)?; + let sub = SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID>( + RootChecked::try_new(&func).expect("Root should be FuncDefn."), + )?; assert!(sub.validate(&func, Default::default()).is_ok()); assert_eq!( sub.signature(&func), @@ -1732,10 +1733,12 @@ mod tests { let (hugr, func_root) = build_hugr().unwrap(); let func = hugr.with_entrypoint(func_root); assert_eq!( - SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID>(&func) - .unwrap() - .nodes() - .len(), + SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID>( + RootChecked::try_new(&func).expect("Root should be FuncDefn.") + ) + .unwrap() + .nodes() + .len(), 4 ); } @@ -1848,8 +1851,10 @@ mod tests { fn preserve_signature() { let (hugr, func_root) = build_hugr_classical().unwrap(); let func_graph = hugr.with_entrypoint(func_root); - let func = - SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID>(&func_graph).unwrap(); + let func = SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID>( + RootChecked::try_new(&func_graph).expect("Root should be FuncDefn."), + ) + .unwrap(); let func_defn = hugr.get_optype(func_root).as_func_defn().unwrap(); assert_eq!(func_defn.signature(), &func.signature(&func_graph).into()); } @@ -1858,8 +1863,10 @@ mod tests { fn extract_subgraph() { let (hugr, func_root) = build_hugr().unwrap(); let func_graph = hugr.with_entrypoint(func_root); - let subgraph = - SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID>(&func_graph).unwrap(); + let subgraph = SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID>( + RootChecked::try_new(&func_graph).expect("Root should be FuncDefn."), + ) + .unwrap(); let extracted = subgraph.extract_subgraph(&hugr, "region"); extracted.validate().unwrap(); @@ -1883,7 +1890,10 @@ mod tests { .outputs(); let outw = [outw1].into_iter().chain(outw2); let h = builder.finish_hugr_with_outputs(outw).unwrap(); - let subg = SiblingSubgraph::try_new_dataflow_subgraph::<_, DfgID>(&h).unwrap(); + let subg = SiblingSubgraph::try_new_dataflow_subgraph::<_, DfgID>( + RootChecked::try_new(&h).expect("Root should be DFG."), + ) + .unwrap(); assert_eq!(subg.nodes().len(), 2); } @@ -2178,9 +2188,10 @@ mod tests { #[rstest] fn test_call_subgraph_from_dfg(hugr_call_subgraph: Hugr) { - let subg = - SiblingSubgraph::try_new_dataflow_subgraph::<_, DataflowParentID>(&hugr_call_subgraph) - .unwrap(); + let subg = SiblingSubgraph::try_new_dataflow_subgraph::<_, DataflowParentID>( + RootChecked::try_new(&hugr_call_subgraph).expect("Root should be DFG container."), + ) + .unwrap(); assert_eq!(subg.function_calls.len(), 1); assert_eq!(subg.function_calls[0].len(), 2); diff --git a/hugr-passes/src/half_node.rs b/hugr-passes/src/half_node.rs index bb46828b7f..42c7bf72a6 100644 --- a/hugr-passes/src/half_node.rs +++ b/hugr-passes/src/half_node.rs @@ -3,7 +3,7 @@ use std::hash::Hash; use super::nest_cfgs::CfgNodeMap; use hugr_core::hugr::internal::HugrInternals; -use hugr_core::hugr::views::RootCheckable; +use hugr_core::hugr::views::RootChecked; use hugr_core::ops::handle::CfgID; use hugr_core::ops::{OpTag, OpTrait}; use hugr_core::{Direction, HugrView, Node}; @@ -32,9 +32,8 @@ struct HalfNodeView { impl HalfNodeView { #[allow(unused)] - pub(crate) fn new(h: impl RootCheckable>) -> Self { - let checked = h.try_into_checked().expect("Hugr must be a CFG region"); - let h = checked.into_hugr(); + pub(crate) fn new(h: RootChecked>) -> Self { + let h = h.into_hugr(); let (entry, exit) = { let mut children = h.children(h.entrypoint()); @@ -99,6 +98,7 @@ mod test { use super::super::nest_cfgs::{EdgeClassifier, test::*}; use super::{HalfNode, HalfNodeView}; use hugr_core::builder::BuildError; + use hugr_core::hugr::views::RootChecked; use hugr_core::ops::handle::NodeHandle; use itertools::Itertools; @@ -118,7 +118,7 @@ mod test { // \---<---<---<---<---<---<---<---<---<---/ // Allowing to identify two nested regions (and fixing the problem with an IdentityCfgMap on the same example) - let v = HalfNodeView::new(&h); + let v = HalfNodeView::new(RootChecked::try_new(&h).expect("Root should be CFG.")); let edge_classes = EdgeClassifier::get_edge_classes(&v); let HalfNodeView { h: _, entry, exit } = v; diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index f499973c3c..63999a8ef6 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -25,7 +25,7 @@ pub mod untuple; /// Merge basic blocks. Subset of [normalize_cfgs], use the latter. #[deprecated(note = "Use normalize_cfgs", since = "0.23.0")] pub mod merge_bbs { - use hugr_core::hugr::{hugrmut::HugrMut, views::RootCheckable}; + use hugr_core::hugr::{hugrmut::HugrMut, views::RootChecked}; use hugr_core::ops::handle::CfgID; /// Merge any basic blocks that are direct children of the specified CFG @@ -38,11 +38,8 @@ pub mod merge_bbs { /// /// [OpType::CFG]: hugr_core::ops::OpType::CFG #[deprecated(note = "Use version in normalize_cfgs", since = "0.23.0")] - pub fn merge_basic_blocks<'h, H: 'h + HugrMut>( - cfg: impl RootCheckable<&'h mut H, CfgID>, - ) { - let checked = cfg.try_into_checked().expect("Hugr must be a CFG region"); - super::normalize_cfgs::merge_basic_blocks(checked.into_hugr()).unwrap(); + pub fn merge_basic_blocks<'h, H: 'h + HugrMut>(cfg: RootChecked<&'h mut H, CfgID>) { + super::normalize_cfgs::merge_basic_blocks(cfg.into_hugr()).unwrap(); } } diff --git a/hugr-passes/src/nest_cfgs.rs b/hugr-passes/src/nest_cfgs.rs index 0a2d882da6..7f456f8811 100644 --- a/hugr-passes/src/nest_cfgs.rs +++ b/hugr-passes/src/nest_cfgs.rs @@ -45,11 +45,10 @@ use itertools::Itertools; use thiserror::Error; use hugr_core::hugr::patch::outline_cfg::OutlineCfg; -use hugr_core::hugr::views::{HugrView, RootCheckable}; +use hugr_core::hugr::views::HugrView; use hugr_core::hugr::{Patch, hugrmut::HugrMut}; use hugr_core::ops::OpTag; use hugr_core::ops::OpTrait; -use hugr_core::ops::handle::CfgID; use hugr_core::{Direction, Hugr, Node}; /// A "view" of a CFG in a Hugr which allows basic blocks in the underlying CFG to be split into @@ -220,10 +219,7 @@ pub struct IdentityCfgMap { } impl IdentityCfgMap { /// Creates an [`IdentityCfgMap`] for the specified CFG - pub fn new(h: impl RootCheckable>) -> Self { - let h = h.try_into_checked().expect("Hugr must be a CFG region"); - let h = h.into_hugr(); - + pub fn new(h: H) -> Self { // Panic if malformed enough not to have two children let (entry, exit) = h.children(h.entrypoint()).take(2).collect_tuple().unwrap(); debug_assert_eq!(h.get_optype(exit).tag(), OpTag::BasicBlockExit); @@ -582,7 +578,6 @@ pub(crate) mod test { use hugr_core::Node; use hugr_core::hugr::patch::insert_identity::{IdentityInsertion, IdentityInsertionError}; - use hugr_core::hugr::views::RootChecked; use hugr_core::ops::Value; use hugr_core::ops::handle::{BasicBlockID, ConstID, NodeHandle}; use hugr_core::types::{EdgeKind, Signature}; @@ -632,11 +627,11 @@ pub(crate) mod test { let exit = cfg_builder.exit_block(); cfg_builder.branch(&tail, 0, &exit)?; - let mut h = cfg_builder.finish_hugr()?; - let rc = RootChecked::<_, CfgID>::try_new(&mut h).unwrap(); + let h = cfg_builder.finish_hugr()?; let (entry, exit) = (entry.node(), exit.node()); let (split, merge, head, tail) = (split.node(), merge.node(), head.node(), tail.node()); - let edge_classes = EdgeClassifier::get_edge_classes(&IdentityCfgMap::new(rc.as_ref())); + let mut v = IdentityCfgMap::new(h); + let edge_classes = EdgeClassifier::get_edge_classes(&v); let [&left, &right] = edge_classes .keys() .filter(|(s, _)| *s == split) @@ -657,7 +652,8 @@ pub(crate) mod test { sorted([(entry, split), (merge, head), (tail, exit)]), // Two regions, conditional and then loop. ]) ); - transform_cfg_to_nested(&mut IdentityCfgMap::new(rc)); + transform_cfg_to_nested(&mut v); + let h = v.h; h.validate().unwrap(); assert_eq!(3, depth(&h, entry)); assert_eq!(3, depth(&h, exit)); @@ -693,7 +689,7 @@ pub(crate) mod test { .try_into() .unwrap(); - let v = IdentityCfgMap::new(RootChecked::try_new(&h).unwrap()); + let v = IdentityCfgMap::new(h); let edge_classes = EdgeClassifier::get_edge_classes(&v); let [&left, &right] = edge_classes .keys() @@ -783,7 +779,7 @@ pub(crate) mod test { // Here we would like an indication that we can make two nested regions, // but there is no edge to act as entry to a region containing just the conditional :-(. - let v = IdentityCfgMap::new(RootChecked::try_new(&h).unwrap()); + let v = IdentityCfgMap::new(h); let edge_classes = EdgeClassifier::get_edge_classes(&v); let IdentityCfgMap { h: _, entry, exit } = v; // merge is unique predecessor of tail diff --git a/hugr-persistent/src/walker.rs b/hugr-persistent/src/walker.rs index db376d235d..eac2f05a8c 100644 --- a/hugr-persistent/src/walker.rs +++ b/hugr-persistent/src/walker.rs @@ -62,7 +62,7 @@ use hugr_core::ops::handle::DataflowParentID; use itertools::{Either, Itertools}; use thiserror::Error; -use hugr_core::{Direction, Hugr, HugrView, Port, PortIndex, hugr::views::RootCheckable}; +use hugr_core::{Direction, Hugr, HugrView, Port, PortIndex, hugr::views::RootChecked}; use crate::{Commit, PersistentReplacement, PinnedSubgraph}; @@ -302,7 +302,7 @@ impl<'a> Walker<'a> { pub fn try_create_commit( &self, subgraph: impl Into, - repl: impl RootCheckable, + mut repl: RootChecked, map_boundary: impl Fn(PatchNode, Port) -> Port, ) -> Result, InvalidCommit> { let pinned_subgraph = subgraph.into(); @@ -312,7 +312,6 @@ impl<'a> Walker<'a> { .map(|id| self.selected_commits.get_commit(id).clone()); let repl = { - let mut repl = repl.try_into_checked().expect("replacement is not DFG"); let new_inputs = subgraph .incoming_ports() .iter() @@ -783,7 +782,7 @@ mod tests { let commit = walker .try_create_commit( PinnedSubgraph::try_from_pinned(std::iter::empty(), [wire], &walker).unwrap(), - empty_hugr, + RootChecked::try_new(empty_hugr).expect("Root should be DFG."), |node, port| { assert_eq!(port.index(), 0); assert!([not0, not2].contains(&node)); diff --git a/hugr-persistent/tests/persistent_walker_example.rs b/hugr-persistent/tests/persistent_walker_example.rs index 8499700b58..ccaaf0140b 100644 --- a/hugr-persistent/tests/persistent_walker_example.rs +++ b/hugr-persistent/tests/persistent_walker_example.rs @@ -8,6 +8,7 @@ use hugr_core::{ Hugr, HugrView, IncomingPort, OutgoingPort, Port, PortIndex, builder::{DFGBuilder, Dataflow, DataflowHugr, endo_sig}, extension::prelude::qb_t, + hugr::views::RootChecked, ops::OpType, types::EdgeKind, }; @@ -251,7 +252,7 @@ fn create_commit<'a>(wire: PersistentWire, walker: &Walker<'a>) -> Option(wire: PersistentWire, walker: &Walker<'a>) -> Option