Skip to content

Commit a02fd36

Browse files
committed
allow closure to return whole error
1 parent 167c687 commit a02fd36

File tree

2 files changed

+41
-5
lines changed

2 files changed

+41
-5
lines changed

hugr-core/src/extension/prelude.rs

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
//! Prelude extension - available in all contexts, defining common types,
22
//! operations and constants.
3+
use std::str::FromStr;
34
use std::sync::{Arc, Weak};
45

56
use itertools::Itertools;
@@ -447,14 +448,19 @@ impl CustomConst for ConstUsize {
447448
}
448449

449450
#[derive(Debug, Clone, PartialEq, Hash, serde::Serialize, serde::Deserialize)]
450-
/// Structure for holding constant usize values.
451+
/// Structure for holding constant [error types].
452+
///
453+
/// [error types]: crate::extension::prelude::error_type
451454
pub struct ConstError {
452455
/// Integer tag/signal for the error.
453456
pub signal: u32,
454457
/// Error message.
455458
pub message: String,
456459
}
457460

461+
/// Default error signal.
462+
pub const DEFAULT_ERROR_SIGNAL: u32 = 1;
463+
458464
impl ConstError {
459465
/// Define a new error value.
460466
pub fn new(signal: u32, message: impl ToString) -> Self {
@@ -464,6 +470,12 @@ impl ConstError {
464470
}
465471
}
466472

473+
/// Define a new error value with the [default signal].
474+
///
475+
/// [default signal]: DEFAULT_ERROR_SIGNAL
476+
pub fn new_default_signal(message: impl ToString) -> Self {
477+
Self::new(DEFAULT_ERROR_SIGNAL, message)
478+
}
467479
/// Returns an "either" value with a failure variant.
468480
///
469481
/// args:
@@ -491,6 +503,20 @@ impl CustomConst for ConstError {
491503
}
492504
}
493505

506+
impl FromStr for ConstError {
507+
type Err = ();
508+
509+
fn from_str(s: &str) -> Result<Self, Self::Err> {
510+
Ok(Self::new_default_signal(s))
511+
}
512+
}
513+
514+
impl From<String> for ConstError {
515+
fn from(s: String) -> Self {
516+
Self::new_default_signal(s)
517+
}
518+
}
519+
494520
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
495521
/// A structure for holding references to external symbols.
496522
pub struct ConstExternalSymbol {

hugr-core/src/extension/prelude/unwrap_builder.rs

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,23 @@ pub trait UnwrapBuilder: Dataflow {
5151

5252
/// Build an unwrap operation for a sum type to extract the variant at the given tag
5353
/// or panic with given message if the tag is not the expected value.
54-
/// Message closure takes the unexpected tag and returns message.
55-
fn build_expect_sum<const N: usize>(
54+
///
55+
/// `error` is a function that takes the actual tag and returns the error message
56+
/// for cases where the tag is not the expected value.
57+
///
58+
/// # Panics
59+
///
60+
/// If `tag` is greater than the number of variants in the sum type.
61+
///
62+
/// # Errors
63+
///
64+
/// Errors in building the unwrapping conditional.
65+
fn build_expect_sum<const N: usize, T: Into<ConstError>>(
5666
&mut self,
5767
tag: usize,
5868
sum_type: SumType,
5969
input: Wire,
60-
message: impl Fn(usize) -> String,
70+
mut error: impl FnMut(usize) -> T,
6171
) -> Result<[Wire; N], BuildError> {
6272
let variants: Vec<TypeRow> = (0..sum_type.num_variants())
6373
.map(|i| {
@@ -79,7 +89,7 @@ pub trait UnwrapBuilder: Dataflow {
7989
} else {
8090
let output_row = output_row.iter().cloned();
8191
let inputs = zip_eq(case.input_wires(), variant.iter().cloned());
82-
let err = ConstError::new(1, message(i));
92+
let err = error(i).into();
8393
let outputs = case.add_panic(err, output_row, inputs)?.outputs();
8494
case.finish_with_outputs(outputs)?;
8595
}

0 commit comments

Comments
 (0)