Skip to content

Commit 6be1240

Browse files
committed
allow one naga module to transpile into multiple outputs
1 parent 5cc7ec4 commit 6be1240

File tree

4 files changed

+78
-62
lines changed

4 files changed

+78
-62
lines changed

Cargo.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ exclude = [
1313
resolver = "2"
1414

1515
[workspace.dependencies]
16-
spirv-builder = { git = "https://github.com/Rust-GPU/rust-gpu", rev = "86fc48032c4cd4afb74f1d81ae859711d20386a1", default-features = false }
16+
spirv-builder = { git = "https://github.com/Rust-GPU/rust-gpu", rev = "8cb17db18d8a44e1de7c9b3ea2b65d5aaf24b919", default-features = false }
1717
anyhow = "1.0.94"
1818
clap = { version = "4.5.37", features = ["derive"] }
1919
crossterm = "0.28.1"

crates/cargo-gpu/Cargo.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@ default-run = "cargo-gpu"
1111
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
1212

1313
[features]
14-
# Enable naga support to convert spirv to wgsl
14+
# Enable naga transpiling
1515
naga = ["dep:naga"]
16+
# Enable naga transpiling to wgsl
17+
wgsl-out = ["naga", "naga/wgsl-out"]
1618

1719
[dependencies]
1820
cargo_metadata.workspace = true
@@ -28,7 +30,7 @@ serde.workspace = true
2830
serde_json.workspace = true
2931
crossterm.workspace = true
3032
semver.workspace = true
31-
naga = { workspace = true, optional = true, features = ["spv-in", "wgsl-out"] }
33+
naga = { workspace = true, optional = true, features = ["spv-in"] }
3234

3335
[dev-dependencies]
3436
test-log.workspace = true

crates/cargo-gpu/src/naga.rs

Lines changed: 71 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,31 @@
11
//! naga transpiling to wgsl support, hidden behind feature `naga`
22
3-
use crate::linkage::spv_entry_point_to_wgsl;
43
use anyhow::Context as _;
54
use naga::error::ShaderError;
65
pub use naga::valid::Capabilities;
7-
use spirv_builder::{CompileResult, ModuleResult};
8-
use std::collections::BTreeMap;
6+
use naga::valid::ModuleInfo;
7+
use naga::Module;
8+
use spirv_builder::{CompileResult, GenericCompileResult};
99
use std::path::{Path, PathBuf};
1010

11+
/// Naga [`Module`] with [`ModuleInfo`]
12+
#[derive(Clone, Debug)]
13+
#[expect(
14+
clippy::exhaustive_structs,
15+
reason = "never adding private members to this struct"
16+
)]
17+
pub struct NagaModule {
18+
/// path to the original spv
19+
pub spv_path: PathBuf,
20+
/// naga shader [`Module`]
21+
pub module: Module,
22+
/// naga [`ModuleInfo`] from validation
23+
pub info: ModuleInfo,
24+
}
25+
1126
/// convert a single spv file to a wgsl file using naga
12-
fn spv_to_wgsl(spv_src: &Path, wgsl_dst: &Path, capabilities: Capabilities) -> anyhow::Result<()> {
13-
let inner = || -> anyhow::Result<()> {
27+
fn parse_spv(spv_src: &Path, capabilities: Capabilities) -> anyhow::Result<NagaModule> {
28+
let inner = || -> anyhow::Result<_> {
1429
let spv_bytes = std::fs::read(spv_src).context("could not read spv file")?;
1530
let opts = naga::front::spv::Options::default();
1631
let module = naga::front::spv::parse_u8_slice(&spv_bytes, &opts)
@@ -30,69 +45,68 @@ fn spv_to_wgsl(spv_src: &Path, wgsl_dst: &Path, capabilities: Capabilities) -> a
3045
inner: Box::new(err),
3146
})
3247
.context("validation of naga module failed")?;
33-
let wgsl =
34-
naga::back::wgsl::write_string(&module, &info, naga::back::wgsl::WriterFlags::empty())
35-
.context("naga conversion to wgsl failed")?;
36-
std::fs::write(wgsl_dst, wgsl).context("failed to write wgsl file")?;
37-
Ok(())
48+
Ok(NagaModule {
49+
module,
50+
info,
51+
spv_path: PathBuf::from(spv_src),
52+
})
3853
};
39-
inner().with_context(|| {
40-
format!(
41-
"converting spv '{}' to wgsl '{}' failed ",
42-
spv_src.display(),
43-
wgsl_dst.display()
44-
)
45-
})
46-
}
47-
48-
/// convert spv file path to a valid unique wgsl file path
49-
fn wgsl_file_name(path: &Path) -> PathBuf {
50-
path.with_extension("wgsl")
54+
inner().with_context(|| format!("parsing spv '{}' failed", spv_src.display()))
5155
}
5256

