// blob: 5c3f6b9495b28f31f03e26696db4883114b62f58 [file] [log] [blame] [edit]
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Also licensed under MIT license, at your choice.
use std::{
borrow::Cow,
collections::{hash_map::Entry, HashMap},
num::{NonZeroU32, NonZeroU64},
sync::atomic::{AtomicU64, Ordering},
};
use futures_intrusive::channel::shared::GenericOneshotReceiver;
use parking_lot::RawMutex;
use wgpu::{
util::DeviceExt, BindGroup, BindGroupLayout, Buffer, BufferAsyncError, BufferSlice, BufferView,
ComputePipeline, Device, Queue, Texture, TextureAspect, TextureFormat, TextureUsages,
TextureView, TextureViewDimension,
};
/// Boxed, type-erased error returned by the fallible operations in this file.
pub type Error = Box<dyn std::error::Error>;
/// Handle to a shader registered with an `Engine`.
///
/// Wraps the index of the shader in the engine's shader list, as handed out by
/// `Engine::add_shader`.
#[derive(Clone, Copy)]
pub struct ShaderId(usize);
/// Non-zero identifier carried by buffer and image proxies.
///
/// Values are handed out by `Id::next` and are used as the key when looking up the
/// resource that backs a proxy.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct Id(NonZeroU64);
/// Global counter behind `Id::next`; holds the number of ids handed out so far.
static ID_COUNTER: AtomicU64 = AtomicU64::new(0);
/// Holds the compute shaders registered through `Engine::add_shader` and executes
/// `Recording`s against them.
pub struct Engine {
/// Registered shaders; a `ShaderId` wraps an index into this list.
shaders: Vec<Shader>,
}
/// A compute pipeline together with the layout of its single bind group.
struct Shader {
/// Pipeline set on the compute pass when the shader is dispatched.
pipeline: ComputePipeline,
/// Layout used to build the bind group for each dispatch of this shader.
bind_group_layout: BindGroupLayout,
}
/// An ordered list of `Command`s, executed front to back by `Engine::run_recording`.
#[derive(Default)]
pub struct Recording {
/// Commands in the order they were pushed.
commands: Vec<Command>,
}
/// Lightweight, copyable handle that stands for a buffer in a `Recording`.
///
/// The proxy holds no GPU resource itself; the buffer backing it is created or looked
/// up by `id` while the recording runs.
#[derive(Clone, Copy)]
pub struct BufProxy {
/// Size in bytes used when a buffer is created for this proxy. `BufProxy::new`
/// raises it to at least 16.
size: u64,
/// Key under which the backing buffer is stored.
id: Id,
}
/// Pixel formats supported for images; see `ImageFormat::to_wgpu` for the wgpu mapping.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum ImageFormat {
/// Maps to `wgpu::TextureFormat::Rgba8Unorm`.
Rgba8,
/// Maps to `wgpu::TextureFormat::Bgra8Unorm`.
Bgra8,
}
/// Lightweight, copyable handle that stands for a 2D image in a `Recording`.
///
/// The proxy holds no GPU resource itself; the texture backing it is created or looked
/// up by `id` while the recording runs.
#[derive(Clone, Copy)]
pub struct ImageProxy {
/// Width used for the texture extent.
width: u32,
/// Height used for the texture extent.
height: u32,
/// Pixel format of the texture.
format: ImageFormat,
/// Key under which the backing texture is stored.
id: Id,
}
/// Either kind of proxy; this is the element type of a dispatch's binding list.
#[derive(Clone, Copy)]
pub enum ResourceProxy {
/// A buffer binding.
Buf(BufProxy),
/// An image binding.
Image(ImageProxy),
}
/// A caller-owned GPU resource paired with the proxy that stands for it in a recording.
///
/// When a dispatch binds a proxy whose id matches one of these, the external resource
/// is bound instead of one created by the engine.
pub enum ExternalResource<'a> {
/// An existing buffer to bind for the given buffer proxy.
Buf(BufProxy, &'a Buffer),
/// An existing texture view to bind for the given image proxy.
Image(ImageProxy, &'a TextureView),
}
/// A single step of a `Recording`, executed in order by `Engine::run_recording`.
pub enum Command {
/// Create a storage buffer for the proxy, initialized with the given bytes.
Upload(BufProxy, Vec<u8>),
/// Create a uniform buffer for the proxy, initialized with the given bytes.
UploadUniform(BufProxy, Vec<u8>),
/// Create a texture for the proxy and fill it from the given bytes.
UploadImage(ImageProxy, Vec<u8>),
// Discussion question: third argument is vec of resources?
// Maybe use tricks to make more ergonomic?
// Alternative: provide bufs & images as separate sequences
/// Dispatch a shader with the given `(x, y, z)` workgroup counts. The resources are
/// bound in order: element `i` goes to binding index `i`.
Dispatch(ShaderId, (u32, u32, u32), Vec<ResourceProxy>),
/// Copy the buffer into a mappable staging buffer that is returned in `Downloads`.
Download(BufProxy),
/// Clear part of a buffer; the offset and optional size are forwarded unchanged to
/// the command encoder's `clear_buffer`.
Clear(BufProxy, u64, Option<NonZeroU64>),
}
/// Staging buffers produced by `Command::Download` during one run of a recording.
#[derive(Default)]
pub struct Downloads {
/// Staging buffer for each downloaded proxy, keyed by the proxy's id.
buf_map: HashMap<Id, Buffer>,
}
/// The type of resource that will be bound to a slot in a shader.
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum BindType {
/// A storage buffer with read/write access.
Buffer,
/// A storage buffer with read only access.
BufReadOnly,
/// A small storage buffer to be used as uniforms.
Uniform,
/// A storage image; `Engine::add_shader` declares it as a write-only 2D storage
/// texture of the given format.
Image(ImageFormat),
/// An image with read only access; `Engine::add_shader` declares it as a regular
/// (non-storage) 2D float texture, and the format is not used for the layout.
ImageRead(ImageFormat),
// TODO: Sampler, maybe others
}
/// Resources created by the engine while running a recording, keyed by proxy id.
#[derive(Default)]
struct BindMap {
/// Buffers created for buffer proxies.
buf_map: HashMap<Id, Buffer>,
/// Textures created for image proxies, each with the view that gets bound.
image_map: HashMap<Id, (Texture, TextureView)>,
}
impl Engine {
    /// Creates an engine with no shaders registered.
    pub fn new() -> Engine {
        Engine { shaders: vec![] }
    }

    /// Add a shader.
    ///
    /// Builds a compute pipeline from `wgsl` whose single bind group has one entry per
    /// element of `layout`; the element at position `i` becomes binding index `i`. The
    /// returned [`ShaderId`] identifies the shader in later dispatch commands.
    ///
    /// This function is somewhat limited, it doesn't apply a label, only allows one bind group,
    /// doesn't support push constants, and entry point is hardcoded as "main".
    ///
    /// Maybe should do template instantiation here? But shader compilation pipeline feels maybe
    /// a bit separate.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`.
    pub fn add_shader(
        &mut self,
        device: &Device,
        wgsl: Cow<'static, str>,
        layout: &[BindType],
    ) -> Result<ShaderId, Error> {
        let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: None,
            source: wgpu::ShaderSource::Wgsl(wgsl),
        });
        let entries = layout
            .iter()
            .enumerate()
            .map(|(i, bind_type)| {
                // The match is exhaustive on purpose: adding a `BindType` variant must
                // fail to compile here instead of panicking at runtime.
                let ty = match *bind_type {
                    BindType::Buffer | BindType::BufReadOnly => wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage {
                            read_only: *bind_type == BindType::BufReadOnly,
                        },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    BindType::Uniform => wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Uniform,
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    BindType::Image(format) => wgpu::BindingType::StorageTexture {
                        access: wgpu::StorageTextureAccess::WriteOnly,
                        format: format.to_wgpu(),
                        view_dimension: wgpu::TextureViewDimension::D2,
                    },
                    // Read-only images are bound as regular textures, so the pixel
                    // format does not appear in the layout.
                    BindType::ImageRead(_) => wgpu::BindingType::Texture {
                        sample_type: wgpu::TextureSampleType::Float { filterable: true },
                        view_dimension: wgpu::TextureViewDimension::D2,
                        multisampled: false,
                    },
                };
                wgpu::BindGroupLayoutEntry {
                    binding: i as u32,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty,
                    count: None,
                }
            })
            .collect::<Vec<_>>();
        let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: None,
            entries: &entries,
        });
        let compute_pipeline_layout =
            device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                label: None,
                bind_group_layouts: &[&bind_group_layout],
                push_constant_ranges: &[],
            });
        let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: None,
            layout: Some(&compute_pipeline_layout),
            module: &shader_module,
            entry_point: "main",
        });
        let shader = Shader {
            pipeline,
            bind_group_layout,
        };
        // The id is the position of the shader in `self.shaders`.
        let id = self.shaders.len();
        self.shaders.push(shader);
        Ok(ShaderId(id))
    }

    /// Encodes every command of `recording` into one command encoder and submits it to
    /// `queue`.
    ///
    /// Resources created along the way live in a `BindMap` local to this call. A proxy
    /// whose id matches an entry of `external_resources` is bound to that external
    /// resource instead. The returned [`Downloads`] holds one staging buffer per
    /// `Command::Download`; the copies into them are part of the submitted work.
    ///
    /// # Errors
    ///
    /// Returns an error when a `Download` names a buffer that no earlier command
    /// created, or when building a bind group or fetching a buffer for `Clear` fails.
    ///
    /// # Panics
    ///
    /// Panics if a `Dispatch` carries a `ShaderId` that is out of range for this engine.
    pub fn run_recording(
        &mut self,
        device: &Device,
        queue: &Queue,
        recording: &Recording,
        external_resources: &[ExternalResource],
    ) -> Result<Downloads, Error> {
        let mut bind_map = BindMap::default();
        let mut downloads = Downloads::default();
        let mut encoder = device.create_command_encoder(&Default::default());
        for command in &recording.commands {
            match command {
                Command::Upload(buf_proxy, bytes) => {
                    let buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                        label: None,
                        contents: bytes,
                        usage: wgpu::BufferUsages::STORAGE
                            | wgpu::BufferUsages::COPY_DST
                            | wgpu::BufferUsages::COPY_SRC,
                    });
                    bind_map.insert_buf(buf_proxy.id, buf);
                }
                Command::UploadUniform(buf_proxy, bytes) => {
                    let buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                        label: None,
                        contents: bytes,
                        usage: wgpu::BufferUsages::UNIFORM,
                    });
                    bind_map.insert_buf(buf_proxy.id, buf);
                }
                Command::UploadImage(image_proxy, bytes) => {
                    // One format for both the texture and its view: the view used to be
                    // hard-coded to Rgba8Unorm, which did not match a Bgra8 texture.
                    let format = image_proxy.format.to_wgpu();
                    let extent = wgpu::Extent3d {
                        width: image_proxy.width,
                        height: image_proxy.height,
                        depth_or_array_layers: 1,
                    };
                    // Staging buffer that the pixel data is copied out of.
                    let buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                        label: None,
                        contents: bytes,
                        usage: wgpu::BufferUsages::COPY_SRC,
                    });
                    let texture = device.create_texture(&wgpu::TextureDescriptor {
                        label: None,
                        size: extent,
                        mip_level_count: 1,
                        sample_count: 1,
                        dimension: wgpu::TextureDimension::D2,
                        usage: TextureUsages::TEXTURE_BINDING | TextureUsages::COPY_DST,
                        format,
                    });
                    let texture_view = texture.create_view(&wgpu::TextureViewDescriptor {
                        label: None,
                        dimension: Some(TextureViewDimension::D2),
                        aspect: TextureAspect::All,
                        mip_level_count: None,
                        base_mip_level: 0,
                        base_array_layer: 0,
                        array_layer_count: None,
                        format: Some(format),
                    });
                    encoder.copy_buffer_to_texture(
                        wgpu::ImageCopyBuffer {
                            buffer: &buf,
                            layout: wgpu::ImageDataLayout {
                                offset: 0,
                                // Both supported formats use 4 bytes per pixel.
                                // NOTE(review): wgpu documents a 256-byte row alignment
                                // for buffer-to-texture copies; widths that are not a
                                // multiple of 64 presumably fail validation - confirm.
                                bytes_per_row: NonZeroU32::new(image_proxy.width * 4),
                                rows_per_image: None,
                            },
                        },
                        wgpu::ImageCopyTexture {
                            texture: &texture,
                            mip_level: 0,
                            origin: wgpu::Origin3d { x: 0, y: 0, z: 0 },
                            aspect: TextureAspect::All,
                        },
                        extent,
                    );
                    bind_map.insert_image(image_proxy.id, texture, texture_view)
                }
                Command::Dispatch(shader_id, wg_size, bindings) => {
                    let shader = &self.shaders[shader_id.0];
                    let bind_group = bind_map.create_bind_group(
                        device,
                        &shader.bind_group_layout,
                        bindings,
                        external_resources,
                    )?;
                    // The pass borrows the encoder and ends when `cpass` goes out of
                    // scope at the end of this arm.
                    let mut cpass = encoder.begin_compute_pass(&Default::default());
                    cpass.set_pipeline(&shader.pipeline);
                    cpass.set_bind_group(0, &bind_group, &[]);
                    cpass.dispatch_workgroups(wg_size.0, wg_size.1, wg_size.2);
                }
                Command::Download(proxy) => {
                    let src_buf = bind_map.buf_map.get(&proxy.id).ok_or("buffer not in map")?;
                    // Mappable staging buffer that receives a copy of the source.
                    let buf = device.create_buffer(&wgpu::BufferDescriptor {
                        label: None,
                        size: proxy.size,
                        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
                        mapped_at_creation: false,
                    });
                    encoder.copy_buffer_to_buffer(src_buf, 0, &buf, 0, proxy.size);
                    downloads.buf_map.insert(proxy.id, buf);
                }
                Command::Clear(proxy, offset, size) => {
                    let buffer = bind_map.get_or_create(*proxy, device)?;
                    encoder.clear_buffer(buffer, *offset, *size);
                }
            }
        }
        queue.submit(Some(encoder.finish()));
        Ok(downloads)
    }
}
impl Recording {
    /// Appends `cmd` to the end of the recording.
    pub fn push(&mut self, cmd: Command) {
        self.commands.push(cmd);
    }

