Skip to content

Commit 5cc7ec4

Browse files
committed
integrated naga transpiling to wgsl
1 parent 8ad72b1 commit 5cc7ec4

File tree

6 files changed

+250
-1
lines changed

6 files changed

+250
-1
lines changed

Cargo.lock

Lines changed: 136 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ tempdir = "0.3.7"
2828
test-log = "0.2.16"
2929
cargo_metadata = "0.19.2"
3030
semver = "1.0.26"
31+
naga = "25.0.1"
3132

3233
# This crate MUST NEVER be upgraded, we need this particular "first" version to support old rust-gpu builds
3334
legacy_target_specs = { package = "rustc_codegen_spirv-target-specs", version = "0.9.0", features = ["include_str"] }

crates/cargo-gpu/Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ build = "build.rs"
1010
default-run = "cargo-gpu"
1111
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
1212

13+
[features]
14+
# Enable naga support to convert spirv to wgsl
15+
naga = ["dep:naga"]
16+
1317
[dependencies]
1418
cargo_metadata.workspace = true
1519
anyhow.workspace = true
@@ -24,6 +28,7 @@ serde.workspace = true
2428
serde_json.workspace = true
2529
crossterm.workspace = true
2630
semver.workspace = true
31+
naga = { workspace = true, optional = true, features = ["spv-in", "wgsl-out"] }
2732

2833
[dev-dependencies]
2934
test-log.workspace = true

crates/cargo-gpu/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,15 @@ mod legacy_target_specs;
6565
mod linkage;
6666
mod lockfile;
6767
mod metadata;
68+
#[cfg(feature = "naga")]
69+
mod naga;
6870
mod show;
6971
mod spirv_source;
7072
mod test;
7173

7274
pub use install::*;
75+
#[cfg(feature = "naga")]
76+
pub use naga::*;
7377
pub use spirv_builder;
7478

7579
/// Central function to write to the user.

crates/cargo-gpu/src/linkage.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,13 @@ impl Linkage {
2323
.map(|comp| comp.as_os_str().to_string_lossy())
2424
.collect::<Vec<_>>()
2525
.join("/"),
26-
wgsl_entry_point: entry_point.as_ref().replace("::", ""),
26+
wgsl_entry_point: spv_entry_point_to_wgsl(entry_point.as_ref()),
2727
entry_point: entry_point.as_ref().to_owned(),
2828
}
2929
}
3030
}
31+
32+
/// Convert a spirv entry point to a valid wgsl entry point
33+
pub fn spv_entry_point_to_wgsl(entry_point: &str) -> String {
34+
entry_point.replace("::", "")
35+
}

crates/cargo-gpu/src/naga.rs

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
//! naga transpiling to wgsl support, hidden behind feature `naga`
2+
3+
use crate::linkage::spv_entry_point_to_wgsl;
4+
use anyhow::Context as _;
5+
use naga::error::ShaderError;
6+
pub use naga::valid::Capabilities;
7+
use spirv_builder::{CompileResult, ModuleResult};
8+
use std::collections::BTreeMap;
9+
use std::path::{Path, PathBuf};
10+
11+
/// convert a single spv file to a wgsl file using naga
12+
fn spv_to_wgsl(spv_src: &Path, wgsl_dst: &Path, capabilities: Capabilities) -> anyhow::Result<()> {
13+
let inner = || -> anyhow::Result<()> {
14+
let spv_bytes = std::fs::read(spv_src).context("could not read spv file")?;
15+
let opts = naga::front::spv::Options::default();
16+
let module = naga::front::spv::parse_u8_slice(&spv_bytes, &opts)
17+
.map_err(|err| ShaderError {
18+
source: String::new(),
19+
label: None,
20+
inner: Box::new(err),
21+
})
22+
.context("naga could not parse spv")?;
23+
let mut validator =
24+
naga::valid::Validator::new(naga::valid::ValidationFlags::default(), capabilities);
25+
let info = validator
26+
.validate(&module)
27+
.map_err(|err| ShaderError {
28+
source: String::new(),
29+
label: None,
30+
inner: Box::new(err),
31+
})
32+
.context("validation of naga module failed")?;
33+
let wgsl =
34+
naga::back::wgsl::write_string(&module, &info, naga::back::wgsl::WriterFlags::empty())
35+
.context("naga conversion to wgsl failed")?;
36+
std::fs::write(wgsl_dst, wgsl).context("failed to write wgsl file")?;
37+
Ok(())
38+
};
39+
inner().with_context(|| {
40+
format!(
41+
"converting spv '{}' to wgsl '{}' failed ",
42+
spv_src.display(),
43+
wgsl_dst.display()
44+
)
45+
})
46+
}
47+
48+
/// convert spv file path to a valid unique wgsl file path
49+
fn wgsl_file_name(path: &Path) -> PathBuf {
50+
path.with_extension("wgsl")
51+
}
52+
53+
/// Extension trait for naga transpiling
54+
pub trait CompileResultNagaExt {
55+
/// Transpile the spirv binaries to wgsl source code using [`naga`], typically for webgpu compatibility.
56+
///
57+
/// Converts this [`CompileResult`] of spirv binaries and entry points to a [`CompileResult`] pointing to wgsl source code files and their associated wgsl entry
58+
/// points.
59+
///
60+
/// # Errors
61+
/// [`naga`] transpile may error in various ways
62+
fn transpile_to_wgsl(&self, capabilities: Capabilities) -> anyhow::Result<CompileResult>;
63+
}
64+
65+
impl CompileResultNagaExt for CompileResult {
66+
#[inline]
67+
fn transpile_to_wgsl(&self, capabilities: Capabilities) -> anyhow::Result<CompileResult> {
68+
Ok(match &self.module {
69+
ModuleResult::SingleModule(spv) => {
70+
let wgsl = wgsl_file_name(spv);
71+
spv_to_wgsl(spv, &wgsl, capabilities)?;
72+
let entry_points = self
73+
.entry_points
74+
.iter()
75+
.map(|entry| spv_entry_point_to_wgsl(entry))
76+
.collect();
77+
Self {
78+
entry_points,
79+
module: ModuleResult::SingleModule(wgsl),
80+
}
81+
}
82+
ModuleResult::MultiModule(map) => {
83+
let new_map: BTreeMap<String, PathBuf> = map
84+
.iter()
85+
.map(|(entry_point, spv)| {
86+
let wgsl = wgsl_file_name(spv);
87+
spv_to_wgsl(spv, &wgsl, capabilities)?;
88+
Ok((spv_entry_point_to_wgsl(entry_point), wgsl))
89+
})
90+
.collect::<anyhow::Result<_>>()?;
91+
Self {
92+
entry_points: new_map.keys().cloned().collect(),
93+
module: ModuleResult::MultiModule(new_map),
94+
}
95+
}
96+
})
97+
}
98+
}

0 commit comments

Comments
 (0)