diff --git a/hugr-core/src/extension/prelude.rs b/hugr-core/src/extension/prelude.rs index d8dbbc957..f88a84a0d 100644 --- a/hugr-core/src/extension/prelude.rs +++ b/hugr-core/src/extension/prelude.rs @@ -1,5 +1,6 @@ //! Prelude extension - available in all contexts, defining common types, //! operations and constants. +use std::str::FromStr; use std::sync::{Arc, Weak}; use itertools::Itertools; @@ -447,7 +448,9 @@ impl CustomConst for ConstUsize { } #[derive(Debug, Clone, PartialEq, Hash, serde::Serialize, serde::Deserialize)] -/// Structure for holding constant usize values. +/// Structure for holding constant [error types]. +/// +/// [error types]: crate::extension::prelude::error_type pub struct ConstError { /// Integer tag/signal for the error. pub signal: u32, @@ -455,6 +458,9 @@ pub struct ConstError { pub message: String, } +/// Default error signal. +pub const DEFAULT_ERROR_SIGNAL: u32 = 1; + impl ConstError { /// Define a new error value. pub fn new(signal: u32, message: impl ToString) -> Self { @@ -464,6 +470,12 @@ impl ConstError { } } + /// Define a new error value with the [default signal]. + /// + /// [default signal]: DEFAULT_ERROR_SIGNAL + pub fn new_default_signal(message: impl ToString) -> Self { + Self::new(DEFAULT_ERROR_SIGNAL, message) + } /// Returns an "either" value with a failure variant. /// /// args: @@ -491,6 +503,20 @@ impl CustomConst for ConstError { } } +impl FromStr for ConstError { + type Err = (); + + fn from_str(s: &str) -> Result { + Ok(Self::new_default_signal(s)) + } +} + +impl From for ConstError { + fn from(s: String) -> Self { + Self::new_default_signal(s) + } +} + #[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] /// A structure for holding references to external symbols. pub struct ConstExternalSymbol { diff --git a/hugr-core/src/extension/prelude/unwrap_builder.rs b/hugr-core/src/extension/prelude/unwrap_builder.rs index 2495a464d..06b4e3939 100644 --- a/hugr-core/src/extension/prelude/unwrap_builder.rs +++ b/hugr-core/src/extension/prelude/unwrap_builder.rs @@ -43,6 +43,31 @@ pub trait UnwrapBuilder: Dataflow { tag: usize, sum_type: SumType, input: Wire, + ) -> Result<[Wire; N], BuildError> { + self.build_expect_sum(tag, sum_type, input, |i| { + format!("Expected variant {} but got variant {}", tag, i) + }) + } + + /// Build an unwrap operation for a sum type to extract the variant at the given tag + /// or panic with given message if the tag is not the expected value. + /// + /// `error` is a function that takes the actual tag and returns the error message + /// for cases where the tag is not the expected value. + /// + /// # Panics + /// + /// If `tag` is greater than the number of variants in the sum type. + /// + /// # Errors + /// + /// Errors in building the unwrapping conditional. + fn build_expect_sum>( + &mut self, + tag: usize, + sum_type: SumType, + input: Wire, + mut error: impl FnMut(usize) -> T, ) -> Result<[Wire; N], BuildError> { let variants: Vec = (0..sum_type.num_variants()) .map(|i| { @@ -64,8 +89,7 @@ pub trait UnwrapBuilder: Dataflow { } else { let output_row = output_row.iter().cloned(); let inputs = zip_eq(case.input_wires(), variant.iter().cloned()); - let err = - ConstError::new(1, format!("Expected variant {} but got variant {}", tag, i)); + let err = error(i).into(); let outputs = case.add_panic(err, output_row, inputs)?.outputs(); case.finish_with_outputs(outputs)?; } diff --git a/uv.lock b/uv.lock index a1d023411..37df68963 100644 --- a/uv.lock +++ b/uv.lock @@ -297,7 +297,7 @@ docs = [ [package.metadata] requires-dist = [ { name = "graphviz", specifier = ">=0.20.3" }, - { name = "pydantic", specifier = ">=2.8,<2.11" }, + { name = "pydantic", specifier = ">=2.8,<2.12" }, { name = "pydantic-extra-types", specifier = ">=2.9.0" }, { name = "pyzstd", specifier = "~=0.16.2" }, { name = "semver", specifier = ">=3.0.2" },