diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index 096a0d9d7e..8c4f2a27a4 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -1140,12 +1140,12 @@ impl<'a> Context<'a> { /// Import a node with a custom operation. /// - /// A custom operation in `hugr-model` referred to by a symbol application - /// term. The name of the symbol specifies the name of the custom operation, - /// and the arguments supplied to the symbol are the arguments to be passed - /// to the custom operation. This method imports the custom operations as - /// [`OpaqueOp`]s. The [`OpaqueOp`]s are then resolved later against the - /// [`ExtensionRegistry`]. + /// A custom operation in `hugr-model` is referred to by a symbol + /// application term. The name of the symbol specifies the name of the + /// custom operation, and the arguments supplied to the symbol are the + /// arguments to be passed to the custom operation. This method imports the + /// custom operations as [`OpaqueOp`]s. The [`OpaqueOp`]s are then resolved + /// later against the [`ExtensionRegistry`]. /// /// Some operations that needed to be builtins in `hugr-core` are custom /// operations in `hugr-model`. 
This method detects these and converts them diff --git a/hugr-model/src/v0/binary/mod.rs b/hugr-model/src/v0/binary/mod.rs index c871181643..1fdca95d9e 100644 --- a/hugr-model/src/v0/binary/mod.rs +++ b/hugr-model/src/v0/binary/mod.rs @@ -10,5 +10,5 @@ mod read; mod write; -pub use read::{ReadError, read_from_reader, read_from_slice}; +pub use read::{ReadError, read_from_reader, read_from_slice, read_from_slice_with_suffix}; pub use write::{WriteError, write_to_vec, write_to_writer}; diff --git a/hugr-model/src/v0/binary/read.rs b/hugr-model/src/v0/binary/read.rs index c5f65463ba..ae44ebb07d 100644 --- a/hugr-model/src/v0/binary/read.rs +++ b/hugr-model/src/v0/binary/read.rs @@ -3,7 +3,7 @@ use crate::v0::table; use crate::{CURRENT_VERSION, v0 as model}; use bumpalo::Bump; use bumpalo::collections::Vec as BumpVec; -use std::io::BufRead; +use std::io::{BufRead, BufReader, Read}; /// An error encountered while deserialising a model. #[derive(Debug, derive_more::From, derive_more::Display, derive_more::Error)] @@ -27,10 +27,27 @@ pub enum ReadError { type ReadResult = Result; /// Read a hugr package from a byte slice. +/// +/// If the slice contains bytes beyond the encoded package, they are ignored. pub fn read_from_slice<'a>(slice: &[u8], bump: &'a Bump) -> ReadResult> { read_from_reader(slice, bump) } +/// Read a hugr package from a byte slice. +/// +/// If the slice contains bytes beyond the encoded package, they are returned +/// along with the decoded package. +pub fn read_from_slice_with_suffix<'a>( + slice: &[u8], + bump: &'a Bump, +) -> ReadResult<(table::Package<'a>, Vec)> { + let mut buffer = BufReader::new(slice); + let package = read_from_reader(&mut buffer, bump)?; + let mut suffix: Vec = vec![]; + buffer.read_to_end(&mut suffix)?; + Ok((package, suffix)) +} + /// Read a hugr package from an impl of [`BufRead`]. 
pub fn read_from_reader(reader: impl BufRead, bump: &Bump) -> ReadResult> { let reader = diff --git a/hugr-py/rust/lib.rs b/hugr-py/rust/lib.rs index 7bb8dacf84..c24c04bc4a 100644 --- a/hugr-py/rust/lib.rs +++ b/hugr-py/rust/lib.rs @@ -74,14 +74,14 @@ fn package_to_bytes(package: ast::Package) -> PyResult> { } #[pyfunction] -fn bytes_to_package(bytes: &[u8]) -> PyResult { +fn bytes_to_package(bytes: &[u8]) -> PyResult<(ast::Package, Vec)> { let bump = bumpalo::Bump::new(); - let table = hugr_model::v0::binary::read_from_slice(bytes, &bump) + let (table, suffix) = hugr_model::v0::binary::read_from_slice_with_suffix(bytes, &bump) .map_err(|err| PyValueError::new_err(err.to_string()))?; let package = table .as_ast() .ok_or_else(|| PyValueError::new_err("Malformed package"))?; - Ok(package) + Ok((package, suffix)) } /// Returns the current version of the HUGR model format as a tuple of (major, minor, patch). diff --git a/hugr-py/src/hugr/_hugr/__init__.pyi b/hugr-py/src/hugr/_hugr/__init__.pyi index 4287361c56..f939d8c330 100644 --- a/hugr-py/src/hugr/_hugr/__init__.pyi +++ b/hugr-py/src/hugr/_hugr/__init__.pyi @@ -20,7 +20,7 @@ def bytes_to_module(binary: bytes) -> hugr.model.Module: ... def package_to_string(package: hugr.model.Package) -> str: ... def string_to_package(string: str) -> hugr.model.Package: ... def package_to_bytes(package: hugr.model.Package) -> bytes: ... -def bytes_to_package(binary: bytes) -> hugr.model.Package: ... +def bytes_to_package(binary: bytes) -> tuple[hugr.model.Package, bytes]: ... def current_model_version() -> tuple[int, int, int]: ... def to_json_envelope(binary: bytes) -> bytes: ... def run_cli() -> None: ... 
diff --git a/hugr-py/src/hugr/envelope.py b/hugr-py/src/hugr/envelope.py index 088e0ae4d9..00b1d5bba9 100644 --- a/hugr-py/src/hugr/envelope.py +++ b/hugr-py/src/hugr/envelope.py @@ -39,7 +39,7 @@ import pyzstd -from hugr import cli +import hugr._hugr as rust if TYPE_CHECKING: from hugr.hugr.base import Hugr @@ -64,7 +64,6 @@ def make_envelope(package: Package | Hugr, config: EnvelopeConfig) -> bytes: if not isinstance(package, Package): package = Package(modules=[package], extensions=[]) - # Currently only uncompressed JSON is supported. payload: bytes match config.format: case EnvelopeFormat.JSON: @@ -103,6 +102,7 @@ def make_envelope_str(package: Package | Hugr, config: EnvelopeConfig) -> str: def read_envelope(envelope: bytes) -> Package: """Decode a HUGR package from an envelope.""" import hugr._serialization.extension as ext_s + from hugr.package import Package header = EnvelopeHeader.from_bytes(envelope) payload = envelope[10:] @@ -113,12 +113,23 @@ def read_envelope(envelope: bytes) -> Package: match header.format: case EnvelopeFormat.JSON: return ext_s.Package.model_validate_json(payload).deserialize() - case EnvelopeFormat.MODEL | EnvelopeFormat.MODEL_WITH_EXTS: - # TODO Going via JSON is a temporary solution, until we get model import to - # python properly implemented. - # https://github.com/CQCL/hugr/issues/2287 - json_data = cli.convert(envelope, format="json") - return read_envelope(json_data) + case EnvelopeFormat.MODEL: + model_package, suffix = rust.bytes_to_package(payload) + if suffix: + msg = f"Excess bytes in envelope with format {EnvelopeFormat.MODEL}." 
+ raise ValueError(msg) + return Package.from_model(model_package) + case EnvelopeFormat.MODEL_WITH_EXTS: + from hugr.ext import Extension + + model_package, suffix = rust.bytes_to_package(payload) + return Package( + modules=Package.from_model(model_package).modules, + extensions=[ + Extension.from_json(json.dumps(extension)) + for extension in json.loads(suffix) + ], + ) def read_envelope_hugr(envelope: bytes) -> Hugr: diff --git a/hugr-py/src/hugr/hugr/base.py b/hugr-py/src/hugr/hugr/base.py index 38e90a34cf..1ab0bd7207 100644 --- a/hugr-py/src/hugr/hugr/base.py +++ b/hugr-py/src/hugr/hugr/base.py @@ -1114,6 +1114,19 @@ def from_str(envelope: str) -> Hugr: """ return read_envelope_hugr_str(envelope) + @staticmethod + def from_model(module: model.Module) -> Hugr: + """Import from the hugr model format.""" + from hugr.model.load import ModelImport + + loader = ModelImport(module=module) + for i, node in enumerate(module.root.children): + loader.import_node_in_module(node, i) + loader.link_ports() + loader.link_static_ports() + loader.add_module_metadata() + return loader.hugr + def to_bytes(self, config: EnvelopeConfig | None = None) -> bytes: """Serialize the HUGR into an envelope byte string. 
diff --git a/hugr-py/src/hugr/model/__init__.py b/hugr-py/src/hugr/model/__init__.py index ac97177cd2..ba8f4b20ac 100644 --- a/hugr-py/src/hugr/model/__init__.py +++ b/hugr-py/src/hugr/model/__init__.py @@ -1,9 +1,9 @@ """HUGR model data structures.""" -from collections.abc import Sequence +import warnings +from collections.abc import Generator, Sequence from dataclasses import dataclass, field from enum import Enum -from typing import Protocol from semver import Version @@ -21,7 +21,7 @@ def _current_version() -> Version: CURRENT_VERSION: Version = _current_version() -class Term(Protocol): +class Term: """A model term for static data such as types, constants and metadata.""" def __str__(self) -> str: @@ -33,6 +33,26 @@ def from_str(s: str) -> "Term": """Read the term from its string representation.""" return rust.string_to_term(s) + def to_list_parts(self) -> Generator["SeqPart"]: + if isinstance(self, List): + for part in self.parts: + if isinstance(part, Splice): + yield from part.seq.to_list_parts() + else: + yield part + else: + yield Splice(self) + + def to_tuple_parts(self) -> Generator["SeqPart"]: + if isinstance(self, Tuple): + for part in self.parts: + if isinstance(part, Splice): + yield from part.seq.to_tuple_parts() + else: + yield part + else: + yield Splice(self) + @dataclass(frozen=True) class Wildcard(Term): @@ -129,9 +149,13 @@ def from_str(s: str) -> "Symbol": return rust.string_to_symbol(s) -class Op(Protocol): +class Op: """The operation of a node.""" + def symbol_name(self) -> str | None: + """Returns name of the symbol introduced by this node, if any.""" + return None + @dataclass(frozen=True) class InvalidOp(Op): @@ -159,6 +183,9 @@ class DefineFunc(Op): symbol: Symbol + def symbol_name(self) -> str | None: + return self.symbol.name + @dataclass(frozen=True) class DeclareFunc(Op): @@ -166,6 +193,9 @@ class DeclareFunc(Op): symbol: Symbol + def symbol_name(self) -> str | None: + return self.symbol.name + @dataclass(frozen=True) class 
CustomOp(Op): @@ -181,6 +211,9 @@ class DefineAlias(Op): symbol: Symbol value: Term + def symbol_name(self) -> str | None: + return self.symbol.name + @dataclass(frozen=True) class DeclareAlias(Op): @@ -188,6 +221,9 @@ class DeclareAlias(Op): symbol: Symbol + def symbol_name(self) -> str | None: + return self.symbol.name + @dataclass(frozen=True) class TailLoop(Op): @@ -205,6 +241,9 @@ class DeclareConstructor(Op): symbol: Symbol + def symbol_name(self) -> str | None: + return self.symbol.name + @dataclass(frozen=True) class DeclareOperation(Op): @@ -212,6 +251,9 @@ class DeclareOperation(Op): symbol: Symbol + def symbol_name(self) -> str | None: + return self.symbol.name + @dataclass(frozen=True) class Import(Op): @@ -219,6 +261,9 @@ class Import(Op): name: str + def symbol_name(self) -> str | None: + return self.name + @dataclass class Node: @@ -302,7 +347,13 @@ def from_str(s: str) -> "Package": @staticmethod def from_bytes(b: bytes) -> "Package": """Read a package from its binary representation.""" - return rust.bytes_to_package(b) + package, suffix = rust.bytes_to_package(b) + if suffix: + warnings.warn( + "Binary encoding of model Package contains extra bytes: ignoring them.", + stacklevel=2, + ) + return package @property def version(self) -> Version: diff --git a/hugr-py/src/hugr/model/load.py b/hugr-py/src/hugr/model/load.py new file mode 100644 index 0000000000..0cbb9edd7b --- /dev/null +++ b/hugr-py/src/hugr/model/load.py @@ -0,0 +1,1117 @@ +"""Helpers to import hugr graphs from hugr model to their python representation.""" + +import json +from collections.abc import Generator, Iterable +from dataclasses import dataclass, field +from typing import Any, cast + +import hugr.model as model +from hugr import val +from hugr.hugr import InPort, OutPort +from hugr.hugr.base import Hugr +from hugr.hugr.node_port import Node +from hugr.ops import ( + CFG, + DFG, + AliasDecl, + AliasDefn, + Call, + CallIndirect, + Case, + Conditional, + Const, + Custom, + 
DataflowBlock, + ExitBlock, + FuncDecl, + FuncDefn, + Input, + LoadConst, + LoadFunc, + MakeTuple, + Op, + Output, + Tag, + TailLoop, + UnpackTuple, +) +from hugr.std.collections.array import ArrayVal +from hugr.std.float import FloatVal +from hugr.std.int import IntVal +from hugr.tys import ( + BoundedNatArg, + BoundedNatParam, + BytesArg, + BytesParam, + ConstParam, + FloatArg, + FloatParam, + FunctionType, + ListArg, + ListConcatArg, + ListParam, + Opaque, + PolyFuncType, + RowVariable, + StringArg, + StringParam, + Sum, + Tuple, + TupleArg, + TupleConcatArg, + TupleParam, + Type, + TypeArg, + TypeBound, + TypeParam, + TypeTypeArg, + TypeTypeParam, + Variable, + _QubitDef, +) + +ImportContext = model.Term | model.Node | model.Region | str + + +class ModelImportError(Exception): + """Exception raised when importing from the model representation fails.""" + + def __init__(self, message: str, location: ImportContext | None = None): + self.message = message + self.location = location + + match location: + case model.Term() as term: + location_error = f"Error caused by term:\n```\n{term}\n```" + case model.Region() as region: + location_error = f"Error caused by region:\n```\n{region}\n```" + case model.Node() as node: + location_error = f"Error caused by node:\n```\n{node}\n```" + case str() as other: + location_error = other + case None: + location_error = "Error in unspecified location." + + super().__init__(f"{message}\n{location_error}") + + +def _collect_meta_json(node: model.Node) -> dict[str, Any]: + """Collects the `compat.meta_json` metadata on the given node.""" + metadata = {} + + for meta in node.meta: + match meta: + case model.Apply( + "compat.meta_json", + [model.Literal(str() as key), model.Literal(str() as value)], + ): + pass + case _: + continue + + try: + decoded = json.loads(value) + except json.JSONDecodeError as err: + error = "Failed to decode JSON metadata."
+ raise ModelImportError(error, node) from err + + metadata[key] = decoded + + return metadata + + +def _find_meta_title(node: model.Node) -> str | None: + """Searches for `core.title` metadata on the given node.""" + for meta in node.meta: + match meta: + case model.Apply("core.title", [model.Literal(str() as title)]): + return title + case model.Apply("core.title"): + error = "Invalid instance of `core.title` metadata." + raise ModelImportError(error, meta) + case _: + pass + + return None + + +def _find_meta_order_region(region: model.Region) -> "RegionOrderHints": + """Searches for order hint metadata on the given region.""" + data = RegionOrderHints() + + for meta in region.meta: + match meta: + case model.Apply( + "core.order_hint.input_key", [model.Literal(int() as key)] + ): + data.input_keys.append(key) + case model.Apply( + "core.order_hint.output_key", [model.Literal(int() as key)] + ): + data.output_keys.append(key) + case model.Apply( + "core.order_hint.order", + [model.Literal(int() as before), model.Literal(int() as after)], + ): + data.edges.append((before, after)) + case _: + pass + + return data + + +def _collect_meta_order_keys(node: model.Node) -> list[int]: + """Collects all order hint keys in the metadata of a node.""" + keys = [] + + for meta in node.meta: + match meta: + case model.Apply("core.order_hint.key", [model.Literal(int() as key)]): + keys.append(key) + case _: + pass + + return keys + + +class ModelImport: + """Helper to import a Hugr.""" + + local_vars: dict[str, "LocalVarData"] + current_symbol: str | None + link_prefix: int | None + linked_ports: dict[str, tuple[list[InPort], list[OutPort]]] + static_edges: list[tuple[Node, Node]] + + module: model.Module + symbols: dict[str, model.Node] + fn_nodes: dict[str, Node] + fn_calls: list[tuple[str, Node]] + hugr: Hugr + + def __init__(self, module: model.Module): + self.local_vars = {} + self.current_symbol = None + self.module = module + self.symbols = {} + self.hugr = Hugr() + 
self.link_prefix = None + self.linked_ports = {} + self.static_edges = [] + self.fn_nodes = {} + self.fn_calls = [] + + for node in module.root.children: + symbol_name = node.operation.symbol_name() + + if symbol_name is None: + continue + + if symbol_name in self.symbols: + error = f"Duplicate symbol name `{symbol_name}`." + raise ModelImportError(error, node) + + self.symbols[symbol_name] = node + + def add_node( + self, node: model.Node, operation: Op, parent: Node, num_outs: int | None = None + ) -> Node: + """Add a model Node to the Hugr and record its in- and out-links.""" + node_id = self.hugr.add_node( + op=operation, + parent=parent, + num_outs=num_outs, + metadata=_collect_meta_json(node), + ) + self.record_in_links(node_id, node.inputs) + self.record_out_links(node_id, node.outputs) + if model.Apply("core.entrypoint") in node.meta: + self.hugr.entrypoint = node_id + return node_id + + def record_in_links(self, node: Node, links: Iterable[str]): + """Record a bunch of links entering the given Hugr Node with the given names.""" + for offset, link in enumerate(links): + in_port = InPort(node=node, offset=offset) + self.linked_ports.setdefault(f"{self.link_prefix}_{link}", ([], []))[ + 0 + ].append(in_port) + + def record_out_links(self, node: Node, links: Iterable[str]): + """Record a bunch of links exiting the given Hugr Node with the given names.""" + for offset, link in enumerate(links): + out_port = OutPort(node=node, offset=offset) + self.linked_ports.setdefault(f"{self.link_prefix}_{link}", ([], []))[ + 1 + ].append(out_port) + + def link_ports(self): + """Add links to the Hugr according to the recorded data.""" + for link, (in_ports, out_ports) in self.linked_ports.items(): + match in_ports, out_ports: + case [[], []]: + raise AssertionError + case _, [out_port]: + for in_port in in_ports: + self.hugr.add_link(out_port, in_port) + case [[in_port], _]: + for out_port in out_ports: + self.hugr.add_link(out_port, in_port) + case _, _: + error = f"Link 
`{link}` has multiple inputs and outputs." + raise ModelImportError(error) + + def link_static_ports(self): + for symbol, callnode in self.fn_calls: + self.static_edges.append((self.fn_nodes[symbol], callnode)) + for src, dst in self.static_edges: + out_port_offset = self.hugr.num_out_ports(src) - 1 + out_port = OutPort(node=src, offset=out_port_offset) + + in_port_offset = self.hugr.num_in_ports(dst) + in_port = InPort(node=dst, offset=in_port_offset) + + self.hugr.add_link(out_port, in_port) + + def add_module_metadata(self): + self.hugr[self.hugr.module_root].metadata = _collect_meta_json(self.module.root) + + def import_dfg_region(self, region: model.Region, parent: Node): + """Import an entire DFG region from the model into the Hugr.""" + signature = self.import_signature(region.signature) + + input_node = self.hugr.add_node( + Input(signature.input), parent=parent, num_outs=len(signature.input) + ) + self.record_out_links(input_node, region.sources) + + output_node = self.hugr.add_node(Output(signature.output), parent=parent) + self.record_in_links(output_node, region.targets) + + order_data = _find_meta_order_region(region) + order_data.add_node_keys(input_node, order_data.input_keys) + order_data.add_node_keys(output_node, order_data.output_keys) + + for child in region.children: + child_id = self.import_node_in_dfg(child, parent) + child_order_keys = _collect_meta_order_keys(child) + order_data.add_node_keys(child_id, child_order_keys) + + for src_key, tgt_key in order_data.edges: + src_node = order_data.get_node_by_key(src_key) + tgt_node = order_data.get_node_by_key(tgt_key) + self.hugr.add_order_link(src_node, tgt_node) + + def import_block(self, block: model.Node, parent: Node): + # 1. Add the DataFlowBlock node: + match block.signature: + case model.Apply("core.ctrl", [ctrl_inputs, ctrl_outputs]): + pass + case _: + error = f"Invalid signature for {block}." 
+ raise ModelImportError(error) + match list(ctrl_inputs.to_list_parts()): + case [inputs]: + pass + case _: + error = f"DFB inputs should be singleton list: {ctrl_inputs}." + raise ModelImportError(error) + assert isinstance(inputs, model.Term) + block_node = self.add_node( + block, + # TODO The translation here seems to be underdetermined. It could be + # DataflowBlock( + # self.import_type_row(inputs), + # Sum(ts), + # ss, + # ), + # where the ctrl_outputs have been expressed as: + # [[*ts[0], *ss], [*ts[1], *ss], ...] + # with ss some common suffix of the lists in ctrl_outputs. But how do we + # decide on that common suffix? Below we take it to be empty. + DataflowBlock( + self.import_type_row(inputs), + Sum( + [ + self.import_type_row(cast(model.Term, output)) + for output in ctrl_outputs.to_list_parts() + ] + ), + [], + ), + parent, + ) + # 2. Import the dataflow region: + [block_region] = block.regions + self.import_dfg_region(block_region, block_node) + + def import_cfg_region( + self, region: model.Region, signature: FunctionType, parent: Node + ): + """Import an entire CFG region from the model into the Hugr.""" + [entry_link] = region.sources + entry_block_idx = None + for i, child in enumerate(region.children): + if entry_link in child.inputs: + entry_block_idx = i + break + assert entry_block_idx is not None + entry_block = region.children[entry_block_idx] + + # 1. Import the entry block: + self.import_block(entry_block, parent) + + # 2. Create the exit node: + exit_node = self.hugr.add_node(ExitBlock(signature.output), parent) + self.record_in_links(exit_node, region.targets) + + # 3. Import the other blocks: + for i, child in enumerate(region.children): + if i != entry_block_idx: + self.import_block(child, parent) + + def import_node_in_dfg(self, node: model.Node, parent: Node) -> Node: + """Import a model Node within a DFG region. + + Returns the Hugr Node corresponding to the model Node. 
The correspondence is + almost 1-1, but a LoadConst model Node requires two Hugr Nodes (Const and + LoadConst); in this case the LoadConst is returned. + """ + signature = self.import_signature(node.signature) + + def import_dfg_node() -> Node: + match node.regions: + case [body]: + pass + case _: + error = "DFG node expects a dataflow region." + raise ModelImportError(error, node) + node_id = self.add_node( + node, DFG(signature.input, signature.output), parent + ) + self.import_dfg_region(body, node_id) + return node_id + + def import_tail_loop() -> Node: + match node.regions: + case [body]: + pass + case _: + error = "Loop node expects a dataflow region." + raise ModelImportError(error, node) + + match body.signature: + case model.Apply("core.fn", [_, body_outputs]): + pass + case _: + error = "Tail loop body expects `(core.fn _ _)` signature." + raise ModelImportError(error, node) + + match list(_import_closed_list(body_outputs)): + case [model.Apply("core.adt", [variants]), *rest]: + pass + case _: + error = "TailLoop body expects `(core.adt _)` as first target type." + raise ModelImportError(error, node) + + match list(_import_closed_list(variants)): + case [just_inputs, just_outputs]: + pass + case _: + error = "TailLoop body expects sum type with two variants." + raise ModelImportError(error, node) + + node_id = self.add_node( + node, + TailLoop( + just_inputs=self.import_type_row(just_inputs), + rest=[self.import_type(t) for t in rest], + _just_outputs=self.import_type_row(just_outputs), + ), + parent, + len(signature.output), + ) + self.import_dfg_region(body, node_id) + return node_id + + def import_custom_node(op: model.Term) -> Node: + match op: + case model.Apply(symbol, args): + pass + case _: + error = "The operation of a custom node must be a symbol " + "application." 
+ raise ModelImportError(error, node) + + match symbol: + case "core.call": + _input_types, _output_types, func = args + match func: + case model.Apply(fn_symbol, fn_args): + pass + case _: + error = "The function of a Call node must be a symbol " + "application." + raise ModelImportError(error, node) + type_args = [self.import_type_arg(fn_arg) for fn_arg in fn_args] + callnode = self.add_node( + node, + Call( + # FIXME PolyFuncType needs list[TypeParam], not + # list[TypeArg]. How to get this? + signature=PolyFuncType(type_args, signature), # type: ignore[arg-type] + instantiation=signature, + type_args=type_args, + ), + parent, + len(signature.output), + ) + self.fn_calls.append((fn_symbol, callnode)) + return callnode + case "core.call_indirect": + [inputs, outputs] = args + sig = FunctionType( + self.import_type_row(inputs), self.import_type_row(outputs) + ) + callindirectnode = self.add_node( + node, CallIndirect(sig), parent, len(signature.output) + ) + return callindirectnode + case "core.load_const": + value = args[-1] + [datatype] = signature.output + match datatype: + case FunctionType(_inputs, _outputs): + # Import as a LoadFunc operation. + match value: + case model.Apply(str() as fn_id, fn_args): + pass + case _: + error = "Unexpected arguments to core.load_const: " + f"{args}" + raise ModelImportError(error, node) + type_args = [ + self.import_type_arg(fn_arg) for fn_arg in fn_args + ] + loadfunc_node = self.add_node( + node, + LoadFunc( + # FIXME PolyFuncType needs list[TypeParam], not + # list[TypeArg]. How to get this? + PolyFuncType(type_args, datatype), # type: ignore[arg-type] + datatype, + type_args, + ), + parent, + 1, + ) + self.fn_calls.append((fn_id, loadfunc_node)) + return loadfunc_node + case _: + # Import as a Const and a LoadConst node. 
+ v = self.import_value(value) + const_node = self.hugr.add_node(Const(v), parent, 1) + loadconst_node = self.add_node( + node, LoadConst(datatype), parent, 1 + ) + self.hugr.add_link( + OutPort(const_node, 0), InPort(loadconst_node, 0) + ) + return loadconst_node + case "core.make_adt": + tag = args[-1] + match tag: + case model.Literal(int() as tagval): + pass + case _: + error = f"Unexpected tag: {tag}" + raise ModelImportError(error) + [sigout] = signature.output + match sigout: + case Sum(_variant_rows) as output_sum: + pass + case _: + error = f"Invalid signature with {symbol}: {node.signature}" + raise ModelImportError(error) + return self.add_node(node, Tag(tagval, output_sum), parent, 1) + case "prelude.MakeTuple": + [arglist] = args + return self.add_node( + node, + MakeTuple(self.import_type_row(arglist)), + parent, + 1, + ) + case "prelude.UnpackTuple": + [arglist] = args + typerow = self.import_type_row(arglist) + return self.add_node( + node, + UnpackTuple(typerow), + parent, + len(typerow), + ) + # Others are imported as Custom nodes. + case _: + extension, op_name = _split_extension_name(symbol) + return self.add_node( + node, + Custom( + op_name=op_name, + extension=extension, + signature=signature, + args=[self.import_type_arg(arg) for arg in args], + ), + parent, + len(signature.output), + ) + + def import_cfg() -> Node: + match node.regions: + case [body]: + pass + case _: + error = "CFG node expects a control-flow region." + raise ModelImportError(error, node) + node_id = self.add_node( + node, CFG(signature.input, signature.output), parent + ) + self.import_cfg_region(body, signature, node_id) + return node_id + + def import_conditional() -> Node: + match node.signature: + case model.Apply("core.fn", [inputs, outputs]): + pass + case _: + error = "Conditional node expects `(core.fn _ _)` signature." 
+ raise ModelImportError(error, node) + + match list(_import_closed_list(inputs)): + case [model.Apply("core.adt", [variants]), *other_inputs]: + sum_ty = Sum( + [ + self.import_type_row(variant) + for variant in _import_closed_list(variants) + ] + ) + case _: + error = ( + "Conditional node expects `(core.adt _)` as first input type." + ) + raise ModelImportError( + error, + node, + ) + + node_id = self.add_node( + node, + Conditional( + sum_ty=sum_ty, + other_inputs=[self.import_type(t) for t in other_inputs], + _outputs=self.import_type_row(outputs), + ), + parent, + ) + + for case_body in node.regions: + case_signature = self.import_signature(case_body.signature) + case_id = self.hugr.add_node( + Case(inputs=case_signature.input, _outputs=case_signature.output), + node_id, + ) + self.import_dfg_region(case_body, case_id) + + return node_id + + match node.operation: + case model.InvalidOp(): + error = "Invalid operation can not be imported." + raise ModelImportError(error, node) + case model.Dfg(): + return import_dfg_node() + case model.Cfg(): + return import_cfg() + case model.Block(): + error = "Unexpected basic block." + raise ModelImportError(error, node) + case model.CustomOp(op): + return import_custom_node(op) + case model.TailLoop(): + return import_tail_loop() + case model.Conditional(): + return import_conditional() + case _: + error = "Unexpected node in DFG region." 
+ raise ModelImportError(error, node) + + def import_node_in_module(self, node: model.Node, link_prefix: int) -> Node | None: + """Import a model Node at the Hugr Module level.""" + self.link_prefix = link_prefix + + def import_declare_func(symbol: model.Symbol) -> Node: + f_name = _find_meta_title(node) + if f_name is None: + f_name = symbol.name + signature = self.enter_symbol(symbol) + node_id = self.add_node( + node, + FuncDecl( + f_name=f_name, signature=signature, visibility=symbol.visibility + ), + self.hugr.module_root, + 1, + ) + self.exit_symbol() + self.fn_nodes[symbol.name] = node_id + return node_id + + def import_define_func(symbol: model.Symbol) -> Node: + f_name = _find_meta_title(node) + if f_name is None: + f_name = symbol.name + signature = self.enter_symbol(symbol) + node_id = self.add_node( + node, + FuncDefn( + f_name=f_name, + inputs=signature.body.input, + _outputs=signature.body.output, + params=signature.params, + visibility=symbol.visibility, + ), + self.hugr.module_root, + 1, + ) + + match node.regions: + case [body]: + pass + case _: + error = "Function definition expects a single region." + raise ModelImportError(error, node) + + self.import_dfg_region(body, node_id) + self.exit_symbol() + self.fn_nodes[symbol.name] = node_id + return node_id + + def import_declare_alias(symbol: model.Symbol) -> Node: + match symbol: + case model.Symbol( + name=name, + visibility=_visibility, + signature=model.Apply("core.type", []), + ): + pass + case _: + error = f"Unexpected symbol in alias declaration: {symbol}" + raise ModelImportError(error) + return self.add_node( + node, + AliasDecl(alias=name, bound=TypeBound.Copyable), # TODO which bound? 
+ self.hugr.module_root, + ) + + def import_define_alias(symbol: model.Symbol, value: model.Term) -> Node: + match symbol: + case model.Symbol( + name=name, + visibility=_visibility, + signature=model.Apply("core.type", []), + ): + pass + case _: + error = f"Unexpected symbol in alias definition: {symbol}" + raise ModelImportError(error) + return self.add_node( + node, + AliasDefn(alias=name, definition=self.import_type(value)), + self.hugr.module_root, + ) + + imported_node = None + match node.operation: + case model.DeclareFunc(symbol): + imported_node = import_declare_func(symbol) + case model.DefineFunc(symbol): + imported_node = import_define_func(symbol) + case model.DeclareAlias(symbol): + imported_node = import_declare_alias(symbol) + case model.DefineAlias(symbol, value): + imported_node = import_define_alias(symbol, value) + case model.Import(): + pass + case model.DeclareConstructor(): + pass + case model.DeclareOperation(): + pass + case _: + error = "Unexpected node in module region." + raise ModelImportError(error, node) + self.link_prefix = None + return imported_node + + def enter_symbol(self, symbol: model.Symbol) -> PolyFuncType: + assert len(self.local_vars) == 0 + + bounds: dict[str, TypeBound] = {} + + for constraint in symbol.constraints: + match constraint: + case model.Apply("core.nonlinear", [model.Var(name)]): + bounds[name] = TypeBound.Copyable + case _: + error = "Constraint other than `core.nonlinear` on a variable." 
+ raise ModelImportError(error, constraint) + + param_types: list[TypeParam] = [] + + for index, param in enumerate(symbol.params): + bound = bounds.get(param.name, TypeBound.Linear) + type = self.import_type_param(param.type, bound=bound) + self.local_vars[param.name] = LocalVarData(index, type) + param_types.append(type) + + body = self.import_signature(symbol.signature) + return PolyFuncType(param_types, body) + + def exit_symbol(self): + self.local_vars = {} + + def import_signature(self, term: model.Term | None) -> FunctionType: + match term: + case None: + error = "Signature required." + raise ModelImportError(error) + case model.Apply("core.fn", [inputs, outputs]): + return FunctionType( + self.import_type_row(inputs), self.import_type_row(outputs) + ) + case _: + error = "Invalid signature." + raise ModelImportError(error, term) + + def lookup_var(self, name: str) -> "LocalVarData": + if name not in self.local_vars: + error = f"Unknown variable `{name}`." + raise ModelImportError(error) + + return self.local_vars[name] + + def import_type_param( + self, term: model.Term, bound: TypeBound = TypeBound.Linear + ) -> TypeParam: + """Import a TypeParam from a model Term.""" + match term: + case model.Apply("core.nat"): + return BoundedNatParam() + case model.Apply("core.str"): + return StringParam() + case model.Apply("core.float"): + return FloatParam() + case model.Apply("core.bytes"): + return BytesParam() + case model.Apply("core.type"): + return TypeTypeParam(bound) + case model.Apply("core.list", [item_type]): + return ListParam(self.import_type_param(item_type)) + case model.Apply("core.tuple", [item_types]): + return TupleParam( + [ + self.import_type_param(item_type) + for item_type in _import_closed_list(item_types) + ] + ) + case model.Apply("core.const", [runtime_type]): + return ConstParam(self.import_type(runtime_type)) + case _: + error = "Failed to import TypeParam." 
+ raise ModelImportError(error, term) + + def import_type_arg(self, term: model.Term) -> TypeArg: + """Import a TypeArg from a model Term.""" + + def import_list(term: model.Term) -> TypeArg: + lists: list[TypeArg] = [] + + for group in _group_seq_parts(term.to_list_parts()): + if isinstance(group, list): + lists.append( + ListArg([self.import_type_arg(item) for item in group]) + ) + else: + lists.append(self.import_type_arg(group)) + + return ListConcatArg(lists).flatten() + + def import_tuple(term: model.Term) -> TypeArg: + tuples: list[TypeArg] = [] + + for group in _group_seq_parts(term.to_list_parts()): + if isinstance(group, list): + tuples.append( + TupleArg([self.import_type_arg(item) for item in group]) + ) + else: + tuples.append(self.import_type_arg(group)) + + return TupleConcatArg(tuples).flatten() + + match term: + case model.Literal(str() as value): + return StringArg(value) + case model.Literal(int() as value): + return BoundedNatArg(value) + case model.Literal(float() as value): + return FloatArg(value) + case model.Literal(bytes() as value): + return BytesArg(value) + case model.List(): + return import_list(term) + case model.Tuple(): + return import_tuple(term) + case _: + # Assume it's a TypeTypeArg + return TypeTypeArg(self.import_type(term)) + + def import_type(self, term: model.Term) -> Type: + """Import the type from a model Term.""" + match term: + case model.Apply("core.fn", [inputs, outputs]): + return FunctionType( + self.import_type_row(inputs), self.import_type_row(outputs) + ) + case model.Apply("core.adt", [variants]): + return Sum( + [ + self.import_type_row(variant) + for variant in _import_closed_list(variants) + ] + ) + case model.Apply("prelude.qubit", []): + return _QubitDef() + case model.Apply(symbol, args): + extension, type_id = _split_extension_name(symbol) + return Opaque( + id=type_id, + extension=extension, + # TODO How to determine the type bound (Copyable or Linear)? 
+ bound=TypeBound.Copyable, + args=[self.import_type_arg(arg) for arg in args], + ) + case model.Var(name): + var_data = self.lookup_var(name) + return Variable(idx=var_data.index, bound=var_data.bound) + case _: + error = "Failed to import Type." + raise ModelImportError(error, term) + + def import_type_row(self, term: model.Term) -> list[Type]: + def import_part(part: model.SeqPart) -> Type: + if isinstance(part, model.Splice): + if isinstance(part.seq, model.Var): + var_data = self.lookup_var(part.seq.name) + return RowVariable(var_data.index, var_data.bound) + else: + error = "Can only import spliced variables." + raise ModelImportError(error, term) + else: + return self.import_type(part) + + return [import_part(part) for part in term.to_list_parts()] + + def import_value(self, term: model.Term) -> val.Value: + match term: + case model.Apply( + "arithmetic.int.const", + [ + model.Literal(int() as int_logwidth), + model.Literal(int() as int_value), + ], + ): + # Ensure value is in signed form for conversion to IntVal: + width = 1 << int_logwidth + if int_value >= 1 << (width - 1): + int_value -= 1 << width + return IntVal(int_value, int_logwidth) + case model.Apply( + "arithmetic.float.const_f64", [model.Literal(float() as float_value)] + ): + return FloatVal(float_value) + case model.Apply("collections.array.const", [_, array_type, array_values]): + return ArrayVal( + [ + self.import_value(cast(model.Term, v)) + for v in array_values.to_list_parts() + ], + self.import_type(array_type), + ) + case model.Apply( + "compat.const_json", [typ, model.Literal(str() as json_str)] + ): + json_dict = json.loads(json_str) + match typ: + case model.Apply(typename, args): + match typename: + case "core.adt": + [arg] = args + match list(arg.to_list_parts()): + case [model.List() as ts]: + pass + case _: + error = f"Unexpected term: {term}" + raise ModelImportError(error) + match json_dict: + case {"c": "ConstExternalSymbol", "v": value}: + return val.Extension( + 
name="ConstExternalSymbol", + typ=Tuple( + *[ + self.import_type( + cast(model.Term, t) + ) + for t in ts.to_list_parts() + ] + ), + val=value, + ) + case _: + error = f"Unexpected term: {term}" + raise ModelImportError(error) + case _: + extension, type_id = _split_extension_name(typename) + match json_dict: + case {"c": name, "v": value}: + # Determine appropriate TypeBound + bound = TypeBound.Copyable + if typename == "collections.list.List": + [arg] = args + datatype = self.import_type(arg) + bound = datatype.type_bound() + # TODO Determine type bound in other cases + return val.Extension( + name=name, + typ=Opaque( + id=type_id, + bound=bound, + args=[ + self.import_type_arg(arg) + for arg in args + ], + extension=extension, + ), + val=value, + ) + case _: + error = f"Unexpected term: {term}" + raise ModelImportError(error) + case _: + error = f"Unexpected compat.const_json type: {typ}" + raise ModelImportError(error) + case model.Apply("core.const.adt", [variants, _types, tag, values]): + match tag: + case model.Literal(int() as tagval): + pass + case _: + error = f"Unexpected tag: {tag}" + raise ModelImportError(error) + return val.Sum( + tag=tagval, + typ=Sum( + variant_rows=[ + [ + self.import_type(cast(model.Term, t)) + for t in cast(model.Term, variant).to_list_parts() + ] + for variant in variants.to_list_parts() + ] + ), + vals=[ + self.import_value(cast(model.Term, v)) + for v in values.to_tuple_parts() + ], + ) + case _: + error = "Unsupported constant value." 
+ raise ModelImportError(error, term) + + +@dataclass +class LocalVarData: + """Data describing a local variable.""" + + index: int + type: TypeParam + bound: TypeBound = field(default=TypeBound.Linear) + + +@dataclass +class RegionOrderHints: + """Order hint metadata.""" + + input_keys: list[int] = field(default_factory=list) + output_keys: list[int] = field(default_factory=list) + edges: list[tuple[int, int]] = field(default_factory=list) + key_to_node: dict[int, Node] = field(default_factory=dict) + + def add_node_keys(self, node: Node, keys: Iterable[int]): + for key in keys: + if key in self.key_to_node: + error = f"Duplicate order key `{key}`." + raise ModelImportError(error) + + self.key_to_node[key] = node + + def get_node_by_key(self, key: int) -> Node: + if key not in self.key_to_node: + error = f"Unknown order key `{key}`." + raise ModelImportError(error) + + return self.key_to_node[key] + + +def _group_seq_parts( + parts: Iterable[model.SeqPart], +) -> Generator[model.Term | list[model.Term]]: + group: list[model.Term] = [] + + for part in parts: + if isinstance(part, model.Splice): + if len(group) > 0: + yield group + group = [] + yield part.seq + else: + group.append(part) + + if len(group) > 0: + yield group + + +def _import_closed_list(term: model.Term) -> Generator[model.Term]: + for part in term.to_list_parts(): + if isinstance(part, model.Splice): + error = "Expected closed list." + raise ModelImportError(error, term) + else: + yield part + + +def _import_closed_tuple(term: model.Term) -> Generator[model.Term]: + for part in term.to_tuple_parts(): + if isinstance(part, model.Splice): + error = "Expected closed tuple." 
+ raise ModelImportError(error, term) + else: + yield part + + +def _split_extension_name(name: str) -> tuple[str, str]: + match name.rsplit(".", 1): + case [extension, id]: + return (extension, id) + case [id]: + return ("", id) + case _: + raise AssertionError diff --git a/hugr-py/src/hugr/package.py b/hugr-py/src/hugr/package.py index c219a750e6..00d3fe26fa 100644 --- a/hugr-py/src/hugr/package.py +++ b/hugr-py/src/hugr/package.py @@ -16,11 +16,11 @@ read_envelope, read_envelope_str, ) +from hugr.hugr.base import Hugr from hugr.ops import FuncDecl, FuncDefn, Op if TYPE_CHECKING: from hugr.ext import Extension - from hugr.hugr.base import Hugr from hugr.hugr.node_port import Node __all__ = [ @@ -82,6 +82,12 @@ def from_str(envelope: str) -> Package: """ return read_envelope_str(envelope) + @staticmethod + def from_model(package: model.Package, extensions: list[Extension] | None = None): + if extensions is None: + extensions = [] + return Package([Hugr.from_model(hugr) for hugr in package.modules], extensions) + def to_bytes(self, config: EnvelopeConfig | None = None) -> bytes: """Serialize the package to a HUGR envelope byte string. 
diff --git a/hugr-py/src/hugr/tys.py b/hugr-py/src/hugr/tys.py index e17d39653a..8cf7aeb970 100644 --- a/hugr-py/src/hugr/tys.py +++ b/hugr-py/src/hugr/tys.py @@ -364,6 +364,15 @@ def to_model(self) -> model.Term: [model.Splice(cast(model.Term, elem.to_model())) for elem in self.lists] ) + def flatten(self) -> TypeArg: + match self.lists: + case []: + return ListArg([]) + case [item]: + return item + case _: + return self + @dataclass(frozen=True) class TupleArg(TypeArg): @@ -405,6 +414,15 @@ def to_model(self) -> model.Term: [model.Splice(cast(model.Term, elem.to_model())) for elem in self.tuples] ) + def flatten(self) -> TypeArg: + match self.tuples: + case []: + return TupleArg([]) + case [item]: + return item + case _: + return self + @dataclass(frozen=True) class VariableArg(TypeArg): diff --git a/hugr-py/tests/conftest.py b/hugr-py/tests/conftest.py index 26c24638a4..50c3e3000d 100644 --- a/hugr-py/tests/conftest.py +++ b/hugr-py/tests/conftest.py @@ -208,6 +208,26 @@ def validate( ), f"HUGRs are not the same for {write_fmt} -> {load_fmt}" +def canonicalize_json(payload): + """Put json into a canonical form for hashing purposes. + + Specifically, replace a general sum of empty rows with an explicit unit sum.""" + if isinstance(payload, list): + return list(map(canonicalize_json, payload)) + elif isinstance(payload, dict): + if ( + set(payload.keys()) == {"t", "s", "rows"} + and payload["t"] == "Sum" + and payload["s"] == "General" + and not any(payload["rows"]) + ): + return {"t": "Sum", "s": "Unit", "size": len(payload["rows"])} + else: + return {k: canonicalize_json(v) for k, v in payload.items()} + else: + return payload + + @dataclass(frozen=True, order=True) class _NodeHash: op: _OpHash @@ -255,7 +275,7 @@ def _hash_node(cls, h: Hugr, n: Node, depth: int, name: str) -> _NodeHash: # StaticArrayVal's dictionary payload containing a SumValue # internally, see `test_val_static_array`). 
value_dict = op_type.val._to_serial_root().model_dump(mode="json") - op = _OpHash("Const", value_dict) + op = _OpHash("Const", canonicalize_json(value_dict)) else: op = _OpHash(op_type.name())