diff --git a/README.md b/README.md index 0191472..5b2b83d 100644 --- a/README.md +++ b/README.md @@ -244,6 +244,68 @@ fn main() { As this is naturally exhaustive, this is only supported for `FromPrimitive`, not also `TryFromPrimitive`. +## Interaction with std Default + +You can define a different "fallback" default for FromPrimitive than what is used by the std lib's Default derive. + +```rust +use num_enum::FromPrimitive; +use std::convert::TryFrom; + +#[derive(Debug, Default, Eq, PartialEq, FromPrimitive)] +#[repr(u8)] +enum Number { + #[default] + Zero = 0, + #[num_enum(default)] + NonZero, +} + +fn main() { + let def = Number::default(); + assert_eq!(def, Number::Zero); + + let zero = Number::from(0u8); + assert_eq!(zero, Number::Zero); + + let one = Number::from(1u8); + assert_eq!(one, Number::NonZero); + + let two = Number::from(2u8); + assert_eq!(two, Number::NonZero); +} +``` + +Or with catch_all + +```rust +use num_enum::FromPrimitive; +use std::convert::TryFrom; + +#[derive(Debug, Default, Eq, PartialEq, FromPrimitive)] +#[repr(u8)] +enum Number { + #[default] + Zero = 0, + #[num_enum(catch_all)] + NonZero(u8), +} + +fn main() { + let def = Number::default(); + assert_eq!(def, Number::Zero); + + let zero = Number::from(0u8); + assert_eq!(zero, Number::Zero); + + let one = Number::from(1u8); + assert_eq!(one, Number::NonZero(1_u8)); + + let two = Number::from(2u8); + assert_eq!(two, Number::NonZero(2_u8)); +} +``` + ## Unsafely turning a primitive into an enum with unchecked_transmute_from If you're really certain a conversion will succeed (and have not made use of `#[num_enum(default)]` or `#[num_enum(alternatives = [..])]` diff --git a/num_enum/tests/complex.rs b/num_enum/tests/complex.rs new file mode 100644 index 0000000..8871257 --- /dev/null +++ b/num_enum/tests/complex.rs @@ -0,0 +1,37 @@ +use num_enum_derive::{FromPrimitive, IntoPrimitive}; + +// Guard against https://github.com/illicitonion/num_enum/issues/27 +mod alloc {} +mod core {} +mod num_enum {} +mod std {} + +#[test] +fn std_default_with_num_enum_catch_all() { + #[derive(Debug, Eq, PartialEq, Default, FromPrimitive, IntoPrimitive)] + #[repr(u8)] + enum Enum { + #[default] + Zero = 0, + #[num_enum(catch_all)] + NonZero(u8), + } + + assert_eq!(Enum::Zero, ::default()); + assert_eq!(Enum::NonZero(5), 5u8.into()); +} + +#[test] +fn std_default_with_num_enum_default() { + #[derive(Debug, Eq, PartialEq, Default, FromPrimitive, IntoPrimitive)] + #[repr(u8)] + enum Enum { + #[default] + Zero = 0, + #[num_enum(default)] + NonZero, + } + + assert_eq!(Enum::Zero, ::default()); + assert_eq!(Enum::NonZero, 5u8.into()); +} diff --git a/num_enum_derive/src/lib.rs b/num_enum_derive/src/lib.rs index 075ca56..fc30e39 100644 --- a/num_enum_derive/src/lib.rs +++ b/num_enum_derive/src/lib.rs @@ -97,10 +97,13 @@ pub fn derive_from_primitive(input: TokenStream) -> TokenStream { Ok(is_naturally_exhaustive) => { if is_naturally_exhaustive { quote! { unreachable!("exhaustive enum") } - } else if let Some(default_ident) = enum_info.default() { - quote! { Self::#default_ident } } else if let Some(catch_all_ident) = enum_info.catch_all() { quote! { Self::#catch_all_ident(number) } + } else if let Some(default_ident) = enum_info.default() { + quote! { Self::#default_ident } + } else if let Some(default_ident) = enum_info.std_default() { + // std default is the last priority to allow for a different num_enum FromPrimitive default than the std Default::default() + quote! { Self::#default_ident } } else { let span = Span::call_site(); let message = @@ -339,9 +342,17 @@ pub fn derive_unsafe_from_primitive(stream: TokenStream) -> TokenStream { pub fn derive_default(stream: TokenStream) -> TokenStream { let enum_info = parse_macro_input!(stream as EnumInfo); - let default_ident = match enum_info.default() { - Some(ident) => ident, - None => { + let default_ident = match ( enum_info.default(), enum_info.std_default() ) { + // num_enum(default) takes precedence over std default + (Some(ident), None) => ident, + (None, Some(ident)) => ident, + (Some(_), Some(_)) => { + let span = Span::call_site(); + let message = + "#[derive(num_enum::Default)] cannot be used with both #[default] and #[num_enum(default)]"; + return syn::Error::new(span, message).to_compile_error().into(); + }, + (None, None) => { let span = Span::call_site(); let message = "#[derive(num_enum::Default)] requires enum to be exhaustive, or a variant marked with `#[default]` or `#[num_enum(default)]`"; diff --git a/num_enum_derive/src/parsing.rs b/num_enum_derive/src/parsing.rs index c0aced1..78248e6 100644 --- a/num_enum_derive/src/parsing.rs +++ b/num_enum_derive/src/parsing.rs @@ -45,6 +45,13 @@ impl EnumInfo { die!(self.repr.clone() => "Failed to parse repr into bit size"); } + pub(crate) fn std_default(&self) -> Option<&Ident> { + self.variants + .iter() + .find(|info| info.is_std_default) + .map(|info| &info.ident) + } + pub(crate) fn default(&self) -> Option<&Ident> { self.variants .iter() @@ -144,8 +151,6 @@ impl Parse for EnumInfo { let crate_path = attributes.crate_path.clone().map(|k| k.path); let mut variants: Vec = vec![]; - let mut has_default_variant: bool = false; - let mut has_catch_all_variant: bool = false; // Vec to keep track of the used discriminants and alt values. let mut discriminant_int_val_set = BTreeSet::new(); @@ -166,22 +171,15 @@ impl Parse for EnumInfo { // `#[num_enum(default)]` is required by `#[derive(FromPrimitive)]` // and forbidden by `#[derive(UnsafeFromPrimitive)]`, so we need to // keep track of whether we encountered such an attribute: + let mut is_std_default: bool = false; let mut is_default: bool = false; let mut is_catch_all: bool = false; + let mut err_token: Option> = None; for attribute in &variant.attrs { if attribute.path().is_ident("default") { - if has_default_variant { - die!(attribute => - "Multiple variants marked `#[default]` or `#[num_enum(default)]` found" - ); - } else if has_catch_all_variant { - die!(attribute => - "Attribute `default` is mutually exclusive with `catch_all`" - ); - } - is_default = true; - has_default_variant = true; + err_token = Some(Box::new(attribute.clone())); + is_std_default = true; } if attribute.path().is_ident("num_enum") { @@ -190,29 +188,11 @@ impl Parse for EnumInfo { for variant_attribute in variant_attributes.items { match variant_attribute { NumEnumVariantAttributeItem::Default(default) => { - if has_default_variant { - die!(default.keyword => - "Multiple variants marked `#[default]` or `#[num_enum(default)]` found" - ); - } else if has_catch_all_variant { - die!(default.keyword => - "Attribute `default` is mutually exclusive with `catch_all`" - ); - } + err_token = Some(Box::new(default.keyword)); is_default = true; - has_default_variant = true; } NumEnumVariantAttributeItem::CatchAll(catch_all) => { - if has_catch_all_variant { - die!(catch_all.keyword => - "Multiple variants marked with `#[num_enum(catch_all)]`" - ); - } else if has_default_variant { - die!(catch_all.keyword => - "Attribute `catch_all` is mutually exclusive with `default`" - ); - } - + err_token = Some(Box::new(catch_all.keyword)); match variant .fields .iter() @@ -224,7 +204,6 @@ impl Parse for EnumInfo { .. }] if path.is_ident(&repr) => { is_catch_all = true; - has_catch_all_variant = true; } _ => { die!(catch_all.keyword => @@ -358,15 +337,17 @@ impl Parse for EnumInfo { discriminant_int_val_set.extend(sorted_alternate_int_values); } - // Add the current discriminant to the the set to keep track. + // Add the current discriminant to the set to keep track. if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value { discriminant_int_val_set.insert(canonical_value_int); } variants.push(VariantInfo { ident, + is_std_default, is_default, is_catch_all, + err_token, canonical_value: discriminant, alternative_values: flattened_raw_alternative_values, }); @@ -382,6 +363,46 @@ impl Parse for EnumInfo { } } + // Validate variants + let with_std_default: Vec<_> = variants.iter().filter(|v| v.is_std_default).collect(); + let with_default: Vec<_> = variants.iter().filter(|v| v.is_default).collect(); + let with_catch_all: Vec<_> = variants.iter().filter(|v| v.is_catch_all).collect(); + if with_std_default.len() > 1 { + if let Some(token) = with_std_default[0].err_token.as_ref() { + die!(token => + "Multiple variants marked `#[default]` found" + ); + } else { + die!(name => + "Multiple variants marked `#[default]` found" + ); + } + } + if with_default.len() > 1 { + let msg = "Multiple variants marked #[num_enum(default)]` found"; + if let Some(token) = with_default[0].err_token.as_ref() { + die!(token => msg); + } else { + die!(msg); + } + } + if with_catch_all.len() > 1 { + let msg = "Multiple variants marked #[num_enum(catch_all)]` found"; + if let Some(token) = with_catch_all[0].err_token.as_ref() { + die!(token => msg); + } else { + die!(msg); + } + } + if with_default.len() > 0 && with_catch_all.len() > 0 { + let msg = "Attribute #[num_enum(catch_all)] is mutually exclusive with #[num_enum(default)]`"; + if let Some(token) = with_catch_all[0].err_token.as_ref() { + die!(token => msg); + } else { + die!(msg); + } + } + let error_type_info = attributes.error_type.map(Into::into).unwrap_or_else(|| { let crate_name = get_crate_path(crate_path.clone()); ErrorType { @@ -501,8 +522,10 @@ fn parse_alternative_values(val_expr: &Expr) -> Result> { pub(crate) struct VariantInfo { ident: Ident, + is_std_default: bool, is_default: bool, is_catch_all: bool, + err_token: Option>, canonical_value: Expr, alternative_values: Vec, }