    /// Records an upload of `data` into a new storage buffer and returns its proxy.
    ///
    /// The proxy is created from the byte length of `data`.
    pub fn upload(&mut self, data: impl Into<Vec<u8>>) -> BufProxy {
        let bytes: Vec<u8> = data.into();
        let proxy = BufProxy::new(bytes.len() as u64);
        self.commands.push(Command::Upload(proxy, bytes));
        proxy
    }

    /// Records an upload of `data` into a new uniform buffer and returns its proxy.
    ///
    /// The proxy is created from the byte length of `data`.
    pub fn upload_uniform(&mut self, data: impl Into<Vec<u8>>) -> BufProxy {
        let bytes: Vec<u8> = data.into();
        let proxy = BufProxy::new(bytes.len() as u64);
        self.commands.push(Command::UploadUniform(proxy, bytes));
        proxy
    }

    /// Records an upload of pixel `data` into a new `width` x `height` image of the
    /// given `format` and returns its proxy.
    pub fn upload_image(
        &mut self,
        width: u32,
        height: u32,
        format: ImageFormat,
        data: impl Into<Vec<u8>>,
    ) -> ImageProxy {
        let pixels: Vec<u8> = data.into();
        let proxy = ImageProxy::new(width, height, format);
        self.commands.push(Command::UploadImage(proxy, pixels));
        proxy
    }

    /// Records a dispatch of `shader` with `wg_size` workgroups.
    ///
    /// `resources` are converted to [`ResourceProxy`] and kept in iteration order,
    /// which is the order they are bound in.
    pub fn dispatch<R>(&mut self, shader: ShaderId, wg_size: (u32, u32, u32), resources: R)
    where
        R: IntoIterator,
        R::Item: Into<ResourceProxy>,
    {
        let bindings: Vec<ResourceProxy> = resources.into_iter().map(Into::into).collect();
        self.commands
            .push(Command::Dispatch(shader, wg_size, bindings));
    }

    /// Records a download of `buf`.
    pub fn download(&mut self, buf: BufProxy) {
        self.commands.push(Command::Download(buf));
    }

    /// Records a clear of `buf` starting at offset 0 with no explicit size.
    pub fn clear_all(&mut self, buf: BufProxy) {
        self.commands.push(Command::Clear(buf, 0, None));
    }
}
impl BufProxy {
    /// Creates a proxy with a fresh id for a buffer of `size` bytes.
    ///
    /// Sizes below 16 are raised to 16.
    pub fn new(size: u64) -> Self {
        let id = Id::next();
        let size = if size < 16 { 16 } else { size };
        Self { size, id }
    }
}
impl ImageFormat {
pub fn to_wgpu(self) -> wgpu::TextureFormat {
match self {
Self::Rgba8 => wgpu::TextureFormat::Rgba8Unorm,
Self::Bgra8 => wgpu::TextureFormat::Bgra8Unorm,
}
}
}
impl ImageProxy {
    /// Creates a proxy with a fresh id for a `width` x `height` image of `format`.
    pub fn new(width: u32, height: u32, format: ImageFormat) -> Self {
        Self {
            id: Id::next(),
            width,
            height,
            format,
        }
    }
}
impl ResourceProxy {
pub fn new_buf(size: u64) -> Self {
Self::Buf(BufProxy::new(size))
}
pub fn new_image(width: u32, height: u32, format: ImageFormat) -> Self {
Self::Image(ImageProxy::new(width, height, format))
}
pub fn as_buf(&self) -> Option<&BufProxy> {
match self {
Self::Buf(proxy) => Some(&proxy),
_ => None,
}
}
pub fn as_image(&self) -> Option<&ImageProxy> {
match self {
Self::Image(proxy) => Some(&proxy),
_ => None,
}
}
}
impl From<BufProxy> for ResourceProxy {
fn from(value: BufProxy) -> Self {
Self::Buf(value)
}
}
impl From<ImageProxy> for ResourceProxy {
fn from(value: ImageProxy) -> Self {
Self::Image(value)
}
}
impl Id {
pub fn next() -> Id {
let val = ID_COUNTER.fetch_add(1, Ordering::Relaxed);
// could use new_unchecked
Id(NonZeroU64::new(val + 1).unwrap())
}
}
impl BindMap {
fn insert_buf(&mut self, id: Id, buf: Buffer) {
self.buf_map.insert(id, buf);
}
fn insert_image(&mut self, id: Id, image: Texture, image_view: TextureView) {
self.image_map.insert(id, (image, image_view));
}
fn create_bind_group(
&mut self,
device: &Device,
layout: &BindGroupLayout,
bindings: &[ResourceProxy],
external_resources: &[ExternalResource],
) -> Result<BindGroup, Error> {
// These functions are ugly and linear, but the remap array should generally be
// small. Should find a better solution for this.
fn find_buf<'a>(
resources: &[ExternalResource<'a>],
proxy: &BufProxy,
) -> Option<&'a Buffer> {
for resource in resources {
match resource {
ExternalResource::Buf(p, buf) if p.id == proxy.id => {
return Some(buf);
}
_ => {}
}
}
None
}
fn find_image<'a>(
resources: &[ExternalResource<'a>],
proxy: &ImageProxy,
) -> Option<&'a TextureView> {
for resource in resources {
match resource {
ExternalResource::Image(p, view) if p.id == proxy.id => {
return Some(view);
}
_ => {}
}
}
None
}
for proxy in bindings {
match proxy {
ResourceProxy::Buf(proxy) => {
if find_buf(external_resources, proxy).is_some() {
continue;
}
if let Entry::Vacant(v) = self.buf_map.entry(proxy.id) {
let buf = device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: proxy.size,
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_DST
| wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
v.insert(buf);
}
}
ResourceProxy::Image(proxy) => {
if find_image(external_resources, proxy).is_some() {
continue;
}
if let Entry::Vacant(v) = self.image_map.entry(proxy.id) {
let texture = device.create_texture(&wgpu::TextureDescriptor {
label: None,
size: wgpu::Extent3d {
width: proxy.width,
height: proxy.height,
depth_or_array_layers: 1,
},
mip_level_count: 1,
sample_count: 1,
dimension: wgpu::TextureDimension::D2,
usage: TextureUsages::TEXTURE_BINDING | TextureUsages::COPY_DST,
format: proxy.format.to_wgpu(),
});
let texture_view = texture.create_view(&wgpu::TextureViewDescriptor {
label: None,
dimension: Some(TextureViewDimension::D2),
aspect: TextureAspect::All,
mip_level_count: None,
base_mip_level: 0,
base_array_layer: 0,
array_layer_count: None,
format: Some(proxy.format.to_wgpu()),
});
v.insert((texture, texture_view));
}
}
}
}
let entries = bindings
.iter()
.enumerate()
.map(|(i, proxy)| match proxy {
ResourceProxy::Buf(proxy) => {
let buf = find_buf(external_resources, proxy)
.or_else(|| self.buf_map.get(&proxy.id))
.unwrap();
Ok(wgpu::BindGroupEntry {
binding: i as u32,
resource: buf.as_entire_binding(),
})
}
ResourceProxy::Image(proxy) => {
let view = find_image(external_resources, proxy)
.or_else(|| self.image_map.get(&proxy.id).map(|v| &v.1))
.unwrap();
Ok(wgpu::BindGroupEntry {
binding: i as u32,
resource: wgpu::BindingResource::TextureView(view),
})
}
})
.collect::<Result<Vec<_>, Error>>()?;
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
label: None,
layout,
entries: &entries,
});
Ok(bind_group)
}
fn get_or_create(&mut self, proxy: BufProxy, device: &Device) -> Result<&Buffer, Error> {
match self.buf_map.entry(proxy.id) {
Entry::Occupied(occupied) => Ok(occupied.into_mut()),
Entry::Vacant(vacant) => {
let buf = device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: proxy.size,
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_DST
| wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
Ok(vacant.insert(buf))
}
}
}
}
/// The downloads of a `Downloads` value after mapping has been requested.
///
/// For each downloaded proxy id this holds the slice of the staging buffer and the
/// receiving end of the channel on which the outcome of the map request is delivered.
pub struct DownloadsMapped<'a>(
HashMap<
Id,
(
BufferSlice<'a>,
GenericOneshotReceiver<RawMutex, Result<(), BufferAsyncError>>,
),
>,
);
impl Downloads {
// Discussion: should API change so we get one buffer, rather than mapping all?
pub fn map(&self) -> DownloadsMapped {
let mut map = HashMap::new();
for (id, buf) in &self.buf_map {
let buf_slice = buf.slice(..);
let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
buf_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());
map.insert(*id, (buf_slice, receiver));
}
DownloadsMapped(map)
}
}
impl<'a> DownloadsMapped<'a> {
    /// Waits for the map request of `proxy`'s staging buffer and returns a view of it.
    ///
    /// # Errors
    ///
    /// Returns an error if `proxy` was not downloaded, if the channel yields no value,
    /// or if the map request itself reported a failure.
    pub async fn get_mapped(&self, proxy: BufProxy) -> Result<BufferView, Error> {
        let (slice, recv) = match self.0.get(&proxy.id) {
            Some(entry) => entry,
            None => return Err("buffer not in map".into()),
        };
        match recv.receive().await {
            Some(outcome) => outcome?,
            None => return Err("channel was closed".into()),
        }
        Ok(slice.get_mapped_range())
    }
}