| // Copyright 2023 the Vello Authors |
| // SPDX-License-Identifier: Apache-2.0 OR MIT |
| |
| use naga::front::wgsl; |
| use naga::valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags}; |
| use naga::{AddressSpace, ArraySize, ImageClass, Module, StorageAccess, WithSpan}; |
| use std::collections::{HashMap, HashSet}; |
| use std::fmt; |
| use std::path::{Path, PathBuf}; |
| use std::sync::OnceLock; |
| use thiserror::Error; |
| |
| pub mod permutations; |
| pub mod preprocess; |
| |
| #[cfg(feature = "msl")] |
| pub mod msl; |
| |
| use crate::types::{BindType, BindingInfo, WorkgroupBufferInfo}; |
| |
| pub type Result<T> = std::result::Result<T, Error>; |
| pub type CoalescedResult<T> = std::result::Result<T, ErrorVec>; |
| |
| #[derive(Error, Debug)] |
| pub struct ErrorVec(Vec<Error>); |
| |
| #[derive(Error, Debug)] |
| #[error("{source} ({name}) {msg}")] |
| pub struct Error { |
| name: String, |
| msg: String, |
| source: InnerError, |
| } |
| |
| #[derive(Error, Debug)] |
| enum InnerError { |
| #[error("failed to parse shader")] |
| Parse(#[from] wgsl::ParseError), |
| |
| #[error("failed to validate shader")] |
| Validate(#[from] WithSpan<ValidationError>), |
| |
| #[error("missing entry point function")] |
| EntryPointNotFound, |
| } |
| |
| impl fmt::Display for ErrorVec { |
| fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| for e in self.0.iter() { |
| write!(f, "{e}")?; |
| } |
| Ok(()) |
| } |
| } |
| |
| impl Error { |
| fn new(wgsl: &str, name: &str, error: impl Into<InnerError>) -> Self { |
| let source = error.into(); |
| Self { |
| name: name.to_owned(), |
| msg: source.emit_msg(wgsl, &format!("({name} preprocessed)")), |
| source, |
| } |
| } |
| } |
| |
| impl InnerError { |
| fn emit_msg(&self, wgsl: &str, name: &str) -> String { |
| match self { |
| Self::Parse(e) => e.emit_to_string_with_path(wgsl, name), |
| Self::Validate(e) => e.emit_to_string_with_path(wgsl, name), |
| _ => String::default(), |
| } |
| } |
| } |
| |
| #[derive(Debug)] |
| pub struct ShaderInfo { |
| pub source: String, |
| pub module: Module, |
| pub module_info: ModuleInfo, |
| pub workgroup_size: [u32; 3], |
| pub bindings: Vec<BindingInfo>, |
| pub workgroup_buffers: Vec<WorkgroupBufferInfo>, |
| } |
| |
| impl ShaderInfo { |
| pub fn new(name: &str, source: String, entry_point: &str) -> Result<Self> { |
| let module = wgsl::parse_str(&source).map_err(|error| Error::new(&source, name, error))?; |
| let module_info = naga::valid::Validator::new(ValidationFlags::all(), Capabilities::all()) |
| .validate(&module) |
| .map_err(|error| Error::new(&source, name, error))?; |
| let (entry_index, entry) = module |
| .entry_points |
| .iter() |
| .enumerate() |
| .find(|(_, entry)| entry.name.as_str() == entry_point) |
| .ok_or(Error::new(&source, name, InnerError::EntryPointNotFound))?; |
| let mut bindings = vec![]; |
| let mut workgroup_buffers = vec![]; |
| let mut wg_buffer_idx = 0; |
| let entry_info = module_info.get_entry_point(entry_index); |
| for (var_handle, var) in module.global_variables.iter() { |
| if entry_info[var_handle].is_empty() { |
| continue; |
| } |
| let binding_ty = match module.types[var.ty].inner { |
| naga::TypeInner::BindingArray { base, .. } => &module.types[base].inner, |
| ref ty => ty, |
| }; |
| let Some(binding) = &var.binding else { |
| if var.space == AddressSpace::WorkGroup { |
| let index = wg_buffer_idx; |
| wg_buffer_idx += 1; |
| let size_in_bytes = match binding_ty { |
| naga::TypeInner::Array { |
| size: ArraySize::Constant(size), |
| stride, |
| .. |
| } => u32::from(*size) * stride, |
| naga::TypeInner::Struct { span, .. } => *span, |
| naga::TypeInner::Scalar(scalar) |
| | naga::TypeInner::Vector { scalar, .. } |
| | naga::TypeInner::Matrix { scalar, .. } |
| | naga::TypeInner::Atomic(scalar) => scalar.width as u32, |
| _ => { |
| // Not a valid workgroup variable type. At least not one that is used |
| // in our shaders. |
| continue; |
| } |
| }; |
| workgroup_buffers.push(WorkgroupBufferInfo { |
| size_in_bytes, |
| index, |
| }); |
| } |
| continue; |
| }; |
| let mut resource = BindingInfo { |
| name: var.name.clone(), |
| location: (binding.group, binding.binding), |
| ty: BindType::Buffer, |
| }; |
| if let naga::TypeInner::Image { class, .. } = &binding_ty { |
| resource.ty = BindType::ImageRead; |
| if let ImageClass::Storage { access, .. } = class { |
| if access.contains(StorageAccess::STORE) { |
| resource.ty = BindType::Image; |
| } |
| } |
| } else { |
| resource.ty = BindType::BufReadOnly; |
| match var.space { |
| AddressSpace::Storage { access } => { |
| if access.contains(StorageAccess::STORE) { |
| resource.ty = BindType::Buffer; |
| } |
| } |
| AddressSpace::Uniform => { |
| resource.ty = BindType::Uniform; |
| } |
| _ => {} |
| } |
| } |
| bindings.push(resource); |
| } |
| bindings.sort_by_key(|res| res.location); |
| let workgroup_size = entry.workgroup_size; |
| Ok(Self { |
| source: postprocess(&source), |
| module, |
| module_info, |
| workgroup_size, |
| bindings, |
| workgroup_buffers, |
| }) |
| } |
| |
| /// Same as [`ShaderInfo::from_dir`] but uses the default shader directory provided by [`shader_dir`]. |
| pub fn from_default() -> CoalescedResult<HashMap<String, Self>> { |
| Self::from_dir(shader_dir()) |
| } |
| |
| pub fn from_dir(shader_dir: impl AsRef<Path>) -> CoalescedResult<HashMap<String, Self>> { |
| use std::fs; |
| let shader_dir = shader_dir.as_ref(); |
| let permutation_map = |
| if let Ok(permutations_source) = fs::read_to_string(shader_dir.join("permutations")) { |
| permutations::parse(&permutations_source) |
| } else { |
| HashMap::default() |
| }; |
| //println!("{permutation_map:?}"); |
| let imports = preprocess::get_imports(shader_dir); |
| let mut errors = vec![]; |
| let mut info = HashMap::default(); |
| let defines: HashSet<_> = HashSet::default(); |
| for entry in shader_dir |
| .read_dir() |
| .expect("Can read shader import directory") |
| .filter_map(move |e| { |
| e.ok() |
| .filter(|e| e.path().extension().map(|e| e == "wgsl").unwrap_or(false)) |
| }) |
| { |
| let file_name = entry.file_name(); |
| if let Some(name) = file_name.to_str() { |
| let suffix = ".wgsl"; |
| if let Some(shader_name) = name.strip_suffix(suffix) { |
| let contents = fs::read_to_string(shader_dir.join(&file_name)) |
| .unwrap_or_else(|_| panic!("Couldn't read shader {shader_name} contents")); |
| if let Some(permutations) = permutation_map.get(shader_name) { |
| for permutation in permutations { |
| let mut defines = defines.clone(); |
| defines.extend(permutation.defines.iter().cloned()); |
| let source = preprocess::preprocess(&contents, &defines, &imports); |
| match Self::new(&permutation.name, source, "main") { |
| Ok(shader_info) => { |
| info.insert(permutation.name.clone(), shader_info); |
| } |
| Err(e) => { |
| errors.push(e); |
| } |
| } |
| } |
| } else { |
| let source = preprocess::preprocess(&contents, &defines, &imports); |
| match Self::new(shader_name, source, "main") { |
| Ok(shader_info) => { |
| info.insert(shader_name.to_string(), shader_info); |
| } |
| Err(e) => { |
| errors.push(e); |
| } |
| } |
| } |
| } |
| } |
| } |
| if !errors.is_empty() { |
| Err(ErrorVec(errors)) |
| } else { |
| Ok(info) |
| } |
| } |
| } |
| |
| // TODO: This is a workaround for gfx-rs/wgpu#5476. Since naga can't handle the `enable` directive, |
| // we allow its use in other WGSL compilers using our own "#enable" post-process directive. Remove |
| // this mechanism once naga supports the directive. |
| fn postprocess(wgsl: &str) -> String { |
| let mut output = String::with_capacity(wgsl.len()); |
| for line in wgsl.lines() { |
| if line.starts_with("//__#enable") { |
| output.push_str(&line["//__#".len()..]); |
| } else { |
| output.push_str(line); |
| } |
| output.push('\n'); |
| } |
| output |
| } |
| |
| /// Returns the absolute path to the directory containing the WGSL shaders. |
| /// |
| /// The path is determined at compile time and is likely only valid on the compiling machine. |
| // NOTE: Embedding build environment info into the code makes reproducible builds trickier. |
| pub fn shader_dir() -> &'static PathBuf { |
| static SHADER_DIR: OnceLock<PathBuf> = OnceLock::new(); |
| SHADER_DIR.get_or_init(|| manifest_dir().join("shader")) |
| } |
| |
| // In a regular cargo build the manifest directory is simply given by CARGO_MANIFEST_DIR. |
| // |
| // Skia, an external consumer of this crate, uses Bazel rules to compile Rust code. Due to |
| // limitations in Bazel's rust support, Skia maintains its own build definitions |
| // (https://source.chromium.org/chromium/chromium/src/+/main:third_party/skia/bazel/external/vello/BUILD.bazel). |
| // |
| // Because of the current setup, Bazel sets CARGO_MANIFEST_DIR to the workspace root instead of the |
| // actual crate being built. This could be improved but until then, we work around this by allowing |
| // the absolute path to the vello_shader crate's manifest to be specified using the |
| // `UNSTABLE_BAZEL_VELLO_SHADERS_CRATE_MANIFEST_PATH` build script environment variable. |
| // |
| // This should never be set when using cargo. |
| fn manifest_dir() -> PathBuf { |
| use std::env; |
| env::var_os("UNSTABLE_BAZEL_VELLO_SHADERS_CRATE_MANIFEST_PATH") |
| .and_then(|p| Path::new(&p).parent().map(|p| p.to_owned())) |
| .unwrap_or_else(|| PathBuf::from(env!("CARGO_MANIFEST_DIR"))) |
| .to_path_buf() |
| } |