Skip to content

Commit 6ea4abf

Browse files
cqc-aleczrho
andauthored
refactor: Direct import of model representation to Python (#2683)
Closes #2287 . Testing: * All existing hugr-py tests that call `validate()` perform a round-trip, checking that the hashes of the start and end Hugrs (computed using the `_NodeHash` method defined in `conftest.py`) agree. * I exported all the Hugrs generated by tests in the guppylang repo (using `just export-integration-tests`) and tried importing them (1) using the old method via json and (2) using the new method, and checked that the hashes of the two imported Hugrs agree (using the same `_NodeHash` method). There are a few TODO comments remaining in `load.py` which I could not see a way to resolve using the existing model. I will investigate these further and raise issues if necessary. --------- Co-authored-by: Lukas Heidemann <[email protected]>
1 parent 7c41f82 commit 6ea4abf

File tree

12 files changed

+1280
-27
lines changed

12 files changed

+1280
-27
lines changed

hugr-core/src/import.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,12 +1140,12 @@ impl<'a> Context<'a> {
11401140

11411141
/// Import a node with a custom operation.
11421142
///
1143-
/// A custom operation in `hugr-model` referred to by a symbol application
1144-
/// term. The name of the symbol specifies the name of the custom operation,
1145-
/// and the arguments supplied to the symbol are the arguments to be passed
1146-
/// to the custom operation. This method imports the custom operations as
1147-
/// [`OpaqueOp`]s. The [`OpaqueOp`]s are then resolved later against the
1148-
/// [`ExtensionRegistry`].
1143+
/// A custom operation in `hugr-model` is referred to by a symbol
1144+
/// application term. The name of the symbol specifies the name of the
1145+
/// custom operation, and the arguments supplied to the symbol are the
1146+
/// arguments to be passed to the custom operation. This method imports the
1147+
/// custom operations as [`OpaqueOp`]s. The [`OpaqueOp`]s are then resolved
1148+
/// later against the [`ExtensionRegistry`].
11491149
///
11501150
/// Some operations that needed to be builtins in `hugr-core` are custom
11511151
/// operations in `hugr-model`. This method detects these and converts them

hugr-model/src/v0/binary/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,5 @@
1010
mod read;
1111
mod write;
1212

13-
pub use read::{ReadError, read_from_reader, read_from_slice};
13+
pub use read::{ReadError, read_from_reader, read_from_slice, read_from_slice_with_suffix};
1414
pub use write::{WriteError, write_to_vec, write_to_writer};

hugr-model/src/v0/binary/read.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use crate::v0::table;
33
use crate::{CURRENT_VERSION, v0 as model};
44
use bumpalo::Bump;
55
use bumpalo::collections::Vec as BumpVec;
6-
use std::io::BufRead;
6+
use std::io::{BufRead, BufReader, Read};
77

88
/// An error encountered while deserialising a model.
99
#[derive(Debug, derive_more::From, derive_more::Display, derive_more::Error)]
@@ -27,10 +27,27 @@ pub enum ReadError {
2727
type ReadResult<T> = Result<T, ReadError>;
2828

2929
/// Read a hugr package from a byte slice.
30+
///
31+
/// If the slice contains bytes beyond the encoded package, they are ignored.
3032
pub fn read_from_slice<'a>(slice: &[u8], bump: &'a Bump) -> ReadResult<table::Package<'a>> {
3133
read_from_reader(slice, bump)
3234
}
3335

36+
/// Read a hugr package from a byte slice.
37+
///
38+
/// If the slice contains bytes beyond the encoded package, they are returned
39+
/// along with the decoded package.
40+
pub fn read_from_slice_with_suffix<'a>(
41+
slice: &[u8],
42+
bump: &'a Bump,
43+
) -> ReadResult<(table::Package<'a>, Vec<u8>)> {
44+
let mut buffer = BufReader::new(slice);
45+
let package = read_from_reader(&mut buffer, bump)?;
46+
let mut suffix: Vec<u8> = vec![];
47+
buffer.read_to_end(&mut suffix)?;
48+
Ok((package, suffix))
49+
}
50+
3451
/// Read a hugr package from an impl of [`BufRead`].
3552
pub fn read_from_reader(reader: impl BufRead, bump: &Bump) -> ReadResult<table::Package<'_>> {
3653
let reader =

hugr-py/rust/lib.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,14 @@ fn package_to_bytes(package: ast::Package) -> PyResult<Vec<u8>> {
7474
}
7575

