Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add build_expect_sum to allow specific error messages #2032

Merged
merged 3 commits into from
Apr 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion hugr-core/src/extension/prelude.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Prelude extension - available in all contexts, defining common types,
//! operations and constants.
use std::str::FromStr;
use std::sync::{Arc, Weak};

use itertools::Itertools;
Expand Down Expand Up @@ -447,14 +448,19 @@ impl CustomConst for ConstUsize {
}

#[derive(Debug, Clone, PartialEq, Hash, serde::Serialize, serde::Deserialize)]
/// Structure for holding constant usize values.
/// Structure for holding constant [error types].
///
/// [error types]: crate::extension::prelude::error_type
pub struct ConstError {
/// Integer tag/signal for the error.
pub signal: u32,
/// Error message.
pub message: String,
}

/// Default error signal.
pub const DEFAULT_ERROR_SIGNAL: u32 = 1;

impl ConstError {
/// Define a new error value.
pub fn new(signal: u32, message: impl ToString) -> Self {
Expand All @@ -464,6 +470,12 @@ impl ConstError {
}
}

/// Define a new error value with the [default signal].
///
/// [default signal]: DEFAULT_ERROR_SIGNAL
pub fn new_default_signal(message: impl ToString) -> Self {
Self::new(DEFAULT_ERROR_SIGNAL, message)
}
/// Returns an "either" value with a failure variant.
///
/// args:
Expand Down Expand Up @@ -491,6 +503,20 @@ impl CustomConst for ConstError {
}
}

impl FromStr for ConstError {
type Err = ();

fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(Self::new_default_signal(s))
}
}

impl From<String> for ConstError {
fn from(s: String) -> Self {
Self::new_default_signal(s)
}
}

#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
/// A structure for holding references to external symbols.
pub struct ConstExternalSymbol {
Expand Down
28 changes: 26 additions & 2 deletions hugr-core/src/extension/prelude/unwrap_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,31 @@ pub trait UnwrapBuilder: Dataflow {
tag: usize,
sum_type: SumType,
input: Wire,
) -> Result<[Wire; N], BuildError> {
self.build_expect_sum(tag, sum_type, input, |i| {
format!("Expected variant {} but got variant {}", tag, i)
})
}

/// Build an unwrap operation for a sum type to extract the variant at the given tag
/// or panic with given message if the tag is not the expected value.
///
/// `error` is a function that takes the actual tag and returns the error message
/// for cases where the tag is not the expected value.
///
/// # Panics
///
/// If `tag` is greater than the number of variants in the sum type.
///
/// # Errors
///
/// Errors in building the unwrapping conditional.
fn build_expect_sum<const N: usize, T: Into<ConstError>>(
&mut self,
tag: usize,
sum_type: SumType,
input: Wire,
mut error: impl FnMut(usize) -> T,
) -> Result<[Wire; N], BuildError> {
let variants: Vec<TypeRow> = (0..sum_type.num_variants())
.map(|i| {
Expand All @@ -64,8 +89,7 @@ pub trait UnwrapBuilder: Dataflow {
} else {
let output_row = output_row.iter().cloned();
let inputs = zip_eq(case.input_wires(), variant.iter().cloned());
let err =
ConstError::new(1, format!("Expected variant {} but got variant {}", tag, i));
let err = error(i).into();
let outputs = case.add_panic(err, output_row, inputs)?.outputs();
case.finish_with_outputs(outputs)?;
}
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading