Skip to content

Commit 6f5014f

Browse files
authored
[wgpu-core/-hal] move raytracing alignments into hal (#6563)
1 parent 2389106 commit 6f5014f

File tree

16 files changed

+102
-64
lines changed

16 files changed

+102
-64
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ By @ErichDonGubler in [#6456](https://github.com/gfx-rs/wgpu/pull/6456), [#6148]
104104
#### General
105105

106106
- Return submission index in `map_async` and `on_submitted_work_done` to track down completion of async callbacks. By @eliemichel in [#6360](https://github.com/gfx-rs/wgpu/pull/6360).
107+
- Move raytracing alignments out of core and into HAL. By @Vecvec in [#6563](https://github.com/gfx-rs/wgpu/pull/6563).
107108

108109
### Changes
109110

wgpu-core/src/command/ray_tracing.rs

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use crate::{
55
id::CommandEncoderId,
66
init_tracker::MemoryInitKind,
77
ray_tracing::{
8-
tlas_instance_into_bytes, BlasAction, BlasBuildEntry, BlasGeometries, BlasTriangleGeometry,
8+
BlasAction, BlasBuildEntry, BlasGeometries, BlasTriangleGeometry,
99
BuildAccelerationStructureError, TlasAction, TlasBuildEntry, TlasInstance, TlasPackage,
1010
TraceBlasBuildEntry, TraceBlasGeometries, TraceBlasTriangleGeometry, TraceTlasInstance,
1111
TraceTlasPackage, ValidateBlasActionsError, ValidateTlasActionsError,
@@ -60,9 +60,6 @@ struct TlasBufferStore {
6060
entry: TlasBuildEntry,
6161
}
6262

63-
// TODO: Get this from the device (e.g. VkPhysicalDeviceAccelerationStructurePropertiesKHR.minAccelerationStructureScratchOffsetAlignment). This is currently the largest possible value; some devices have 0, 64, or 128 (lower limits), so this could create excess allocation. (Note: dx12 has 256.)
64-
const SCRATCH_BUFFER_ALIGNMENT: u32 = 256;
65-
6663
impl Global {
6764
// Currently this function is very similar to its safe counterpart, however certain parts of it are very different,
6865
// making for the two to be implemented differently, the main difference is this function has separate buffers for each
@@ -193,6 +190,7 @@ impl Global {
193190
&mut scratch_buffer_blas_size,
194191
&mut blas_storage,
195192
hub,
193+
device.alignments.ray_tracing_scratch_buffer_alignment,
196194
)?;
197195

198196
let mut scratch_buffer_tlas_size = 0;
@@ -260,7 +258,7 @@ impl Global {
260258
let scratch_buffer_offset = scratch_buffer_tlas_size;
261259
scratch_buffer_tlas_size += align_to(
262260
tlas.size_info.build_scratch_size as u32,
263-
SCRATCH_BUFFER_ALIGNMENT,
261+
device.alignments.ray_tracing_scratch_buffer_alignment,
264262
) as u64;
265263

266264
tlas_storage.push(UnsafeTlasStore {
@@ -508,6 +506,7 @@ impl Global {
508506
&mut scratch_buffer_blas_size,
509507
&mut blas_storage,
510508
hub,
509+
device.alignments.ray_tracing_scratch_buffer_alignment,
511510
)?;
512511
let mut tlas_lock_store = Vec::<(Option<TlasPackage>, Arc<Tlas>)>::new();
513512

@@ -535,7 +534,7 @@ impl Global {
535534
let scratch_buffer_offset = scratch_buffer_tlas_size;
536535
scratch_buffer_tlas_size += align_to(
537536
tlas.size_info.build_scratch_size as u32,
538-
SCRATCH_BUFFER_ALIGNMENT,
537+
device.alignments.ray_tracing_scratch_buffer_alignment,
539538
) as u64;
540539

541540
let first_byte_index = instance_buffer_staging_source.len();
@@ -558,10 +557,13 @@ impl Global {
558557

559558
cmd_buf_data.trackers.blas_s.set_single(blas.clone());
560559

561-
instance_buffer_staging_source.extend(tlas_instance_into_bytes(
562-
&instance,
563-
blas.handle,
564-
device.backend(),
560+
instance_buffer_staging_source.extend(device.raw().tlas_instance_to_bytes(
561+
hal::TlasInstance {
562+
transform: *instance.transform,
563+
custom_index: instance.custom_index,
564+
mask: instance.mask,
565+
blas_address: blas.handle,
566+
},
565567
));
566568

567569
instance_count += 1;
@@ -1013,6 +1015,7 @@ fn iter_buffers<'a, 'b>(
10131015
scratch_buffer_blas_size: &mut u64,
10141016
blas_storage: &mut Vec<BlasStore<'a>>,
10151017
hub: &Hub,
1018+
ray_tracing_scratch_buffer_alignment: u32,
10161019
) -> Result<(), BuildAccelerationStructureError> {
10171020
let mut triangle_entries =
10181021
Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::new();
@@ -1192,7 +1195,7 @@ fn iter_buffers<'a, 'b>(
11921195
let scratch_buffer_offset = *scratch_buffer_blas_size;
11931196
*scratch_buffer_blas_size += align_to(
11941197
blas.size_info.build_scratch_size as u32,
1195-
SCRATCH_BUFFER_ALIGNMENT,
1198+
ray_tracing_scratch_buffer_alignment,
11961199
) as u64;
11971200

11981201
blas_storage.push(BlasStore {

wgpu-core/src/device/ray_tracing.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use crate::{
1212
global::Global,
1313
id::{self, BlasId, TlasId},
1414
lock::RwLock,
15-
ray_tracing::{get_raw_tlas_instance_size, CreateBlasError, CreateTlasError},
15+
ray_tracing::{CreateBlasError, CreateTlasError},
1616
resource, LabelHelpers,
1717
};
1818
use hal::AccelerationStructureTriangleIndices;
@@ -135,7 +135,7 @@ impl Device {
135135
.map_err(DeviceError::from_hal)?;
136136

137137
let instance_buffer_size =
138-
get_raw_tlas_instance_size(self.backend()) * desc.max_instances.max(1) as usize;
138+
self.alignments.raw_tlas_instance_size * desc.max_instances.max(1) as usize;
139139
let instance_buffer = unsafe {
140140
self.raw().create_buffer(&hal::BufferDescriptor {
141141
label: Some("(wgpu-core) instances_buffer"),

wgpu-core/src/ray_tracing.rs

Lines changed: 2 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ use crate::{
1313
id::{BlasId, BufferId, TlasId},
1414
resource::CreateBufferError,
1515
};
16-
use std::{mem::size_of, sync::Arc};
17-
use std::{num::NonZeroU64, slice};
16+
use std::num::NonZeroU64;
17+
use std::sync::Arc;
1818

1919
use crate::resource::{Blas, ResourceErrorIdent, Tlas};
2020
use thiserror::Error;
@@ -276,48 +276,3 @@ pub struct TraceTlasPackage {
276276
pub instances: Vec<Option<TraceTlasInstance>>,
277277
pub lowest_unmodified: u32,
278278
}
279-
280-
pub(crate) fn get_raw_tlas_instance_size(backend: wgt::Backend) -> usize {
281-
// TODO: this should be provided by the backend
282-
match backend {
283-
wgt::Backend::Empty => 0,
284-
wgt::Backend::Vulkan => 64,
285-
_ => unimplemented!(),
286-
}
287-
}
288-
289-
#[derive(Clone)]
290-
#[repr(C)]
291-
struct RawTlasInstance {
292-
transform: [f32; 12],
293-
custom_index_and_mask: u32,
294-
shader_binding_table_record_offset_and_flags: u32,
295-
acceleration_structure_reference: u64,
296-
}
297-
298-
pub(crate) fn tlas_instance_into_bytes(
299-
instance: &TlasInstance,
300-
blas_address: u64,
301-
backend: wgt::Backend,
302-
) -> Vec<u8> {
303-
// TODO: get the device to do this
304-
match backend {
305-
wgt::Backend::Empty => vec![],
306-
wgt::Backend::Vulkan => {
307-
const MAX_U24: u32 = (1u32 << 24u32) - 1u32;
308-
let temp = RawTlasInstance {
309-
transform: *instance.transform,
310-
custom_index_and_mask: (instance.custom_index & MAX_U24)
311-
| (u32::from(instance.mask) << 24),
312-
shader_binding_table_record_offset_and_flags: 0,
313-
acceleration_structure_reference: blas_address,
314-
};
315-
let temp: *const _ = &temp;
316-
unsafe {
317-
slice::from_raw_parts::<u8>(temp.cast::<u8>(), size_of::<RawTlasInstance>())
318-
.to_vec()
319-
}
320-
}
321-
_ => unimplemented!(),
322-
}
323-
}

wgpu-hal/src/dx12/adapter.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,8 @@ impl super::Adapter {
522522
// Direct3D correctly bounds-checks all array accesses:
523523
// https://microsoft.github.io/DirectX-Specs/d3d/archive/D3D11_3_FunctionalSpec.htm#18.6.8.2%20Device%20Memory%20Reads
524524
uniform_bounds_check_alignment: wgt::BufferSize::new(1).unwrap(),
525+
raw_tlas_instance_size: 0,
526+
ray_tracing_scratch_buffer_alignment: 0,
525527
},
526528
downlevel,
527529
},

wgpu-hal/src/dx12/device.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use super::{conv, descriptor, D3D12Lib};
2121
use crate::{
2222
auxil::{self, dxgi::result::HResult},
2323
dx12::{borrow_optional_interface_temporarily, shader_compilation, Event},
24+
TlasInstance,
2425
};
2526

2627
// this has to match Naga's HLSL backend, and also needs to be null-terminated
@@ -1939,4 +1940,8 @@ impl crate::Device for super::Device {
19391940
total_reserved_bytes: upstream.total_reserved_bytes,
19401941
})
19411942
}
1943+
1944+
fn tlas_instance_to_bytes(&self, _instance: TlasInstance) -> Vec<u8> {
1945+
todo!()
1946+
}
19421947
}

wgpu-hal/src/dynamic/device.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use crate::{
55
GetAccelerationStructureBuildSizesDescriptor, Label, MemoryRange, PipelineCacheDescriptor,
66
PipelineCacheError, PipelineError, PipelineLayoutDescriptor, RenderPipelineDescriptor,
77
SamplerDescriptor, ShaderError, ShaderInput, ShaderModuleDescriptor, TextureDescriptor,
8-
TextureViewDescriptor,
8+
TextureViewDescriptor, TlasInstance,
99
};
1010

1111
use super::{
@@ -158,6 +158,7 @@ pub trait DynDevice: DynResource {
158158
&self,
159159
acceleration_structure: Box<dyn DynAccelerationStructure>,
160160
);
161+
fn tlas_instance_to_bytes(&self, instance: TlasInstance) -> Vec<u8>;
161162

162163
fn get_internal_counters(&self) -> wgt::HalCounters;
163164
fn generate_allocator_report(&self) -> Option<wgt::AllocatorReport>;
@@ -520,6 +521,10 @@ impl<D: Device + DynResource> DynDevice for D {
520521
unsafe { D::destroy_acceleration_structure(self, acceleration_structure.unbox()) }
521522
}
522523

524+
fn tlas_instance_to_bytes(&self, instance: TlasInstance) -> Vec<u8> {
525+
D::tlas_instance_to_bytes(self, instance)
526+
}
527+
523528
fn get_internal_counters(&self) -> wgt::HalCounters {
524529
D::get_internal_counters(self)
525530
}

wgpu-hal/src/empty.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#![allow(unused_variables)]
22

3+
use crate::TlasInstance;
34
use std::ops::Range;
45

56
#[derive(Clone, Debug)]
@@ -306,6 +307,10 @@ impl crate::Device for Context {
306307
}
307308
unsafe fn destroy_acceleration_structure(&self, _acceleration_structure: Resource) {}
308309

310+
fn tlas_instance_to_bytes(&self, instance: TlasInstance) -> Vec<u8> {
311+
vec![]
312+
}
313+
309314
fn get_internal_counters(&self) -> wgt::HalCounters {
310315
Default::default()
311316
}

wgpu-hal/src/gles/adapter.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -851,6 +851,8 @@ impl super::Adapter {
851851
// being, provide 1 as the value here, to cause as little
852852
// trouble as possible.
853853
uniform_bounds_check_alignment: wgt::BufferSize::new(1).unwrap(),
854+
raw_tlas_instance_size: 0,
855+
ray_tracing_scratch_buffer_alignment: 0,
854856
},
855857
},
856858
})

wgpu-hal/src/gles/device.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use std::{
88
sync::{Arc, Mutex},
99
};
1010

11-
use crate::AtomicFenceValue;
11+
use crate::{AtomicFenceValue, TlasInstance};
1212
use arrayvec::ArrayVec;
1313
use std::sync::atomic::Ordering;
1414

@@ -1633,6 +1633,10 @@ impl crate::Device for super::Device {
16331633
) {
16341634
}
16351635

1636+
fn tlas_instance_to_bytes(&self, _instance: TlasInstance) -> Vec<u8> {
1637+
unimplemented!()
1638+
}
1639+
16361640
fn get_internal_counters(&self) -> wgt::HalCounters {
16371641
self.counters.clone()
16381642
}

wgpu-hal/src/lib.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -971,6 +971,7 @@ pub trait Device: WasmNotSendSync {
971971
&self,
972972
acceleration_structure: <Self::A as Api>::AccelerationStructure,
973973
);
974+
fn tlas_instance_to_bytes(&self, instance: TlasInstance) -> Vec<u8>;
974975

975976
fn get_internal_counters(&self) -> wgt::HalCounters;
976977

@@ -1771,6 +1772,12 @@ pub struct Alignments {
17711772
/// [`Uniform`]: wgt::BufferBindingType::Uniform
17721773
/// [size]: BufferBinding::size
17731774
pub uniform_bounds_check_alignment: wgt::BufferSize,
1775+
1776+
/// The size, in bytes, of a single raw (backend-specific) TLAS instance
1777+
pub raw_tlas_instance_size: usize,
1778+
1779+
/// The alignment, in bytes, that scratch buffer offsets must satisfy when building an acceleration structure
1780+
pub ray_tracing_scratch_buffer_alignment: u32,
17741781
}
17751782

17761783
#[derive(Clone, Debug)]
@@ -2519,3 +2526,11 @@ bitflags::bitflags! {
25192526
pub struct AccelerationStructureBarrier {
25202527
pub usage: Range<AccelerationStructureUses>,
25212528
}
2529+
2530+
#[derive(Debug, Copy, Clone)]
2531+
pub struct TlasInstance {
2532+
pub transform: [f32; 12],
2533+
pub custom_index: u32,
2534+
pub mask: u8,
2535+
pub blas_address: u64,
2536+
}

wgpu-hal/src/metal/adapter.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1001,6 +1001,8 @@ impl super::PrivateCapabilities {
10011001
// Metal Shading Language it generates, so from `wgpu_hal`'s
10021002
// users' point of view, references are tightly checked.
10031003
uniform_bounds_check_alignment: wgt::BufferSize::new(1).unwrap(),
1004+
raw_tlas_instance_size: 0,
1005+
ray_tracing_scratch_buffer_alignment: 0,
10041006
},
10051007
downlevel,
10061008
}

wgpu-hal/src/metal/device.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use std::{
88

99
use super::conv;
1010
use crate::auxil::map_naga_stage;
11+
use crate::TlasInstance;
1112

1213
type DeviceResult<T> = Result<T, crate::DeviceError>;
1314

@@ -1426,6 +1427,10 @@ impl crate::Device for super::Device {
14261427
unimplemented!()
14271428
}
14281429

1430+
fn tlas_instance_to_bytes(&self, _instance: TlasInstance) -> Vec<u8> {
1431+
unimplemented!()
1432+
}
1433+
14291434
fn get_internal_counters(&self) -> wgt::HalCounters {
14301435
self.counters.clone()
14311436
}

wgpu-hal/src/vulkan/adapter.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,6 +1140,13 @@ impl PhysicalDeviceProperties {
11401140
};
11411141
wgt::BufferSize::new(alignment).unwrap()
11421142
},
1143+
raw_tlas_instance_size: 64,
1144+
ray_tracing_scratch_buffer_alignment: self.acceleration_structure.map_or(
1145+
0,
1146+
|acceleration_structure| {
1147+
acceleration_structure.min_acceleration_structure_scratch_offset_alignment
1148+
},
1149+
),
11431150
}
11441151
}
11451152
}

wgpu-hal/src/vulkan/device.rs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
1-
use super::conv;
1+
use super::{conv, RawTlasInstance};
22

33
use arrayvec::ArrayVec;
44
use ash::{khr, vk};
55
use parking_lot::Mutex;
66

7+
use crate::TlasInstance;
78
use std::{
89
borrow::Cow,
910
collections::{hash_map::Entry, BTreeMap},
1011
ffi::{CStr, CString},
12+
mem,
1113
mem::MaybeUninit,
1214
num::NonZeroU32,
13-
ptr,
15+
ptr, slice,
1416
sync::Arc,
1517
};
1618

@@ -2557,6 +2559,22 @@ impl crate::Device for super::Device {
25572559

25582560
self.counters.clone()
25592561
}
2562+
2563+
fn tlas_instance_to_bytes(&self, instance: TlasInstance) -> Vec<u8> {
2564+
const MAX_U24: u32 = (1u32 << 24u32) - 1u32;
2565+
let temp = RawTlasInstance {
2566+
transform: instance.transform,
2567+
custom_index_and_mask: (instance.custom_index & MAX_U24)
2568+
| (u32::from(instance.mask) << 24),
2569+
shader_binding_table_record_offset_and_flags: 0,
2570+
acceleration_structure_reference: instance.blas_address,
2571+
};
2572+
let temp: *const _ = &temp;
2573+
unsafe {
2574+
slice::from_raw_parts::<u8>(temp.cast::<u8>(), mem::size_of::<RawTlasInstance>())
2575+
.to_vec()
2576+
}
2577+
}
25602578
}
25612579

25622580
impl super::DeviceShared {

0 commit comments

Comments
 (0)