Skip to content

Commit 9d91c6e

Browse files
zrhoaborgna-q
andauthored
feat: Export and import entrypoints via metadata in hugr-model. (#2172)
This PR adds code to import and export the entrypoint field of modules to the `hugr-model` representation. It assumes that only regions should be made the entrypoint. The PR also fixes some bugs in the parser that prevented packages with multiple modules to be parsed correctly. --------- Co-authored-by: Agustín Borgna <[email protected]>
1 parent 7c101d9 commit 9d91c6e

File tree

9 files changed

+227
-58
lines changed

9 files changed

+227
-58
lines changed

hugr-core/src/export.rs

+71-41
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ impl<'a> Context<'a> {
120120
self.symbols.enter(self.module.root);
121121
self.links.enter(self.module.root);
122122

123-
let hugr_children = self.hugr.children(self.hugr.entrypoint());
123+
let hugr_children = self.hugr.children(self.hugr.module_root());
124124
let mut children = Vec::with_capacity(hugr_children.size_hint().0);
125125

126126
for child in hugr_children.clone() {
@@ -291,9 +291,11 @@ impl<'a> Context<'a> {
291291
}
292292

293293
OpType::DFG(_) => {
294-
regions = self
295-
.bump
296-
.alloc_slice_copy(&[self.export_dfg(node, model::ScopeClosure::Open)]);
294+
regions = self.bump.alloc_slice_copy(&[self.export_dfg(
295+
node,
296+
model::ScopeClosure::Open,
297+
false,
298+
)]);
297299
table::Operation::Dfg
298300
}
299301

@@ -313,18 +315,22 @@ impl<'a> Context<'a> {
313315
}
314316

315317
OpType::DataflowBlock(_) => {
316-
regions = self
317-
.bump
318-
.alloc_slice_copy(&[self.export_dfg(node, model::ScopeClosure::Open)]);
318+
regions = self.bump.alloc_slice_copy(&[self.export_dfg(
319+
node,
320+
model::ScopeClosure::Open,
321+
false,
322+
)]);
319323
table::Operation::Block
320324
}
321325

322326
OpType::FuncDefn(func) => self.with_local_scope(node_id, |this| {
323327
let name = this.get_func_name(node).unwrap();
324328
let symbol = this.export_poly_func_type(name, &func.signature);
325-
regions = this
326-
.bump
327-
.alloc_slice_copy(&[this.export_dfg(node, model::ScopeClosure::Closed)]);
329+
regions = this.bump.alloc_slice_copy(&[this.export_dfg(
330+
node,
331+
model::ScopeClosure::Closed,
332+
false,
333+
)]);
328334
table::Operation::DefineFunc(symbol)
329335
}),
330336

@@ -426,9 +432,11 @@ impl<'a> Context<'a> {
426432
}
427433

428434
OpType::TailLoop(_) => {
429-
regions = self
430-
.bump
431-
.alloc_slice_copy(&[self.export_dfg(node, model::ScopeClosure::Open)]);
435+
regions = self.bump.alloc_slice_copy(&[self.export_dfg(
436+
node,
437+
model::ScopeClosure::Open,
438+
false,
439+
)]);
432440
table::Operation::TailLoop
433441
}
434442

@@ -479,7 +487,12 @@ impl<'a> Context<'a> {
479487
let inputs = self.make_ports(node, Direction::Incoming, num_inputs);
480488
let outputs = self.make_ports(node, Direction::Outgoing, num_outputs);
481489

482-
let meta = self.export_node_metadata(node);
490+
let meta = {
491+
let mut meta = Vec::new();
492+
self.export_node_json_metadata(node, &mut meta);
493+
self.export_node_order_metadata(node, &mut meta);
494+
self.bump.alloc_slice_copy(&meta)
495+
};
483496

484497
self.module.nodes[node_id.index()] = table::Node {
485498
operation,
@@ -532,7 +545,7 @@ impl<'a> Context<'a> {
532545
}
533546

534547
for (name, value) in opdef.iter_misc() {
535-
meta.push(self.export_json_meta(name, value));
548+
meta.push(self.make_json_meta(name, value));
536549
}
537550

538551
self.bump.alloc_slice_copy(&meta)
@@ -574,7 +587,12 @@ impl<'a> Context<'a> {
574587
/// Creates a data flow region from the given node's children.
575588
///
576589
/// `Input` and `Output` nodes are used to determine the source and target ports of the region.
577-
pub fn export_dfg(&mut self, node: Node, closure: model::ScopeClosure) -> table::RegionId {
590+
pub fn export_dfg(
591+
&mut self,
592+
node: Node,
593+
closure: model::ScopeClosure,
594+
export_json_meta: bool,
595+
) -> table::RegionId {
578596
let region = self.module.insert_region(table::Region::default());
579597

580598
self.symbols.enter(region);
@@ -586,8 +604,14 @@ impl<'a> Context<'a> {
586604
let mut targets: &[_] = &[];
587605
let mut input_types = None;
588606
let mut output_types = None;
607+
589608
let mut meta = Vec::new();
590609

610+
if export_json_meta {
611+
self.export_node_json_metadata(node, &mut meta);
612+
}
613+
self.export_node_entrypoint_metadata(node, &mut meta);
614+
591615
let children = self.hugr.children(node);
592616
let mut region_children = BumpVec::with_capacity_in(children.size_hint().0 - 2, self.bump);
593617

@@ -672,6 +696,10 @@ impl<'a> Context<'a> {
672696
let mut source = None;
673697
let mut targets: &[_] = &[];
674698

699+
let mut meta = Vec::new();
700+
self.export_node_json_metadata(node, &mut meta);
701+
self.export_node_entrypoint_metadata(node, &mut meta);
702+
675703
let children = self.hugr.children(node);
676704
let mut region_children = BumpVec::with_capacity_in(children.size_hint().0 - 1, self.bump);
677705

@@ -728,7 +756,7 @@ impl<'a> Context<'a> {
728756
sources: self.bump.alloc_slice_copy(&[source.unwrap()]),
729757
targets,
730758
children: region_children.into_bump_slice(),
731-
meta: &[], // TODO: Export metadata
759+
meta: self.bump.alloc_slice_copy(&meta),
732760
signature,
733761
scope,
734762
};
@@ -746,7 +774,7 @@ impl<'a> Context<'a> {
746774
panic!("expected a `Case` node as a child of a `Conditional` node");
747775
};
748776

749-
regions.push(self.export_dfg(child, model::ScopeClosure::Open));
777+
regions.push(self.export_dfg(child, model::ScopeClosure::Open, true));
750778
}
751779

752780
regions.into_bump_slice()
@@ -1000,7 +1028,7 @@ impl<'a> Context<'a> {
10001028

10011029
let region = match hugr.entrypoint_optype() {
10021030
OpType::DFG(_) => {
1003-
self.export_dfg(hugr.entrypoint(), model::ScopeClosure::Closed)
1031+
self.export_dfg(hugr.entrypoint(), model::ScopeClosure::Closed, true)
10041032
}
10051033
_ => panic!("Value::Function root must be a DFG"),
10061034
};
@@ -1031,41 +1059,43 @@ impl<'a> Context<'a> {
10311059
}
10321060
}
10331061

1034-
pub fn export_node_metadata(&mut self, node: Node) -> &'a [table::TermId] {
1062+
fn export_node_json_metadata(&mut self, node: Node, meta: &mut Vec<table::TermId>) {
10351063
let metadata_map = self.hugr.node_metadata_map(node);
1064+
meta.reserve(metadata_map.len());
10361065

1037-
let has_order_edges = {
1038-
fn is_relevant_node(hugr: &Hugr, node: Node) -> bool {
1039-
let optype = hugr.get_optype(node);
1040-
!optype.is_input() && !optype.is_output()
1041-
}
1042-
1043-
let optype = self.hugr.get_optype(node);
1066+
for (name, value) in metadata_map {
1067+
meta.push(self.make_json_meta(name, value));
1068+
}
1069+
}
10441070

1045-
Direction::BOTH
1046-
.iter()
1047-
.filter(|dir| optype.other_port_kind(**dir) == Some(EdgeKind::StateOrder))
1048-
.filter_map(|dir| optype.other_port(*dir))
1049-
.flat_map(|port| self.hugr.linked_ports(node, port))
1050-
.any(|(other, _)| is_relevant_node(self.hugr, other))
1051-
};
1071+
fn export_node_order_metadata(&mut self, node: Node, meta: &mut Vec<table::TermId>) {
1072+
fn is_relevant_node(hugr: &Hugr, node: Node) -> bool {
1073+
let optype = hugr.get_optype(node);
1074+
!optype.is_input() && !optype.is_output()
1075+
}
10521076

1053-
let meta_capacity = metadata_map.len() + has_order_edges as usize;
1054-
let mut meta = BumpVec::with_capacity_in(meta_capacity, self.bump);
1077+
let optype = self.hugr.get_optype(node);
10551078

1056-
for (name, value) in metadata_map {
1057-
meta.push(self.export_json_meta(name, value));
1058-
}
1079+
let has_order_edges = Direction::BOTH
1080+
.iter()
1081+
.filter(|dir| optype.other_port_kind(**dir) == Some(EdgeKind::StateOrder))
1082+
.filter_map(|dir| optype.other_port(*dir))
1083+
.flat_map(|port| self.hugr.linked_ports(node, port))
1084+
.any(|(other, _)| is_relevant_node(self.hugr, other));
10591085

10601086
if has_order_edges {
10611087
let key = self.make_term(model::Literal::Nat(node.index() as u64).into());
10621088
meta.push(self.make_term_apply(model::ORDER_HINT_KEY, &[key]));
10631089
}
1090+
}
10641091

1065-
meta.into_bump_slice()
1092+
fn export_node_entrypoint_metadata(&mut self, node: Node, meta: &mut Vec<table::TermId>) {
1093+
if self.hugr.entrypoint() == node {
1094+
meta.push(self.make_term_apply(model::CORE_ENTRYPOINT, &[]));
1095+
}
10661096
}
10671097

1068-
pub fn export_json_meta(&mut self, name: &str, value: &serde_json::Value) -> table::TermId {
1098+
pub fn make_json_meta(&mut self, name: &str, value: &serde_json::Value) -> table::TermId {
10691099
let value = serde_json::to_string(value).expect("json values are always serializable");
10701100
let value = self.make_term(model::Literal::Str(value.into()).into());
10711101
let name = self.make_term(model::Literal::Str(name.into()).into());

hugr-core/src/import.rs

+29-10
Original file line numberDiff line numberDiff line change
@@ -216,30 +216,41 @@ impl<'a> Context<'a> {
216216
self.record_links(node, Direction::Incoming, node_data.inputs);
217217
self.record_links(node, Direction::Outgoing, node_data.outputs);
218218

219-
// Import the JSON metadata
220219
for meta_item in node_data.meta {
221-
let Some([name_arg, json_arg]) =
222-
self.match_symbol(*meta_item, model::COMPAT_META_JSON)?
223-
else {
224-
continue;
225-
};
220+
self.import_node_metadata(node, *meta_item)?;
221+
}
226222

223+
Ok(node)
224+
}
225+
226+
fn import_node_metadata(
227+
&mut self,
228+
node: Node,
229+
meta_item: table::TermId,
230+
) -> Result<(), ImportError> {
231+
// Import the JSON metadata
232+
if let Some([name_arg, json_arg]) = self.match_symbol(meta_item, model::COMPAT_META_JSON)? {
227233
let table::Term::Literal(model::Literal::Str(name)) = self.get_term(name_arg)? else {
228-
return Err(table::ModelError::TypeError(*meta_item).into());
234+
return Err(table::ModelError::TypeError(meta_item).into());
229235
};
230236

231237
let table::Term::Literal(model::Literal::Str(json_str)) = self.get_term(json_arg)?
232238
else {
233-
return Err(table::ModelError::TypeError(*meta_item).into());
239+
return Err(table::ModelError::TypeError(meta_item).into());
234240
};
235241

236242
let json_value: NodeMetadata = serde_json::from_str(json_str)
237-
.map_err(|_| table::ModelError::TypeError(*meta_item))?;
243+
.map_err(|_| table::ModelError::TypeError(meta_item))?;
238244

239245
self.hugr.set_metadata(node, name, json_value);
240246
}
241247

242-
Ok(node)
248+
// Set the entrypoint
249+
if let Some([]) = self.match_symbol(meta_item, model::CORE_ENTRYPOINT)? {
250+
self.hugr.set_entrypoint(node);
251+
}
252+
253+
Ok(())
243254
}
244255

245256
/// Associate links with the ports of the given node in the given direction.
@@ -654,6 +665,10 @@ impl<'a> Context<'a> {
654665

655666
self.create_order_edges(region)?;
656667

668+
for meta_item in region_data.meta {
669+
self.import_node_metadata(node, *meta_item)?;
670+
}
671+
657672
self.region_scope = prev_region;
658673

659674
Ok(())
@@ -956,6 +971,10 @@ impl<'a> Context<'a> {
956971
self.record_links(exit, Direction::Incoming, region_data.targets);
957972
}
958973

974+
for meta_item in region_data.meta {
975+
self.import_node_metadata(node, *meta_item)?;
976+
}
977+
959978
self.region_scope = prev_region;
960979

961980
Ok(())

hugr-core/tests/model.rs

+8
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,11 @@ pub fn test_roundtrip_order() {
9595
"../../hugr-model/tests/fixtures/model-order.edn"
9696
)));
9797
}
98+
99+
#[test]
100+
#[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri
101+
pub fn test_roundtrip_entrypoint() {
102+
insta::assert_snapshot!(roundtrip(include_str!(
103+
"../../hugr-model/tests/fixtures/model-entrypoint.edn"
104+
)));
105+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
---
2+
source: hugr-core/tests/model.rs
3+
expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-entrypoint.edn\"))"
4+
---
5+
(hugr 0)
6+
7+
(mod)
8+
9+
(import core.fn)
10+
11+
(import core.entrypoint)
12+
13+
(define-func main (core.fn [] [])
14+
(dfg (signature (core.fn [] [])) (meta core.entrypoint)))
15+
16+
(mod)
17+
18+
(import core.fn)
19+
20+
(import core.entrypoint)
21+
22+
(define-func wrapper_dfg (core.fn [] [])
23+
(dfg (signature (core.fn [] [])) (meta core.entrypoint)))
24+
25+
(mod)
26+
27+
(import core.make_adt)
28+
29+
(import core.ctrl)
30+
31+
(import core.fn)
32+
33+
(import core.entrypoint)
34+
35+
(import core.adt)
36+
37+
(define-func wrapper_cfg (core.fn [] [])
38+
(dfg
39+
(signature (core.fn [] []))
40+
(cfg
41+
(signature (core.fn [] []))
42+
(cfg [%0] [%1]
43+
(signature (core.fn [(core.ctrl [])] [(core.ctrl [])]))
44+
(meta core.entrypoint)
45+
(block [%0] [%1]
46+
(signature (core.fn [(core.ctrl [])] [(core.ctrl [])]))
47+
(dfg [] [%2]
48+
(signature (core.fn [] [(core.adt [[]])]))
49+
((core.make_adt _ _ 0) [] [%2]
50+
(signature (core.fn [] [(core.adt [[]])])))))))))

0 commit comments

Comments
 (0)