5357
/// Extension trait for naga transpiling
5458
pub trait CompileResultNagaExt {
55-
/// Transpile the spirv binaries to wgsl source code using [`naga`], typically for webgpu compatibility.
56-
///
57-
/// Converts this [`CompileResult`] of spirv binaries and entry points to a [`CompileResult`] pointing to wgsl source code files and their associated wgsl entry
58-
/// points.
59+
/// Transpile the spirv binaries to some other format using [`naga`].
5960
///
6061
/// # Errors
6162
/// [`naga`] transpile may error in various ways
62-
fn transpile_to_wgsl(&self, capabilities: Capabilities) -> anyhow::Result<CompileResult>;
63+
fn naga_transpile(&self, capabilities: Capabilities) -> anyhow::Result<NagaTranspile>;
6364
}
6465

6566
impl CompileResultNagaExt for CompileResult {
6667
#[inline]
67-
fn transpile_to_wgsl(&self, capabilities: Capabilities) -> anyhow::Result<CompileResult> {
68-
Ok(match &self.module {
69-
ModuleResult::SingleModule(spv) => {
70-
let wgsl = wgsl_file_name(spv);
71-
spv_to_wgsl(spv, &wgsl, capabilities)?;
72-
let entry_points = self
73-
.entry_points
74-
.iter()
75-
.map(|entry| spv_entry_point_to_wgsl(entry))
76-
.collect();
77-
Self {
78-
entry_points,
79-
module: ModuleResult::SingleModule(wgsl),
80-
}
81-
}
82-
ModuleResult::MultiModule(map) => {
83-
let new_map: BTreeMap<String, PathBuf> = map
84-
.iter()
85-
.map(|(entry_point, spv)| {
86-
let wgsl = wgsl_file_name(spv);
87-
spv_to_wgsl(spv, &wgsl, capabilities)?;
88-
Ok((spv_entry_point_to_wgsl(entry_point), wgsl))
89-
})
90-
.collect::<anyhow::Result<_>>()?;
91-
Self {
92-
entry_points: new_map.keys().cloned().collect(),
93-
module: ModuleResult::MultiModule(new_map),
94-
}
95-
}
96-
})
68+
fn naga_transpile(&self, capabilities: Capabilities) -> anyhow::Result<NagaTranspile> {
69+
Ok(NagaTranspile(self.try_map(
70+
|entry| Ok(entry.clone()),
71+
|spv| parse_spv(spv, capabilities),
72+
)?))
73+
}
74+
}
75+
76+
/// Main struct for naga transpilation
77+
#[expect(
78+
clippy::exhaustive_structs,
79+
reason = "never adding private members to this struct"
80+
)]
81+
pub struct NagaTranspile(pub GenericCompileResult<NagaModule>);
82+
83+
impl NagaTranspile {
84+
/// Transpile to wgsl source code, typically for webgpu compatibility.
85+
///
86+
/// Returns a [`CompileResult`] of wgsl source code files and their associated wgsl entry points.
87+
///
88+
/// # Errors
89+
/// converting naga module to wgsl may fail
90+
#[inline]
91+
#[cfg(feature = "wgsl-out")]
92+
pub fn to_wgsl(
93+
&self,
94+
writer_flags: naga::back::wgsl::WriterFlags,
95+
) -> anyhow::Result<CompileResult> {
96+
self.0.try_map(
97+
|entry| Ok(crate::linkage::spv_entry_point_to_wgsl(entry)),
98+
|module| {
99+
let inner = || -> anyhow::Result<_> {
100+
let wgsl_dst = module.spv_path.with_extension("wgsl");
101+
let wgsl =
102+
naga::back::wgsl::write_string(&module.module, &module.info, writer_flags)
103+
.context("naga conversion to wgsl failed")?;
104+
std::fs::write(&wgsl_dst, wgsl).context("failed to write wgsl file")?;
105+
Ok(wgsl_dst)
106+
};
107+
inner()
108+
.with_context(|| format!("transpiling to wgsl '{}'", module.spv_path.display()))
109+
},
110+
)
97111
}
98112
}

0 commit comments

Comments
 (0)