From c6c8056928c35e5bc673ee9f4bce34a628454627 Mon Sep 17 00:00:00 2001 From: Christian Legnitto Date: Tue, 3 Jun 2025 16:10:03 -0400 Subject: [PATCH] Fix union field access by representing unions as structs When a union has multiple fields, represent it as a struct with all fields at offset 0 instead of just using the largest field. This allows pointer casts between union fields to work correctly by enabling the recover_access_chain_from_offset function to find valid access chains. Fixes https://github.com/Rust-GPU/rust-gpu/issues/241. --- crates/rustc_codegen_spirv/src/abi.rs | 31 +++++++++++++------------- tests/ui/lang/core/union_cast.rs | 32 +++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 15 deletions(-) create mode 100644 tests/ui/lang/core/union_cast.rs diff --git a/crates/rustc_codegen_spirv/src/abi.rs b/crates/rustc_codegen_spirv/src/abi.rs index 04d1bf9860..34c968bdfc 100644 --- a/crates/rustc_codegen_spirv/src/abi.rs +++ b/crates/rustc_codegen_spirv/src/abi.rs @@ -718,21 +718,22 @@ fn trans_aggregate<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx> FieldsShape::Union(_) => { assert!(!ty.is_unsized(), "{ty:#?}"); - // Represent the `union` with its largest case, which should work - // for at least `MaybeUninit` (which is between `T` and `()`), - // but also potentially some other ones as well. - // NOTE(eddyb) even if long-term this may become a byte array, that - // only works for "data types" and not "opaque handles" (images etc.). - let largest_case = (0..ty.fields.count()) - .map(|i| ty.field(cx, i)) - .max_by_key(|case| case.size); - - if let Some(case) = largest_case { - assert_eq!(ty.size, case.size); - case.spirv_type(span, cx) - } else { - assert_eq!(ty.size, Size::ZERO); - create_zst(cx, span, ty) + // NOTE(eddyb) even if long-term this may become a byte array, that only + // works for "data types" and not "opaque handles" (images etc.). + match ty.fields.count() { + // For unions with no fields, represent as a zero-sized type. + 0 => { + assert_eq!(ty.size, Size::ZERO); + create_zst(cx, span, ty) + } + // For unions with a single field, represent as the field itself. + 1 => { + let field = ty.field(cx, 0); + assert_eq!(ty.size, field.size); + field.spirv_type(span, cx) + } + // For unions with multiple fields, represent as struct with all fields at offset 0 + _ => trans_struct(cx, span, ty), } } FieldsShape::Array { stride, count } => { diff --git a/tests/ui/lang/core/union_cast.rs b/tests/ui/lang/core/union_cast.rs new file mode 100644 index 0000000000..c400df1a19 --- /dev/null +++ b/tests/ui/lang/core/union_cast.rs @@ -0,0 +1,32 @@ +// build-pass + +use spirv_std::spirv; + +#[repr(C)] +#[derive(Clone, Copy)] +struct Data { + a: f32, + b: [f32; 3], + c: f32, +} + +#[repr(C)] +union DataOrArray { + arr: [f32; 5], + str: Data, +} + +impl DataOrArray { + fn arr(&self) -> [f32; 5] { + unsafe { self.arr } + } + fn new(arr: [f32; 5]) -> Self { + Self { arr } + } +} + +#[spirv(fragment)] +pub fn main() { + let dora = DataOrArray::new([0.0, 0.0, 0.0, 0.0, 0.0]); + let _arr = dora.arr(); +} \ No newline at end of file