Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion hugr-core/src/hugr/views.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub use self::petgraph::PetgraphWrapper;
use self::render::{MermaidFormatter, RenderConfig};
pub use nodes_iter::NodesIter;
pub use rerooted::Rerooted;
pub use root_checked::{InvalidSignature, RootCheckable, RootChecked, check_tag};
pub use root_checked::{InvalidSignature, RootChecked, check_tag};
pub use sibling_subgraph::SiblingSubgraph;

use itertools::Itertools;
Expand Down
20 changes: 0 additions & 20 deletions hugr-core/src/hugr/views/root_checked.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,26 +67,6 @@ impl<H: AsRef<Hugr>, Handle> AsRef<Hugr> for RootChecked<H, Handle> {
}
}

/// A trait for types that can be checked for a specific [`OpTag`] at their entrypoint node.
///
/// This is used mainly specifying function inputs that may either be a [`HugrView`] or an already checked [`RootChecked`].
pub trait RootCheckable<H: HugrView, Handle: NodeHandle<H::Node>>: Sized {
/// Wrap the Hugr in a [`RootChecked`] if it is valid for the required [`OpTag`].
///
/// If `Self` is already a [`RootChecked`], it is a no-op.
fn try_into_checked(self) -> Result<RootChecked<H, Handle>, HugrError>;
}
impl<H: HugrView, Handle: NodeHandle<H::Node>> RootCheckable<H, Handle> for H {
fn try_into_checked(self) -> Result<RootChecked<H, Handle>, HugrError> {
RootChecked::try_new(self)
}
}
impl<H: HugrView, Handle: NodeHandle<H::Node>> RootCheckable<H, Handle> for RootChecked<H, Handle> {
fn try_into_checked(self) -> Result<RootChecked<H, Handle>, HugrError> {
Ok(self)
}
}

