blob: de00349a984688dfe688dd6927c9d357e44a7fdc [file] [log] [blame]
//! Generation of Rust derive functions for encoding.
use quote::{format_ident, quote};
use crate::layout::{LayoutModule, LayoutTypeDef};
use crate::parse::{GpuScalar, GpuType};
/// Generate a Rust module mirroring a GPU layout module.
///
/// The emitted `mod` (named after `module.name`) contains:
/// * a `HalfToLeBytes` helper trait implemented for `half::f16`, so generated
///   encoders can call `to_le_bytes()` on `f16` values; and
/// * for every name in `module.def_names`, in that order, the type definition
///   and `crate::encoder::Encode` impl produced by `gen_derive_def`.
///
/// # Panics
/// Panics if a name listed in `module.def_names` has no entry in `module.defs`.
pub fn gen_derive(module: &LayoutModule) -> proc_macro2::TokenStream {
    let mod_ident = format_ident!("{}", module.name);
    // Lazily generate each definition; `quote!` drains this iterator in order.
    let items = module.def_names.iter().map(|def_name| {
        let entry = module.defs.get(def_name).unwrap();
        gen_derive_def(def_name, entry.0.size, &entry.1)
    });
    quote! {
        mod #mod_ident {
            pub trait HalfToLeBytes {
                fn to_le_bytes(&self) -> [u8; 2];
            }
            impl HalfToLeBytes for half::f16 {
                fn to_le_bytes(&self) -> [u8; 2] {
                    self.to_bits().to_le_bytes()
                }
            }
            #(#items)*
        }
    }
}
/// Generate the Rust definition of one layout type plus its
/// `crate::encoder::Encode` impl.
///
/// * `name` – identifier of the generated struct or enum.
/// * `size` – value returned by the generated `fixed_size()`.
/// * `def` – the struct fields or enum variants to emit.
///
/// Struct fields are serialized by `gen_encode_field`. For enums, the variant
/// index is written as a little-endian `u32` tag into `buf[0..4]`. If a variant
/// has a `TagFlags` field, that value is shifted left by 16 and OR-ed into the tag
/// word rather than encoded on its own. Every other payload field is written
/// with its own `encode_to` at the field's offset.
fn gen_derive_def(name: &str, size: usize, def: &LayoutTypeDef) -> proc_macro2::TokenStream {
    let type_id = format_ident!("{}", name);
    match def {
        LayoutTypeDef::Struct(fields) => {
            let field_decls = fields.iter().map(|(field_name, _offset, field_ty)| {
                let ident = format_ident!("{}", field_name);
                let rust_ty = gen_derive_ty(&field_ty.ty);
                quote! {
                    pub #ident: #rust_ty,
                }
            });
            let field_writes = fields.iter().map(|(field_name, offset, field_ty)| {
                gen_encode_field(field_name, *offset, &field_ty.ty)
            });
            quote! {
                #[derive(Clone)]
                pub struct #type_id {
                    #(#field_decls)*
                }
                impl crate::encoder::Encode for #type_id {
                    fn fixed_size() -> usize {
                        #size
                    }
                    fn encode_to(&self, buf: &mut [u8]) {
                        #(#field_writes)*
                    }
                }
            }
        }
        LayoutTypeDef::Enum(variants) => {
            let mut variant_decls = Vec::with_capacity(variants.len());
            let mut match_arms = Vec::with_capacity(variants.len());
            for (index, (variant_name, payload)) in variants.iter().enumerate() {
                let variant_id = format_ident!("{}", variant_name);
                let payload_tys = payload
                    .iter()
                    .map(|(_offset, field_ty)| gen_derive_ty(&field_ty.ty));
                variant_decls.push(quote! {
                    #variant_id(#(#payload_tys),*),
                });

                // One positional binding (`f0`, `f1`, ...) per payload field.
                let bindings: Vec<_> = (0..payload.len())
                    .map(|i| format_ident!("f{}", i))
                    .collect();
                // The last `TagFlags` field (if any) is folded into the tag word.
                let mut flags_binding = None;
                let mut payload_writes = Vec::new();
                for (binding, (offset, field_ty)) in bindings.iter().zip(payload.iter()) {
                    if matches!(field_ty.ty, GpuType::Scalar(GpuScalar::TagFlags)) {
                        flags_binding = Some(binding);
                    } else {
                        payload_writes.push(quote! {
                            #binding.encode_to(&mut buf[#offset..]);
                        });
                    }
                }

                let tag = index as u32;
                let tag_write = if let Some(flags) = flags_binding {
                    quote! {
                        buf[0..4].copy_from_slice(&(#tag | ((*#flags as u32) << 16)).to_le_bytes());
                    }
                } else {
                    quote! {
                        buf[0..4].copy_from_slice(&#tag.to_le_bytes());
                    }
                };

                match_arms.push(quote! {
                    #type_id::#variant_id(#(#bindings),*) => {
                        #tag_write
                        #(#payload_writes)*
                    }
                });
            }
            quote! {
                #[derive(Clone)]
                pub enum #type_id {
                    #(#variant_decls)*
                }
                impl crate::encoder::Encode for #type_id {
                    fn fixed_size() -> usize {
                        #size
                    }
                    fn encode_to(&self, buf: &mut [u8]) {
                        match self {
                            #(#match_arms)*
                        }
                    }
                }
            }
        }
    }
}
/// Map a layout `GpuType` to the Rust type used for it in generated definitions.
///
/// Scalars map through `gen_derive_scalar_ty`. Vectors become fixed-size arrays
/// of their scalar. Inline structs refer to the generated type by name, and
/// refs become `crate::encoder::Ref<T>` over the pointee's generated type.
fn gen_derive_ty(ty: &GpuType) -> proc_macro2::TokenStream {
    match ty {
        GpuType::Scalar(scalar) => gen_derive_scalar_ty(scalar),
        GpuType::Vector(scalar, len) => {
            let elem = gen_derive_scalar_ty(scalar);
            quote!([#elem; #len])
        }
        GpuType::InlineStruct(struct_name) => {
            let ident = format_ident!("{}", struct_name);
            quote!(#ident)
        }
        GpuType::Ref(inner) => {
            let target = gen_derive_ty(inner);
            quote!(crate::encoder::Ref<#target>)
        }
    }
}
/// Map a `GpuScalar` to the Rust type that represents it in generated code.
///
/// `TagFlags` is stored as a plain `u16`, and `F16` uses `half::f16`.
fn gen_derive_scalar_ty(ty: &GpuScalar) -> proc_macro2::TokenStream {
    match *ty {
        GpuScalar::U8 => quote! { u8 },
        GpuScalar::U16 | GpuScalar::TagFlags => quote! { u16 },
        GpuScalar::U32 => quote! { u32 },
        GpuScalar::I8 => quote! { i8 },
        GpuScalar::I16 => quote! { i16 },
        GpuScalar::I32 => quote! { i32 },
        GpuScalar::F16 => quote! { half::f16 },
        GpuScalar::F32 => quote! { f32 },
    }
}
/// Generate the statements that serialize one struct field into `buf`.
///
/// * `name` – the field name on `self`.
/// * `offset` – byte offset of the field within the encoded value.
/// * `ty` – the field's layout type.
///
/// Scalars and vector elements are written as little-endian bytes. Refs are
/// written as the 4 little-endian bytes of their `offset()`. Inline structs
/// delegate to their own `encode_to`. The emitted code slices `buf` directly, so
/// it panics at runtime if `buf` is too short.
fn gen_encode_field(name: &str, offset: usize, ty: &GpuType) -> proc_macro2::TokenStream {
    let name_id = format_ident!("{}", name);
    match ty {
        // Encoding of flags into the tag word is handled by the enum case of
        // `gen_derive_def`, so nothing is emitted here.
        GpuType::Scalar(GpuScalar::TagFlags) => quote! {},
        GpuType::Scalar(s) => {
            let end = offset + s.size();
            quote! {
                buf[#offset..#end].copy_from_slice(&self.#name_id.to_le_bytes());
            }
        }
        GpuType::Vector(s, len) => {
            let size = s.size();
            quote! {
                for i in 0..#len {
                    let offset = #offset + i * #size;
                    buf[offset..offset + #size].copy_from_slice(&self.#name_id[i].to_le_bytes());
                }
            }
        }
        GpuType::Ref(_) => {
            // Compute the slice end at macro-expansion time, like the scalar arm.
            let end = offset + 4;
            quote! {
                buf[#offset..#end].copy_from_slice(&self.#name_id.offset().to_le_bytes());
            }
        }
        // Matched explicitly (not `_`) so that adding a `GpuType` variant forces
        // this function to be revisited.
        GpuType::InlineStruct(_) => {
            // No leading `&`: borrowing the `()` returned by `encode_to` was a
            // meaningless statement that rustc flags as an unused borrow.
            quote! {
                self.#name_id.encode_to(&mut buf[#offset..]);
            }
        }
    }
}