7676
#[pyfunction]
77-
fn bytes_to_package(bytes: &[u8]) -> PyResult<ast::Package> {
77+
fn bytes_to_package(bytes: &[u8]) -> PyResult<(ast::Package, Vec<u8>)> {
7878
let bump = bumpalo::Bump::new();
79-
let table = hugr_model::v0::binary::read_from_slice(bytes, &bump)
79+
let (table, suffix) = hugr_model::v0::binary::read_from_slice_with_suffix(bytes, &bump)
8080
.map_err(|err| PyValueError::new_err(err.to_string()))?;
8181
let package = table
8282
.as_ast()
8383
.ok_or_else(|| PyValueError::new_err("Malformed package"))?;
84-
Ok(package)
84+
Ok((package, suffix))
8585
}
8686

8787
/// Returns the current version of the HUGR model format as a tuple of (major, minor, patch).

hugr-py/src/hugr/_hugr/__init__.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def bytes_to_module(binary: bytes) -> hugr.model.Module: ...
2020
def package_to_string(package: hugr.model.Package) -> str: ...
2121
def string_to_package(string: str) -> hugr.model.Package: ...
2222
def package_to_bytes(package: hugr.model.Package) -> bytes: ...
23-
def bytes_to_package(binary: bytes) -> hugr.model.Package: ...
23+
def bytes_to_package(binary: bytes) -> tuple[hugr.model.Package, bytes]: ...
2424
def current_model_version() -> tuple[int, int, int]: ...
2525
def to_json_envelope(binary: bytes) -> bytes: ...
2626
def run_cli() -> None: ...

hugr-py/src/hugr/envelope.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040
import pyzstd
4141

42-
from hugr import cli
42+
import hugr._hugr as rust
4343

4444
if TYPE_CHECKING:
4545
from hugr.hugr.base import Hugr
@@ -64,7 +64,6 @@ def make_envelope(package: Package | Hugr, config: EnvelopeConfig) -> bytes:
6464
if not isinstance(package, Package):
6565
package = Package(modules=[package], extensions=[])
6666

67-
# Currently only uncompressed JSON is supported.
6867
payload: bytes
6968
match config.format:
7069
case EnvelopeFormat.JSON:
@@ -103,6 +102,7 @@ def make_envelope_str(package: Package | Hugr, config: EnvelopeConfig) -> str:
103102
def read_envelope(envelope: bytes) -> Package:
104103
"""Decode a HUGR package from an envelope."""
105104
import hugr._serialization.extension as ext_s
105+
from hugr.package import Package
106106

107107
header = EnvelopeHeader.from_bytes(envelope)
108108
payload = envelope[10:]
@@ -113,12 +113,23 @@ def read_envelope(envelope: bytes) -> Package:
113113
match header.format:
114114
case EnvelopeFormat.JSON:
115115
return ext_s.Package.model_validate_json(payload).deserialize()
116-
case EnvelopeFormat.MODEL | EnvelopeFormat.MODEL_WITH_EXTS:
117-
# TODO Going via JSON is a temporary solution, until we get model import to
118-
# python properly implemented.
119-
# https://github.com/CQCL/hugr/issues/2287
120-
json_data = cli.convert(envelope, format="json")
121-
return read_envelope(json_data)
116+
case EnvelopeFormat.MODEL:
117+
model_package, suffix = rust.bytes_to_package(payload)
118+
if suffix:
119+
msg = f"Excess bytes in envelope with format {EnvelopeFormat.MODEL}."
120+
raise ValueError(msg)
121+
return Package.from_model(model_package)
122+
case EnvelopeFormat.MODEL_WITH_EXTS:
123+
from hugr.ext import Extension
124+
125+
model_package, suffix = rust.bytes_to_package(payload)
126+
return Package(
127+
modules=Package.from_model(model_package).modules,
128+
extensions=[
129+
Extension.from_json(json.dumps(extension))
130+
for extension in json.loads(suffix)
131+
],
132+
)
122133

123134

124135
def read_envelope_hugr(envelope: bytes) -> Hugr:

hugr-py/src/hugr/hugr/base.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,6 +1114,19 @@ def from_str(envelope: str) -> Hugr:
11141114
"""
11151115
return read_envelope_hugr_str(envelope)
11161116

1117+
@staticmethod
1118+
def from_model(module: model.Module) -> Hugr:
1119+
"""Import from the hugr model format."""
1120+
from hugr.model.load import ModelImport
1121+
1122+
loader = ModelImport(module=module)
1123+
for i, node in enumerate(module.root.children):
1124+
loader.import_node_in_module(node, i)
1125+
loader.link_ports()
1126+
loader.link_static_ports()
1127+
loader.add_module_metadata()
1128+
return loader.hugr
1129+
11171130
def to_bytes(self, config: EnvelopeConfig | None = None) -> bytes:
11181131
"""Serialize the HUGR into an envelope byte string.
11191132

hugr-py/src/hugr/model/__init__.py

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""HUGR model data structures."""
22

3-
from collections.abc import Sequence
3+
import warnings
4+
from collections.abc import Generator, Sequence
45
from dataclasses import dataclass, field
56
from enum import Enum
6-
from typing import Protocol
77

88
from semver import Version
99

@@ -21,7 +21,7 @@ def _current_version() -> Version:
2121
CURRENT_VERSION: Version = _current_version()
2222

2323

24-
class Term(Protocol):
24+
class Term:
2525
"""A model term for static data such as types, constants and metadata."""
2626

2727
def __str__(self) -> str:
@@ -33,6 +33,26 @@ def from_str(s: str) -> "Term":
3333
"""Read the term from its string representation."""
3434
return rust.string_to_term(s)
3535

36+
def to_list_parts(self) -> Generator["SeqPart"]:
37+
if isinstance(self, List):
38+
for part in self.parts:
39+
if isinstance(part, Splice):
40+
yield from part.seq.to_list_parts()
41+
else:
42+
yield part
43+
else:
44+
yield Splice(self)
45+
46+
def to_tuple_parts(self) -> Generator["SeqPart"]:
47+
if isinstance(self, Tuple):
48+
for part in self.parts:
49+
if isinstance(part, Splice):
50+
yield from part.seq.to_tuple_parts()
51+
else:
52+
yield part
53+
else:
54+
yield Splice(self)
55+
3656

3757
@dataclass(frozen=True)
3858
class Wildcard(Term):
@@ -129,9 +149,13 @@ def from_str(s: str) -> "Symbol":
129149
return rust.string_to_symbol(s)
130150

131151

132-
class Op(Protocol):
152+
class Op:
133153
"""The operation of a node."""
134154

155+
def symbol_name(self) -> str | None:
156+
"""Returns name of the symbol introduced by this node, if any."""
157+
return None
158+
135159

136160
@dataclass(frozen=True)
137161
class InvalidOp(Op):
@@ -159,13 +183,19 @@ class DefineFunc(Op):
159183

160184
symbol: Symbol
161185

186+
def symbol_name(self) -> str | None:
187+
return self.symbol.name
188+
162189

163190
@dataclass(frozen=True)
164191
class DeclareFunc(Op):
165192
"""Function declaration."""
166193

167194
symbol: Symbol
168195

196+
def symbol_name(self) -> str | None:
197+
return self.symbol.name
198+
169199

170200
@dataclass(frozen=True)
171201
class CustomOp(Op):
@@ -181,13 +211,19 @@ class DefineAlias(Op):
181211
symbol: Symbol
182212
value: Term
183213

214+
def symbol_name(self) -> str | None:
215+
return self.symbol.name
216+
184217

185218
@dataclass(frozen=True)
186219
class DeclareAlias(Op):
187220
"""Alias declaration."""
188221

189222
symbol: Symbol
190223

224+
def symbol_name(self) -> str | None:
225+
return self.symbol.name
226+
191227

192228
@dataclass(frozen=True)
193229
class TailLoop(Op):
@@ -205,20 +241,29 @@ class DeclareConstructor(Op):
205241

206242
symbol: Symbol
207243

244+
def symbol_name(self) -> str | None:
245+
return self.symbol.name
246+
208247

209248
@dataclass(frozen=True)
210249
class DeclareOperation(Op):
211250
"""Operation declaration."""
212251

213252
symbol: Symbol
214253

254+
def symbol_name(self) -> str | None:
255+
return self.symbol.name
256+
215257

216258
@dataclass(frozen=True)
217259
class Import(Op):
218260
"""Import operation."""
219261

220262
name: str
221263

264+
def symbol_name(self) -> str | None:
265+
return self.name
266+
222267

223268
@dataclass
224269
class Node:
@@ -302,7 +347,13 @@ def from_str(s: str) -> "Package":
302347
@staticmethod
303348
def from_bytes(b: bytes) -> "Package":
304349
"""Read a package from its binary representation."""
305-
return rust.bytes_to_package(b)
350+
package, suffix = rust.bytes_to_package(b)
351+
if suffix:
352+
warnings.warn(
353+
"Binary encoding of model Package contains extra bytes: ignoring them.",
354+
stacklevel=2,
355+
)
356+
return package
306357

307358
@property
308359
def version(self) -> Version:

0 commit comments

Comments
 (0)