Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,68 @@ fn main() {

As this is naturally exhaustive, this is only supported for `FromPrimitive`, not also `TryFromPrimitive`.

## Interaction with std Default

You can define a different "fallback" default for FromPrimitive than what is used by the std lib's Default derive.

```rust
use num_enum::FromPrimitive;
use std::convert::TryFrom;

#[derive(Debug, Default, Eq, PartialEq, FromPrimitive)]
#[repr(u8)]
enum Number {
#[default]
Zero = 0,
#[num_enum(default)]
NonZero,
}

fn main() {
let def = Number::default();
assert_eq!(def, Number::Zero);

let zero = Number::from(0u8);
assert_eq!(zero, Number::Zero);

let one = Number::from(1u8);
assert_eq!(one, Number::NonZero);

let two = Number::from(2u8);
assert_eq!(two, Number::NonZero);
}
```

Or with catch_all

```rust
use num_enum::FromPrimitive;
use std::convert::TryFrom;

#[derive(Debug, Default, Eq, PartialEq, FromPrimitive)]
#[repr(u8)]
enum Number {
#[default]
Zero = 0,
#[num_enum(catch_all)]
NonZero(u8),
}

fn main() {
let def = Number::default();
assert_eq!(def, Number::Zero);

let zero = Number::from(0u8);
assert_eq!(zero, Number::Zero);

let one = Number::from(1u8);
assert_eq!(one, Number::NonZero(1_u8));

let two = Number::from(2u8);
assert_eq!(two, Number::NonZero(2_u8));
}
```

## Unsafely turning a primitive into an enum with unchecked_transmute_from

If you're really certain a conversion will succeed (and have not made use of `#[num_enum(default)]` or `#[num_enum(alternatives = [..])]`
Expand Down
37 changes: 37 additions & 0 deletions num_enum/tests/complex.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
use num_enum_derive::{FromPrimitive, IntoPrimitive};

// Guard against https://github.com/illicitonion/num_enum/issues/27
mod alloc {}
mod core {}
mod num_enum {}
mod std {}

#[test]
fn std_default_with_num_enum_catch_all() {
#[derive(Debug, Eq, PartialEq, Default, FromPrimitive, IntoPrimitive)]
#[repr(u8)]
enum Enum {
#[default]
Zero = 0,
#[num_enum(catch_all)]
NonZero(u8),
}

assert_eq!(Enum::Zero, <Enum as ::core::default::Default>::default());
assert_eq!(Enum::NonZero(5), 5u8.into());
}

#[test]
fn std_default_with_num_enum_default() {
#[derive(Debug, Eq, PartialEq, Default, FromPrimitive, IntoPrimitive)]
#[repr(u8)]
enum Enum {
#[default]
Zero = 0,
#[num_enum(default)]
NonZero,
}

assert_eq!(Enum::Zero, <Enum as ::core::default::Default>::default());
assert_eq!(Enum::NonZero, 5u8.into());
}
21 changes: 16 additions & 5 deletions num_enum_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,13 @@ pub fn derive_from_primitive(input: TokenStream) -> TokenStream {
Ok(is_naturally_exhaustive) => {
if is_naturally_exhaustive {
quote! { unreachable!("exhaustive enum") }
} else if let Some(default_ident) = enum_info.default() {
quote! { Self::#default_ident }
} else if let Some(catch_all_ident) = enum_info.catch_all() {
quote! { Self::#catch_all_ident(number) }
} else if let Some(default_ident) = enum_info.default() {
quote! { Self::#default_ident }
} else if let Some(default_ident) = enum_info.std_default() {
// std default is the last priority to allow for a different num_enum FromPrimitive default than the std Default::default()
quote! { Self::#default_ident }
} else {
let span = Span::call_site();
let message =
Expand Down Expand Up @@ -339,9 +342,17 @@ pub fn derive_unsafe_from_primitive(stream: TokenStream) -> TokenStream {
pub fn derive_default(stream: TokenStream) -> TokenStream {
let enum_info = parse_macro_input!(stream as EnumInfo);

let default_ident = match enum_info.default() {
Some(ident) => ident,
None => {
let default_ident = match ( enum_info.default(), enum_info.std_default() ) {
// num_enum(default) takes precedence over std default
(Some(ident), None) => ident,
(None, Some(ident)) => ident,
(Some(_), Some(_)) => {
let span = Span::call_site();
let message =
"#[derive(num_enum::Default)] cannot be used with both #[default] and #[num_enum(default)]";
return syn::Error::new(span, message).to_compile_error().into();
},
(None, None) => {
let span = Span::call_site();
let message =
"#[derive(num_enum::Default)] requires enum to be exhaustive, or a variant marked with `#[default]` or `#[num_enum(default)]`";
Expand Down
93 changes: 58 additions & 35 deletions num_enum_derive/src/parsing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ impl EnumInfo {
die!(self.repr.clone() => "Failed to parse repr into bit size");
}

pub(crate) fn std_default(&self) -> Option<&Ident> {
self.variants
.iter()
.find(|info| info.is_std_default)
.map(|info| &info.ident)
}

pub(crate) fn default(&self) -> Option<&Ident> {
self.variants
.iter()
Expand Down Expand Up @@ -144,8 +151,6 @@ impl Parse for EnumInfo {
let crate_path = attributes.crate_path.clone().map(|k| k.path);

let mut variants: Vec<VariantInfo> = vec![];
let mut has_default_variant: bool = false;
let mut has_catch_all_variant: bool = false;

// Vec to keep track of the used discriminants and alt values.
let mut discriminant_int_val_set = BTreeSet::new();
Expand All @@ -166,22 +171,15 @@ impl Parse for EnumInfo {
// `#[num_enum(default)]` is required by `#[derive(FromPrimitive)]`
// and forbidden by `#[derive(UnsafeFromPrimitive)]`, so we need to
// keep track of whether we encountered such an attribute:
let mut is_std_default: bool = false;
let mut is_default: bool = false;
let mut is_catch_all: bool = false;
let mut err_token: Option<Box<dyn ToTokens>> = None;

for attribute in &variant.attrs {
if attribute.path().is_ident("default") {
if has_default_variant {
die!(attribute =>
"Multiple variants marked `#[default]` or `#[num_enum(default)]` found"
);
} else if has_catch_all_variant {
die!(attribute =>
"Attribute `default` is mutually exclusive with `catch_all`"
);
}
is_default = true;
has_default_variant = true;
err_token = Some(Box::new(attribute.clone()));
is_std_default = true;
}

if attribute.path().is_ident("num_enum") {
Expand All @@ -190,29 +188,11 @@ impl Parse for EnumInfo {
for variant_attribute in variant_attributes.items {
match variant_attribute {
NumEnumVariantAttributeItem::Default(default) => {
if has_default_variant {
die!(default.keyword =>
"Multiple variants marked `#[default]` or `#[num_enum(default)]` found"
);
} else if has_catch_all_variant {
die!(default.keyword =>
"Attribute `default` is mutually exclusive with `catch_all`"
);
}
err_token = Some(Box::new(default.keyword));
is_default = true;
has_default_variant = true;
}
NumEnumVariantAttributeItem::CatchAll(catch_all) => {
if has_catch_all_variant {
die!(catch_all.keyword =>
"Multiple variants marked with `#[num_enum(catch_all)]`"
);
} else if has_default_variant {
die!(catch_all.keyword =>
"Attribute `catch_all` is mutually exclusive with `default`"
);
}

err_token = Some(Box::new(catch_all.keyword));
match variant
.fields
.iter()
Expand All @@ -224,7 +204,6 @@ impl Parse for EnumInfo {
..
}] if path.is_ident(&repr) => {
is_catch_all = true;
has_catch_all_variant = true;
}
_ => {
die!(catch_all.keyword =>
Expand Down Expand Up @@ -358,15 +337,17 @@ impl Parse for EnumInfo {
discriminant_int_val_set.extend(sorted_alternate_int_values);
}

// Add the current discriminant to the the set to keep track.
// Add the current discriminant to the set to keep track.
if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value {
discriminant_int_val_set.insert(canonical_value_int);
}

variants.push(VariantInfo {
ident,
is_std_default,
is_default,
is_catch_all,
err_token,
canonical_value: discriminant,
alternative_values: flattened_raw_alternative_values,
});
Expand All @@ -382,6 +363,46 @@ impl Parse for EnumInfo {
}
}

// Validate variants
let with_std_default: Vec<_> = variants.iter().filter(|v| v.is_std_default).collect();
let with_default: Vec<_> = variants.iter().filter(|v| v.is_default).collect();
let with_catch_all: Vec<_> = variants.iter().filter(|v| v.is_catch_all).collect();
if with_std_default.len() > 1 {
if let Some(token) = with_std_default[0].err_token.as_ref() {
die!(token =>
"Multiple variants marked `#[default]` found"
);
} else {
die!(name =>
"Multiple variants marked `#[default]` found"
);
}
}
if with_default.len() > 1 {
let msg = "Multiple variants marked #[num_enum(default)]` found";
if let Some(token) = with_default[0].err_token.as_ref() {
die!(token => msg);
} else {
die!(msg);
}
}
if with_catch_all.len() > 1 {
let msg = "Multiple variants marked #[num_enum(catch_all)]` found";
if let Some(token) = with_catch_all[0].err_token.as_ref() {
die!(token => msg);
} else {
die!(msg);
}
}
if with_default.len() > 0 && with_catch_all.len() > 0 {
let msg = "Attribute #[num_enum(catch_all)] is mutually exclusive with #[num_enum(default)]`";
if let Some(token) = with_catch_all[0].err_token.as_ref() {
die!(token => msg);
} else {
die!(msg);
}
}

let error_type_info = attributes.error_type.map(Into::into).unwrap_or_else(|| {
let crate_name = get_crate_path(crate_path.clone());
ErrorType {
Expand Down Expand Up @@ -501,8 +522,10 @@ fn parse_alternative_values(val_expr: &Expr) -> Result<Vec<DiscriminantValue>> {

pub(crate) struct VariantInfo {
ident: Ident,
is_std_default: bool,
is_default: bool,
is_catch_all: bool,
err_token: Option<Box<dyn ToTokens>>,
canonical_value: Expr,
alternative_values: Vec<Expr>,
}
Expand Down