From 90cf470434d08b64ae73a514de2d19300d1ba31b Mon Sep 17 00:00:00 2001 From: redain <52823087+redain@users.noreply.github.com> Date: Sat, 7 Jun 2025 15:52:55 +0400 Subject: [PATCH 1/2] fix: array stride validation errors --- .../src/linker/array_stride_fixer.rs | 612 ++++++++++++++++++ crates/rustc_codegen_spirv/src/linker/mod.rs | 7 + 2 files changed, 619 insertions(+) create mode 100644 crates/rustc_codegen_spirv/src/linker/array_stride_fixer.rs diff --git a/crates/rustc_codegen_spirv/src/linker/array_stride_fixer.rs b/crates/rustc_codegen_spirv/src/linker/array_stride_fixer.rs new file mode 100644 index 0000000000..129494b771 --- /dev/null +++ b/crates/rustc_codegen_spirv/src/linker/array_stride_fixer.rs @@ -0,0 +1,612 @@ +//! Fix ArrayStride decorations for newer SPIR-V versions. +//! +//! Newer SPIR-V versions forbid explicit layouts (ArrayStride decorations) in certain +//! storage classes (Function, Private, Workgroup), but allow them in others +//! (StorageBuffer, Uniform). This module removes ArrayStride decorations from +//! array types that are used in contexts where they're forbidden. + +use rspirv::dr::{Module, Operand}; +use rspirv::spirv::{Capability, Decoration, Op, StorageClass, Word}; +use rustc_data_structures::fx::FxHashSet; + +/// Check if a storage class allows explicit layout decorations based on SPIR-V version and capabilities. +/// This matches the logic from SPIRV-Tools validate_decorations.cpp AllowsLayout function. +fn allows_layout( + storage_class: StorageClass, + spirv_version: (u8, u8), + has_workgroup_layout_capability: bool, +) -> bool { + match storage_class { + // Always explicitly laid out + StorageClass::StorageBuffer + | StorageClass::Uniform + | StorageClass::PhysicalStorageBuffer + | StorageClass::PushConstant => true, + + // Never allows layout + StorageClass::UniformConstant => false, + + // Requires explicit capability + StorageClass::Workgroup => has_workgroup_layout_capability, + + // Only forbidden in SPIR-V 1.4+ + StorageClass::Function | StorageClass::Private => spirv_version < (1, 4), + + // Block is used generally and mesh shaders use Offset + StorageClass::Input | StorageClass::Output => true, + + // TODO: Some storage classes in ray tracing use explicit layout + // decorations, but it is not well documented which. For now treat other + // storage classes as allowed to be laid out. + _ => true, + } +} + +/// Remove ArrayStride decorations from array types used in storage classes where +/// newer SPIR-V versions forbid explicit layouts. +pub fn fix_array_stride_decorations(module: &mut Module) { + // Get SPIR-V version from module header + let spirv_version = module + .header + .as_ref() + .map(|h| h.version()) + .unwrap_or((1, 0)); // Default to 1.0 if no header + + // Check for WorkgroupMemoryExplicitLayoutKHR capability + let has_workgroup_layout_capability = module.capabilities.iter().any(|inst| { + inst.class.opcode == Op::Capability + && inst.operands.first() + == Some(&Operand::Capability( + Capability::WorkgroupMemoryExplicitLayoutKHR, + )) + }); + + // Find all array types that have ArrayStride decorations + let mut array_types_with_stride = FxHashSet::default(); + for inst in &module.annotations { + if inst.class.opcode == Op::Decorate + && inst.operands.len() >= 2 + && inst.operands[1] == Operand::Decoration(Decoration::ArrayStride) + { + let target_id = inst.operands[0].unwrap_id_ref(); + array_types_with_stride.insert(target_id); + } + } + + // Check each array type with ArrayStride to see if it's used in forbidden contexts + let mut array_types_to_fix = FxHashSet::default(); + for &array_type_id in &array_types_with_stride { + if is_array_type_used_in_forbidden_storage_class( + array_type_id, + module, + spirv_version, + has_workgroup_layout_capability, + ) { + array_types_to_fix.insert(array_type_id); + } + } + + // Remove ArrayStride decorations for the problematic types + module.annotations.retain(|inst| { + if inst.class.opcode == Op::Decorate + && inst.operands.len() >= 2 + && inst.operands[1] == Operand::Decoration(Decoration::ArrayStride) + { + let target_id = inst.operands[0].unwrap_id_ref(); + !array_types_to_fix.contains(&target_id) + } else { + true + } + }); +} + +/// Check if an array type is used in any variable with a forbidden storage class +fn is_array_type_used_in_forbidden_storage_class( + array_type_id: Word, + module: &Module, + spirv_version: (u8, u8), + has_workgroup_layout_capability: bool, +) -> bool { + // Check global variables + for inst in &module.types_global_values { + if inst.class.opcode == Op::Variable && inst.operands.len() >= 1 { + let storage_class = inst.operands[0].unwrap_storage_class(); + + // Check if this storage class forbids explicit layouts + if !allows_layout( + storage_class, + spirv_version, + has_workgroup_layout_capability, + ) { + // Check if this variable's type hierarchy contains the array type + if let Some(var_type_id) = inst.result_type { + if type_hierarchy_contains_array_type(var_type_id, array_type_id, module) { + return true; + } + } + } + } + } + + // Check function-local variables + for function in &module.functions { + for block in &function.blocks { + for inst in &block.instructions { + if inst.class.opcode == Op::Variable && inst.operands.len() >= 1 { + let storage_class = inst.operands[0].unwrap_storage_class(); + + // Check if this storage class forbids explicit layouts + if !allows_layout( + storage_class, + spirv_version, + has_workgroup_layout_capability, + ) { + // Check if this variable's type hierarchy contains the array type + if let Some(var_type_id) = inst.result_type { + if type_hierarchy_contains_array_type( + var_type_id, + array_type_id, + module, + ) { + return true; + } + } + } + } + } + } + } + + false +} + +/// Check if a type hierarchy contains a specific array type +fn type_hierarchy_contains_array_type( + type_id: Word, + target_array_type_id: Word, + module: &Module, +) -> bool { + if type_id == target_array_type_id { + return true; + } + + // Find the type definition + if let Some(type_inst) = module + .types_global_values + .iter() + .find(|inst| inst.result_id == Some(type_id)) + { + match type_inst.class.opcode { + Op::TypeArray | Op::TypeRuntimeArray => { + // Check element type recursively + if !type_inst.operands.is_empty() { + let element_type = type_inst.operands[0].unwrap_id_ref(); + return type_hierarchy_contains_array_type( + element_type, + target_array_type_id, + module, + ); + } + } + Op::TypeStruct => { + // Check all field types + for operand in &type_inst.operands { + if let Ok(field_type) = operand.id_ref_any().ok_or(()) { + if type_hierarchy_contains_array_type( + field_type, + target_array_type_id, + module, + ) { + return true; + } + } + } + } + Op::TypePointer => { + // Follow pointer to pointee type + if type_inst.operands.len() >= 2 { + let pointee_type = type_inst.operands[1].unwrap_id_ref(); + return type_hierarchy_contains_array_type( + pointee_type, + target_array_type_id, + module, + ); + } + } + _ => {} + } + } + false +} + +#[cfg(test)] +mod tests { + use super::*; + use rspirv::dr::Module; + + // Helper function to assemble SPIR-V from text + fn assemble_spirv(spirv: &str) -> Vec { + use spirv_tools::assembler::{self, Assembler}; + + let assembler = assembler::create(None); + let spv_binary = assembler + .assemble(spirv, assembler::AssemblerOptions::default()) + .expect("Failed to assemble test SPIR-V"); + let contents: &[u8] = spv_binary.as_ref(); + contents.to_vec() + } + + // Helper function to load SPIR-V binary into Module + fn load_spirv(bytes: &[u8]) -> Module { + use rspirv::dr::Loader; + + let mut loader = Loader::new(); + rspirv::binary::parse_bytes(bytes, &mut loader).unwrap(); + loader.module() + } + + // Helper function to count ArrayStride decorations + fn count_array_stride_decorations(module: &Module) -> usize { + module + .annotations + .iter() + .filter(|inst| { + inst.class.opcode == Op::Decorate + && inst.operands.len() >= 2 + && inst.operands[1] == Operand::Decoration(Decoration::ArrayStride) + }) + .count() + } + + #[test] + fn test_removes_array_stride_from_workgroup_arrays() { + let spirv = r#" + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + + ; Type declarations + %void = OpTypeVoid + %func_ty = OpTypeFunction %void + %u32 = OpTypeInt 32 0 + %u32_256 = OpConstant %u32 256 + %array_ty = OpTypeArray %u32 %u32_256 + + ; Pointer types for workgroup storage (forbidden in newer SPIR-V) + %ptr_workgroup = OpTypePointer Workgroup %array_ty + + ; Variables in workgroup storage class + %workgroup_var = OpVariable %ptr_workgroup Workgroup + + ; ArrayStride decoration that should be removed + OpDecorate %array_ty ArrayStride 4 + + %main = OpFunction %void None %func_ty + %entry = OpLabel + OpReturn + OpFunctionEnd + "#; + + let bytes = assemble_spirv(spirv); + let mut module = load_spirv(&bytes); + + assert_eq!(count_array_stride_decorations(&module), 1); + + fix_array_stride_decorations(&mut module); + + // ArrayStride should be removed from arrays used in Workgroup storage + assert_eq!(count_array_stride_decorations(&module), 0); + } + + #[test] + fn test_keeps_array_stride_for_workgroup_with_capability() { + let spirv = r#" + OpCapability Shader + OpCapability WorkgroupMemoryExplicitLayoutKHR + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + + ; Type declarations + %void = OpTypeVoid + %func_ty = OpTypeFunction %void + %u32 = OpTypeInt 32 0 + %u32_256 = OpConstant %u32 256 + %array_ty = OpTypeArray %u32 %u32_256 + + ; Pointer types for workgroup storage (allowed with capability) + %ptr_workgroup = OpTypePointer Workgroup %array_ty + + ; Variables in workgroup storage class + %workgroup_var = OpVariable %ptr_workgroup Workgroup + + ; ArrayStride decoration that should be kept with capability + OpDecorate %array_ty ArrayStride 4 + + %main = OpFunction %void None %func_ty + %entry = OpLabel + OpReturn + OpFunctionEnd + "#; + + let bytes = assemble_spirv(spirv); + let mut module = load_spirv(&bytes); + + assert_eq!(count_array_stride_decorations(&module), 1); + + fix_array_stride_decorations(&mut module); + + // ArrayStride should be kept when WorkgroupMemoryExplicitLayoutKHR capability is present + assert_eq!(count_array_stride_decorations(&module), 1); + } + + #[test] + fn test_keeps_array_stride_for_storage_buffer_arrays() { + let spirv = r#" + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + + ; Type declarations + %void = OpTypeVoid + %func_ty = OpTypeFunction %void + %u32 = OpTypeInt 32 0 + %u32_256 = OpConstant %u32 256 + %array_ty = OpTypeArray %u32 %u32_256 + + ; Pointer types for storage buffer (always allowed) + %ptr_storage_buffer = OpTypePointer StorageBuffer %array_ty + + ; Variables in storage buffer storage class + %storage_buffer_var = OpVariable %ptr_storage_buffer StorageBuffer + + ; ArrayStride decoration that should be kept + OpDecorate %array_ty ArrayStride 4 + OpDecorate %storage_buffer_var DescriptorSet 0 + OpDecorate %storage_buffer_var Binding 0 + + %main = OpFunction %void None %func_ty + %entry = OpLabel + OpReturn + OpFunctionEnd + "#; + + let bytes = assemble_spirv(spirv); + let mut module = load_spirv(&bytes); + + assert_eq!(count_array_stride_decorations(&module), 1); + + fix_array_stride_decorations(&mut module); + + // ArrayStride should be kept for StorageBuffer storage class + assert_eq!(count_array_stride_decorations(&module), 1); + } + + #[test] + fn test_handles_runtime_arrays_in_workgroup() { + let spirv = r#" + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + + ; Type declarations + %void = OpTypeVoid + %func_ty = OpTypeFunction %void + %u32 = OpTypeInt 32 0 + %runtime_array_ty = OpTypeRuntimeArray %u32 + + ; Pointer types for workgroup storage (forbidden) + %ptr_workgroup = OpTypePointer Workgroup %runtime_array_ty + + ; Variables in workgroup storage class + %workgroup_var = OpVariable %ptr_workgroup Workgroup + + ; ArrayStride decoration that should be removed + OpDecorate %runtime_array_ty ArrayStride 4 + + %main = OpFunction %void None %func_ty + %entry = OpLabel + OpReturn + OpFunctionEnd + "#; + + let bytes = assemble_spirv(spirv); + let mut module = load_spirv(&bytes); + + assert_eq!(count_array_stride_decorations(&module), 1); + + fix_array_stride_decorations(&mut module); + + // ArrayStride should be removed from runtime arrays in Workgroup storage + assert_eq!(count_array_stride_decorations(&module), 0); + } + + #[test] + fn test_mixed_storage_classes_removes_problematic_arrays() { + let spirv = r#" + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + + ; Type declarations + %void = OpTypeVoid + %func_ty = OpTypeFunction %void + %u32 = OpTypeInt 32 0 + %u32_256 = OpConstant %u32 256 + %forbidden_array_ty = OpTypeArray %u32 %u32_256 + %allowed_array_ty = OpTypeArray %u32 %u32_256 + + ; Pointer types for different storage classes + %ptr_workgroup = OpTypePointer Workgroup %forbidden_array_ty + %ptr_storage_buffer = OpTypePointer StorageBuffer %allowed_array_ty + + ; Variables in different storage classes + %workgroup_var = OpVariable %ptr_workgroup Workgroup + %storage_buffer_var = OpVariable %ptr_storage_buffer StorageBuffer + + ; ArrayStride decorations + OpDecorate %forbidden_array_ty ArrayStride 4 + OpDecorate %allowed_array_ty ArrayStride 4 + OpDecorate %storage_buffer_var DescriptorSet 0 + OpDecorate %storage_buffer_var Binding 0 + + %main = OpFunction %void None %func_ty + %entry = OpLabel + OpReturn + OpFunctionEnd + "#; + + let bytes = assemble_spirv(spirv); + let mut module = load_spirv(&bytes); + + assert_eq!(count_array_stride_decorations(&module), 2); + + fix_array_stride_decorations(&mut module); + + // Only the Workgroup array should have its ArrayStride removed + assert_eq!(count_array_stride_decorations(&module), 1); + } + + #[test] + fn test_nested_structs_and_arrays_in_function_storage() { + let spirv = r#" + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + + ; ArrayStride decoration that should be removed in SPIR-V 1.4+ + OpDecorate %inner_array_ty ArrayStride 16 + + ; Type declarations + %void = OpTypeVoid + %float = OpTypeFloat 32 + %u32 = OpTypeInt 32 0 + %u32_4 = OpConstant %u32 4 + %inner_array_ty = OpTypeArray %float %u32_4 + %inner_struct_ty = OpTypeStruct %inner_array_ty + %outer_struct_ty = OpTypeStruct %inner_struct_ty + + ; Pointer types for function storage (forbidden in SPIR-V 1.4+) + %ptr_function = OpTypePointer Function %outer_struct_ty + %func_ty = OpTypeFunction %void + + ; Function variable inside function + %main = OpFunction %void None %func_ty + %entry = OpLabel + %function_var = OpVariable %ptr_function Function + OpReturn + OpFunctionEnd + "#; + + let bytes = assemble_spirv(spirv); + let mut module = load_spirv(&bytes); + + // Force SPIR-V 1.4 for this test + if let Some(ref mut header) = module.header { + header.set_version(1, 4); + } + + assert_eq!(count_array_stride_decorations(&module), 1); + + fix_array_stride_decorations(&mut module); + + // ArrayStride should be removed in SPIR-V 1.4+ for Function storage + assert_eq!(count_array_stride_decorations(&module), 0); + } + + #[test] + fn test_function_storage_spirv_13_keeps_decorations() { + let spirv = r#" + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + + ; Type declarations + %void = OpTypeVoid + %func_ty = OpTypeFunction %void + %u32 = OpTypeInt 32 0 + %u32_256 = OpConstant %u32 256 + %array_ty = OpTypeArray %u32 %u32_256 + + ; Pointer types for function storage + %ptr_function = OpTypePointer Function %array_ty + + ; Function variable + %main = OpFunction %void None %func_ty + %entry = OpLabel + %function_var = OpVariable %ptr_function Function + OpReturn + OpFunctionEnd + + ; ArrayStride decoration that should be kept in SPIR-V 1.3 + OpDecorate %array_ty ArrayStride 4 + "#; + + let bytes = assemble_spirv(spirv); + let mut module = load_spirv(&bytes); + + // Force SPIR-V 1.3 for this test + if let Some(ref mut header) = module.header { + header.set_version(1, 3); + } + + assert_eq!(count_array_stride_decorations(&module), 1); + + fix_array_stride_decorations(&mut module); + + // ArrayStride should be kept in SPIR-V 1.3 for Function storage + assert_eq!(count_array_stride_decorations(&module), 1); + } + + #[test] + fn test_private_storage_spirv_14_removes_decorations() { + let spirv = r#" + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + + ; Type declarations + %void = OpTypeVoid + %func_ty = OpTypeFunction %void + %u32 = OpTypeInt 32 0 + %u32_256 = OpConstant %u32 256 + %array_ty = OpTypeArray %u32 %u32_256 + + ; Pointer types for private storage + %ptr_private = OpTypePointer Private %array_ty + + ; Variables in private storage class + %private_var = OpVariable %ptr_private Private + + ; ArrayStride decoration that should be removed in SPIR-V 1.4+ + OpDecorate %array_ty ArrayStride 4 + + %main = OpFunction %void None %func_ty + %entry = OpLabel + OpReturn + OpFunctionEnd + "#; + + let bytes = assemble_spirv(spirv); + let mut module = load_spirv(&bytes); + + // Force SPIR-V 1.4 for this test + if let Some(ref mut header) = module.header { + header.set_version(1, 4); + } + + assert_eq!(count_array_stride_decorations(&module), 1); + + fix_array_stride_decorations(&mut module); + + // ArrayStride should be removed in SPIR-V 1.4+ for Private storage + assert_eq!(count_array_stride_decorations(&module), 0); + } +} diff --git a/crates/rustc_codegen_spirv/src/linker/mod.rs b/crates/rustc_codegen_spirv/src/linker/mod.rs index fa69dc8e7f..8e4933da36 100644 --- a/crates/rustc_codegen_spirv/src/linker/mod.rs +++ b/crates/rustc_codegen_spirv/src/linker/mod.rs @@ -1,6 +1,7 @@ #[cfg(test)] mod test; +mod array_stride_fixer; mod dce; mod destructure_composites; mod duplicates; @@ -355,6 +356,12 @@ pub fn link( }); } + // Fix ArrayStride decorations for arrays in storage classes where newer SPIR-V versions forbid explicit layouts + { + let _timer = sess.timer("fix_array_stride_decorations"); + array_stride_fixer::fix_array_stride_decorations(&mut output); + } + // NOTE(eddyb) with SPIR-T, we can do `mem2reg` before inlining, too! { if opts.dce { From 44f5287eb53d875f3f9b25a2a68d9a3f1c71e269 Mon Sep 17 00:00:00 2001 From: redain <52823087+redain@users.noreply.github.com> Date: Wed, 11 Jun 2025 15:25:19 +0400 Subject: [PATCH 2/2] fixes --- .../src/linker/array_stride_fixer.rs | 942 +++++++++--------- .../src/linker/duplicates.rs | 54 +- crates/rustc_codegen_spirv/src/linker/mod.rs | 4 +- .../function_storage_spirv13_kept.rs | 21 + .../function_storage_spirv13_kept.stderr | 30 + .../mixed_storage_classes.rs | 21 + .../mixed_storage_classes.stderr | 36 + .../nested_structs_function_storage.rs | 34 + .../nested_structs_function_storage.stderr | 45 + .../private_storage_spirv14_removed.rs | 22 + .../private_storage_spirv14_removed.stderr | 22 + .../runtime_arrays_in_workgroup.rs | 22 + .../runtime_arrays_in_workgroup.stderr | 37 + .../storage_buffer_arrays_kept.rs | 17 + .../storage_buffer_arrays_kept.stderr | 31 + .../workgroup_arrays_removed.rs | 21 + .../workgroup_arrays_removed.stderr | 36 + .../workgroup_arrays_with_capability.rs | 22 + .../workgroup_arrays_with_capability.stderr | 39 + 19 files changed, 987 insertions(+), 469 deletions(-) create mode 100644 tests/ui/linker/array_stride_fixer/function_storage_spirv13_kept.rs create mode 100644 tests/ui/linker/array_stride_fixer/function_storage_spirv13_kept.stderr create mode 100644 tests/ui/linker/array_stride_fixer/mixed_storage_classes.rs create mode 100644 tests/ui/linker/array_stride_fixer/mixed_storage_classes.stderr create mode 100644 tests/ui/linker/array_stride_fixer/nested_structs_function_storage.rs create mode 100644 tests/ui/linker/array_stride_fixer/nested_structs_function_storage.stderr create mode 100644 tests/ui/linker/array_stride_fixer/private_storage_spirv14_removed.rs create mode 100644 tests/ui/linker/array_stride_fixer/private_storage_spirv14_removed.stderr create mode 100644 tests/ui/linker/array_stride_fixer/runtime_arrays_in_workgroup.rs create mode 100644 tests/ui/linker/array_stride_fixer/runtime_arrays_in_workgroup.stderr create mode 100644 tests/ui/linker/array_stride_fixer/storage_buffer_arrays_kept.rs create mode 100644 tests/ui/linker/array_stride_fixer/storage_buffer_arrays_kept.stderr create mode 100644 tests/ui/linker/array_stride_fixer/workgroup_arrays_removed.rs create mode 100644 tests/ui/linker/array_stride_fixer/workgroup_arrays_removed.stderr create mode 100644 tests/ui/linker/array_stride_fixer/workgroup_arrays_with_capability.rs create mode 100644 tests/ui/linker/array_stride_fixer/workgroup_arrays_with_capability.stderr diff --git a/crates/rustc_codegen_spirv/src/linker/array_stride_fixer.rs b/crates/rustc_codegen_spirv/src/linker/array_stride_fixer.rs index 129494b771..ca1736af79 100644 --- a/crates/rustc_codegen_spirv/src/linker/array_stride_fixer.rs +++ b/crates/rustc_codegen_spirv/src/linker/array_stride_fixer.rs @@ -1,28 +1,44 @@ -//! Fix ArrayStride decorations for newer SPIR-V versions. +//! Fix `ArrayStride` decorations for newer SPIR-V versions. //! -//! Newer SPIR-V versions forbid explicit layouts (ArrayStride decorations) in certain +//! Newer SPIR-V versions forbid explicit layouts (`ArrayStride` decorations) in certain //! storage classes (Function, Private, Workgroup), but allow them in others -//! (StorageBuffer, Uniform). This module removes ArrayStride decorations from +//! (`StorageBuffer`, Uniform). This module removes `ArrayStride` decorations from //! array types that are used in contexts where they're forbidden. use rspirv::dr::{Module, Operand}; use rspirv::spirv::{Capability, Decoration, Op, StorageClass, Word}; -use rustc_data_structures::fx::FxHashSet; +use rustc_data_structures::fx::{FxHashMap, FxHashSet}; + +/// Describes how an array type is used across different storage classes +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ArrayUsagePattern { + /// Array is only used in storage classes that require explicit layout + LayoutRequired, + /// Array is only used in storage classes that forbid explicit layout + LayoutForbidden, + /// Array is used in both types of storage classes (needs specialization) + MixedUsage, + /// Array is not used in any variables (orphaned) + Unused, +} + +/// Context information about array type usage +#[derive(Debug, Clone)] +pub struct ArrayStorageContext { + /// Which storage classes this array type is used in + pub storage_classes: FxHashSet, + /// Whether this array allows or forbids layout in its contexts + pub usage_pattern: ArrayUsagePattern, +} /// Check if a storage class allows explicit layout decorations based on SPIR-V version and capabilities. -/// This matches the logic from SPIRV-Tools validate_decorations.cpp AllowsLayout function. +/// This matches the logic from SPIRV-Tools `validate_decorations.cpp` `AllowsLayout` function. fn allows_layout( storage_class: StorageClass, spirv_version: (u8, u8), has_workgroup_layout_capability: bool, ) -> bool { match storage_class { - // Always explicitly laid out - StorageClass::StorageBuffer - | StorageClass::Uniform - | StorageClass::PhysicalStorageBuffer - | StorageClass::PushConstant => true, - // Never allows layout StorageClass::UniformConstant => false, @@ -32,25 +48,18 @@ fn allows_layout( // Only forbidden in SPIR-V 1.4+ StorageClass::Function | StorageClass::Private => spirv_version < (1, 4), - // Block is used generally and mesh shaders use Offset - StorageClass::Input | StorageClass::Output => true, - - // TODO: Some storage classes in ray tracing use explicit layout - // decorations, but it is not well documented which. For now treat other - // storage classes as allowed to be laid out. + // All other storage classes allow layout by default _ => true, } } -/// Remove ArrayStride decorations from array types used in storage classes where -/// newer SPIR-V versions forbid explicit layouts. -pub fn fix_array_stride_decorations(module: &mut Module) { +/// Comprehensive fix for `ArrayStride` decorations with optional type deduplication +pub fn fix_array_stride_decorations_with_deduplication( + module: &mut Module, + use_context_aware_deduplication: bool, +) { // Get SPIR-V version from module header - let spirv_version = module - .header - .as_ref() - .map(|h| h.version()) - .unwrap_or((1, 0)); // Default to 1.0 if no header + let spirv_version = module.header.as_ref().map_or((1, 0), |h| h.version()); // Default to 1.0 if no header // Check for WorkgroupMemoryExplicitLayoutKHR capability let has_workgroup_layout_capability = module.capabilities.iter().any(|inst| { @@ -61,94 +70,139 @@ pub fn fix_array_stride_decorations(module: &mut Module) { )) }); - // Find all array types that have ArrayStride decorations - let mut array_types_with_stride = FxHashSet::default(); - for inst in &module.annotations { - if inst.class.opcode == Op::Decorate - && inst.operands.len() >= 2 - && inst.operands[1] == Operand::Decoration(Decoration::ArrayStride) - { - let target_id = inst.operands[0].unwrap_id_ref(); - array_types_with_stride.insert(target_id); - } - } + // Analyze storage class contexts for all array types + let array_contexts = + analyze_array_storage_contexts(module, spirv_version, has_workgroup_layout_capability); - // Check each array type with ArrayStride to see if it's used in forbidden contexts - let mut array_types_to_fix = FxHashSet::default(); - for &array_type_id in &array_types_with_stride { - if is_array_type_used_in_forbidden_storage_class( - array_type_id, + // Handle mixed usage arrays by creating specialized versions + let specializations = create_specialized_array_types(module, &array_contexts); + + // Update references to use appropriate specialized types + if !specializations.is_empty() { + update_references_for_specialized_arrays( module, + &specializations, spirv_version, has_workgroup_layout_capability, - ) { - array_types_to_fix.insert(array_type_id); - } + ); } - // Remove ArrayStride decorations for the problematic types + // Apply context-aware type deduplication if requested + if use_context_aware_deduplication { + crate::linker::duplicates::remove_duplicate_types_with_array_context( + module, + Some(&array_contexts), + ); + } + + // Remove ArrayStride decorations from arrays used in forbidden contexts + remove_array_stride_decorations_for_forbidden_contexts(module, &array_contexts); +} + +/// Remove `ArrayStride` decorations from arrays used in layout-forbidden storage classes +fn remove_array_stride_decorations_for_forbidden_contexts( + module: &mut Module, + array_contexts: &FxHashMap, +) { + // Find array types that should have their ArrayStride decorations removed + // Remove from arrays used in forbidden contexts OR mixed usage that includes forbidden contexts + let arrays_to_remove_stride: FxHashSet = array_contexts + .iter() + .filter_map(|(&id, context)| { + match context.usage_pattern { + // Always remove from arrays used only in forbidden contexts + ArrayUsagePattern::LayoutForbidden => Some(id), + // For mixed usage, remove if it includes forbidden contexts that would cause validation errors + ArrayUsagePattern::MixedUsage => { + // If the array is used in any context that forbids layout, remove the decoration + // This is a conservative approach that prevents validation errors + let has_forbidden_context = context.storage_classes.iter().any(|&sc| { + !allows_layout(sc, (1, 4), false) // Use SPIR-V 1.4 rules for conservative check + }); + + if has_forbidden_context { + Some(id) + } else { + None + } + } + ArrayUsagePattern::LayoutRequired | ArrayUsagePattern::Unused => None, + } + }) + .collect(); + + // Remove ArrayStride decorations for layout-forbidden arrays module.annotations.retain(|inst| { if inst.class.opcode == Op::Decorate && inst.operands.len() >= 2 && inst.operands[1] == Operand::Decoration(Decoration::ArrayStride) { let target_id = inst.operands[0].unwrap_id_ref(); - !array_types_to_fix.contains(&target_id) + !arrays_to_remove_stride.contains(&target_id) } else { true } }); } -/// Check if an array type is used in any variable with a forbidden storage class -fn is_array_type_used_in_forbidden_storage_class( - array_type_id: Word, +/// Analyze storage class contexts for all array types in the module +pub fn analyze_array_storage_contexts( module: &Module, spirv_version: (u8, u8), has_workgroup_layout_capability: bool, -) -> bool { - // Check global variables +) -> FxHashMap { + let mut array_contexts: FxHashMap = FxHashMap::default(); + + // Find all array and runtime array types + let mut array_types = FxHashSet::default(); for inst in &module.types_global_values { - if inst.class.opcode == Op::Variable && inst.operands.len() >= 1 { + if matches!(inst.class.opcode, Op::TypeArray | Op::TypeRuntimeArray) { + if let Some(result_id) = inst.result_id { + array_types.insert(result_id); + array_contexts.insert(result_id, ArrayStorageContext { + storage_classes: FxHashSet::default(), + usage_pattern: ArrayUsagePattern::Unused, + }); + } + } + } + + // Analyze global variables + for inst in &module.types_global_values { + if inst.class.opcode == Op::Variable && !inst.operands.is_empty() { let storage_class = inst.operands[0].unwrap_storage_class(); - // Check if this storage class forbids explicit layouts - if !allows_layout( - storage_class, - spirv_version, - has_workgroup_layout_capability, - ) { - // Check if this variable's type hierarchy contains the array type - if let Some(var_type_id) = inst.result_type { + if let Some(var_type_id) = inst.result_type { + // Check if this variable's type hierarchy contains any array types + for &array_type_id in &array_types { if type_hierarchy_contains_array_type(var_type_id, array_type_id, module) { - return true; + if let Some(context) = array_contexts.get_mut(&array_type_id) { + context.storage_classes.insert(storage_class); + } } } } } } - // Check function-local variables + // Analyze function-local variables for function in &module.functions { for block in &function.blocks { for inst in &block.instructions { - if inst.class.opcode == Op::Variable && inst.operands.len() >= 1 { + if inst.class.opcode == Op::Variable && !inst.operands.is_empty() { let storage_class = inst.operands[0].unwrap_storage_class(); - // Check if this storage class forbids explicit layouts - if !allows_layout( - storage_class, - spirv_version, - has_workgroup_layout_capability, - ) { - // Check if this variable's type hierarchy contains the array type - if let Some(var_type_id) = inst.result_type { + if let Some(var_type_id) = inst.result_type { + // Check if this variable's type hierarchy contains any array types + for &array_type_id in &array_types { if type_hierarchy_contains_array_type( var_type_id, array_type_id, module, ) { - return true; + if let Some(context) = array_contexts.get_mut(&array_type_id) { + context.storage_classes.insert(storage_class); + } } } } @@ -157,7 +211,357 @@ fn is_array_type_used_in_forbidden_storage_class( } } - false + // Determine usage patterns + for context in array_contexts.values_mut() { + if context.storage_classes.is_empty() { + context.usage_pattern = ArrayUsagePattern::Unused; + } else { + let mut requires_layout = false; + let mut forbids_layout = false; + + for &storage_class in &context.storage_classes { + if allows_layout( + storage_class, + spirv_version, + has_workgroup_layout_capability, + ) { + requires_layout = true; + } else { + forbids_layout = true; + } + } + + context.usage_pattern = match (requires_layout, forbids_layout) { + (true, true) => ArrayUsagePattern::MixedUsage, + (true, false) => ArrayUsagePattern::LayoutRequired, + (false, true) => ArrayUsagePattern::LayoutForbidden, + (false, false) => ArrayUsagePattern::Unused, // Should not happen + }; + } + } + + array_contexts +} + +/// Create specialized array types for mixed usage scenarios +fn create_specialized_array_types( + module: &mut Module, + array_contexts: &FxHashMap, +) -> FxHashMap { + let mut specializations = FxHashMap::default(); // original_id -> (layout_required_id, layout_forbidden_id) + + // Find arrays that need specialization (mixed usage) + let arrays_to_specialize: Vec = array_contexts + .iter() + .filter_map(|(&id, context)| { + if context.usage_pattern == ArrayUsagePattern::MixedUsage { + Some(id) + } else { + None + } + }) + .collect(); + + if arrays_to_specialize.is_empty() { + return specializations; + } + + // Generate new IDs for specialized types + let mut next_id = module.header.as_ref().map_or(1, |h| h.bound); + + for &original_id in &arrays_to_specialize { + let layout_required_id = next_id; + next_id += 1; + let layout_forbidden_id = next_id; + next_id += 1; + + specializations.insert(original_id, (layout_required_id, layout_forbidden_id)); + } + + // Update the module header bound + if let Some(ref mut header) = module.header { + header.bound = next_id; + } + + // Create specialized array type definitions + let mut new_type_instructions = Vec::new(); + + for &original_id in &arrays_to_specialize { + if let Some((layout_required_id, layout_forbidden_id)) = specializations.get(&original_id) { + // Find the original array type instruction + if let Some(original_inst) = module + .types_global_values + .iter() + .find(|inst| inst.result_id == Some(original_id)) + .cloned() + { + // Create layout-required variant (keeps ArrayStride decorations) + let mut layout_required_inst = original_inst.clone(); + layout_required_inst.result_id = Some(*layout_required_id); + new_type_instructions.push(layout_required_inst); + + // Create layout-forbidden variant (will have ArrayStride decorations removed later) + let mut layout_forbidden_inst = original_inst.clone(); + layout_forbidden_inst.result_id = Some(*layout_forbidden_id); + new_type_instructions.push(layout_forbidden_inst); + } + } + } + + // IMPORTANT: Do not add the specialized arrays to the end - this would create forward references + // Instead, we need to insert them in the correct position to maintain SPIR-V type ordering + + // Find the insertion point: after the last original array type that needs specialization + // This ensures all specialized arrays are defined before any types that might reference them + let mut insertion_point = 0; + for (i, inst) in module.types_global_values.iter().enumerate() { + if let Some(result_id) = inst.result_id { + if arrays_to_specialize.contains(&result_id) { + insertion_point = i + 1; + } + } + } + + // Insert the specialized array types at the calculated position + // This maintains the invariant that referenced types appear before referencing types + for (i, new_inst) in new_type_instructions.into_iter().enumerate() { + module + .types_global_values + .insert(insertion_point + i, new_inst); + } + + specializations +} + +/// Update all references to specialized array types based on storage class context +fn update_references_for_specialized_arrays( + module: &mut Module, + specializations: &FxHashMap, + spirv_version: (u8, u8), + has_workgroup_layout_capability: bool, +) { + // Update struct types that contain specialized arrays + // This is safe now because all specialized arrays have been properly positioned in the types section + for inst in &mut module.types_global_values { + if inst.class.opcode == Op::TypeStruct { + for operand in &mut inst.operands { + if let Some(referenced_id) = operand.id_ref_any() { + if let Some(&(layout_required_id, _layout_forbidden_id)) = + specializations.get(&referenced_id) + { + // For struct types, we use the layout-required variant since structs + // can be used in both layout-required and layout-forbidden contexts + *operand = Operand::IdRef(layout_required_id); + } + } + } + } + } + + // Collect all existing pointer types that reference specialized arrays FIRST + let mut existing_pointers_to_specialize = Vec::new(); + for inst in &module.types_global_values { + if inst.class.opcode == Op::TypePointer && inst.operands.len() >= 2 { + let pointee_type = inst.operands[1].unwrap_id_ref(); + if specializations.contains_key(&pointee_type) { + existing_pointers_to_specialize.push(inst.clone()); + } + } + } + + // Create ALL specialized pointer types from the collected existing ones + let mut next_id = module.header.as_ref().map_or(1, |h| h.bound); + let mut new_pointer_instructions = Vec::new(); + let mut pointer_type_mappings = FxHashMap::default(); // old_pointer_id -> new_pointer_id + + // Create new pointer types for each storage class context + for inst in &existing_pointers_to_specialize { + let storage_class = inst.operands[0].unwrap_storage_class(); + let pointee_type = inst.operands[1].unwrap_id_ref(); + + if let Some(&(layout_required_id, layout_forbidden_id)) = specializations.get(&pointee_type) + { + let allows_layout_for_sc = allows_layout( + storage_class, + spirv_version, + has_workgroup_layout_capability, + ); + + // Create new pointer type pointing to appropriate specialized array + let target_array_id = if allows_layout_for_sc { + layout_required_id + } else { + layout_forbidden_id + }; + + let mut new_pointer_inst = inst.clone(); + new_pointer_inst.result_id = Some(next_id); + new_pointer_inst.operands[1] = Operand::IdRef(target_array_id); + new_pointer_instructions.push(new_pointer_inst); + + // Map old pointer to new pointer + if let Some(old_pointer_id) = inst.result_id { + pointer_type_mappings.insert(old_pointer_id, next_id); + } + next_id += 1; + } + } + + // Update module header bound to account for the new pointer types + if let Some(ref mut header) = module.header { + header.bound = next_id; + } + + // Insert new pointer type instructions in the correct position + // They must come after the specialized arrays they reference, but before any variables that use them + + // Find the last specialized array position to ensure pointers come after their pointee types + let mut pointer_insertion_point = 0; + for (i, inst) in module.types_global_values.iter().enumerate() { + if let Some(result_id) = inst.result_id { + // Check if this is one of our specialized arrays + if specializations + .values() + .any(|&(req_id, forb_id)| result_id == req_id || result_id == forb_id) + { + pointer_insertion_point = i + 1; + } + } + } + + // Insert the new pointer types at the calculated position + // This ensures they appear after specialized arrays but before variables + for (i, new_pointer_inst) in new_pointer_instructions.into_iter().enumerate() { + module + .types_global_values + .insert(pointer_insertion_point + i, new_pointer_inst); + } + + // Update ALL references to old pointer types throughout the entire module + // This includes variables, function parameters, and all instructions + + // Update global variables and function types + for inst in &mut module.types_global_values { + match inst.class.opcode { + Op::Variable => { + if let Some(var_type_id) = inst.result_type { + if let Some(&new_pointer_id) = pointer_type_mappings.get(&var_type_id) { + inst.result_type = Some(new_pointer_id); + } + } + } + Op::TypeFunction => { + // Update function type operands (return type and parameter types) + for operand in &mut inst.operands { + if let Some(referenced_id) = operand.id_ref_any() { + if let Some(&new_pointer_id) = pointer_type_mappings.get(&referenced_id) { + *operand = Operand::IdRef(new_pointer_id); + } + } + } + } + _ => {} + } + } + + // Update function signatures and local variables + for function in &mut module.functions { + // Update function parameters + for param in &mut function.parameters { + if let Some(param_type_id) = param.result_type { + if let Some(&new_pointer_id) = pointer_type_mappings.get(¶m_type_id) { + param.result_type = Some(new_pointer_id); + } + } + } + + // Update all instructions in function bodies + for block in &mut function.blocks { + for inst in &mut block.instructions { + // Update result type + if let Some(result_type_id) = inst.result_type { + if let Some(&new_pointer_id) = pointer_type_mappings.get(&result_type_id) { + inst.result_type = Some(new_pointer_id); + } + } + + // Update operand references + for operand in &mut inst.operands { + if let Some(referenced_id) = operand.id_ref_any() { + if let Some(&new_pointer_id) = pointer_type_mappings.get(&referenced_id) { + *operand = Operand::IdRef(new_pointer_id); + } + } + } + } + } + } + + // Remove old pointer type instructions that reference specialized arrays + module.types_global_values.retain(|inst| { + if inst.class.opcode == Op::TypePointer && inst.operands.len() >= 2 { + let pointee_type = inst.operands[1].unwrap_id_ref(); + !specializations.contains_key(&pointee_type) + } else { + true + } + }); + + // Remove original array type instructions that were specialized + let arrays_to_remove: FxHashSet = specializations.keys().cloned().collect(); + module.types_global_values.retain(|inst| { + if let Some(result_id) = inst.result_id { + !arrays_to_remove.contains(&result_id) + } else { + true + } + }); + + // STEP 8: Copy ArrayStride decorations from original arrays to layout-required variants + // and remove them from layout-forbidden variants + let mut decorations_to_add = Vec::new(); + let layout_forbidden_arrays: FxHashSet = specializations + .values() + .map(|&(_, layout_forbidden_id)| layout_forbidden_id) + .collect(); + + // Find existing ArrayStride decorations on original arrays and copy them to layout-required variants + for inst in &module.annotations { + if inst.class.opcode == Op::Decorate + && inst.operands.len() >= 2 + && inst.operands[1] == Operand::Decoration(Decoration::ArrayStride) + { + let target_id = inst.operands[0].unwrap_id_ref(); + if let Some(&(layout_required_id, _)) = specializations.get(&target_id) { + // Copy the decoration to the layout-required variant + let mut new_decoration = inst.clone(); + new_decoration.operands[0] = Operand::IdRef(layout_required_id); + decorations_to_add.push(new_decoration); + } + } + } + + // Add the copied decorations + module.annotations.extend(decorations_to_add); + + // Remove ArrayStride decorations from layout-forbidden arrays and original arrays + let arrays_to_remove_decorations: FxHashSet = layout_forbidden_arrays + .iter() + .cloned() + .chain(specializations.keys().cloned()) // Also remove from original arrays + .collect(); + + module.annotations.retain(|inst| { + if inst.class.opcode == Op::Decorate + && inst.operands.len() >= 2 + && inst.operands[1] == Operand::Decoration(Decoration::ArrayStride) + { + let target_id = inst.operands[0].unwrap_id_ref(); + !arrays_to_remove_decorations.contains(&target_id) + } else { + true + } + }); } /// Check if a type hierarchy contains a specific array type @@ -218,395 +622,3 @@ fn type_hierarchy_contains_array_type( } false } - -#[cfg(test)] -mod tests { - use super::*; - use rspirv::dr::Module; - - // Helper function to assemble SPIR-V from text - fn assemble_spirv(spirv: &str) -> Vec { - use spirv_tools::assembler::{self, Assembler}; - - let assembler = assembler::create(None); - let spv_binary = assembler - .assemble(spirv, assembler::AssemblerOptions::default()) - .expect("Failed to assemble test SPIR-V"); - let contents: &[u8] = spv_binary.as_ref(); - contents.to_vec() - } - - // Helper function to load SPIR-V binary into Module - fn load_spirv(bytes: &[u8]) -> Module { - use rspirv::dr::Loader; - - let mut loader = Loader::new(); - rspirv::binary::parse_bytes(bytes, &mut loader).unwrap(); - loader.module() - } - - // Helper function to count ArrayStride decorations - fn count_array_stride_decorations(module: &Module) -> usize { - module - .annotations - .iter() - .filter(|inst| { - inst.class.opcode == Op::Decorate - && inst.operands.len() >= 2 - && inst.operands[1] == Operand::Decoration(Decoration::ArrayStride) - }) - .count() - } - - #[test] - fn test_removes_array_stride_from_workgroup_arrays() { - let spirv = r#" - OpCapability Shader - OpMemoryModel Logical GLSL450 - OpEntryPoint GLCompute %main "main" - OpExecutionMode %main LocalSize 1 1 1 - - ; Type declarations - %void = OpTypeVoid - %func_ty = OpTypeFunction %void - %u32 = OpTypeInt 32 0 - %u32_256 = OpConstant %u32 256 - %array_ty = OpTypeArray %u32 %u32_256 - - ; Pointer types for workgroup storage (forbidden in newer SPIR-V) - %ptr_workgroup = OpTypePointer Workgroup %array_ty - - ; Variables in workgroup storage class - %workgroup_var = OpVariable %ptr_workgroup Workgroup - - ; ArrayStride decoration that should be removed - OpDecorate %array_ty ArrayStride 4 - - %main = OpFunction %void None %func_ty - %entry = OpLabel - OpReturn - OpFunctionEnd - "#; - - let bytes = assemble_spirv(spirv); - let mut module = load_spirv(&bytes); - - assert_eq!(count_array_stride_decorations(&module), 1); - - fix_array_stride_decorations(&mut module); - - // ArrayStride should be removed from arrays used in Workgroup storage - assert_eq!(count_array_stride_decorations(&module), 0); - } - - #[test] - fn test_keeps_array_stride_for_workgroup_with_capability() { - let spirv = r#" - OpCapability Shader - OpCapability WorkgroupMemoryExplicitLayoutKHR - OpMemoryModel Logical GLSL450 - OpEntryPoint GLCompute %main "main" - OpExecutionMode %main LocalSize 1 1 1 - - ; Type declarations - %void = OpTypeVoid - %func_ty = OpTypeFunction %void - %u32 = OpTypeInt 32 0 - %u32_256 = OpConstant %u32 256 - %array_ty = OpTypeArray %u32 %u32_256 - - ; Pointer types for workgroup storage (allowed with capability) - %ptr_workgroup = OpTypePointer Workgroup %array_ty - - ; Variables in workgroup storage class - %workgroup_var = OpVariable %ptr_workgroup Workgroup - - ; ArrayStride decoration that should be kept with capability - OpDecorate %array_ty ArrayStride 4 - - %main = OpFunction %void None %func_ty - %entry = OpLabel - OpReturn - OpFunctionEnd - "#; - - let bytes = assemble_spirv(spirv); - let mut module = load_spirv(&bytes); - - assert_eq!(count_array_stride_decorations(&module), 1); - - fix_array_stride_decorations(&mut module); - - // ArrayStride should be kept when WorkgroupMemoryExplicitLayoutKHR capability is present - assert_eq!(count_array_stride_decorations(&module), 1); - } - - #[test] - fn test_keeps_array_stride_for_storage_buffer_arrays() { - let spirv = r#" - OpCapability Shader - OpMemoryModel Logical GLSL450 - OpEntryPoint GLCompute %main "main" - OpExecutionMode %main LocalSize 1 1 1 - - ; Type declarations - %void = OpTypeVoid - %func_ty = OpTypeFunction %void - %u32 = OpTypeInt 32 0 - %u32_256 = OpConstant %u32 256 - %array_ty = OpTypeArray %u32 %u32_256 - - ; Pointer types for storage buffer (always allowed) - %ptr_storage_buffer = OpTypePointer StorageBuffer %array_ty - - ; Variables in storage buffer storage class - %storage_buffer_var = OpVariable %ptr_storage_buffer StorageBuffer - - ; ArrayStride decoration that should be kept - OpDecorate %array_ty ArrayStride 4 - OpDecorate %storage_buffer_var DescriptorSet 0 - OpDecorate %storage_buffer_var Binding 0 - - %main = OpFunction %void None %func_ty - %entry = OpLabel - OpReturn - OpFunctionEnd - "#; - - let bytes = assemble_spirv(spirv); - let mut module = load_spirv(&bytes); - - assert_eq!(count_array_stride_decorations(&module), 1); - - fix_array_stride_decorations(&mut module); - - // ArrayStride should be kept for StorageBuffer storage class - assert_eq!(count_array_stride_decorations(&module), 1); - } - - #[test] - fn test_handles_runtime_arrays_in_workgroup() { - let spirv = r#" - OpCapability Shader - OpMemoryModel Logical GLSL450 - OpEntryPoint GLCompute %main "main" - OpExecutionMode %main LocalSize 1 1 1 - - ; Type declarations - %void = OpTypeVoid - %func_ty = OpTypeFunction %void - %u32 = OpTypeInt 32 0 - %runtime_array_ty = OpTypeRuntimeArray %u32 - - ; Pointer types for workgroup storage (forbidden) - %ptr_workgroup = OpTypePointer Workgroup %runtime_array_ty - - ; Variables in workgroup storage class - %workgroup_var = OpVariable %ptr_workgroup Workgroup - - ; ArrayStride decoration that should be removed - OpDecorate %runtime_array_ty ArrayStride 4 - - %main = OpFunction %void None %func_ty - %entry = OpLabel - OpReturn - OpFunctionEnd - "#; - - let bytes = assemble_spirv(spirv); - let mut module = load_spirv(&bytes); - - assert_eq!(count_array_stride_decorations(&module), 1); - - fix_array_stride_decorations(&mut module); - - // ArrayStride should be removed from runtime arrays in Workgroup storage - assert_eq!(count_array_stride_decorations(&module), 0); - } - - #[test] - fn test_mixed_storage_classes_removes_problematic_arrays() { - let spirv = r#" - OpCapability Shader - OpMemoryModel Logical GLSL450 - OpEntryPoint GLCompute %main "main" - OpExecutionMode %main LocalSize 1 1 1 - - ; Type declarations - %void = OpTypeVoid - %func_ty = OpTypeFunction %void - %u32 = OpTypeInt 32 0 - %u32_256 = OpConstant %u32 256 - %forbidden_array_ty = OpTypeArray %u32 %u32_256 - %allowed_array_ty = OpTypeArray %u32 %u32_256 - - ; Pointer types for different storage classes - %ptr_workgroup = OpTypePointer Workgroup %forbidden_array_ty - %ptr_storage_buffer = OpTypePointer StorageBuffer %allowed_array_ty - - ; Variables in different storage classes - %workgroup_var = OpVariable %ptr_workgroup Workgroup - %storage_buffer_var = OpVariable %ptr_storage_buffer StorageBuffer - - ; ArrayStride decorations - OpDecorate %forbidden_array_ty ArrayStride 4 - OpDecorate %allowed_array_ty ArrayStride 4 - OpDecorate %storage_buffer_var DescriptorSet 0 - OpDecorate %storage_buffer_var Binding 0 - - %main = OpFunction %void None %func_ty - %entry = OpLabel - OpReturn - OpFunctionEnd - "#; - - let bytes = assemble_spirv(spirv); - let mut module = load_spirv(&bytes); - - assert_eq!(count_array_stride_decorations(&module), 2); - - fix_array_stride_decorations(&mut module); - - // Only the Workgroup array should have its ArrayStride removed - assert_eq!(count_array_stride_decorations(&module), 1); - } - - #[test] - fn test_nested_structs_and_arrays_in_function_storage() { - let spirv = r#" - OpCapability Shader - OpMemoryModel Logical GLSL450 - OpEntryPoint GLCompute %main "main" - OpExecutionMode %main LocalSize 1 1 1 - - ; ArrayStride decoration that should be removed in SPIR-V 1.4+ - OpDecorate %inner_array_ty ArrayStride 16 - - ; Type declarations - %void = OpTypeVoid - %float = OpTypeFloat 32 - %u32 = OpTypeInt 32 0 - %u32_4 = OpConstant %u32 4 - %inner_array_ty = OpTypeArray %float %u32_4 - %inner_struct_ty = OpTypeStruct %inner_array_ty - %outer_struct_ty = OpTypeStruct %inner_struct_ty - - ; Pointer types for function storage (forbidden in SPIR-V 1.4+) - %ptr_function = OpTypePointer Function %outer_struct_ty - %func_ty = OpTypeFunction %void - - ; Function variable inside function - %main = OpFunction %void None %func_ty - %entry = OpLabel - %function_var = OpVariable %ptr_function Function - OpReturn - OpFunctionEnd - "#; - - let bytes = assemble_spirv(spirv); - let mut module = load_spirv(&bytes); - - // Force SPIR-V 1.4 for this test - if let Some(ref mut header) = module.header { - header.set_version(1, 4); - } - - assert_eq!(count_array_stride_decorations(&module), 1); - - fix_array_stride_decorations(&mut module); - - // ArrayStride should be removed in SPIR-V 1.4+ for Function storage - assert_eq!(count_array_stride_decorations(&module), 0); - } - - #[test] - fn test_function_storage_spirv_13_keeps_decorations() { - let spirv = r#" - OpCapability Shader - OpMemoryModel Logical GLSL450 - OpEntryPoint GLCompute %main "main" - OpExecutionMode %main LocalSize 1 1 1 - - ; Type declarations - %void = OpTypeVoid - %func_ty = OpTypeFunction %void - %u32 = OpTypeInt 32 0 - %u32_256 = OpConstant %u32 256 - %array_ty = OpTypeArray %u32 %u32_256 - - ; Pointer types for function storage - %ptr_function = OpTypePointer Function %array_ty - - ; Function variable - %main = OpFunction %void None %func_ty - %entry = OpLabel - %function_var = OpVariable %ptr_function Function - OpReturn - OpFunctionEnd - - ; ArrayStride decoration that should be kept in SPIR-V 1.3 - OpDecorate %array_ty ArrayStride 4 - "#; - - let bytes = assemble_spirv(spirv); - let mut module = load_spirv(&bytes); - - // Force SPIR-V 1.3 for this test - if let Some(ref mut header) = module.header { - header.set_version(1, 3); - } - - assert_eq!(count_array_stride_decorations(&module), 1); - - fix_array_stride_decorations(&mut module); - - // ArrayStride should be kept in SPIR-V 1.3 for Function storage - assert_eq!(count_array_stride_decorations(&module), 1); - } - - #[test] - fn test_private_storage_spirv_14_removes_decorations() { - let spirv = r#" - OpCapability Shader - OpMemoryModel Logical GLSL450 - OpEntryPoint GLCompute %main "main" - OpExecutionMode %main LocalSize 1 1 1 - - ; Type declarations - %void = OpTypeVoid - %func_ty = OpTypeFunction %void - %u32 = OpTypeInt 32 0 - %u32_256 = OpConstant %u32 256 - %array_ty = OpTypeArray %u32 %u32_256 - - ; Pointer types for private storage - %ptr_private = OpTypePointer Private %array_ty - - ; Variables in private storage class - %private_var = OpVariable %ptr_private Private - - ; ArrayStride decoration that should be removed in SPIR-V 1.4+ - OpDecorate %array_ty ArrayStride 4 - - %main = OpFunction %void None %func_ty - %entry = OpLabel - OpReturn - OpFunctionEnd - "#; - - let bytes = assemble_spirv(spirv); - let mut module = load_spirv(&bytes); - - // Force SPIR-V 1.4 for this test - if let Some(ref mut header) = module.header { - header.set_version(1, 4); - } - - assert_eq!(count_array_stride_decorations(&module), 1); - - fix_array_stride_decorations(&mut module); - - // ArrayStride should be removed in SPIR-V 1.4+ for Private storage - assert_eq!(count_array_stride_decorations(&module), 0); - } -} diff --git a/crates/rustc_codegen_spirv/src/linker/duplicates.rs b/crates/rustc_codegen_spirv/src/linker/duplicates.rs index 6b1b45d8cd..7631972001 100644 --- a/crates/rustc_codegen_spirv/src/linker/duplicates.rs +++ b/crates/rustc_codegen_spirv/src/linker/duplicates.rs @@ -117,11 +117,14 @@ fn gather_names(debug_names: &[Instruction]) -> FxHashMap { .collect() } -fn make_dedupe_key( +fn make_dedupe_key_with_array_context( inst: &Instruction, unresolved_forward_pointers: &FxHashSet, annotations: &FxHashMap>, names: &FxHashMap, + array_contexts: Option< + &FxHashMap, + >, ) -> Vec { let mut data = vec![inst.class.opcode as u32]; @@ -169,6 +172,38 @@ fn make_dedupe_key( } } + // For array types, include storage class context in the key to prevent + // inappropriate deduplication between different storage class contexts + if let Some(result_id) = inst.result_id { + if matches!(inst.class.opcode, Op::TypeArray | Op::TypeRuntimeArray) { + if let Some(contexts) = array_contexts { + if let Some(context) = contexts.get(&result_id) { + // Include usage pattern in the key so arrays with different contexts won't deduplicate + let usage_pattern_discriminant = match context.usage_pattern { + crate::linker::array_stride_fixer::ArrayUsagePattern::LayoutRequired => { + 1u32 + } + crate::linker::array_stride_fixer::ArrayUsagePattern::LayoutForbidden => { + 2u32 + } + crate::linker::array_stride_fixer::ArrayUsagePattern::MixedUsage => 3u32, + crate::linker::array_stride_fixer::ArrayUsagePattern::Unused => 4u32, + }; + data.push(usage_pattern_discriminant); + + // Also include the specific storage classes for fine-grained differentiation + let mut storage_classes: Vec = context + .storage_classes + .iter() + .map(|sc| *sc as u32) + .collect(); + storage_classes.sort(); // Ensure deterministic ordering + data.extend(storage_classes); + } + } + } + } + data } @@ -185,6 +220,15 @@ fn rewrite_inst_with_rules(inst: &mut Instruction, rules: &FxHashMap) } pub fn remove_duplicate_types(module: &mut Module) { + remove_duplicate_types_with_array_context(module, None); +} + +pub fn remove_duplicate_types_with_array_context( + module: &mut Module, + array_contexts: Option< + &FxHashMap, + >, +) { // Keep in mind, this algorithm requires forward type references to not exist - i.e. it's a valid spir-v module. // When a duplicate type is encountered, then this is a map from the deleted ID, to the new, deduplicated ID. @@ -222,7 +266,13 @@ pub fn remove_duplicate_types(module: &mut Module) { // all_inst_iter_mut pass below. However, the code is a lil bit cleaner this way I guess. rewrite_inst_with_rules(inst, &rewrite_rules); - let key = make_dedupe_key(inst, &unresolved_forward_pointers, &annotations, &names); + let key = make_dedupe_key_with_array_context( + inst, + &unresolved_forward_pointers, + &annotations, + &names, + array_contexts, + ); match key_to_result_id.entry(key) { hash_map::Entry::Vacant(entry) => { diff --git a/crates/rustc_codegen_spirv/src/linker/mod.rs b/crates/rustc_codegen_spirv/src/linker/mod.rs index 8e4933da36..d78525fcff 100644 --- a/crates/rustc_codegen_spirv/src/linker/mod.rs +++ b/crates/rustc_codegen_spirv/src/linker/mod.rs @@ -356,10 +356,10 @@ pub fn link( }); } - // Fix ArrayStride decorations for arrays in storage classes where newer SPIR-V versions forbid explicit layouts + // Fix ArrayStride decorations (after storage classes are resolved to avoid conflicts) { let _timer = sess.timer("fix_array_stride_decorations"); - array_stride_fixer::fix_array_stride_decorations(&mut output); + array_stride_fixer::fix_array_stride_decorations_with_deduplication(&mut output, false); } // NOTE(eddyb) with SPIR-T, we can do `mem2reg` before inlining, too! diff --git a/tests/ui/linker/array_stride_fixer/function_storage_spirv13_kept.rs b/tests/ui/linker/array_stride_fixer/function_storage_spirv13_kept.rs new file mode 100644 index 0000000000..c7219048b2 --- /dev/null +++ b/tests/ui/linker/array_stride_fixer/function_storage_spirv13_kept.rs @@ -0,0 +1,21 @@ +// Test that ArrayStride decorations are kept for function storage in SPIR-V 1.3 + +// build-pass +// compile-flags: -C llvm-args=--disassemble-globals +// normalize-stderr-test "OpLine .*\n" -> "" +// normalize-stderr-test "OpSource .*\n" -> "" +// normalize-stderr-test "\S*/lib/rustlib/" -> "$SYSROOT/lib/rustlib/" +// only-spv1.3 +use spirv_std::spirv; + +#[spirv(compute(threads(1)))] +pub fn main( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] output: &mut [u32; 1], +) { + // Function storage in SPIR-V 1.3 should keep ArrayStride decorations + let mut function_var: [u32; 256] = [0; 256]; + function_var[0] = 42; + function_var[1] = function_var[0] + 1; + // Force the array to be used by writing to output + output[0] = function_var[1]; +} diff --git a/tests/ui/linker/array_stride_fixer/function_storage_spirv13_kept.stderr b/tests/ui/linker/array_stride_fixer/function_storage_spirv13_kept.stderr new file mode 100644 index 0000000000..5c2290c73a --- /dev/null +++ b/tests/ui/linker/array_stride_fixer/function_storage_spirv13_kept.stderr @@ -0,0 +1,30 @@ +OpCapability Shader +OpCapability Float64 +OpCapability Int64 +OpCapability Int16 +OpCapability Int8 +OpCapability ShaderClockKHR +OpExtension "SPV_KHR_shader_clock" +OpMemoryModel Logical Simple +OpEntryPoint GLCompute %1 "main" +OpExecutionMode %1 LocalSize 1 1 1 +%2 = OpString "$OPSTRING_FILENAME/function_storage_spirv13_kept.rs" +OpDecorate %4 ArrayStride 4 +OpDecorate %5 Block +OpMemberDecorate %5 0 Offset 0 +OpDecorate %3 Binding 0 +OpDecorate %3 DescriptorSet 0 +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeInt 32 0 +%9 = OpConstant %8 1 +%4 = OpTypeArray %8 %9 +%10 = OpTypePointer StorageBuffer %4 +%5 = OpTypeStruct %4 +%11 = OpTypePointer StorageBuffer %5 +%3 = OpVariable %11 StorageBuffer +%12 = OpConstant %8 0 +%13 = OpTypeBool +%14 = OpConstant %8 256 +%15 = OpConstant %8 42 +%16 = OpTypePointer StorageBuffer %8 diff --git a/tests/ui/linker/array_stride_fixer/mixed_storage_classes.rs b/tests/ui/linker/array_stride_fixer/mixed_storage_classes.rs new file mode 100644 index 0000000000..9947ce3426 --- /dev/null +++ b/tests/ui/linker/array_stride_fixer/mixed_storage_classes.rs @@ -0,0 +1,21 @@ +// Test that mixed storage class usage results in proper ArrayStride handling + +// compile-flags: -C llvm-args=--disassemble-globals +// only-vulkan1.1 +// normalize-stderr-test "OpLine .*\n" -> "" +// normalize-stderr-test "OpSource .*\n" -> "" +// normalize-stderr-test "\S*/lib/rustlib/" -> "$SYSROOT/lib/rustlib/" +use spirv_std::spirv; + +#[spirv(compute(threads(64)))] +pub fn main( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] storage_data: &mut [u32; 256], + #[spirv(workgroup)] workgroup_data: &mut [u32; 256], +) { + // Both variables use the same array type [u32; 256] but in different storage classes: + // - storage_data is in StorageBuffer (requires ArrayStride) + // - workgroup_data is in Workgroup (forbids ArrayStride in SPIR-V 1.4+) + + storage_data[0] = 42; + workgroup_data[0] = storage_data[0]; +} diff --git a/tests/ui/linker/array_stride_fixer/mixed_storage_classes.stderr b/tests/ui/linker/array_stride_fixer/mixed_storage_classes.stderr new file mode 100644 index 0000000000..a6d066e4ae --- /dev/null +++ b/tests/ui/linker/array_stride_fixer/mixed_storage_classes.stderr @@ -0,0 +1,36 @@ +OpCapability Shader +OpCapability Float64 +OpCapability Int64 +OpCapability Int16 +OpCapability Int8 +OpCapability ShaderClockKHR +OpCapability VulkanMemoryModel +OpExtension "SPV_KHR_shader_clock" +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical Vulkan +OpEntryPoint GLCompute %1 "main" +OpExecutionMode %1 LocalSize 64 1 1 +%2 = OpString "$OPSTRING_FILENAME/mixed_storage_classes.rs" +OpName %4 "workgroup_data" +OpDecorate %5 ArrayStride 4 +OpDecorate %6 Block +OpMemberDecorate %6 0 Offset 0 +OpDecorate %3 Binding 0 +OpDecorate %3 DescriptorSet 0 +%7 = OpTypeVoid +%8 = OpTypeFunction %7 +%9 = OpTypeInt 32 0 +%10 = OpConstant %9 256 +%5 = OpTypeArray %9 %10 +%11 = OpTypePointer StorageBuffer %5 +%6 = OpTypeStruct %5 +%12 = OpTypePointer StorageBuffer %6 +%3 = OpVariable %12 StorageBuffer +%13 = OpConstant %9 0 +%14 = OpTypeBool +%15 = OpTypePointer StorageBuffer %9 +%16 = OpConstant %9 42 +%17 = OpTypePointer Workgroup %9 +%18 = OpTypeArray %9 %10 +%19 = OpTypePointer Workgroup %18 +%4 = OpVariable %19 Workgroup diff --git a/tests/ui/linker/array_stride_fixer/nested_structs_function_storage.rs b/tests/ui/linker/array_stride_fixer/nested_structs_function_storage.rs new file mode 100644 index 0000000000..3407f66b90 --- /dev/null +++ b/tests/ui/linker/array_stride_fixer/nested_structs_function_storage.rs @@ -0,0 +1,34 @@ +// Test that ArrayStride decorations are removed from nested structs with arrays in Function storage class + +// build-pass +// compile-flags: -C llvm-args=--disassemble-globals +// only-vulkan1.2 +// normalize-stderr-test "OpLine .*\n" -> "" +// normalize-stderr-test "OpSource .*\n" -> "" +// normalize-stderr-test "\S*/lib/rustlib/" -> "$SYSROOT/lib/rustlib/" +use spirv_std::spirv; + +#[derive(Copy, Clone)] +struct InnerStruct { + data: [f32; 4], +} + +#[derive(Copy, Clone)] +struct OuterStruct { + inner: InnerStruct, +} + +#[spirv(compute(threads(1)))] +pub fn main( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] output: &mut [f32; 1], +) { + // Function-local variables with nested structs containing arrays + // Should have ArrayStride removed in SPIR-V 1.4+ + let mut function_var = OuterStruct { + inner: InnerStruct { data: [0.0; 4] }, + }; + function_var.inner.data[0] = 42.0; + function_var.inner.data[1] = function_var.inner.data[0] + 1.0; + // Force usage to prevent optimization + output[0] = function_var.inner.data[1]; +} diff --git a/tests/ui/linker/array_stride_fixer/nested_structs_function_storage.stderr b/tests/ui/linker/array_stride_fixer/nested_structs_function_storage.stderr new file mode 100644 index 0000000000..89a68f3a87 --- /dev/null +++ b/tests/ui/linker/array_stride_fixer/nested_structs_function_storage.stderr @@ -0,0 +1,45 @@ +OpCapability Shader +OpCapability Float64 +OpCapability Int64 +OpCapability Int16 +OpCapability Int8 +OpCapability ShaderClockKHR +OpCapability VulkanMemoryModel +OpExtension "SPV_KHR_shader_clock" +OpMemoryModel Logical Vulkan +OpEntryPoint GLCompute %1 "main" %2 +OpExecutionMode %1 LocalSize 1 1 1 +%3 = OpString "$OPSTRING_FILENAME/nested_structs_function_storage.rs" +OpName %4 "InnerStruct" +OpMemberName %4 0 "data" +OpName %5 "OuterStruct" +OpMemberName %5 0 "inner" +OpDecorate %6 ArrayStride 4 +OpDecorate %7 Block +OpMemberDecorate %7 0 Offset 0 +OpDecorate %2 Binding 0 +OpDecorate %2 DescriptorSet 0 +OpMemberDecorate %4 0 Offset 0 +OpMemberDecorate %5 0 Offset 0 +%8 = OpTypeFloat 32 +%9 = OpTypeInt 32 0 +%10 = OpConstant %9 1 +%6 = OpTypeArray %8 %10 +%7 = OpTypeStruct %6 +%11 = OpTypePointer StorageBuffer %7 +%12 = OpTypeVoid +%13 = OpTypeFunction %12 +%14 = OpTypePointer StorageBuffer %6 +%2 = OpVariable %11 StorageBuffer +%15 = OpConstant %9 0 +%16 = OpConstant %9 4 +%17 = OpTypeArray %8 %16 +%4 = OpTypeStruct %17 +%5 = OpTypeStruct %4 +%18 = OpConstant %8 0 +%19 = OpConstantComposite %17 %18 %18 %18 %18 +%20 = OpUndef %5 +%21 = OpTypeBool +%22 = OpConstant %8 1109917696 +%23 = OpConstant %8 1065353216 +%24 = OpTypePointer StorageBuffer %8 diff --git a/tests/ui/linker/array_stride_fixer/private_storage_spirv14_removed.rs b/tests/ui/linker/array_stride_fixer/private_storage_spirv14_removed.rs new file mode 100644 index 0000000000..19de8523be --- /dev/null +++ b/tests/ui/linker/array_stride_fixer/private_storage_spirv14_removed.rs @@ -0,0 +1,22 @@ +// Test that ArrayStride decorations are removed from private storage in SPIR-V 1.4 + +// build-pass +// compile-flags: -C llvm-args=--disassemble-globals +// normalize-stderr-test "OpLine .*\n" -> "" +// normalize-stderr-test "OpSource .*\n" -> "" +// normalize-stderr-test "\S*/lib/rustlib/" -> "$SYSROOT/lib/rustlib/" +// only-spv1.4 +use spirv_std::spirv; + +// Helper function to create an array in private storage +fn create_private_array() -> [u32; 4] { + [0, 1, 2, 3] +} + +#[spirv(compute(threads(1)))] +pub fn main() { + // This creates a private storage array in SPIR-V 1.4+ + // ArrayStride decorations should be removed + let mut private_array = create_private_array(); + private_array[0] = 42; +} diff --git a/tests/ui/linker/array_stride_fixer/private_storage_spirv14_removed.stderr b/tests/ui/linker/array_stride_fixer/private_storage_spirv14_removed.stderr new file mode 100644 index 0000000000..2eff9c5bee --- /dev/null +++ b/tests/ui/linker/array_stride_fixer/private_storage_spirv14_removed.stderr @@ -0,0 +1,22 @@ +OpCapability Shader +OpCapability Float64 +OpCapability Int64 +OpCapability Int16 +OpCapability Int8 +OpCapability ShaderClockKHR +OpExtension "SPV_KHR_shader_clock" +OpMemoryModel Logical Simple +OpEntryPoint GLCompute %1 "main" +OpExecutionMode %1 LocalSize 1 1 1 +%2 = OpString "$OPSTRING_FILENAME/private_storage_spirv14_removed.rs" +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeInt 32 0 +%7 = OpConstant %6 4 +%8 = OpTypeArray %6 %7 +%9 = OpTypeFunction %8 +%10 = OpConstant %6 0 +%11 = OpConstant %6 1 +%12 = OpConstant %6 2 +%13 = OpConstant %6 3 +%14 = OpTypeBool diff --git a/tests/ui/linker/array_stride_fixer/runtime_arrays_in_workgroup.rs b/tests/ui/linker/array_stride_fixer/runtime_arrays_in_workgroup.rs new file mode 100644 index 0000000000..9c486544ab --- /dev/null +++ b/tests/ui/linker/array_stride_fixer/runtime_arrays_in_workgroup.rs @@ -0,0 +1,22 @@ +// Test that ArrayStride decorations are removed from runtime arrays in Workgroup storage class + +// build-pass +// compile-flags: -C llvm-args=--disassemble-globals +// only-vulkan1.1 +// normalize-stderr-test "OpLine .*\n" -> "" +// normalize-stderr-test "OpSource .*\n" -> "" +// normalize-stderr-test "\S*/lib/rustlib/" -> "$SYSROOT/lib/rustlib/" +use spirv_std::RuntimeArray; +use spirv_std::spirv; + +#[spirv(compute(threads(64)))] +pub fn main( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] output: &mut [u32; 1], + #[spirv(workgroup)] shared_array: &mut [u32; 256], +) { + // Workgroup arrays should have ArrayStride removed + shared_array[0] = 42; + shared_array[1] = shared_array[0] + 1; + // Force usage to prevent optimization + output[0] = shared_array[1]; +} diff --git a/tests/ui/linker/array_stride_fixer/runtime_arrays_in_workgroup.stderr b/tests/ui/linker/array_stride_fixer/runtime_arrays_in_workgroup.stderr new file mode 100644 index 0000000000..42ca0186aa --- /dev/null +++ b/tests/ui/linker/array_stride_fixer/runtime_arrays_in_workgroup.stderr @@ -0,0 +1,37 @@ +OpCapability Shader +OpCapability Float64 +OpCapability Int64 +OpCapability Int16 +OpCapability Int8 +OpCapability ShaderClockKHR +OpCapability VulkanMemoryModel +OpExtension "SPV_KHR_shader_clock" +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical Vulkan +OpEntryPoint GLCompute %1 "main" +OpExecutionMode %1 LocalSize 64 1 1 +%2 = OpString "$OPSTRING_FILENAME/runtime_arrays_in_workgroup.rs" +OpName %4 "shared_array" +OpDecorate %5 ArrayStride 4 +OpDecorate %6 Block +OpMemberDecorate %6 0 Offset 0 +OpDecorate %3 Binding 0 +OpDecorate %3 DescriptorSet 0 +%7 = OpTypeVoid +%8 = OpTypeFunction %7 +%9 = OpTypeInt 32 0 +%10 = OpConstant %9 1 +%5 = OpTypeArray %9 %10 +%11 = OpTypePointer StorageBuffer %5 +%6 = OpTypeStruct %5 +%12 = OpTypePointer StorageBuffer %6 +%3 = OpVariable %12 StorageBuffer +%13 = OpConstant %9 0 +%14 = OpTypeBool +%15 = OpConstant %9 256 +%16 = OpTypePointer Workgroup %9 +%17 = OpTypeArray %9 %15 +%18 = OpTypePointer Workgroup %17 +%4 = OpVariable %18 Workgroup +%19 = OpConstant %9 42 +%20 = OpTypePointer StorageBuffer %9 diff --git a/tests/ui/linker/array_stride_fixer/storage_buffer_arrays_kept.rs b/tests/ui/linker/array_stride_fixer/storage_buffer_arrays_kept.rs new file mode 100644 index 0000000000..5b9261a8b0 --- /dev/null +++ b/tests/ui/linker/array_stride_fixer/storage_buffer_arrays_kept.rs @@ -0,0 +1,17 @@ +// Test that ArrayStride decorations are kept for arrays in StorageBuffer storage class + +// build-pass +// compile-flags: -C llvm-args=--disassemble-globals +// only-vulkan1.1 +// normalize-stderr-test "OpLine .*\n" -> "" +// normalize-stderr-test "OpSource .*\n" -> "" +// normalize-stderr-test "\S*/lib/rustlib/" -> "$SYSROOT/lib/rustlib/" +use spirv_std::spirv; + +#[spirv(compute(threads(1)))] +pub fn main( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] storage_buffer_var: &mut [u32; 256], +) { + // StorageBuffer storage class should keep ArrayStride decorations + storage_buffer_var[0] = 42; +} diff --git a/tests/ui/linker/array_stride_fixer/storage_buffer_arrays_kept.stderr b/tests/ui/linker/array_stride_fixer/storage_buffer_arrays_kept.stderr new file mode 100644 index 0000000000..5f0007f120 --- /dev/null +++ b/tests/ui/linker/array_stride_fixer/storage_buffer_arrays_kept.stderr @@ -0,0 +1,31 @@ +OpCapability Shader +OpCapability Float64 +OpCapability Int64 +OpCapability Int16 +OpCapability Int8 +OpCapability ShaderClockKHR +OpCapability VulkanMemoryModel +OpExtension "SPV_KHR_shader_clock" +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical Vulkan +OpEntryPoint GLCompute %1 "main" +OpExecutionMode %1 LocalSize 1 1 1 +%2 = OpString "$OPSTRING_FILENAME/storage_buffer_arrays_kept.rs" +OpDecorate %4 ArrayStride 4 +OpDecorate %5 Block +OpMemberDecorate %5 0 Offset 0 +OpDecorate %3 Binding 0 +OpDecorate %3 DescriptorSet 0 +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeInt 32 0 +%9 = OpConstant %8 256 +%4 = OpTypeArray %8 %9 +%10 = OpTypePointer StorageBuffer %4 +%5 = OpTypeStruct %4 +%11 = OpTypePointer StorageBuffer %5 +%3 = OpVariable %11 StorageBuffer +%12 = OpConstant %8 0 +%13 = OpTypeBool +%14 = OpTypePointer StorageBuffer %8 +%15 = OpConstant %8 42 diff --git a/tests/ui/linker/array_stride_fixer/workgroup_arrays_removed.rs b/tests/ui/linker/array_stride_fixer/workgroup_arrays_removed.rs new file mode 100644 index 0000000000..967113c70f --- /dev/null +++ b/tests/ui/linker/array_stride_fixer/workgroup_arrays_removed.rs @@ -0,0 +1,21 @@ +// Test that ArrayStride decorations are removed from arrays in Function storage class (SPIR-V 1.4+) + +// build-pass +// compile-flags: -C llvm-args=--disassemble-globals +// only-vulkan1.2 +// normalize-stderr-test "OpLine .*\n" -> "" +// normalize-stderr-test "OpSource .*\n" -> "" +// normalize-stderr-test "\S*/lib/rustlib/" -> "$SYSROOT/lib/rustlib/" +use spirv_std::spirv; + +#[spirv(compute(threads(64)))] +pub fn main( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] output: &mut [u32; 1], + #[spirv(workgroup)] shared_data: &mut [u32; 256], +) { + // Workgroup storage arrays should have ArrayStride removed + shared_data[0] = 42; + shared_data[1] = shared_data[0] + 1; + // Force usage to prevent optimization + output[0] = shared_data[1]; +} diff --git a/tests/ui/linker/array_stride_fixer/workgroup_arrays_removed.stderr b/tests/ui/linker/array_stride_fixer/workgroup_arrays_removed.stderr new file mode 100644 index 0000000000..3f769cba39 --- /dev/null +++ b/tests/ui/linker/array_stride_fixer/workgroup_arrays_removed.stderr @@ -0,0 +1,36 @@ +OpCapability Shader +OpCapability Float64 +OpCapability Int64 +OpCapability Int16 +OpCapability Int8 +OpCapability ShaderClockKHR +OpCapability VulkanMemoryModel +OpExtension "SPV_KHR_shader_clock" +OpMemoryModel Logical Vulkan +OpEntryPoint GLCompute %1 "main" %2 %3 +OpExecutionMode %1 LocalSize 64 1 1 +%4 = OpString "$OPSTRING_FILENAME/workgroup_arrays_removed.rs" +OpName %3 "shared_data" +OpDecorate %5 ArrayStride 4 +OpDecorate %6 Block +OpMemberDecorate %6 0 Offset 0 +OpDecorate %2 Binding 0 +OpDecorate %2 DescriptorSet 0 +%7 = OpTypeInt 32 0 +%8 = OpConstant %7 1 +%5 = OpTypeArray %7 %8 +%6 = OpTypeStruct %5 +%9 = OpTypePointer StorageBuffer %6 +%10 = OpConstant %7 256 +%11 = OpTypeArray %7 %10 +%12 = OpTypePointer Workgroup %11 +%13 = OpTypeVoid +%14 = OpTypeFunction %13 +%15 = OpTypePointer StorageBuffer %5 +%2 = OpVariable %9 StorageBuffer +%16 = OpConstant %7 0 +%17 = OpTypeBool +%18 = OpTypePointer Workgroup %7 +%3 = OpVariable %12 Workgroup +%19 = OpConstant %7 42 +%20 = OpTypePointer StorageBuffer %7 diff --git a/tests/ui/linker/array_stride_fixer/workgroup_arrays_with_capability.rs b/tests/ui/linker/array_stride_fixer/workgroup_arrays_with_capability.rs new file mode 100644 index 0000000000..505809003e --- /dev/null +++ b/tests/ui/linker/array_stride_fixer/workgroup_arrays_with_capability.rs @@ -0,0 +1,22 @@ +// Test that ArrayStride decorations are kept for arrays in Workgroup storage class with WorkgroupMemoryExplicitLayoutKHR capability + +// build-pass +// compile-flags: -C llvm-args=--disassemble-globals -Ctarget-feature=+WorkgroupMemoryExplicitLayoutKHR,+ext:SPV_KHR_workgroup_memory_explicit_layout +// normalize-stderr-test "OpLine .*\n" -> "" +// normalize-stderr-test "OpSource .*\n" -> "" +// normalize-stderr-test "\S*/lib/rustlib/" -> "$SYSROOT/lib/rustlib/" +// only-vulkan1.2 + +use spirv_std::spirv; + +#[spirv(compute(threads(64)))] +pub fn main( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] output: &mut [u32; 1], + #[spirv(workgroup)] shared_data: &mut [u32; 256], +) { + // With WorkgroupMemoryExplicitLayoutKHR capability, ArrayStride should be kept + shared_data[0] = 42; + shared_data[1] = shared_data[0] + 1; + // Force usage to prevent optimization + output[0] = shared_data[1]; +} diff --git a/tests/ui/linker/array_stride_fixer/workgroup_arrays_with_capability.stderr b/tests/ui/linker/array_stride_fixer/workgroup_arrays_with_capability.stderr new file mode 100644 index 0000000000..a645d0bcbd --- /dev/null +++ b/tests/ui/linker/array_stride_fixer/workgroup_arrays_with_capability.stderr @@ -0,0 +1,39 @@ +OpCapability Shader +OpCapability Float64 +OpCapability Int64 +OpCapability Int16 +OpCapability Int8 +OpCapability WorkgroupMemoryExplicitLayoutKHR +OpCapability ShaderClockKHR +OpCapability VulkanMemoryModel +OpExtension "SPV_KHR_shader_clock" +OpExtension "SPV_KHR_workgroup_memory_explicit_layout" +OpMemoryModel Logical Vulkan +OpEntryPoint GLCompute %1 "main" %2 %3 +OpExecutionMode %1 LocalSize 64 1 1 +%4 = OpString "$OPSTRING_FILENAME/workgroup_arrays_with_capability.rs" +OpName %3 "shared_data" +OpDecorate %5 ArrayStride 4 +OpDecorate %6 Block +OpMemberDecorate %6 0 Offset 0 +OpDecorate %7 ArrayStride 4 +OpDecorate %2 Binding 0 +OpDecorate %2 DescriptorSet 0 +%8 = OpTypeInt 32 0 +%9 = OpConstant %8 1 +%5 = OpTypeArray %8 %9 +%6 = OpTypeStruct %5 +%10 = OpTypePointer StorageBuffer %6 +%11 = OpConstant %8 256 +%7 = OpTypeArray %8 %11 +%12 = OpTypePointer Workgroup %7 +%13 = OpTypeVoid +%14 = OpTypeFunction %13 +%15 = OpTypePointer StorageBuffer %5 +%2 = OpVariable %10 StorageBuffer +%16 = OpConstant %8 0 +%17 = OpTypeBool +%18 = OpTypePointer Workgroup %8 +%3 = OpVariable %12 Workgroup +%19 = OpConstant %8 42 +%20 = OpTypePointer StorageBuffer %8