/// Check that the node in a HUGR can be represented by the required tag.
pub fn check_tag<Required: NodeHandle<N>, N>(
hugr: &impl HugrView<Node = N>,
Expand Down
49 changes: 30 additions & 19 deletions hugr-core/src/hugr/views/sibling_subgraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use crate::ops::{NamedOp, OpTag, OpTrait, OpType};
use crate::types::{Signature, Type};
use crate::{Hugr, IncomingPort, Node, OutgoingPort, Port, SimpleReplacement};

use super::root_checked::RootCheckable;
use super::RootChecked;

/// A non-empty convex subgraph of a HUGR sibling graph.
///
Expand Down Expand Up @@ -110,15 +110,12 @@ impl<N: HugrNode> SiblingSubgraph<N> {
/// This will return an [`InvalidSubgraph::EmptySubgraph`] error if the
/// subgraph is empty.
pub fn try_new_dataflow_subgraph<'h, H, Root>(
dfg_graph: impl RootCheckable<&'h H, Root>,
dfg_graph: RootChecked<&'h H, Root>,
) -> Result<Self, InvalidSubgraph<N>>
where
H: 'h + Clone + HugrView<Node = N>,
Root: ContainerHandle<N, ChildrenHandle = DataflowOpID>,
{
let Ok(dfg_graph) = dfg_graph.try_into_checked() else {
return Err(InvalidSubgraph::NonDataflowRegion);
};
let dfg_graph = dfg_graph.into_hugr();

let parent = HugrView::entrypoint(&dfg_graph);
Expand Down Expand Up @@ -1651,7 +1648,9 @@ mod tests {
fn construct_simple_replacement() -> Result<(), InvalidSubgraph> {
let (mut hugr, func_root) = build_hugr().unwrap();
let func = hugr.with_entrypoint(func_root);
let sub = SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID<true>>(&func)?;
let sub = SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID<true>>(
RootChecked::try_new(&func).expect("Root should be FuncDefn."),
)?;
assert!(sub.validate(&func, Default::default()).is_ok());

let empty_dfg = {
Expand Down Expand Up @@ -1699,7 +1698,9 @@ mod tests {
fn test_signature() -> Result<(), InvalidSubgraph> {
let (hugr, dfg) = build_hugr().unwrap();
let func = hugr.with_entrypoint(dfg);
let sub = SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID<true>>(&func)?;
let sub = SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID<true>>(
RootChecked::try_new(&func).expect("Root should be FuncDefn."),
)?;
assert!(sub.validate(&func, Default::default()).is_ok());
assert_eq!(
sub.signature(&func),
Expand Down Expand Up @@ -1732,10 +1733,12 @@ mod tests {
let (hugr, func_root) = build_hugr().unwrap();
let func = hugr.with_entrypoint(func_root);
assert_eq!(
SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID<true>>(&func)
.unwrap()
.nodes()
.len(),
SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID<true>>(
RootChecked::try_new(&func).expect("Root should be FuncDefn.")
)
.unwrap()
.nodes()
.len(),
4
);
}
Expand Down Expand Up @@ -1848,8 +1851,10 @@ mod tests {
fn preserve_signature() {
let (hugr, func_root) = build_hugr_classical().unwrap();
let func_graph = hugr.with_entrypoint(func_root);
let func =
SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID<true>>(&func_graph).unwrap();
let func = SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID<true>>(
RootChecked::try_new(&func_graph).expect("Root should be FuncDefn."),
)
.unwrap();
let func_defn = hugr.get_optype(func_root).as_func_defn().unwrap();
assert_eq!(func_defn.signature(), &func.signature(&func_graph).into());
}
Expand All @@ -1858,8 +1863,10 @@ mod tests {
fn extract_subgraph() {
let (hugr, func_root) = build_hugr().unwrap();
let func_graph = hugr.with_entrypoint(func_root);
let subgraph =
SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID<true>>(&func_graph).unwrap();
let subgraph = SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID<true>>(
RootChecked::try_new(&func_graph).expect("Root should be FuncDefn."),
)
.unwrap();
let extracted = subgraph.extract_subgraph(&hugr, "region");

extracted.validate().unwrap();
Expand All @@ -1883,7 +1890,10 @@ mod tests {
.outputs();
let outw = [outw1].into_iter().chain(outw2);
let h = builder.finish_hugr_with_outputs(outw).unwrap();
let subg = SiblingSubgraph::try_new_dataflow_subgraph::<_, DfgID>(&h).unwrap();
let subg = SiblingSubgraph::try_new_dataflow_subgraph::<_, DfgID>(
RootChecked::try_new(&h).expect("Root should be DFG."),
)
.unwrap();
assert_eq!(subg.nodes().len(), 2);
}

Expand Down Expand Up @@ -2178,9 +2188,10 @@ mod tests {

#[rstest]
fn test_call_subgraph_from_dfg(hugr_call_subgraph: Hugr) {
let subg =
SiblingSubgraph::try_new_dataflow_subgraph::<_, DataflowParentID>(&hugr_call_subgraph)
.unwrap();
let subg = SiblingSubgraph::try_new_dataflow_subgraph::<_, DataflowParentID>(
RootChecked::try_new(&hugr_call_subgraph).expect("Root should be DFG container."),
)
.unwrap();

assert_eq!(subg.function_calls.len(), 1);
assert_eq!(subg.function_calls[0].len(), 2);
Expand Down
10 changes: 5 additions & 5 deletions hugr-passes/src/half_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::hash::Hash;
use super::nest_cfgs::CfgNodeMap;

use hugr_core::hugr::internal::HugrInternals;
use hugr_core::hugr::views::RootCheckable;
use hugr_core::hugr::views::RootChecked;
use hugr_core::ops::handle::CfgID;
use hugr_core::ops::{OpTag, OpTrait};
use hugr_core::{Direction, HugrView, Node};
Expand Down Expand Up @@ -32,9 +32,8 @@ struct HalfNodeView<H: HugrInternals> {

impl<H: HugrView> HalfNodeView<H> {
#[allow(unused)]
pub(crate) fn new(h: impl RootCheckable<H, CfgID<H::Node>>) -> Self {
let checked = h.try_into_checked().expect("Hugr must be a CFG region");
let h = checked.into_hugr();
pub(crate) fn new(h: RootChecked<H, CfgID<H::Node>>) -> Self {
let h = h.into_hugr();

let (entry, exit) = {
let mut children = h.children(h.entrypoint());
Expand Down Expand Up @@ -99,6 +98,7 @@ mod test {
use super::super::nest_cfgs::{EdgeClassifier, test::*};
use super::{HalfNode, HalfNodeView};
use hugr_core::builder::BuildError;
use hugr_core::hugr::views::RootChecked;
use hugr_core::ops::handle::NodeHandle;

use itertools::Itertools;
Expand All @@ -118,7 +118,7 @@ mod test {
// \---<---<---<---<---<---<---<---<---<---/
// Allowing to identify two nested regions (and fixing the problem with an IdentityCfgMap on the same example)

let v = HalfNodeView::new(&h);
let v = HalfNodeView::new(RootChecked::try_new(&h).expect("Root should be CFG."));

let edge_classes = EdgeClassifier::get_edge_classes(&v);
let HalfNodeView { h: _, entry, exit } = v;
Expand Down
9 changes: 3 additions & 6 deletions hugr-passes/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub mod untuple;
/// Merge basic blocks. Subset of [normalize_cfgs], use the latter.
#[deprecated(note = "Use normalize_cfgs", since = "0.23.0")]
pub mod merge_bbs {
use hugr_core::hugr::{hugrmut::HugrMut, views::RootCheckable};
use hugr_core::hugr::{hugrmut::HugrMut, views::RootChecked};
use hugr_core::ops::handle::CfgID;

/// Merge any basic blocks that are direct children of the specified CFG
Expand All @@ -38,11 +38,8 @@ pub mod merge_bbs {
///
/// [OpType::CFG]: hugr_core::ops::OpType::CFG
#[deprecated(note = "Use version in normalize_cfgs", since = "0.23.0")]
pub fn merge_basic_blocks<'h, H: 'h + HugrMut>(
cfg: impl RootCheckable<&'h mut H, CfgID<H::Node>>,
) {
let checked = cfg.try_into_checked().expect("Hugr must be a CFG region");
super::normalize_cfgs::merge_basic_blocks(checked.into_hugr()).unwrap();
pub fn merge_basic_blocks<'h, H: 'h + HugrMut>(cfg: RootChecked<&'h mut H, CfgID<H::Node>>) {
super::normalize_cfgs::merge_basic_blocks(cfg.into_hugr()).unwrap();
}
}

Expand Down
22 changes: 9 additions & 13 deletions hugr-passes/src/nest_cfgs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,10 @@ use itertools::Itertools;
use thiserror::Error;

use hugr_core::hugr::patch::outline_cfg::OutlineCfg;
use hugr_core::hugr::views::{HugrView, RootCheckable};
use hugr_core::hugr::views::HugrView;
use hugr_core::hugr::{Patch, hugrmut::HugrMut};
use hugr_core::ops::OpTag;
use hugr_core::ops::OpTrait;
use hugr_core::ops::handle::CfgID;
use hugr_core::{Direction, Hugr, Node};

/// A "view" of a CFG in a Hugr which allows basic blocks in the underlying CFG to be split into
Expand Down Expand Up @@ -220,10 +219,7 @@ pub struct IdentityCfgMap<H: HugrView> {
}
impl<H: HugrView> IdentityCfgMap<H> {
/// Creates an [`IdentityCfgMap`] for the specified CFG
pub fn new(h: impl RootCheckable<H, CfgID<H::Node>>) -> Self {
let h = h.try_into_checked().expect("Hugr must be a CFG region");
let h = h.into_hugr();

pub fn new(h: H) -> Self {
// Panic if malformed enough not to have two children
let (entry, exit) = h.children(h.entrypoint()).take(2).collect_tuple().unwrap();
debug_assert_eq!(h.get_optype(exit).tag(), OpTag::BasicBlockExit);
Expand Down Expand Up @@ -582,7 +578,6 @@ pub(crate) mod test {

use hugr_core::Node;
use hugr_core::hugr::patch::insert_identity::{IdentityInsertion, IdentityInsertionError};
use hugr_core::hugr::views::RootChecked;
use hugr_core::ops::Value;
use hugr_core::ops::handle::{BasicBlockID, ConstID, NodeHandle};
use hugr_core::types::{EdgeKind, Signature};
Expand Down Expand Up @@ -632,11 +627,11 @@ pub(crate) mod test {
let exit = cfg_builder.exit_block();
cfg_builder.branch(&tail, 0, &exit)?;

let mut h = cfg_builder.finish_hugr()?;
let rc = RootChecked::<_, CfgID>::try_new(&mut h).unwrap();
let h = cfg_builder.finish_hugr()?;
let (entry, exit) = (entry.node(), exit.node());
let (split, merge, head, tail) = (split.node(), merge.node(), head.node(), tail.node());
let edge_classes = EdgeClassifier::get_edge_classes(&IdentityCfgMap::new(rc.as_ref()));
let mut v = IdentityCfgMap::new(h);
let edge_classes = EdgeClassifier::get_edge_classes(&v);
let [&left, &right] = edge_classes
.keys()
.filter(|(s, _)| *s == split)
Expand All @@ -657,7 +652,8 @@ pub(crate) mod test {
sorted([(entry, split), (merge, head), (tail, exit)]), // Two regions, conditional and then loop.
])
);
transform_cfg_to_nested(&mut IdentityCfgMap::new(rc));
transform_cfg_to_nested(&mut v);
let h = v.h;
h.validate().unwrap();
assert_eq!(3, depth(&h, entry));
assert_eq!(3, depth(&h, exit));
Expand Down Expand Up @@ -693,7 +689,7 @@ pub(crate) mod test {
.try_into()
.unwrap();

let v = IdentityCfgMap::new(RootChecked::try_new(&h).unwrap());
let v = IdentityCfgMap::new(h);
let edge_classes = EdgeClassifier::get_edge_classes(&v);
let [&left, &right] = edge_classes
.keys()
Expand Down Expand Up @@ -783,7 +779,7 @@ pub(crate) mod test {
// Here we would like an indication that we can make two nested regions,
// but there is no edge to act as entry to a region containing just the conditional :-(.

let v = IdentityCfgMap::new(RootChecked::try_new(&h).unwrap());
let v = IdentityCfgMap::new(h);
let edge_classes = EdgeClassifier::get_edge_classes(&v);
let IdentityCfgMap { h: _, entry, exit } = v;
// merge is unique predecessor of tail
Expand Down
7 changes: 3 additions & 4 deletions hugr-persistent/src/walker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ use hugr_core::ops::handle::DataflowParentID;
use itertools::{Either, Itertools};
use thiserror::Error;

use hugr_core::{Direction, Hugr, HugrView, Port, PortIndex, hugr::views::RootCheckable};
use hugr_core::{Direction, Hugr, HugrView, Port, PortIndex, hugr::views::RootChecked};

use crate::{Commit, PersistentReplacement, PinnedSubgraph};

Expand Down Expand Up @@ -302,7 +302,7 @@ impl<'a> Walker<'a> {
pub fn try_create_commit(
&self,
subgraph: impl Into<PinnedSubgraph>,
repl: impl RootCheckable<Hugr, DataflowParentID>,
mut repl: RootChecked<Hugr, DataflowParentID>,
map_boundary: impl Fn(PatchNode, Port) -> Port,
) -> Result<Commit<'a>, InvalidCommit> {
let pinned_subgraph = subgraph.into();
Expand All @@ -312,7 +312,6 @@ impl<'a> Walker<'a> {
.map(|id| self.selected_commits.get_commit(id).clone());

let repl = {
let mut repl = repl.try_into_checked().expect("replacement is not DFG");
let new_inputs = subgraph
.incoming_ports()
.iter()
Expand Down Expand Up @@ -783,7 +782,7 @@ mod tests {
let commit = walker
.try_create_commit(
PinnedSubgraph::try_from_pinned(std::iter::empty(), [wire], &walker).unwrap(),
empty_hugr,
RootChecked::try_new(empty_hugr).expect("Root should be DFG."),
|node, port| {
assert_eq!(port.index(), 0);
assert!([not0, not2].contains(&node));
Expand Down
5 changes: 3 additions & 2 deletions hugr-persistent/tests/persistent_walker_example.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use hugr_core::{
Hugr, HugrView, IncomingPort, OutgoingPort, Port, PortIndex,
builder::{DFGBuilder, Dataflow, DataflowHugr, endo_sig},
extension::prelude::qb_t,
hugr::views::RootChecked,
ops::OpType,
types::EdgeKind,
};
Expand Down Expand Up @@ -251,7 +252,7 @@ fn create_commit<'a>(wire: PersistentWire, walker: &Walker<'a>) -> Option<Commit
// Create the commit
walker.try_create_commit(
PinnedSubgraph::try_from_wires(wires, walker).unwrap(),
empty_2qb_hugr(add_swap),
RootChecked::try_new(empty_2qb_hugr(add_swap)).expect("Root should be DFG."),
|_, port| {
// the incoming/outgoing ports of the subgraph map trivially to the empty 2qb
// HUGR
Expand All @@ -273,7 +274,7 @@ fn create_commit<'a>(wire: PersistentWire, walker: &Walker<'a>) -> Option<Commit

walker.try_create_commit(
PinnedSubgraph::try_from_wires([wire], walker).unwrap(),
repl_hugr,
RootChecked::try_new(repl_hugr).expect("Root should be DFG."),
|node, port| {
// map the incoming/outgoing ports of the subgraph to the replacement as
// follows:
Expand Down
Loading