@@ -26,8 +26,9 @@ use crate::{
2626 dxgi:: { name:: ObjectExt as _, result:: HResult as _} ,
2727 } ,
2828 dx12:: {
29- borrow_optional_interface_temporarily, shader_compilation, suballocation, DCompLib ,
30- DynamicStorageBufferOffsets , Event , ShaderCacheKey , ShaderCacheValue ,
29+ borrow_optional_interface_temporarily, pipeline_desc:: RenderPipelineStateStreamDesc ,
30+ shader_compilation, suballocation, DCompLib , DynamicStorageBufferOffsets , Event ,
31+ ShaderCacheKey , ShaderCacheValue ,
3132 } ,
3233 AccelerationStructureEntries , TlasInstance ,
3334} ;
@@ -1866,8 +1867,6 @@ impl crate::Device for super::Device {
18661867 > ,
18671868 ) -> Result < super :: RenderPipeline , crate :: PipelineError > {
18681869 let mut shader_stages = wgt:: ShaderStages :: empty ( ) ;
1869- let root_signature =
1870- unsafe { borrow_optional_interface_temporarily ( & desc. layout . shared . signature ) } ;
18711870 let ( topology_class, topology) = conv:: map_topology ( desc. primitive . topology ) ;
18721871 let mut rtv_formats = [ Dxgi :: Common :: DXGI_FORMAT_UNKNOWN ;
18731872 Direct3D12 :: D3D12_SIMULTANEOUS_RENDER_TARGET_COUNT as usize ] ;
@@ -1907,6 +1906,7 @@ impl crate::Device for super::Device {
19071906 Direct3D12 :: D3D12_CONSERVATIVE_RASTERIZATION_MODE_OFF
19081907 } ,
19091908 } ;
1909+
19101910 let blob_fs = match desc. fragment_stage {
19111911 Some ( ref stage) => {
19121912 shader_stages |= wgt:: ShaderStages :: FRAGMENT ;
@@ -1918,7 +1918,6 @@ impl crate::Device for super::Device {
19181918 Some ( shader) => shader. create_native_shader ( ) ,
19191919 None => Direct3D12 :: D3D12_SHADER_BYTECODE :: default ( ) ,
19201920 } ;
1921- let mut vertex_strides = [ None ; crate :: MAX_VERTEX_BUFFERS ] ;
19221921 let stream_output = Direct3D12 :: D3D12_STREAM_OUTPUT_DESC {
19231922 pSODeclaration : ptr:: null ( ) ,
19241923 NumEntries : 0 ,
@@ -1953,20 +1952,51 @@ impl crate::Device for super::Device {
19531952 } ;
19541953 let flags = Direct3D12 :: D3D12_PIPELINE_STATE_FLAG_NONE ;
19551954
1956- let raw: Direct3D12 :: ID3D12PipelineState = match & desc. vertex_processor {
1955+ let mut stream_desc = RenderPipelineStateStreamDesc {
1956+ // Shared by vertex and mesh pipelines
1957+ root_signature : desc. layout . shared . signature . as_ref ( ) ,
1958+ pixel_shader,
1959+ blend_state,
1960+ sample_mask : desc. multisample . mask as u32 ,
1961+ rasterizer_state,
1962+ depth_stencil_state,
1963+ primitive_topology_type : topology_class,
1964+ rtv_formats : Direct3D12 :: D3D12_RT_FORMAT_ARRAY {
1965+ RTFormats : rtv_formats,
1966+ NumRenderTargets : desc. color_targets . len ( ) as u32 ,
1967+ } ,
1968+ dsv_format,
1969+ sample_desc,
1970+ node_mask : 0 ,
1971+ cached_pso,
1972+ flags,
1973+
1974+ // Optional data that depends on the pipeline type (vertex vs mesh).
1975+ vertex_shader : Default :: default ( ) ,
1976+ input_layout : Default :: default ( ) ,
1977+ index_buffer_strip_cut_value : Default :: default ( ) ,
1978+ stream_output,
1979+ task_shader : Default :: default ( ) ,
1980+ mesh_shader : Default :: default ( ) ,
1981+ } ;
1982+ let mut input_element_descs = Vec :: new ( ) ;
1983+ let blob_vs;
1984+ let blob_ts;
1985+ let blob_ms;
1986+ let mut vertex_strides = [ None ; crate :: MAX_VERTEX_BUFFERS ] ;
1987+ match & desc. vertex_processor {
19571988 & crate :: VertexProcessor :: Standard {
19581989 vertex_buffers,
19591990 ref vertex_stage,
19601991 } => {
19611992 shader_stages |= wgt:: ShaderStages :: VERTEX ;
1962- let blob_vs = self . load_shader (
1993+ blob_vs = Some ( self . load_shader (
19631994 vertex_stage,
19641995 desc. layout ,
19651996 naga:: ShaderStage :: Vertex ,
19661997 desc. fragment_stage . as_ref ( ) ,
1967- ) ?;
1998+ ) ?) ;
19681999
1969- let mut input_element_descs = Vec :: new ( ) ;
19702000 for ( i, ( stride, vbuf) ) in vertex_strides. iter_mut ( ) . zip ( vertex_buffers) . enumerate ( )
19712001 {
19722002 * stride = Some ( vbuf. array_stride as u32 ) ;
@@ -1990,54 +2020,37 @@ impl crate::Device for super::Device {
19902020 } ) ;
19912021 }
19922022 }
1993- let raw_desc = Direct3D12 :: D3D12_GRAPHICS_PIPELINE_STATE_DESC {
1994- pRootSignature : root_signature,
1995- VS : blob_vs. create_native_shader ( ) ,
1996- PS : pixel_shader,
1997- GS : Direct3D12 :: D3D12_SHADER_BYTECODE :: default ( ) ,
1998- DS : Direct3D12 :: D3D12_SHADER_BYTECODE :: default ( ) ,
1999- HS : Direct3D12 :: D3D12_SHADER_BYTECODE :: default ( ) ,
2000- StreamOutput : stream_output,
2001- BlendState : blend_state,
2002- SampleMask : desc. multisample . mask as u32 ,
2003- RasterizerState : rasterizer_state,
2004- DepthStencilState : depth_stencil_state,
2005- InputLayout : Direct3D12 :: D3D12_INPUT_LAYOUT_DESC {
2006- pInputElementDescs : if input_element_descs. is_empty ( ) {
2007- ptr:: null ( )
2008- } else {
2009- input_element_descs. as_ptr ( )
2010- } ,
2011- NumElements : input_element_descs. len ( ) as u32 ,
2012- } ,
2013- IBStripCutValue : match desc. primitive . strip_index_format {
2014- Some ( wgt:: IndexFormat :: Uint16 ) => {
2015- Direct3D12 :: D3D12_INDEX_BUFFER_STRIP_CUT_VALUE_0xFFFF
2016- }
2017- Some ( wgt:: IndexFormat :: Uint32 ) => {
2018- Direct3D12 :: D3D12_INDEX_BUFFER_STRIP_CUT_VALUE_0xFFFFFFFF
2019- }
2020- None => Direct3D12 :: D3D12_INDEX_BUFFER_STRIP_CUT_VALUE_DISABLED ,
2023+ stream_desc. vertex_shader = blob_vs. as_ref ( ) . unwrap ( ) . create_native_shader ( ) ;
2024+ stream_desc. input_layout = Direct3D12 :: D3D12_INPUT_LAYOUT_DESC {
2025+ pInputElementDescs : if input_element_descs. is_empty ( ) {
2026+ ptr:: null ( )
2027+ } else {
2028+ input_element_descs. as_ptr ( )
20212029 } ,
2022- PrimitiveTopologyType : topology_class,
2023- NumRenderTargets : desc. color_targets . len ( ) as u32 ,
2024- RTVFormats : rtv_formats,
2025- DSVFormat : dsv_format,
2026- SampleDesc : sample_desc,
2027- NodeMask : 0 ,
2028- CachedPSO : cached_pso,
2029- Flags : flags,
2030+ NumElements : input_element_descs. len ( ) as u32 ,
2031+ } ;
2032+ stream_desc. index_buffer_strip_cut_value = match desc. primitive . strip_index_format {
2033+ Some ( wgt:: IndexFormat :: Uint16 ) => {
2034+ Direct3D12 :: D3D12_INDEX_BUFFER_STRIP_CUT_VALUE_0xFFFF
2035+ }
2036+ Some ( wgt:: IndexFormat :: Uint32 ) => {
2037+ Direct3D12 :: D3D12_INDEX_BUFFER_STRIP_CUT_VALUE_0xFFFFFFFF
2038+ }
2039+ None => Direct3D12 :: D3D12_INDEX_BUFFER_STRIP_CUT_VALUE_DISABLED ,
2040+ } ;
2041+ stream_desc. stream_output = Direct3D12 :: D3D12_STREAM_OUTPUT_DESC {
2042+ pSODeclaration : ptr:: null ( ) ,
2043+ NumEntries : 0 ,
2044+ pBufferStrides : ptr:: null ( ) ,
2045+ NumStrides : 0 ,
2046+ RasterizedStream : 0 ,
20302047 } ;
2031- unsafe {
2032- profiling:: scope!( "ID3D12Device::CreateGraphicsPipelineState" ) ;
2033- self . raw . CreateGraphicsPipelineState ( & raw_desc)
2034- }
20352048 }
20362049 crate :: VertexProcessor :: Mesh {
20372050 task_stage,
20382051 mesh_stage,
20392052 } => {
2040- let blob_ts = if let Some ( ts) = task_stage {
2053+ blob_ts = if let Some ( ts) = task_stage {
20412054 shader_stages |= wgt:: ShaderStages :: TASK ;
20422055 Some ( self . load_shader (
20432056 ts,
@@ -2054,48 +2067,36 @@ impl crate::Device for super::Device {
20542067 Default :: default ( )
20552068 } ;
20562069 shader_stages |= wgt:: ShaderStages :: MESH ;
2057- let blob_ms = self . load_shader (
2070+ blob_ms = Some ( self . load_shader (
20582071 mesh_stage,
20592072 desc. layout ,
20602073 naga:: ShaderStage :: Mesh ,
20612074 desc. fragment_stage . as_ref ( ) ,
2062- ) ?;
2063- let desc = super :: MeshShaderPipelineStateStream {
2064- root_signature : root_signature
2065- . as_ref ( )
2066- . map ( |a| a. as_raw ( ) . cast ( ) )
2067- . unwrap_or ( ptr:: null_mut ( ) ) ,
2068- task_shader,
2069- pixel_shader,
2070- mesh_shader : blob_ms. create_native_shader ( ) ,
2071- blend_state,
2072- sample_mask : desc. multisample . mask as u32 ,
2073- rasterizer_state,
2074- depth_stencil_state,
2075- primitive_topology_type : topology_class,
2076- rtv_formats : Direct3D12 :: D3D12_RT_FORMAT_ARRAY {
2077- RTFormats : rtv_formats,
2078- NumRenderTargets : desc. color_targets . len ( ) as u32 ,
2079- } ,
2080- dsv_format,
2081- sample_desc,
2082- node_mask : 0 ,
2083- cached_pso,
2084- flags,
2085- } ;
2086- let mut raw_desc = unsafe { desc. to_bytes ( ) } ;
2087- let stream_desc = Direct3D12 :: D3D12_PIPELINE_STATE_STREAM_DESC {
2088- SizeInBytes : raw_desc. len ( ) ,
2089- pPipelineStateSubobjectStream : raw_desc. as_mut_ptr ( ) . cast ( ) ,
2090- } ;
2091- let device: Direct3D12 :: ID3D12Device2 = self . raw . cast ( ) . unwrap ( ) ;
2075+ ) ?) ;
2076+ stream_desc. task_shader = task_shader;
2077+ stream_desc. mesh_shader = blob_ms. as_ref ( ) . unwrap ( ) . create_native_shader ( ) ;
2078+ }
2079+ } ;
2080+ let raw: Direct3D12 :: ID3D12PipelineState =
2081+ // If stream descriptors are available, use them as they are more flexible.
2082+ if let Ok ( device) = self . raw . cast :: < Direct3D12 :: ID3D12Device2 > ( ) {
2083+ // Prefer stream descs where possible
2084+ let mut stream = stream_desc. to_stream ( ) ;
20922085 unsafe {
20932086 profiling:: scope!( "ID3D12Device2::CreatePipelineState" ) ;
2094- device. CreatePipelineState ( & stream_desc)
2087+ stream. create_pipeline_state ( & device) . map_err ( |err| {
2088+ crate :: PipelineError :: Linkage ( shader_stages, err. to_string ( ) )
2089+ } ) ?
20952090 }
2096- }
2097- }
2098- . map_err ( |err| crate :: PipelineError :: Linkage ( shader_stages, err. to_string ( ) ) ) ?;
2091+ } else {
2092+ unsafe {
2093+ // Safety: `stream_desc` entirely outlives the `desc`.
2094+ let desc = stream_desc. to_graphics_pipeline_descriptor ( ) ;
2095+ self . raw . CreateGraphicsPipelineState ( & desc) . map_err ( |err| {
2096+ crate :: PipelineError :: Linkage ( shader_stages, err. to_string ( ) )
2097+ } ) ?
2098+ }
2099+ } ;
20992100
21002101 if let Some ( label) = desc. label {
21012102 raw. set_name ( label) ?;
0 commit comments