Skip to content

Commit bd51b65

Browse files
Switch to pipeline stream desc for dx12 pipeline creation
Co-authored-by: Inner Daemons <[email protected]>
1 parent 71820ee commit bd51b65

File tree

3 files changed

+420
-198
lines changed

3 files changed

+420
-198
lines changed

wgpu-hal/src/dx12/device.rs

Lines changed: 86 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,9 @@ use crate::{
2626
dxgi::{name::ObjectExt as _, result::HResult as _},
2727
},
2828
dx12::{
29-
borrow_optional_interface_temporarily, shader_compilation, suballocation, DCompLib,
30-
DynamicStorageBufferOffsets, Event, ShaderCacheKey, ShaderCacheValue,
29+
borrow_optional_interface_temporarily, pipeline_desc::RenderPipelineStateStreamDesc,
30+
shader_compilation, suballocation, DCompLib, DynamicStorageBufferOffsets, Event,
31+
ShaderCacheKey, ShaderCacheValue,
3132
},
3233
AccelerationStructureEntries, TlasInstance,
3334
};
@@ -1866,8 +1867,6 @@ impl crate::Device for super::Device {
18661867
>,
18671868
) -> Result<super::RenderPipeline, crate::PipelineError> {
18681869
let mut shader_stages = wgt::ShaderStages::empty();
1869-
let root_signature =
1870-
unsafe { borrow_optional_interface_temporarily(&desc.layout.shared.signature) };
18711870
let (topology_class, topology) = conv::map_topology(desc.primitive.topology);
18721871
let mut rtv_formats = [Dxgi::Common::DXGI_FORMAT_UNKNOWN;
18731872
Direct3D12::D3D12_SIMULTANEOUS_RENDER_TARGET_COUNT as usize];
@@ -1907,6 +1906,7 @@ impl crate::Device for super::Device {
19071906
Direct3D12::D3D12_CONSERVATIVE_RASTERIZATION_MODE_OFF
19081907
},
19091908
};
1909+
19101910
let blob_fs = match desc.fragment_stage {
19111911
Some(ref stage) => {
19121912
shader_stages |= wgt::ShaderStages::FRAGMENT;
@@ -1918,7 +1918,6 @@ impl crate::Device for super::Device {
19181918
Some(shader) => shader.create_native_shader(),
19191919
None => Direct3D12::D3D12_SHADER_BYTECODE::default(),
19201920
};
1921-
let mut vertex_strides = [None; crate::MAX_VERTEX_BUFFERS];
19221921
let stream_output = Direct3D12::D3D12_STREAM_OUTPUT_DESC {
19231922
pSODeclaration: ptr::null(),
19241923
NumEntries: 0,
@@ -1953,20 +1952,51 @@ impl crate::Device for super::Device {
19531952
};
19541953
let flags = Direct3D12::D3D12_PIPELINE_STATE_FLAG_NONE;
19551954

1956-
let raw: Direct3D12::ID3D12PipelineState = match &desc.vertex_processor {
1955+
let mut stream_desc = RenderPipelineStateStreamDesc {
1956+
// Shared by vertex and mesh pipelines
1957+
root_signature: desc.layout.shared.signature.as_ref(),
1958+
pixel_shader,
1959+
blend_state,
1960+
sample_mask: desc.multisample.mask as u32,
1961+
rasterizer_state,
1962+
depth_stencil_state,
1963+
primitive_topology_type: topology_class,
1964+
rtv_formats: Direct3D12::D3D12_RT_FORMAT_ARRAY {
1965+
RTFormats: rtv_formats,
1966+
NumRenderTargets: desc.color_targets.len() as u32,
1967+
},
1968+
dsv_format,
1969+
sample_desc,
1970+
node_mask: 0,
1971+
cached_pso,
1972+
flags,
1973+
1974+
// Optional data that depends on the pipeline type (vertex vs mesh).
1975+
vertex_shader: Default::default(),
1976+
input_layout: Default::default(),
1977+
index_buffer_strip_cut_value: Default::default(),
1978+
stream_output,
1979+
task_shader: Default::default(),
1980+
mesh_shader: Default::default(),
1981+
};
1982+
let mut input_element_descs = Vec::new();
1983+
let blob_vs;
1984+
let blob_ts;
1985+
let blob_ms;
1986+
let mut vertex_strides = [None; crate::MAX_VERTEX_BUFFERS];
1987+
match &desc.vertex_processor {
19571988
&crate::VertexProcessor::Standard {
19581989
vertex_buffers,
19591990
ref vertex_stage,
19601991
} => {
19611992
shader_stages |= wgt::ShaderStages::VERTEX;
1962-
let blob_vs = self.load_shader(
1993+
blob_vs = Some(self.load_shader(
19631994
vertex_stage,
19641995
desc.layout,
19651996
naga::ShaderStage::Vertex,
19661997
desc.fragment_stage.as_ref(),
1967-
)?;
1998+
)?);
19681999

1969-
let mut input_element_descs = Vec::new();
19702000
for (i, (stride, vbuf)) in vertex_strides.iter_mut().zip(vertex_buffers).enumerate()
19712001
{
19722002
*stride = Some(vbuf.array_stride as u32);
@@ -1990,54 +2020,37 @@ impl crate::Device for super::Device {
19902020
});
19912021
}
19922022
}
1993-
let raw_desc = Direct3D12::D3D12_GRAPHICS_PIPELINE_STATE_DESC {
1994-
pRootSignature: root_signature,
1995-
VS: blob_vs.create_native_shader(),
1996-
PS: pixel_shader,
1997-
GS: Direct3D12::D3D12_SHADER_BYTECODE::default(),
1998-
DS: Direct3D12::D3D12_SHADER_BYTECODE::default(),
1999-
HS: Direct3D12::D3D12_SHADER_BYTECODE::default(),
2000-
StreamOutput: stream_output,
2001-
BlendState: blend_state,
2002-
SampleMask: desc.multisample.mask as u32,
2003-
RasterizerState: rasterizer_state,
2004-
DepthStencilState: depth_stencil_state,
2005-
InputLayout: Direct3D12::D3D12_INPUT_LAYOUT_DESC {
2006-
pInputElementDescs: if input_element_descs.is_empty() {
2007-
ptr::null()
2008-
} else {
2009-
input_element_descs.as_ptr()
2010-
},
2011-
NumElements: input_element_descs.len() as u32,
2012-
},
2013-
IBStripCutValue: match desc.primitive.strip_index_format {
2014-
Some(wgt::IndexFormat::Uint16) => {
2015-
Direct3D12::D3D12_INDEX_BUFFER_STRIP_CUT_VALUE_0xFFFF
2016-
}
2017-
Some(wgt::IndexFormat::Uint32) => {
2018-
Direct3D12::D3D12_INDEX_BUFFER_STRIP_CUT_VALUE_0xFFFFFFFF
2019-
}
2020-
None => Direct3D12::D3D12_INDEX_BUFFER_STRIP_CUT_VALUE_DISABLED,
2023+
stream_desc.vertex_shader = blob_vs.as_ref().unwrap().create_native_shader();
2024+
stream_desc.input_layout = Direct3D12::D3D12_INPUT_LAYOUT_DESC {
2025+
pInputElementDescs: if input_element_descs.is_empty() {
2026+
ptr::null()
2027+
} else {
2028+
input_element_descs.as_ptr()
20212029
},
2022-
PrimitiveTopologyType: topology_class,
2023-
NumRenderTargets: desc.color_targets.len() as u32,
2024-
RTVFormats: rtv_formats,
2025-
DSVFormat: dsv_format,
2026-
SampleDesc: sample_desc,
2027-
NodeMask: 0,
2028-
CachedPSO: cached_pso,
2029-
Flags: flags,
2030+
NumElements: input_element_descs.len() as u32,
2031+
};
2032+
stream_desc.index_buffer_strip_cut_value = match desc.primitive.strip_index_format {
2033+
Some(wgt::IndexFormat::Uint16) => {
2034+
Direct3D12::D3D12_INDEX_BUFFER_STRIP_CUT_VALUE_0xFFFF
2035+
}
2036+
Some(wgt::IndexFormat::Uint32) => {
2037+
Direct3D12::D3D12_INDEX_BUFFER_STRIP_CUT_VALUE_0xFFFFFFFF
2038+
}
2039+
None => Direct3D12::D3D12_INDEX_BUFFER_STRIP_CUT_VALUE_DISABLED,
2040+
};
2041+
stream_desc.stream_output = Direct3D12::D3D12_STREAM_OUTPUT_DESC {
2042+
pSODeclaration: ptr::null(),
2043+
NumEntries: 0,
2044+
pBufferStrides: ptr::null(),
2045+
NumStrides: 0,
2046+
RasterizedStream: 0,
20302047
};
2031-
unsafe {
2032-
profiling::scope!("ID3D12Device::CreateGraphicsPipelineState");
2033-
self.raw.CreateGraphicsPipelineState(&raw_desc)
2034-
}
20352048
}
20362049
crate::VertexProcessor::Mesh {
20372050
task_stage,
20382051
mesh_stage,
20392052
} => {
2040-
let blob_ts = if let Some(ts) = task_stage {
2053+
blob_ts = if let Some(ts) = task_stage {
20412054
shader_stages |= wgt::ShaderStages::TASK;
20422055
Some(self.load_shader(
20432056
ts,
@@ -2054,48 +2067,36 @@ impl crate::Device for super::Device {
20542067
Default::default()
20552068
};
20562069
shader_stages |= wgt::ShaderStages::MESH;
2057-
let blob_ms = self.load_shader(
2070+
blob_ms = Some(self.load_shader(
20582071
mesh_stage,
20592072
desc.layout,
20602073
naga::ShaderStage::Mesh,
20612074
desc.fragment_stage.as_ref(),
2062-
)?;
2063-
let desc = super::MeshShaderPipelineStateStream {
2064-
root_signature: root_signature
2065-
.as_ref()
2066-
.map(|a| a.as_raw().cast())
2067-
.unwrap_or(ptr::null_mut()),
2068-
task_shader,
2069-
pixel_shader,
2070-
mesh_shader: blob_ms.create_native_shader(),
2071-
blend_state,
2072-
sample_mask: desc.multisample.mask as u32,
2073-
rasterizer_state,
2074-
depth_stencil_state,
2075-
primitive_topology_type: topology_class,
2076-
rtv_formats: Direct3D12::D3D12_RT_FORMAT_ARRAY {
2077-
RTFormats: rtv_formats,
2078-
NumRenderTargets: desc.color_targets.len() as u32,
2079-
},
2080-
dsv_format,
2081-
sample_desc,
2082-
node_mask: 0,
2083-
cached_pso,
2084-
flags,
2085-
};
2086-
let mut raw_desc = unsafe { desc.to_bytes() };
2087-
let stream_desc = Direct3D12::D3D12_PIPELINE_STATE_STREAM_DESC {
2088-
SizeInBytes: raw_desc.len(),
2089-
pPipelineStateSubobjectStream: raw_desc.as_mut_ptr().cast(),
2090-
};
2091-
let device: Direct3D12::ID3D12Device2 = self.raw.cast().unwrap();
2075+
)?);
2076+
stream_desc.task_shader = task_shader;
2077+
stream_desc.mesh_shader = blob_ms.as_ref().unwrap().create_native_shader();
2078+
}
2079+
};
2080+
let raw: Direct3D12::ID3D12PipelineState =
2081+
// If stream descriptors are available, use them as they are more flexible.
2082+
if let Ok(device) = self.raw.cast::<Direct3D12::ID3D12Device2>() {
2083+
// Prefer stream descs where possible
2084+
let mut stream = stream_desc.to_stream();
20922085
unsafe {
20932086
profiling::scope!("ID3D12Device2::CreatePipelineState");
2094-
device.CreatePipelineState(&stream_desc)
2087+
stream.create_pipeline_state(&device).map_err(|err| {
2088+
crate::PipelineError::Linkage(shader_stages, err.to_string())
2089+
})?
20952090
}
2096-
}
2097-
}
2098-
.map_err(|err| crate::PipelineError::Linkage(shader_stages, err.to_string()))?;
2091+
} else {
2092+
unsafe {
2093+
// Safety: `stream_desc` entirely outlives the `desc`.
2094+
let desc = stream_desc.to_graphics_pipeline_descriptor();
2095+
self.raw.CreateGraphicsPipelineState(&desc).map_err(|err| {
2096+
crate::PipelineError::Linkage(shader_stages, err.to_string())
2097+
})?
2098+
}
2099+
};
20992100

21002101
if let Some(label) = desc.label {
21012102
raw.set_name(label)?;

wgpu-hal/src/dx12/mod.rs

Lines changed: 1 addition & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ mod dcomp;
7979
mod descriptor;
8080
mod device;
8181
mod instance;
82+
mod pipeline_desc;
8283
mod sampler;
8384
mod shader_compilation;
8485
mod suballocation;
@@ -1615,116 +1616,3 @@ pub enum ShaderModuleSource {
16151616
DxilPassthrough(DxilPassthroughShader),
16161617
HlslPassthrough(HlslPassthroughShader),
16171618
}
1618-
1619-
#[repr(C)]
1620-
#[derive(Debug)]
1621-
struct MeshShaderPipelineStateStream {
1622-
root_signature: *mut Direct3D12::ID3D12RootSignature,
1623-
task_shader: Direct3D12::D3D12_SHADER_BYTECODE,
1624-
mesh_shader: Direct3D12::D3D12_SHADER_BYTECODE,
1625-
pixel_shader: Direct3D12::D3D12_SHADER_BYTECODE,
1626-
blend_state: Direct3D12::D3D12_BLEND_DESC,
1627-
sample_mask: u32,
1628-
rasterizer_state: Direct3D12::D3D12_RASTERIZER_DESC,
1629-
depth_stencil_state: Direct3D12::D3D12_DEPTH_STENCIL_DESC,
1630-
primitive_topology_type: Direct3D12::D3D12_PRIMITIVE_TOPOLOGY_TYPE,
1631-
rtv_formats: Direct3D12::D3D12_RT_FORMAT_ARRAY,
1632-
dsv_format: Dxgi::Common::DXGI_FORMAT,
1633-
sample_desc: Dxgi::Common::DXGI_SAMPLE_DESC,
1634-
node_mask: u32,
1635-
cached_pso: Direct3D12::D3D12_CACHED_PIPELINE_STATE,
1636-
flags: Direct3D12::D3D12_PIPELINE_STATE_FLAGS,
1637-
}
1638-
impl MeshShaderPipelineStateStream {
1639-
/// # Safety
1640-
///
1641-
/// Returned bytes contain pointers into this struct, for them to be valid,
1642-
/// this struct may be at the same location. As if `as_bytes<'a>(&'a self) -> Vec<u8> + 'a`
1643-
pub unsafe fn to_bytes(&self) -> Vec<u8> {
1644-
use Direct3D12::*;
1645-
let mut bytes = Vec::new();
1646-
1647-
macro_rules! push_subobject {
1648-
($subobject_type:expr, $data:expr) => {{
1649-
// Ensure 8-byte alignment for the subobject start
1650-
let alignment = 8;
1651-
let aligned_length = bytes.len().next_multiple_of(alignment);
1652-
bytes.resize(aligned_length, 0);
1653-
1654-
// Append the type tag (u32)
1655-
let tag: u32 = $subobject_type.0 as u32;
1656-
bytes.extend_from_slice(&tag.to_ne_bytes());
1657-
1658-
// Align the data
1659-
let obj_align = align_of_val(&$data);
1660-
let data_start = bytes.len().next_multiple_of(obj_align);
1661-
bytes.resize(data_start, 0);
1662-
1663-
// Append the data itself
1664-
#[allow(clippy::ptr_as_ptr, trivial_casts)]
1665-
let data_ptr = &$data as *const _ as *const u8;
1666-
let data_size = size_of_val(&$data);
1667-
let slice = unsafe { core::slice::from_raw_parts(data_ptr, data_size) };
1668-
bytes.extend_from_slice(slice);
1669-
}};
1670-
}
1671-
push_subobject!(
1672-
D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_ROOT_SIGNATURE,
1673-
self.root_signature
1674-
);
1675-
if !self.task_shader.pShaderBytecode.is_null() {
1676-
push_subobject!(D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_AS, self.task_shader);
1677-
}
1678-
push_subobject!(D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_MS, self.mesh_shader);
1679-
if !self.pixel_shader.pShaderBytecode.is_null() {
1680-
push_subobject!(D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_PS, self.pixel_shader);
1681-
}
1682-
push_subobject!(D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_BLEND, self.blend_state);
1683-
push_subobject!(
1684-
D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_SAMPLE_MASK,
1685-
self.sample_mask
1686-
);
1687-
push_subobject!(
1688-
D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_RASTERIZER,
1689-
self.rasterizer_state
1690-
);
1691-
push_subobject!(
1692-
D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_DEPTH_STENCIL,
1693-
self.depth_stencil_state
1694-
);
1695-
push_subobject!(
1696-
D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_PRIMITIVE_TOPOLOGY,
1697-
self.primitive_topology_type
1698-
);
1699-
if self.rtv_formats.NumRenderTargets != 0 {
1700-
push_subobject!(
1701-
D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_RENDER_TARGET_FORMATS,
1702-
self.rtv_formats
1703-
);
1704-
}
1705-
if self.dsv_format != Dxgi::Common::DXGI_FORMAT_UNKNOWN {
1706-
push_subobject!(
1707-
D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_DEPTH_STENCIL_FORMAT,
1708-
self.dsv_format
1709-
);
1710-
}
1711-
push_subobject!(
1712-
D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_SAMPLE_DESC,
1713-
self.sample_desc
1714-
);
1715-
if self.node_mask != 0 {
1716-
push_subobject!(
1717-
D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_NODE_MASK,
1718-
self.node_mask
1719-
);
1720-
}
1721-
if !self.cached_pso.pCachedBlob.is_null() {
1722-
push_subobject!(
1723-
D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_CACHED_PSO,
1724-
self.cached_pso
1725-
);
1726-
}
1727-
push_subobject!(D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_FLAGS, self.flags);
1728-
bytes
1729-
}
1730-
}

0 commit comments

Comments
 (0)