Skip to content

Commit f50ee3b

Browse files
committed
refactor: Put extension inference behind a feature gate
1 parent b680662 commit f50ee3b

File tree

7 files changed

+63
-11
lines changed

7 files changed

+63
-11
lines changed

.github/workflows/ci.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ name: Continuous integration
33
on:
44
push:
55
branches:
6-
- main
6+
- main
77
pull_request:
88
branches:
99
- main
@@ -33,7 +33,7 @@ jobs:
3333
- name: Check formatting
3434
run: cargo fmt -- --check
3535
- name: Run clippy
36-
run: cargo clippy --all-targets -- -D warnings
36+
run: cargo clippy --all-targets --all-features -- -D warnings
3737
- name: Build docs
3838
run: cargo doc --no-deps --all-features
3939
env:

Cargo.toml

+3
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ name = "hugr"
2222
bench = false
2323
path = "src/lib.rs"
2424

25+
[features]
26+
extension_inference = []
27+
2528
[dependencies]
2629
thiserror = "1.0.28"
2730
portgraph = { version = "0.11.0", features = ["serde", "petgraph"] }

src/extension.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,11 @@ use crate::types::type_param::{check_type_args, TypeArgError};
1818
use crate::types::type_param::{TypeArg, TypeParam};
1919
use crate::types::{check_typevar_decl, CustomType, PolyFuncType, Substitution, TypeBound};
2020

21+
#[allow(dead_code)]
2122
mod infer;
22-
pub use infer::{infer_extensions, ExtensionSolution, InferExtensionError};
23+
#[cfg(feature = "extension_inference")]
24+
pub use infer::infer_extensions;
25+
pub use infer::{ExtensionSolution, InferExtensionError};
2326

2427
mod op_def;
2528
pub use op_def::{

src/extension/infer/test.rs

+12-2
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,22 @@
11
use std::error::Error;
22

33
use super::*;
4+
#[cfg(feature = "extension_inference")]
45
use crate::builder::test::closed_dfg_root_hugr;
56
use crate::builder::{
67
Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, ModuleBuilder,
78
};
89
use crate::extension::prelude::QB_T;
910
use crate::extension::ExtensionId;
1011
use crate::extension::{prelude::PRELUDE_REGISTRY, ExtensionSet};
11-
use crate::hugr::{validate::ValidationError, Hugr, HugrMut, HugrView, NodeType};
12+
#[cfg(feature = "extension_inference")]
13+
use crate::hugr::validate::ValidationError;
14+
use crate::hugr::{Hugr, HugrMut, HugrView, NodeType};
1215
use crate::macros::const_extension_ids;
1316
use crate::ops::custom::{ExternalOp, OpaqueOp};
14-
use crate::ops::{self, dataflow::IOTrait, handle::NodeHandle};
17+
#[cfg(feature = "extension_inference")]
18+
use crate::ops::handle::NodeHandle;
19+
use crate::ops::{self, dataflow::IOTrait};
1520
use crate::ops::{LeafOp, OpType};
1621

1722
use crate::type_row;
@@ -153,6 +158,7 @@ fn plus() -> Result<(), InferExtensionError> {
153158
Ok(())
154159
}
155160

161+
#[cfg(feature = "extension_inference")]
156162
#[test]
157163
// This generates a solution that causes validation to fail
158164
// because of a missing lift node
@@ -214,6 +220,7 @@ fn open_variables() -> Result<(), InferExtensionError> {
214220
Ok(())
215221
}
216222

223+
#[cfg(feature = "extension_inference")]
217224
#[test]
218225
// Infer the extensions on a child node with no inputs
219226
fn dangling_src() -> Result<(), Box<dyn Error>> {
@@ -305,6 +312,7 @@ fn create_with_io(
305312
Ok([node, input, output])
306313
}
307314

315+
#[cfg(feature = "extension_inference")]
308316
#[test]
309317
fn test_conditional_inference() -> Result<(), Box<dyn Error>> {
310318
fn build_case(
@@ -967,6 +975,7 @@ fn simple_funcdefn() -> Result<(), Box<dyn Error>> {
967975
Ok(())
968976
}
969977

978+
#[cfg(feature = "extension_inference")]
970979
#[test]
971980
fn funcdefn_signature_mismatch() -> Result<(), Box<dyn Error>> {
972981
let mut builder = ModuleBuilder::new();
@@ -997,6 +1006,7 @@ fn funcdefn_signature_mismatch() -> Result<(), Box<dyn Error>> {
9971006
Ok(())
9981007
}
9991008

1009+
#[cfg(feature = "extension_inference")]
10001010
#[test]
10011011
// Test that the difference between a FuncDefn's input and output nodes is being
10021012
// constrained to be the same as the extension delta in the FuncDefn signature.

src/hugr.rs

+20-3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ pub mod serialize;
88
pub mod validate;
99
pub mod views;
1010

11+
#[cfg(not(feature = "extension_inference"))]
12+
use std::collections::HashMap;
1113
use std::collections::VecDeque;
1214
use std::iter;
1315

@@ -23,9 +25,9 @@ use thiserror::Error;
2325

2426
pub use self::views::{HugrView, RootTagged};
2527
use crate::core::NodeIndex;
26-
use crate::extension::{
27-
infer_extensions, ExtensionRegistry, ExtensionSet, ExtensionSolution, InferExtensionError,
28-
};
28+
#[cfg(feature = "extension_inference")]
29+
use crate::extension::infer_extensions;
30+
use crate::extension::{ExtensionRegistry, ExtensionSet, ExtensionSolution, InferExtensionError};
2931
use crate::ops::custom::resolve_extension_ops;
3032
use crate::ops::{OpTag, OpTrait, OpType, DEFAULT_OPTYPE};
3133
use crate::types::FunctionType;
@@ -197,12 +199,19 @@ impl Hugr {
197199
/// Infer extension requirements and add new information to `op_types` field
198200
///
199201
/// See [`infer_extensions`] for details on the "closure" value
202+
#[cfg(feature = "extension_inference")]
200203
pub fn infer_extensions(&mut self) -> Result<ExtensionSolution, InferExtensionError> {
201204
let (solution, extension_closure) = infer_extensions(self)?;
202205
self.instantiate_extensions(solution);
203206
Ok(extension_closure)
204207
}
208+
/// Do nothing - this functionality is gated by the feature "extension_inference"
209+
#[cfg(not(feature = "extension_inference"))]
210+
pub fn infer_extensions(&mut self) -> Result<ExtensionSolution, InferExtensionError> {
211+
Ok(HashMap::new())
212+
}
205213

214+
#[allow(dead_code)]
206215
/// Add extension requirement information to the hugr in place.
207216
fn instantiate_extensions(&mut self, solution: ExtensionSolution) {
208217
// We only care about inferred _input_ extensions, because `NodeType`
@@ -345,13 +354,20 @@ pub enum HugrError {
345354
#[cfg(test)]
346355
mod test {
347356
use super::{Hugr, HugrView};
357+
#[cfg(feature = "extension_inference")]
348358
use crate::builder::test::closed_dfg_root_hugr;
359+
#[cfg(feature = "extension_inference")]
349360
use crate::extension::ExtensionSet;
361+
#[cfg(feature = "extension_inference")]
350362
use crate::hugr::HugrMut;
363+
#[cfg(feature = "extension_inference")]
351364
use crate::ops;
365+
#[cfg(feature = "extension_inference")]
352366
use crate::type_row;
367+
#[cfg(feature = "extension_inference")]
353368
use crate::types::{FunctionType, Type};
354369

370+
#[cfg(feature = "extension_inference")]
355371
use std::error::Error;
356372

357373
#[test]
@@ -371,6 +387,7 @@ mod test {
371387
assert_matches!(hugr.get_io(hugr.root()), Some(_));
372388
}
373389

390+
#[cfg(feature = "extension_inference")]
374391
#[test]
375392
fn extension_instantiation() -> Result<(), Box<dyn Error>> {
376393
const BIT: Type = crate::extension::prelude::USIZE_T;

src/hugr/validate.rs

+10-2
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@ use petgraph::visit::{Topo, Walker};
99
use portgraph::{LinkView, PortView};
1010
use thiserror::Error;
1111

12+
#[cfg(feature = "extension_inference")]
13+
use crate::extension::validate::ExtensionValidator;
1214
use crate::extension::SignatureError;
1315
use crate::extension::{
14-
validate::{ExtensionError, ExtensionValidator},
15-
ExtensionRegistry, ExtensionSolution, InferExtensionError,
16+
validate::ExtensionError, ExtensionRegistry, ExtensionSolution, InferExtensionError,
1617
};
1718

1819
use crate::ops::custom::CustomOpError;
@@ -36,6 +37,7 @@ struct ValidationContext<'a, 'b> {
3637
/// Dominator tree for each CFG region, using the container node as index.
3738
dominators: HashMap<Node, Dominators<Node>>,
3839
/// Context for the extension validation.
40+
#[cfg(feature = "extension_inference")]
3941
extension_validator: ExtensionValidator,
4042
/// Registry of available Extensions
4143
extension_registry: &'b ExtensionRegistry,
@@ -64,6 +66,9 @@ impl Hugr {
6466

6567
impl<'a, 'b> ValidationContext<'a, 'b> {
6668
/// Create a new validation context.
69+
// Allow unused "extension_closure" variable for when
70+
// the "extension_inference" feature is disabled.
71+
#[allow(unused_variables)]
6772
pub fn new(
6873
hugr: &'a Hugr,
6974
extension_closure: ExtensionSolution,
@@ -72,6 +77,7 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
7277
Self {
7378
hugr,
7479
dominators: HashMap::new(),
80+
#[cfg(feature = "extension_inference")]
7581
extension_validator: ExtensionValidator::new(hugr, extension_closure),
7682
extension_registry,
7783
}
@@ -163,6 +169,7 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
163169

164170
// FuncDefns have no resources since they're static nodes, but the
165171
// functions they define can have any extension delta.
172+
#[cfg(feature = "extension_inference")]
166173
if node_type.tag() != OpTag::FuncDefn {
167174
// If this is a container with I/O nodes, check that the extension they
168175
// define match the extensions of the container.
@@ -240,6 +247,7 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
240247
let other_node: Node = self.hugr.graph.port_node(link).unwrap().into();
241248
let other_offset = self.hugr.graph.port_offset(link).unwrap().into();
242249

250+
#[cfg(feature = "extension_inference")]
243251
self.extension_validator
244252
.check_extensions_compatible(&(node, port), &(other_node, other_offset))?;
245253

src/hugr/validate/test.rs

+12-1
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,18 @@ use cool_asserts::assert_matches;
22

33
use super::*;
44
use crate::builder::test::closed_dfg_root_hugr;
5+
#[cfg(feature = "extension_inference")]
6+
use crate::builder::ModuleBuilder;
57
use crate::builder::{
68
BuildError, Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder,
7-
ModuleBuilder,
89
};
910
use crate::extension::prelude::{BOOL_T, PRELUDE, USIZE_T};
1011
use crate::extension::{
1112
Extension, ExtensionId, ExtensionSet, TypeDefBound, EMPTY_REG, PRELUDE_REGISTRY,
1213
};
1314
use crate::hugr::hugrmut::sealed::HugrMutInternals;
1415
use crate::hugr::{HugrError, HugrMut, NodeType};
16+
#[cfg(feature = "extension_inference")]
1517
use crate::macros::const_extension_ids;
1618
use crate::ops::dataflow::IOTrait;
1719
use crate::ops::{self, Const, LeafOp, OpType};
@@ -23,6 +25,7 @@ use crate::values::Value;
2325
use crate::{type_row, Direction, IncomingPort, Node};
2426

2527
const NAT: Type = crate::extension::prelude::USIZE_T;
28+
#[cfg(feature = "infer_extensions")]
2629
const Q: Type = crate::extension::prelude::QB_T;
2730

2831
/// Creates a hugr with a single function definition that copies a bit `copies` times.
@@ -71,6 +74,7 @@ fn add_df_children(b: &mut Hugr, parent: Node, copies: usize) -> (Node, Node, No
7174
/// Intended to be used to populate a BasicBlock node in a CFG.
7275
///
7376
/// Returns the node indices of each of the operations.
77+
#[cfg(feature = "infer_extensions")]
7478
fn add_block_children(
7579
b: &mut Hugr,
7680
parent: Node,
@@ -257,6 +261,7 @@ fn df_children_restrictions() {
257261
);
258262
}
259263

264+
#[cfg(feature = "extension_inference")]
260265
#[test]
261266
/// Validation errors in a dataflow subgraph.
262267
fn cfg_children_restrictions() {
@@ -404,6 +409,7 @@ fn test_ext_edge() -> Result<(), HugrError> {
404409
Ok(())
405410
}
406411

412+
#[cfg(feature = "extension_inference")]
407413
const_extension_ids! {
408414
const XA: ExtensionId = "A";
409415
const XB: ExtensionId = "BOOL_EXT";
@@ -441,6 +447,7 @@ fn test_local_const() -> Result<(), HugrError> {
441447
Ok(())
442448
}
443449

450+
#[cfg(feature = "extension_inference")]
444451
#[test]
445452
/// A wire with no extension requirements is wired into a node which has
446453
/// [A,BOOL_T] extensions required on its inputs and outputs. This could be fixed
@@ -474,6 +481,7 @@ fn missing_lift_node() -> Result<(), BuildError> {
474481
Ok(())
475482
}
476483

484+
#[cfg(feature = "extension_inference")]
477485
#[test]
478486
/// A wire with extension requirement `[A]` is wired into a an output with no
479487
/// extension req. In the validation extension typechecking, we don't do any
@@ -505,6 +513,7 @@ fn too_many_extension() -> Result<(), BuildError> {
505513
Ok(())
506514
}
507515

516+
#[cfg(feature = "extension_inference")]
508517
#[test]
509518
/// A wire with extension requirements `[A]` and another with requirements
510519
/// `[BOOL_T]` are both wired into a node which requires its inputs to have
@@ -558,6 +567,7 @@ fn extensions_mismatch() -> Result<(), BuildError> {
558567
Ok(())
559568
}
560569

570+
#[cfg(feature = "extension_inference")]
561571
#[test]
562572
fn parent_signature_mismatch() -> Result<(), BuildError> {
563573
let rs = ExtensionSet::singleton(&XA);
@@ -740,6 +750,7 @@ fn invalid_types() {
740750
);
741751
}
742752

753+
#[cfg(feature = "extension_inference")]
743754
#[test]
744755
fn parent_io_mismatch() {
745756
// The DFG node declares that it has an empty extension delta,

0 commit comments

Comments
 (0)