diff --git a/shared/src/core/filter.rs b/shared/src/core/filter.rs index 4cf2b81..740c036 100644 --- a/shared/src/core/filter.rs +++ b/shared/src/core/filter.rs @@ -13,13 +13,13 @@ pub struct FilterSample { #[repr(C)] #[derive(Clone, Debug, Copy)] -pub struct FilterSampler { +pub struct DeviceFilterSampler { pub domain: Bounds2f, pub distrib: DevicePiecewiseConstant2D, pub f: DeviceArray2D, } -impl FilterSampler { +impl DeviceFilterSampler { pub fn sample(&self, u: Point2f) -> FilterSample { let (p, pdf, pi) = self.distrib.sample(u); @@ -38,7 +38,7 @@ pub trait FilterTrait { fn radius(&self) -> Vector2f; fn evaluate(&self, p: Point2f) -> Float; fn integral(&self) -> Float; - fn sample(&self, u: Point2f) -> FilterSample; + fn sample(&self, u: Point2f) -> DeviceFilterSample; } #[repr(C)] diff --git a/shared/src/core/primitive.rs b/shared/src/core/primitive.rs index 6cbb18d..72b3476 100644 --- a/shared/src/core/primitive.rs +++ b/shared/src/core/primitive.rs @@ -205,6 +205,7 @@ impl PrimitiveTrait for KdTreeAggregate { } } +#[repr(C)] #[derive(Clone, Debug, Copy)] #[enum_dispatch(PrimitiveTrait)] pub enum Primitive { diff --git a/shared/src/filters/boxf.rs b/shared/src/filters/boxf.rs index bee4080..413cb8d 100644 --- a/shared/src/filters/boxf.rs +++ b/shared/src/filters/boxf.rs @@ -1,5 +1,5 @@ use crate::Float; -use crate::core::filter::{FilterSample, FilterTrait}; +use crate::core::filter::{DeviceFilterSample, FilterTrait}; use crate::core::geometry::{Point2f, Vector2f}; use crate::utils::math::lerp; @@ -31,7 +31,7 @@ impl FilterTrait for BoxFilter { (2.0 * self.radius.x()) * (2.0 * self.radius.y()) } - fn sample(&self, u: Point2f) -> FilterSample { + fn sample(&self, u: Point2f) -> DeviceFilterSample { let p = Point2f::new( lerp(u[0], -self.radius.x(), self.radius.x()), lerp(u[1], -self.radius.y(), self.radius.y()), diff --git a/shared/src/filters/gaussian.rs b/shared/src/filters/gaussian.rs index 9d3997e..da81189 100644 --- a/shared/src/filters/gaussian.rs +++ b/shared/src/filters/gaussian.rs @@ -1,5 +1,5 @@ use crate::Float; -use crate::core::filter::{FilterSample, FilterSampler, FilterTrait}; +use crate::core::filter::{DeviceFilterSample, FilterSampler, FilterTrait}; use crate::core::geometry::{Point2f, Vector2f}; use crate::utils::math::{gaussian, gaussian_integral}; @@ -30,7 +30,7 @@ impl FilterTrait for GaussianFilter { - 2.0 * self.radius.y() * self.exp_y) } - fn sample(&self, u: Point2f) -> FilterSample { + fn sample(&self, u: Point2f) -> DeviceFilterSample { self.sampler.sample(u) } } diff --git a/shared/src/filters/lanczos.rs b/shared/src/filters/lanczos.rs index 700c110..a22db3d 100644 --- a/shared/src/filters/lanczos.rs +++ b/shared/src/filters/lanczos.rs @@ -1,5 +1,5 @@ use crate::Float; -use crate::core::filter::{FilterSample, FilterSampler, FilterTrait}; +use crate::core::filter::{DeviceFilterSample, FilterSampler, FilterTrait}; use crate::core::geometry::{Point2f, Vector2f}; use crate::utils::math::{lerp, windowed_sinc}; @@ -26,7 +26,7 @@ impl FilterTrait for LanczosSincFilter { self.integral } - fn sample(&self, u: Point2f) -> FilterSample { + fn sample(&self, u: Point2f) -> DeviceFilterSample { self.sampler.sample(u) } } diff --git a/shared/src/filters/mitchell.rs b/shared/src/filters/mitchell.rs index 03425ad..c17ceb0 100644 --- a/shared/src/filters/mitchell.rs +++ b/shared/src/filters/mitchell.rs @@ -1,5 +1,5 @@ use crate::Float; -use crate::core::filter::{FilterSample, FilterSampler, FilterTrait}; +use crate::core::filter::{DeviceFilterSample, FilterSampler, FilterTrait}; use crate::core::geometry::{Point2f, Vector2f}; use num_traits::Float as NumFloat; @@ -9,7 +9,7 @@ pub struct MitchellFilter { pub radius: Vector2f, pub b: Float, pub c: Float, - pub sampler: FilterSampler, + pub sampler: DeviceFilterSampler, } impl MitchellFilter { @@ -50,7 +50,7 @@ impl FilterTrait for MitchellFilter { self.radius.x() * self.radius.y() / 4.0 } - fn sample(&self, u: Point2f) -> FilterSample { + fn sample(&self, u: Point2f) -> DeviceFilterSample { self.sampler.sample(u) } } diff --git a/shared/src/filters/triangle.rs b/shared/src/filters/triangle.rs index 1572ca5..b014fce 100644 --- a/shared/src/filters/triangle.rs +++ b/shared/src/filters/triangle.rs @@ -1,5 +1,5 @@ use crate::Float; -use crate::core::filter::{FilterSample, FilterTrait}; +use crate::core::filter::{DeviceFilterSample, FilterTrait}; use crate::core::geometry::{Point2f, Vector2f}; use crate::utils::math::sample_tent; use num_traits::Float as NumFloat; @@ -29,11 +29,11 @@ impl FilterTrait for TriangleFilter { self.radius.x().powi(2) * self.radius.y().powi(2) } - fn sample(&self, u: Point2f) -> FilterSample { + fn sample(&self, u: Point2f) -> DeviceFilterSample { let p = Point2f::new( sample_tent(u[0], self.radius.x()), sample_tent(u[1], self.radius.y()), ); - FilterSample { p, weight: 1.0 } + DeviceFilterSample { p, weight: 1.0 } } } diff --git a/shared/src/utils/mod.rs b/shared/src/utils/mod.rs index 8207bdc..a18ea95 100644 --- a/shared/src/utils/mod.rs +++ b/shared/src/utils/mod.rs @@ -18,6 +18,15 @@ pub use options::PBRTOptions; pub use ptr::Ptr; pub use transform::{AnimatedTransform, Transform, TransformGeneric}; +use proc_macro::TokenStream; +use proc_macro2::TokenStream as TokenStream2; +use quote::{format_ident, quote}; +use syn::{ + parse_macro_input, Attribute, Data, DeriveInput, Expr, Fields, GenericArgument, Ident, Lit, + PathArguments, Type, Variant, +}; + + use crate::Float; use core::sync::atomic::{AtomicU32, Ordering}; @@ -128,3 +137,515 @@ pub fn gpu_array_from_fn(mut f: impl FnMut(usize) -> T) -> [T arr.assume_init() } } + +/// # Enum variant attributes +/// +/// | Attribute | Effect | +/// |-----------|--------| +/// | *(none)* | Inner type has `DeviceRepr`; auto-call `upload_value` | +/// | `#[device(clone)]` | Same type on both sides, just clone | +/// | `#[device(custom = "method")]` | You provide `fn method(inner: &T, arena) -> DeviceT` | +/// | `#[device(variant_type = "T")]` | Override the device-side variant's inner type | +/// +/// # Container attribute +/// +/// `#[device(name = "DeviceFoo")]` — override the generated type name (default: `Device{Name}`). +#[proc_macro_derive(Device, attributes(device))] +pub fn derive_device(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + match derive_impl(input) { + Ok(tokens) => tokens.into(), + Err(e) => e.to_compile_error().into(), + } +} + +fn derive_impl(input: DeriveInput) -> syn::Result { + match &input.data { + Data::Struct(_) => derive_struct(input), + Data::Enum(_) => derive_enum(input), + Data::Union(_) => Err(syn::Error::new_spanned( + &input.ident, + "Device derive does not support unions", + )), + } +} + +// Struct derivation + +fn derive_struct(input: DeriveInput) -> syn::Result { + let host_name = &input.ident; + let vis = &input.vis; + let device_name = get_device_name(&input.attrs, host_name)?; + + let fields = match &input.data { + Data::Struct(s) => match &s.fields { + Fields::Named(named) => &named.named, + _ => { + return Err(syn::Error::new_spanned( + host_name, + "Device derive only supports structs with named fields", + )) + } + }, + _ => unreachable!(), + }; + + let mut device_fields = Vec::new(); + let mut upload_stmts = Vec::new(); + let mut device_field_inits = Vec::new(); + let mut spread_expr: Option = None; + + for field in fields { + let field_name = field.ident.as_ref().unwrap(); + let attrs = parse_field_attrs(&field.attrs)?; + + if attrs.skip { + continue; + } + + if let Some(ref expr_str) = attrs.spread { + spread_expr = Some(syn::parse_str(expr_str).map_err(|e| { + syn::Error::new_spanned(field, format!("invalid device(spread): {}", e)) + })?); + continue; + } + + if let Some(expr_str) = &attrs.expr { + let expr: Expr = syn::parse_str(expr_str).map_err(|e| { + syn::Error::new_spanned(field, format!("invalid device(expr): {}", e)) + })?; + let ty = &field.ty; + device_fields.push(quote! { pub #field_name: #ty }); + upload_stmts.push(quote! { let #field_name = #expr; }); + device_field_inits.push(quote! { #field_name }); + continue; + } + + match classify_type(&field.ty) { + FieldClass::VecCopy(inner_ty) => { + let len_name = format_ident!("{}_len", field_name); + device_fields.push(quote! { pub #field_name: Ptr<#inner_ty> }); + device_fields.push(quote! { pub #len_name: usize }); + upload_stmts.push(quote! { + let (#field_name, #len_name) = arena.alloc_slice(&self.#field_name); + }); + device_field_inits.push(quote! { #field_name }); + device_field_inits.push(quote! { #len_name }); + } + FieldClass::VecUploadable(inner_ty) => { + let len_name = format_ident!("{}_len", field_name); + device_fields.push(quote! { + pub #field_name: Ptr<<#inner_ty as DeviceRepr>::Target> + }); + device_fields.push(quote! { pub #len_name: usize }); + upload_stmts.push(quote! { + let __up: Vec<<#inner_ty as DeviceRepr>::Target> = self.#field_name + .iter() + .map(|item| DeviceRepr::upload_value(item, arena)) + .collect(); + let (#field_name, #len_name) = arena.alloc_slice(&__up); + }); + device_field_inits.push(quote! { #field_name }); + device_field_inits.push(quote! { #len_name }); + } + FieldClass::Option(inner_ty) => { + device_fields.push(quote! { + pub #field_name: Ptr<<#inner_ty as DeviceRepr>::Target> + }); + upload_stmts.push(quote! { + let #field_name = match &self.#field_name { + Some(val) => DeviceRepr::upload(val, arena), + None => Ptr::null(), + }; + }); + device_field_inits.push(quote! { #field_name }); + } + FieldClass::Arc(inner_ty) => { + if attrs.flatten { + device_fields.push(quote! { + pub #field_name: <#inner_ty as DeviceRepr>::Target + }); + upload_stmts.push(quote! { + let #field_name = DeviceRepr::upload_value(&*self.#field_name, arena); + }); + } else { + device_fields.push(quote! { + pub #field_name: Ptr<<#inner_ty as DeviceRepr>::Target> + }); + upload_stmts.push(quote! { + let #field_name = DeviceRepr::upload(&*self.#field_name, arena); + }); + } + device_field_inits.push(quote! { #field_name }); + } + FieldClass::Plain => { + let ty = &field.ty; + if attrs.copy_upload { + device_fields.push(quote! { pub #field_name: #ty }); + upload_stmts.push(quote! { + let #field_name = self.#field_name.clone(); + }); + } else if attrs.flatten { + device_fields.push(quote! { + pub #field_name: <#ty as DeviceRepr>::Target + }); + upload_stmts.push(quote! { + let #field_name = DeviceRepr::upload_value(&self.#field_name, arena); + }); + } else if attrs.upload { + device_fields.push(quote! { + pub #field_name: Ptr<<#ty as DeviceRepr>::Target> + }); + upload_stmts.push(quote! { + let #field_name = DeviceRepr::upload(&self.#field_name, arena); + }); + } else { + device_fields.push(quote! { pub #field_name: #ty }); + upload_stmts.push(quote! { + let #field_name = self.#field_name; + }); + } + device_field_inits.push(quote! { #field_name }); + } + } + } + + let constructor = if let Some(spread) = spread_expr { + quote! { + #device_name { + #(#device_field_inits,)* + ..#spread + } + } + } else { + quote! { + #device_name { + #(#device_field_inits,)* + } + } + }; + + Ok(quote! { + #[repr(C)] + #[derive(Debug, Copy, Clone)] + #vis struct #device_name { + #(#device_fields,)* + } + + unsafe impl Send for #device_name {} + unsafe impl Sync for #device_name {} + + impl DeviceRepr for #host_name { + type Target = #device_name; + + fn upload_value(&self, arena: &Arena) -> Self::Target { + #(#upload_stmts)* + #constructor + } + } + }) +} + +// Enum derivation +fn derive_enum(input: DeriveInput) -> syn::Result { + let host_name = &input.ident; + let vis = &input.vis; + let device_name = get_device_name(&input.attrs, host_name)?; + + let variants = match &input.data { + Data::Enum(e) => &e.variants, + _ => unreachable!(), + }; + + let mut device_variants = Vec::new(); + let mut match_arms = Vec::new(); + + for variant in variants { + let var_name = &variant.ident; + let var_attrs = parse_variant_attrs(&variant.attrs)?; + let inner_ty = get_variant_inner_type(variant)?; + + // Determine the device-side inner type for this variant + let device_inner: Type = if let Some(ref ty_str) = var_attrs.variant_type { + syn::parse_str(ty_str).map_err(|e| { + syn::Error::new_spanned(variant, format!("invalid variant_type: {}", e)) + })? + } else if var_attrs.clone_variant { + // clone: same type on both sides + inner_ty.clone() + } else { + // auto-upload: use DeviceRepr::Target + syn::parse_str(&format!("<{} as DeviceRepr>::Target", quote!(#inner_ty))).map_err( + |e| { + syn::Error::new_spanned(variant, format!("cannot construct Target type: {}", e)) + }, + )? + }; + + device_variants.push(quote! { #var_name(#device_inner) }); + + if var_attrs.clone_variant { + match_arms.push(quote! { + #host_name::#var_name(inner) => #device_name::#var_name(inner.clone()) + }); + } else if let Some(ref method) = var_attrs.custom { + let method_ident = format_ident!("{}", method); + match_arms.push(quote! { + #host_name::#var_name(inner) => { + #device_name::#var_name(Self::#method_ident(inner, arena)) + } + }); + } else { + // Default: inner implements DeviceRepr + match_arms.push(quote! { + #host_name::#var_name(inner) => { + #device_name::#var_name(DeviceRepr::upload_value(inner, arena)) + } + }); + } + } + + Ok(quote! { + #[repr(C)] + #[derive(Debug, Copy, Clone)] + #vis enum #device_name { + #(#device_variants,)* + } + + unsafe impl Send for #device_name {} + unsafe impl Sync for #device_name {} + + impl DeviceRepr for #host_name { + type Target = #device_name; + + fn upload_value(&self, arena: &Arena) -> Self::Target { + match self { + #(#match_arms,)* + } + } + } + }) +} + +fn get_variant_inner_type(variant: &Variant) -> syn::Result { + match &variant.fields { + Fields::Unnamed(fields) if fields.unnamed.len() == 1 => { + Ok(fields.unnamed.first().unwrap().ty.clone()) + } + Fields::Unit => Err(syn::Error::new_spanned( + variant, + "Device derive: enum variants must have exactly one field, e.g. Variant(Type)", + )), + _ => Err(syn::Error::new_spanned( + variant, + "Device derive: only single-field tuple variants supported, e.g. Variant(Type)", + )), + } +} + +// Attribute parsing for variants +struct VariantAttrs { + clone_variant: bool, + custom: Option, + variant_type: Option, +} + +fn parse_variant_attrs(attrs: &[Attribute]) -> syn::Result { + let mut result = VariantAttrs { + clone_variant: false, + custom: None, + variant_type: None, + }; + + for attr in attrs { + if !attr.path().is_ident("device") { + continue; + } + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("clone") { + result.clone_variant = true; + Ok(()) + } else if meta.path.is_ident("custom") { + let value = meta.value()?; + let lit: Lit = value.parse()?; + if let Lit::Str(s) = lit { + result.custom = Some(s.value()); + Ok(()) + } else { + Err(meta.error("expected string literal")) + } + } else if meta.path.is_ident("variant_type") { + let value = meta.value()?; + let lit: Lit = value.parse()?; + if let Lit::Str(s) = lit { + result.variant_type = Some(s.value()); + Ok(()) + } else { + Err(meta.error("expected string literal")) + } + } else { + Err(meta.error("unknown device variant attribute")) + } + })?; + } + + Ok(result) +} + +// Attribute parsing for fields +struct FieldAttrs { + skip: bool, + expr: Option, + copy_upload: bool, + flatten: bool, + upload: bool, + spread: Option, +} + +fn parse_field_attrs(attrs: &[Attribute]) -> syn::Result { + let mut result = FieldAttrs { + skip: false, + expr: None, + copy_upload: false, + flatten: false, + upload: false, + spread: None, + }; + + for attr in attrs { + if !attr.path().is_ident("device") { + continue; + } + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("skip") { + result.skip = true; + Ok(()) + } else if meta.path.is_ident("expr") { + let value = meta.value()?; + let lit: Lit = value.parse()?; + if let Lit::Str(s) = lit { + result.expr = Some(s.value()); + Ok(()) + } else { + Err(meta.error("expected string literal")) + } + } else if meta.path.is_ident("copy_upload") { + result.copy_upload = true; + Ok(()) + } else if meta.path.is_ident("flatten") { + result.flatten = true; + Ok(()) + } else if meta.path.is_ident("upload") { + result.upload = true; + Ok(()) + } else if meta.path.is_ident("spread") { + let value = meta.value()?; + let lit: Lit = value.parse()?; + if let Lit::Str(s) = lit { + result.spread = Some(s.value()); + Ok(()) + } else { + Err(meta.error("expected string literal")) + } + } else { + Err(meta.error("unknown device attribute")) + } + })?; + } + + Ok(result) +} + +// Container-level name attribute +fn get_device_name(attrs: &[Attribute], host_name: &Ident) -> syn::Result { + for attr in attrs { + if !attr.path().is_ident("device") { + continue; + } + let mut name = None; + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("name") { + let value = meta.value()?; + let lit: Lit = value.parse()?; + if let Lit::Str(s) = lit { + name = Some(format_ident!("{}", s.value())); + Ok(()) + } else { + Err(meta.error("expected string literal")) + } + } else { + Ok(()) + } + })?; + if let Some(n) = name { + return Ok(n); + } + } + Ok(format_ident!("Device{}", host_name)) +} + +// Type classification +enum FieldClass { + VecCopy(Type), + VecUploadable(Type), + Option(Type), + Arc(Type), + Plain, +} + +fn classify_type(ty: &Type) -> FieldClass { + if let Some(inner) = extract_generic_arg(ty, "Vec") { + if is_copy_primitive(&inner) { + FieldClass::VecCopy(inner) + } else { + FieldClass::VecUploadable(inner) + } + } else if let Some(inner) = extract_generic_arg(ty, "Option") { + FieldClass::Option(inner) + } else if let Some(inner) = extract_generic_arg(ty, "Arc") { + FieldClass::Arc(inner) + } else { + FieldClass::Plain + } +} + +fn extract_generic_arg(ty: &Type, wrapper: &str) -> Option { + if let Type::Path(type_path) = ty { + let seg = type_path.path.segments.last()?; + if seg.ident != wrapper { + return None; + } + if let PathArguments::AngleBracketed(args) = &seg.arguments { + if let Some(GenericArgument::Type(inner)) = args.args.first() { + return Some(inner.clone()); + } + } + } + None +} + +fn is_copy_primitive(ty: &Type) -> bool { + if let Type::Path(type_path) = ty { + if let Some(seg) = type_path.path.segments.last() { + let name = seg.ident.to_string(); + return matches!( + name.as_str(), + "f32" + | "f64" + | "u8" + | "u16" + | "u32" + | "u64" + | "i8" + | "i16" + | "i32" + | "i64" + | "usize" + | "isize" + | "bool" + | "Float" + ); + } + } + false +} diff --git a/shared/src/utils/sampling.rs b/shared/src/utils/sampling.rs index dffe01a..5a8bd2a 100644 --- a/shared/src/utils/sampling.rs +++ b/shared/src/utils/sampling.rs @@ -695,8 +695,8 @@ impl VarianceEstimator { } } -#[derive(Debug, Copy, Clone, Default)] #[repr(C)] +#[derive(Debug, Copy, Clone, Default)] pub struct PLSample { pub p: Point2f, pub pdf: Float, diff --git a/src/core/aggregates.rs b/src/core/aggregates.rs index 4cb8dea..3fda232 100644 --- a/src/core/aggregates.rs +++ b/src/core/aggregates.rs @@ -1,5 +1,5 @@ use rayon::prelude::*; -use shared::core::aggregates::DeviceBVHAggregate; +use shared::core::aggregates::{{DeviceBVHAggregate, LinearBVHNode}; use shared::core::geometry::{Bounds3f, Point3f, Ray, Vector3f}; use shared::core::primitive::{Primitive, PrimitiveTrait}; use shared::core::shape::ShapeIntersection; @@ -26,14 +26,6 @@ struct BVHSplitBucket { pub bounds: Bounds3f, } -#[derive(Debug, Clone, Default)] -pub struct LinearBVHNode { - pub bounds: Bounds3f, - pub primitives_offset: usize, - pub n_primitives: u16, - pub axis: u8, - pub pad: u8, -} #[derive(Debug, Clone, Copy, Default)] struct MortonPrimitive { @@ -48,7 +40,7 @@ struct LBVHTreelet { #[derive(Debug, Clone)] pub struct BVHPrimitiveInfo { - primitive_number: usize, // Index into the original primitives vector + primitive_number: usize, bounds: Bounds3f, centroid: Point3f, } diff --git a/src/core/filter.rs b/src/core/filter.rs index ec7ba5c..b8874d2 100644 --- a/src/core/filter.rs +++ b/src/core/filter.rs @@ -1,12 +1,13 @@ use crate::filters::*; use crate::utils::containers::Array2D; use crate::utils::sampling::PiecewiseConstant2D; +use crate::utils::DeviceRepr; use crate::utils::{FileLoc, ParameterDictionary}; -use anyhow::{Result, anyhow}; -use shared::Float; -use shared::core::filter::{Filter, FilterSampler}; +use anyhow::{anyhow, Result}; +use shared::core::filter::{DeviceFilterSampler, Filter}; use shared::core::geometry::{Bounds2f, Point2f, Vector2f}; use shared::filters::*; +use shared::Float; pub trait FilterFactory { fn create(name: &str, params: &ParameterDictionary, loc: &FileLoc) -> Result; @@ -54,14 +55,16 @@ impl FilterFactory for Filter { } } -pub trait CreateFilterSampler { - fn new(radius: Vector2f, func: F) -> Self - where - F: Fn(Point2f) -> Float; +#[repr(C)] +#[derive(Clone, Debug, Copy)] +pub struct FilterSampler { + pub domain: Bounds2f, + pub distrib: PiecewiseConstant2D, + pub f: Array2D, } -impl CreateFilterSampler for FilterSampler { - fn new(radius: Vector2f, func: F) -> Self +impl FilterSampler { + pub fn new(radius: Vector2f, func: F) -> Self where F: Fn(Point2f) -> Float, { @@ -72,7 +75,6 @@ impl CreateFilterSampler for FilterSampler { let nx = (32.0 * radius.x()) as i32; let ny = (32.0 * radius.y()) as i32; - let mut f = Array2D::new_dims(nx, ny); for y in 0..f.y_size() { for x in 0..f.x_size() { @@ -83,11 +85,21 @@ impl CreateFilterSampler for FilterSampler { f[(x as i32, y as i32)] = func(p); } } + let distrib = PiecewiseConstant2D::new_with_bounds(&f, domain); - Self { - domain, - f: *f.device(), - distrib: distrib.device, + + Self { domain, distrib, f } + } +} + +impl DeviceRepr for FilterSampler { + type Target = DeviceFilterSampler; + + fn upload_value(&self, arena: &Arena) -> DeviceFilterSampler { + DeviceFilterSampler { + domain: self.domain, + distrib: self.distrib.upload_value(arena), + f: self.f.upload_value(arena), } } } diff --git a/src/filters/gaussian.rs b/src/filters/gaussian.rs index 3954837..0267a54 100644 --- a/src/filters/gaussian.rs +++ b/src/filters/gaussian.rs @@ -5,27 +5,26 @@ use shared::core::geometry::{Point2f, Vector2f}; use shared::filters::GaussianFilter; use shared::utils::math::gaussian; -pub trait GaussianFilterCreator { - fn new(radius: Vector2f, sigma: Float) -> Self; +#[derive(Clone, Debug, Device)] +#[device(name = "GaussianFilter")] +pub struct GaussianFilterHost { + pub radius: Vector2f, + pub sigma: Float, + pub exp_x: Float, + pub exp_y: Float, + #[device(flatten)] + pub sampler: FilterSampler, } -impl GaussianFilterCreator for GaussianFilter { - fn new(radius: Vector2f, sigma: Float) -> Self { - let exp_x = gaussian(radius.x(), 0., sigma); - let exp_y = gaussian(radius.y(), 0., sigma); - +impl GaussianFilterHost { + pub fn new(radius: Vector2f, sigma: Float) -> Self { + let exp_x = gaussian(radius.x(), 0.0, sigma); + let exp_y = gaussian(radius.y(), 0.0, sigma); let sampler = FilterSampler::new(radius, move |p: Point2f| { - let gx = (gaussian(p.x(), 0., sigma) - exp_x).max(0.0); - let gy = (gaussian(p.y(), 0., sigma) - exp_y).max(0.0); + let gx = (gaussian(p.x(), 0.0, sigma) - exp_x).max(0.0); + let gy = (gaussian(p.y(), 0.0, sigma) - exp_y).max(0.0); gx * gy }); - - Self { - radius, - sigma, - exp_x: gaussian(radius.x(), 0., sigma), - exp_y: gaussian(radius.y(), 0., sigma), - sampler, - } + Self { radius, sigma, exp_x, exp_y, sampler } } } diff --git a/src/filters/lanczos.rs b/src/filters/lanczos.rs index d5340fa..4497bf0 100644 --- a/src/filters/lanczos.rs +++ b/src/filters/lanczos.rs @@ -6,43 +6,35 @@ use shared::core::geometry::{Point2f, Vector2f}; use shared::filters::LanczosSincFilter; use shared::utils::math::{lerp, windowed_sinc}; -pub trait LanczosFilterCreator { - fn new(radius: Vector2f, tau: Float) -> Self; +#[derive(Clone, Debug, Device)] +#[device(name = "LanczosSincFilter")] +pub struct LanczosSincFilterHost { + pub radius: Vector2f, + pub tau: Float, + #[device(flatten)] + pub sampler: FilterSampler, } -impl LanczosFilterCreator for LanczosSincFilter { - fn new(radius: Vector2f, tau: Float) -> Self { - let evaluate = |p: Point2f| -> Float { +impl LanczosSincFilterHost { + pub fn new(radius: Vector2f, tau: Float) -> Self { + let sampler = FilterSampler::new(radius, move |p: Point2f| { windowed_sinc(p.x(), radius.x(), tau) * windowed_sinc(p.y(), radius.y(), tau) - }; - - let sampler = FilterSampler::new(radius, evaluate); - let sqrt_samples = 64; - let n_samples = sqrt_samples * sqrt_samples; - let area = (2.0 * radius.x()) * (2.0 * radius.y()); - let mut sum = 0.0; - let mut rng = rand::rng(); - - for y in 0..sqrt_samples { - for x in 0..sqrt_samples { - let u = Point2f::new( - (x as Float + rng.random::()) / sqrt_samples as Float, - (y as Float + rng.random::()) / sqrt_samples as Float, - ); - let p = Point2f::new( - lerp(u.x(), -radius.x(), radius.x()), - lerp(u.y(), -radius.y(), radius.y()), - ); - sum += evaluate(p); - } - } - let integral = sum / n_samples as Float * area; - - Self { - radius, - tau, - sampler, - integral, - } + }); + Self { radius, tau, sampler } + } +} + +fn windowed_sinc(x: Float, radius: Float, tau: Float) -> Float { + use std::f32::consts::PI; + let x = x.abs(); + if x > radius { + return 0.0; + } + if x < 1e-5 { + 1.0 + } else { + let xpi = x * PI; + let xpit = xpi * tau; + (xpi.sin() / xpi) * (xpit.sin() / xpit) } } diff --git a/src/filters/mitchell.rs b/src/filters/mitchell.rs index 56b9c1b..a8b2626 100644 --- a/src/filters/mitchell.rs +++ b/src/filters/mitchell.rs @@ -4,23 +4,39 @@ use shared::core::filter::FilterSampler; use shared::core::geometry::{Point2f, Vector2f}; use shared::filters::MitchellFilter; -pub trait MitchellFilterCreator { - fn new(radius: Vector2f, b: Float, c: Float) -> Self; +#[derive(Clone, Debug, Device)] +#[device(name = "MitchellFilter")] +pub struct MitchellFilterHost { + pub radius: Vector2f, + pub b: Float, + pub c: Float, + #[device(flatten)] + pub sampler: FilterSampler, } -impl MitchellFilterCreator for MitchellFilter { - fn new(radius: Vector2f, b: Float, c: Float) -> Self { +impl MitchellFilterHost { + pub fn new(radius: Vector2f, b: Float, c: Float) -> Self { let sampler = FilterSampler::new(radius, move |p: Point2f| { - let nx = 2.0 * p.x() / radius.x(); - let ny = 2.0 * p.y() / radius.y(); - Self::mitchell_1d_eval(b, c, nx) * Self::mitchell_1d_eval(b, c, ny) + mitchell_1d(p.x() / radius.x(), b, c) * mitchell_1d(p.y() / radius.y(), b, c) }); - - Self { - radius, - b, - c, - sampler, - } + Self { radius, b, c, sampler } + } +} + +fn mitchell_1d(x: Float, b: Float, c: Float) -> Float { + let x = (2.0 * x).abs(); + if x <= 1.0 { + ((12.0 - 9.0 * b - 6.0 * c) * x * x * x + + (-18.0 + 12.0 * b + 6.0 * c) * x * x + + (6.0 - 2.0 * b)) + / 6.0 + } else if x <= 2.0 { + ((-b - 6.0 * c) * x * x * x + + (6.0 * b + 30.0 * c) * x * x + + (-12.0 * b - 48.0 * c) * x + + (8.0 * b + 24.0 * c)) + / 6.0 + } else { + 0.0 } } diff --git a/src/shapes/mesh.rs b/src/shapes/mesh.rs index 9b761c5..ee62212 100644 --- a/src/shapes/mesh.rs +++ b/src/shapes/mesh.rs @@ -21,16 +21,18 @@ pub struct TriQuadMesh { pub quad_indices: Vec, } -#[derive(Debug)] -pub(crate) struct TriangleMeshStorage { - pub p: Vec, - pub n: Vec, - pub s: Vec, - pub uv: Vec, - pub vertex_indices: Vec, - pub face_indices: Vec, +#[derive(DeviceRepr)] +#[device(name = "DeviceTriangleMesh")] +pub struct TriangleMeshStorage { + pub vertex_indices: Vec, // → Ptr + len (always present) + pub p: Vec, // → Ptr + len (always present) + pub n: Vec, // → Ptr + len (empty → null Ptr, len 0) + pub s: Vec, // → Ptr + len + pub uv: Vec, // → Ptr + len + pub face_indices: Vec, // → Ptr + len } + #[derive(Debug)] pub(crate) struct BilinearMeshStorage { pub vertex_indices: Vec, diff --git a/src/utils/arena.rs b/src/utils/arena.rs index f6267f2..485a183 100644 --- a/src/utils/arena.rs +++ b/src/utils/arena.rs @@ -1,147 +1,148 @@ -use crate::core::image::Image; -use crate::core::texture::{FloatTexture, SpectrumTexture}; -use crate::shapes::{BilinearPatchMesh, TriangleMesh}; -use crate::spectra::DenselySampledSpectrumBuffer; use crate::utils::backend::GpuAllocator; -use crate::utils::mipmap::MIPMap; -use crate::utils::sampling::{PiecewiseConstant2D, WindowedPiecewiseConstant2D}; -use parking_lot::Mutex; -use shared::core::color::RGBToSpectrumTable; -use shared::core::image::DeviceImage; -use shared::core::light::Light; -use shared::core::material::Material; use shared::core::shape::Shape; +use shared::core::material::Material; use shared::core::spectrum::Spectrum; -use shared::core::texture::{GPUFloatTexture, GPUSpectrumTexture}; -use shared::spectra::{DenselySampledSpectrum, DeviceStandardColorSpaces, RGBColorSpace}; -use shared::textures::*; -use shared::utils::mesh::{DeviceBilinearPatchMesh, DeviceTriangleMesh}; -use shared::utils::sampling::{ - DevicePiecewiseConstant1D, DevicePiecewiseConstant2D, DeviceWindowedPiecewiseConstant2D, -}; -use shared::utils::Ptr; +use parking_lot::Mutex; +use shared::Ptr; use std::alloc::Layout; use std::collections::HashMap; use std::panic::Location; -use std::slice::from_raw_parts; use std::sync::Arc; -pub struct Arena { +struct Chunk { + ptr: *mut u8, + layout: Layout, +} + +struct GpuBump { allocator: A, - inner: Mutex, + current: *mut u8, + end: *mut u8, + chunks: Vec, } -struct ArenaInner { - blocks: Vec<(*mut u8, Layout)>, - current_block: *mut u8, - current_offset: usize, - current_capacity: usize, - current_align: usize, - texture_cache: HashMap, +const CHUNK_SIZE: usize = 256 * 1024; + +impl GpuBump { + fn new(allocator: A) -> Self { + Self { + allocator, + current: std::ptr::null_mut(), + end: std::ptr::null_mut(), + chunks: Vec::new(), + } + } + + fn alloc(&mut self, value: T) -> *mut T { + let layout = Layout::new::(); + let ptr = self.alloc_layout(layout) as *mut T; + unsafe { ptr.write(value) }; + ptr + } + + fn alloc_slice(&mut self, values: &[T]) -> (*mut T, usize) { + if values.is_empty() { + return (std::ptr::null_mut(), 0); + } + let layout = Layout::array::(values.len()).unwrap(); + let ptr = self.alloc_layout(layout) as *mut T; + unsafe { + std::ptr::copy_nonoverlapping(values.as_ptr(), ptr, values.len()); + } + (ptr, values.len()) + } + + fn alloc_layout(&mut self, layout: Layout) -> *mut u8 { + let size = layout.size(); + let align = layout.align(); + + if size == 0 { + return align as *mut u8; + } + + // Fast path: bump from current chunk + let start = self.current as usize; + let aligned = (start + align - 1) & !(align - 1); + let end = aligned + size; + + if end <= self.end as usize { + self.current = end as *mut u8; + return aligned as *mut u8; + } + + let chunk_size = if size > CHUNK_SIZE { + size.next_multiple_of(align.max(16)) + } else { + CHUNK_SIZE.max(size.next_multiple_of(align.max(16))) + }; + + let chunk_layout = Layout::from_size_align(chunk_size, align.max(16)).unwrap(); + let chunk = unsafe { self.allocator.alloc(chunk_layout) }; + + let caller = Location::caller(); + if chunk.is_null() { + panic!( + "GpuBump OOM {} {}: chunk_size={} align={} backend={}", + caller.file(), + caller.line(), + chunk_size, + chunk_layout.align(), + std::any::type_name::() + ); + } + + self.chunks.push(Chunk { + ptr: chunk, + layout: chunk_layout, + }); + + // If this object alone fills the chunk, mark it consumed and return. + if size > CHUNK_SIZE { + self.current = unsafe { chunk.add(chunk_size) }; + self.end = self.current; + return chunk; + } + + // Set up bump pointers inside the new chunk. + self.current = chunk; + self.end = unsafe { chunk.add(chunk_size) }; + + let aligned = (self.current as usize + align - 1) & !(align - 1); + self.current = (aligned + size) as *mut u8; + aligned as *mut u8 + } } -const DEFAULT_BLOCK_SIZE: usize = 256 * 1024; +impl Drop for GpuBump { + fn drop(&mut self) { + for chunk in self.chunks.drain(..) { + unsafe { self.allocator.dealloc(chunk.ptr, chunk.layout) }; + } + } +} + +unsafe impl Send for GpuBump {} +unsafe impl Sync for GpuBump {} + +pub struct Arena { + bump: Mutex>, + texture_cache: Mutex>, +} impl Arena { pub fn new(allocator: A) -> Self { Self { - allocator, - inner: Mutex::new(ArenaInner { - blocks: Vec::new(), - current_block: std::ptr::null_mut(), - current_offset: 0, - current_capacity: 0, - current_align: 1, - texture_cache: HashMap::new(), - }), + bump: Mutex::new(GpuBump::new(allocator)), + texture_cache: Mutex::new(HashMap::new()), } } pub fn alloc(&self, value: T) -> Ptr { - let layout = Layout::new::(); - let mut inner = self.inner.lock(); - - let aligned = (inner.current_offset + layout.align() - 1) & !(layout.align() - 1); - - // Checking if current block alignment is sufficient - if aligned + layout.size() > inner.current_capacity || inner.current_align < layout.align() - { - let block_size = DEFAULT_BLOCK_SIZE.max(layout.size() * 2); - let block_layout = Layout::from_size_align(block_size, layout.align().max(16)).unwrap(); - let block = unsafe { self.allocator.alloc(block_layout) }; - - // null check - if block.is_null() { - panic!( - "Arena alloc failed at {}:{} — size={} align={}", - caller.file(), - caller.line(), - block_layout.size(), - block_layout.align() - ); - } - - inner.blocks.push((block, block_layout)); - inner.current_block = block; - inner.current_offset = 0; - inner.current_capacity = block_size; - inner.current_align = block_layout.align(); // NEW - - let aligned = (0 + layout.align() - 1) & !(layout.align() - 1); - let ptr = unsafe { inner.current_block.add(aligned) as *mut T }; - unsafe { ptr.write(value) }; - inner.current_offset = aligned + layout.size(); - return Ptr::from_raw(ptr); - } - - let ptr = unsafe { inner.current_block.add(aligned) as *mut T }; - unsafe { ptr.write(value) }; - inner.current_offset = aligned + layout.size(); + let mut bump = self.bump.lock(); + let ptr = bump.alloc(value); Ptr::from_raw(ptr) } - pub fn alloc_slice(&self, values: &[T]) -> (Ptr, usize) { - if values.is_empty() { - return (Ptr::null(), 0); - } - let layout = Layout::array::(values.len()).unwrap(); - let mut inner = self.inner.lock(); - - let aligned = (inner.current_offset + layout.align() - 1) & !(layout.align() - 1); - - if aligned + layout.size() > inner.current_capacity || inner.current_align < layout.align() - { - let block_size = DEFAULT_BLOCK_SIZE.max(layout.size() * 2); - let block_layout = Layout::from_size_align(block_size, layout.align().max(16)).unwrap(); - let block = unsafe { self.allocator.alloc(block_layout) }; - - if block.is_null() { - panic!( - "Arena allocation failed: size={} align={}", - block_layout.size(), - block_layout.align() - ); - } - - inner.blocks.push((block, block_layout)); - inner.current_block = block; - inner.current_offset = 0; - inner.current_capacity = block_size; - inner.current_align = block_layout.align(); // NEW - - let aligned = 0; - let ptr = unsafe { inner.current_block.add(aligned) as *mut T }; - unsafe { std::ptr::copy_nonoverlapping(values.as_ptr(), ptr, values.len()) }; - inner.current_offset = aligned + layout.size(); - return (Ptr::from_raw(ptr), values.len()); - } - - let ptr = unsafe { inner.current_block.add(aligned) as *mut T }; - unsafe { std::ptr::copy_nonoverlapping(values.as_ptr(), ptr, values.len()) }; - inner.current_offset = aligned + layout.size(); - (Ptr::from_raw(ptr), values.len()) - } - pub fn alloc_opt(&self, value: Option) -> Ptr { match value { Some(v) => self.alloc(v), @@ -149,20 +150,20 @@ impl Arena { } } + pub fn alloc_slice(&self, values: &[T]) -> (Ptr, usize) { + let mut bump = self.bump.lock(); + let (ptr, len) = bump.alloc_slice(values); + (Ptr::from_raw(ptr), len) + } + pub fn get_texture_object(&self, mipmap: &Arc) -> u64 { let key = Arc::as_ptr(mipmap) as usize; - let mut inner = self.inner.lock(); - - if let Some(&tex_obj) = inner.texture_cache.get(&key) { + let mut cache = self.texture_cache.lock(); + if let Some(&tex_obj) = cache.get(&key) { return tex_obj; } - - // TODO: Backend-specific texture object creation. - // CUDA: cudaCreateTextureObject - // Vulkan: VkImageView + VkSampler -> descriptor index - let tex_obj = 0u64; - - inner.texture_cache.insert(key, tex_obj); + let tex_obj = 0u64; // TODO: backend-specific creation + cache.insert(key, tex_obj); tex_obj } } @@ -173,76 +174,279 @@ impl Default for Arena { } } -impl Drop for Arena { - fn drop(&mut self) { - let inner = self.inner.get_mut(); - for (ptr, layout) in inner.blocks.drain(..) { - unsafe { self.allocator.dealloc(ptr, layout) }; +pub trait DeviceRepr { + /// The `#[repr(C)] Copy` device-side struct. + type Target: Copy; + + /// Upload into the arena and return the device struct by value. + /// Use this when embedding the result inline in another device struct. + fn upload_value(&self, arena: &Arena) -> Self::Target; + + /// Upload into the arena and return a Ptr to the device struct. + /// This is the common entry point — allocates the Target in the arena. + fn upload(&self, arena: &Arena) -> Ptr { + let value = self.upload_value(arena); + arena.alloc(value) + } +} + +impl DeviceRepr for Option { + type Target = T::Target; + + fn upload_value(&self, arena: &Arena) -> Self::Target { + match self { + Some(val) => val.upload_value(arena), + None => panic!("Cannot upload_value on None — use upload() which returns Ptr::null()"), + } + } + + fn upload(&self, arena: &Arena) -> Ptr { + match self { + Some(val) => val.upload(arena), + None => Ptr::null(), } } } -unsafe impl Send for Arena {} -unsafe impl Sync for Arena {} +impl DeviceRepr for std::sync::Arc { + type Target = T::Target; -pub trait Upload { - type Target: Copy; + fn upload_value(&self, arena: &Arena) -> Self::Target { + (**self).upload_value(arena) + } - fn upload(&self, arena: &Arena) -> Ptr; + fn upload(&self, arena: &Arena) -> Ptr { + (**self).upload(arena) + } } -impl Upload for Shape { +impl DeviceRepr for Box { + type Target = T::Target; + + fn upload_value(&self, arena: &Arena) -> Self::Target { + (**self).upload_value(arena) + } + + fn upload(&self, arena: &Arena) -> Ptr { + (**self).upload(arena) + } +} + +impl DeviceRepr for Shape { type Target = Shape; - fn upload(&self, arena: &Arena) -> Ptr { - arena.alloc(self.clone()) + fn upload_value(&self, _arena: &Arena) -> Shape { + self.clone() } } -impl Upload for Light { +impl DeviceRepr for Light { type Target = Light; - fn upload(&self, arena: &Arena) -> Ptr { - arena.alloc(self.clone()) + fn upload_value(&self, _arena: &Arena) -> Light { + self.clone() } } -impl Upload for Image { - type Target = DeviceImage; - fn upload(&self, arena: &Arena) -> Ptr { - arena.alloc(*self.device()) - } -} - -impl Upload for Spectrum { +impl DeviceRepr for Spectrum { type Target = Spectrum; - fn upload(&self, arena: &Arena) -> Ptr { - arena.alloc(self.clone()) + fn upload_value(&self, _arena: &Arena) -> Spectrum { + self.clone() } } -impl Upload for Material { +impl DeviceRepr for Material { type Target = Material; - fn upload(&self, arena: &Arena) -> Ptr { - arena.alloc(self.clone()) + fn upload_value(&self, _arena: &Arena) -> Material { + self.clone() } } -impl Upload for DenselySampledSpectrumBuffer { +// ============================================================================= +// Image → DeviceImage +// ============================================================================= + +impl DeviceRepr for Image { + type Target = DeviceImage; + fn upload_value(&self, _arena: &Arena) -> DeviceImage { + *self.device() + } +} + +// ============================================================================= +// DenselySampledSpectrumBuffer → DenselySampledSpectrum +// ============================================================================= + +impl DeviceRepr for DenselySampledSpectrumBuffer { type Target = DenselySampledSpectrum; - fn upload(&self, arena: &Arena) -> Ptr { - arena.alloc(*&self.device()) + fn upload_value(&self, _arena: &Arena) -> DenselySampledSpectrum { + self.device() } } -impl Upload for SpectrumTexture { +// ============================================================================= +// RGBToSpectrumTable — re-uploads Ptr fields into arena +// ============================================================================= + +impl DeviceRepr for RGBToSpectrumTable { + type Target = RGBToSpectrumTable; + + fn upload_value(&self, arena: &Arena) -> RGBToSpectrumTable { + let n_nodes = self.n_nodes as usize; + + // Safety: these Ptrs point into static or previously-uploaded data; + // we're copying the contents into the arena for a new lifetime. + let z_slice = unsafe { from_raw_parts(self.z_nodes.as_raw(), n_nodes) }; + let (z_ptr, _) = arena.alloc_slice(z_slice); + + let n_coeffs = 3 * n_nodes.pow(3); + let coeffs_slice = unsafe { from_raw_parts(self.coeffs.as_raw(), n_coeffs) }; + let (c_ptr, _) = arena.alloc_slice(coeffs_slice); + + RGBToSpectrumTable { + z_nodes: z_ptr, + coeffs: c_ptr, + n_nodes: self.n_nodes, + } + } +} + +// ============================================================================= +// RGBColorSpace — nested upload of spectrum table +// ============================================================================= + +impl DeviceRepr for RGBColorSpace { + type Target = RGBColorSpace; + + fn upload_value(&self, arena: &Arena) -> RGBColorSpace { + let table_ptr = self.rgb_to_spectrum_table.upload(arena); + + RGBColorSpace { + r: self.r, + g: self.g, + b: self.b, + w: self.w, + illuminant: self.illuminant.clone(), + rgb_to_spectrum_table: table_ptr, + xyz_from_rgb: self.xyz_from_rgb, + rgb_from_xyz: self.rgb_from_xyz, + } + } +} + +// ============================================================================= +// DeviceStandardColorSpaces — composition of color space uploads +// ============================================================================= + +impl DeviceRepr for DeviceStandardColorSpaces { + type Target = DeviceStandardColorSpaces; + + fn upload_value(&self, arena: &Arena) -> DeviceStandardColorSpaces { + DeviceStandardColorSpaces { + srgb: self.srgb.upload(arena), + dci_p3: self.dci_p3.upload(arena), + rec2020: self.rec2020.upload(arena), + aces2065_1: self.aces2065_1.upload(arena), + } + } +} + +// ============================================================================= +// TriangleMesh → DeviceTriangleMesh +// ============================================================================= +// TriangleMesh should own all its fields directly — no .device field. +// The host struct holds Vec arrays + scalar metadata. derive(Device) handles it: +// +// #[derive(Device)] +// #[device(name = "DeviceTriangleMesh")] +// pub struct TriangleMesh { +// pub vertex_indices: Vec, +// pub p: Vec, +// pub n: Vec, +// pub s: Vec, +// pub uv: Vec, +// pub face_indices: Vec, +// pub n_triangles: u32, +// pub reverse_orientation: bool, +// pub transform_swaps_handedness: bool, +// } +// +// Until the mesh struct is refactored to remove .device, here's a manual +// impl that lists every field. This is the MIGRATION TARGET — once TriangleMesh +// drops its .device field and puts scalars at the top level, switch to derive(Device). + +impl DeviceRepr for TriangleMesh { + type Target = DeviceTriangleMesh; + + fn upload_value(&self, arena: &Arena) -> DeviceTriangleMesh { + let s = &self.storage; + + let (vertex_indices_ptr, _) = arena.alloc_slice(&s.vertex_indices); + let (p_ptr, _) = arena.alloc_slice(&s.p); + let (n_ptr, _) = arena.alloc_slice(&s.n); + let (s_ptr, _) = arena.alloc_slice(&s.s); + let (uv_ptr, _) = arena.alloc_slice(&s.uv); + let (face_indices_ptr, _) = arena.alloc_slice(&s.face_indices); + + DeviceTriangleMesh { + vertex_indices: vertex_indices_ptr, + p: p_ptr, + n: n_ptr, + s: s_ptr, + uv: uv_ptr, + face_indices: face_indices_ptr, + n_triangles: self.n_triangles, + reverse_orientation: self.reverse_orientation, + transform_swaps_handedness: self.transform_swaps_handedness, + } + } +} + +// ============================================================================= +// BilinearPatchMesh → DeviceBilinearPatchMesh +// ============================================================================= +// Same as TriangleMesh: once the host struct is refactored to own scalars +// directly (no .device field), this switches to derive(Device). + +impl DeviceRepr for BilinearPatchMesh { + type Target = DeviceBilinearPatchMesh; + + fn upload_value(&self, arena: &Arena) -> DeviceBilinearPatchMesh { + let s = &self.storage; + + let (vertex_indices_ptr, _) = arena.alloc_slice(&s.vertex_indices); + let (p_ptr, _) = arena.alloc_slice(&s.p); + let (n_ptr, _) = arena.alloc_slice(&s.n); + let (uv_ptr, _) = arena.alloc_slice(&s.uv); + + let image_dist_ptr = s.image_distribution.upload(arena); + + DeviceBilinearPatchMesh { + vertex_indices: vertex_indices_ptr, + p: p_ptr, + n: n_ptr, + uv: uv_ptr, + image_distribution: image_dist_ptr, + n_patches: self.n_patches, + reverse_orientation: self.reverse_orientation, + transform_swaps_handedness: self.transform_swaps_handedness, + } + } +} + +// ============================================================================= +// SpectrumTexture → GPUSpectrumTexture +// ============================================================================= + +impl DeviceRepr for SpectrumTexture { type Target = GPUSpectrumTexture; - fn upload(&self, arena: &Arena) -> Ptr { - let gpu_variant = match self { + + fn upload_value(&self, arena: &Arena) -> GPUSpectrumTexture { + match self { SpectrumTexture::Constant(tex) => GPUSpectrumTexture::Constant(tex.clone()), SpectrumTexture::Checkerboard(tex) => GPUSpectrumTexture::Checkerboard(tex.clone()), SpectrumTexture::Dots(tex) => GPUSpectrumTexture::Dots(tex.clone()), SpectrumTexture::Image(tex) => { let tex_obj = arena.get_texture_object(&tex.base.mipmap); - let gpu_img = GPUSpectrumImageTexture { + GPUSpectrumTexture::Image(GPUSpectrumImageTexture { mapping: tex.base.mapping, tex_obj, scale: tex.base.scale, @@ -255,55 +459,50 @@ impl Upload for SpectrumTexture { .clone() .unwrap_or_else(crate::spectra::default_colorspace), spectrum_type: tex.spectrum_type, - }; - GPUSpectrumTexture::Image(gpu_img) + }) } SpectrumTexture::Bilerp(tex) => GPUSpectrumTexture::Bilerp(tex.clone()), SpectrumTexture::Scaled(tex) => { let child_ptr = tex.tex.upload(arena); - - let gpu_scaled = GPUSpectrumScaledTexture { + let scale_ptr = tex.scale.upload(arena); + GPUSpectrumTexture::Scaled(GPUSpectrumScaledTexture { tex: child_ptr, - scale: tex.scale.upload(arena), - }; - GPUSpectrumTexture::Scaled(gpu_scaled) + scale: scale_ptr, + }) } SpectrumTexture::Marble(tex) => GPUSpectrumTexture::Marble(tex.clone()), SpectrumTexture::Mix(tex) => { let tex1_ptr = tex.tex1.upload(arena); let tex2_ptr = tex.tex2.upload(arena); let amount_ptr = tex.amount.upload(arena); - - let gpu_mix = GPUSpectrumMixTexture { + GPUSpectrumTexture::Mix(GPUSpectrumMixTexture { tex1: tex1_ptr, tex2: tex2_ptr, amount: amount_ptr, - }; - - GPUSpectrumTexture::Mix(gpu_mix) + }) } SpectrumTexture::DirectionMix(tex) => { let tex1_ptr = tex.tex1.upload(arena); let tex2_ptr = tex.tex2.upload(arena); - - let gpu_mix = GPUSpectrumDirectionMixTexture { + GPUSpectrumTexture::DirectionMix(GPUSpectrumDirectionMixTexture { tex1: tex1_ptr, tex2: tex2_ptr, dir: tex.dir, - }; - - GPUSpectrumTexture::DirectionMix(gpu_mix) + }) } - }; - - arena.alloc(gpu_variant) + } } } -impl Upload for FloatTexture { +// ============================================================================= +// FloatTexture → GPUFloatTexture +// ============================================================================= + +impl DeviceRepr for FloatTexture { type Target = GPUFloatTexture; - fn upload(&self, arena: &Arena) -> Ptr { - let gpu_variant = match self { + + fn upload_value(&self, arena: &Arena) -> GPUFloatTexture { + match self { FloatTexture::Constant(tex) => GPUFloatTexture::Constant(tex.clone()), FloatTexture::Checkerboard(tex) => GPUFloatTexture::Checkerboard(tex.clone()), FloatTexture::Dots(tex) => GPUFloatTexture::Dots(tex.clone()), @@ -312,249 +511,40 @@ impl Upload for FloatTexture { FloatTexture::Wrinkled(tex) => GPUFloatTexture::Wrinkled(tex.clone()), FloatTexture::Scaled(tex) => { let child_ptr = tex.tex.upload(arena); - - let gpu_scaled = GPUFloatScaledTexture { + let scale_ptr = tex.scale.upload(arena); + GPUFloatTexture::Scaled(GPUFloatScaledTexture { tex: child_ptr, - scale: tex.scale.upload(arena), - }; - GPUFloatTexture::Scaled(gpu_scaled) + scale: scale_ptr, + }) } - FloatTexture::Mix(tex) => { let tex1_ptr = tex.tex1.upload(arena); let tex2_ptr = tex.tex2.upload(arena); let amount_ptr = tex.amount.upload(arena); - - let gpu_mix = GPUFloatMixTexture { + GPUFloatTexture::Mix(GPUFloatMixTexture { tex1: tex1_ptr, tex2: tex2_ptr, amount: amount_ptr, - }; - GPUFloatTexture::Mix(gpu_mix) + }) } - FloatTexture::DirectionMix(tex) => { let tex1_ptr = tex.tex1.upload(arena); let tex2_ptr = tex.tex2.upload(arena); - let gpu_dmix = GPUFloatDirectionMixTexture { + GPUFloatTexture::DirectionMix(GPUFloatDirectionMixTexture { tex1: tex1_ptr, tex2: tex2_ptr, dir: tex.dir, - }; - GPUFloatTexture::DirectionMix(gpu_dmix) + }) } - FloatTexture::Image(tex) => { - let gpu_image_tex = GPUFloatImageTexture { + GPUFloatTexture::Image(GPUFloatImageTexture { mapping: tex.base.mapping, tex_obj: tex.base.mipmap.texture_object(), scale: tex.base.scale, invert: tex.base.invert, - }; - GPUFloatTexture::Image(gpu_image_tex) + }) } - FloatTexture::Bilerp(tex) => GPUFloatTexture::Bilerp(tex.clone()), - }; - - arena.alloc(gpu_variant) - } -} - -impl Upload for RGBToSpectrumTable { - type Target = RGBToSpectrumTable; - - fn upload(&self, arena: &Arena) -> Ptr { - let n_nodes = self.n_nodes as usize; - let z_slice = unsafe { from_raw_parts(self.z_nodes.as_raw(), n_nodes) }; - let (z_ptr, _) = arena.alloc_slice(z_slice); - - let n_coeffs = 3 * (n_nodes as usize).pow(3); - let coeffs_slice = unsafe { from_raw_parts(self.coeffs.as_raw(), n_coeffs) }; - let (c_ptr, _) = arena.alloc_slice(coeffs_slice); - - arena.alloc(RGBToSpectrumTable { - z_nodes: z_ptr, - coeffs: c_ptr, - n_nodes: self.n_nodes, - }) - } -} - -impl Upload for RGBColorSpace { - type Target = RGBColorSpace; - - fn upload(&self, arena: &Arena) -> Ptr { - let table_ptr = self.rgb_to_spectrum_table.upload(arena); - - let shared_space = RGBColorSpace { - r: self.r, - g: self.g, - b: self.b, - w: self.w, - illuminant: self.illuminant.clone(), - rgb_to_spectrum_table: table_ptr, - xyz_from_rgb: self.xyz_from_rgb, - rgb_from_xyz: self.rgb_from_xyz, - }; - - arena.alloc(shared_space) - } -} - -impl Upload for DeviceStandardColorSpaces { - type Target = DeviceStandardColorSpaces; - - fn upload(&self, arena: &Arena) -> Ptr { - let srgb_ptr = self.srgb.upload(arena); - let dci_ptr = self.dci_p3.upload(arena); - let rec_ptr = self.rec2020.upload(arena); - let aces_ptr = self.aces2065_1.upload(arena); - - let registry = DeviceStandardColorSpaces { - srgb: srgb_ptr, - dci_p3: dci_ptr, - rec2020: rec_ptr, - aces2065_1: aces_ptr, - }; - - arena.alloc(registry) - } -} - -impl Upload for PiecewiseConstant2D { - type Target = DevicePiecewiseConstant2D; - - fn upload(&self, arena: &Arena) -> Ptr { - let marginal_shared = self.marginal.to_shared(arena); - - let conditionals_shared: Vec = self - .conditionals - .iter() - .map(|c| c.to_shared(arena)) - .collect(); - - let (conditionals_ptr, _) = arena.alloc_slice(&conditionals_shared); - - let shared_2d = DevicePiecewiseConstant2D { - conditionals: conditionals_ptr, - marginal: marginal_shared, - ..self.device - }; - - arena.alloc(shared_2d) - } -} - -impl Upload for WindowedPiecewiseConstant2D { - type Target = DeviceWindowedPiecewiseConstant2D; - fn upload(&self, arena: &Arena) -> Ptr { - let specific = DeviceWindowedPiecewiseConstant2D { - sat: self.sat, - func: self.func, - }; - arena.alloc(specific) - } -} - -impl Upload for TriangleMesh { - type Target = DeviceTriangleMesh; - - fn upload(&self, arena: &Arena) -> Ptr { - let storage = &self.storage; - - // Upload all arrays to arena - let (vertex_indices_ptr, _) = arena.alloc_slice(&storage.vertex_indices); - let (p_ptr, _) = arena.alloc_slice(&storage.p); - - let (n_ptr, _) = if storage.n.is_empty() { - (Ptr::null(), 0) - } else { - arena.alloc_slice(&storage.n) - }; - - let (s_ptr, _) = if storage.s.is_empty() { - (Ptr::null(), 0) - } else { - arena.alloc_slice(&storage.s) - }; - - let (uv_ptr, _) = if storage.uv.is_empty() { - (Ptr::null(), 0) - } else { - arena.alloc_slice(&storage.uv) - }; - - let (face_indices_ptr, _) = if storage.face_indices.is_empty() { - (Ptr::null(), 0) - } else { - arena.alloc_slice(&storage.face_indices) - }; - - let device = DeviceTriangleMesh { - vertex_indices: vertex_indices_ptr, - p: p_ptr, - n: n_ptr, - s: s_ptr, - uv: uv_ptr, - face_indices: face_indices_ptr, - ..self.device - }; - - arena.alloc(device) - } -} - -impl Upload for BilinearPatchMesh { - type Target = DeviceBilinearPatchMesh; - - fn upload(&self, arena: &Arena) -> Ptr { - let storage = &self.storage; - - let (vertex_indices_ptr, _) = arena.alloc_slice(&storage.vertex_indices); - let (p_ptr, _) = arena.alloc_slice(&storage.p); - - let (n_ptr, _) = if storage.n.is_empty() { - (Ptr::null(), 0) - } else { - arena.alloc_slice(&storage.n) - }; - - let (uv_ptr, _) = if storage.uv.is_empty() { - (Ptr::null(), 0) - } else { - arena.alloc_slice(&storage.uv) - }; - - let image_dist_ptr = storage.image_distribution.upload(arena); - - let device = DeviceBilinearPatchMesh { - vertex_indices: vertex_indices_ptr, - p: p_ptr, - n: n_ptr, - uv: uv_ptr, - image_distribution: image_dist_ptr, - ..self.device - }; - - arena.alloc(device) - } -} - -impl Upload for Option { - type Target = T::Target; - fn upload(&self, arena: &Arena) -> Ptr { - match self { - Some(val) => val.upload(arena), - None => Ptr::null(), } } } - -impl Upload for Arc { - type Target = T::Target; - - fn upload(&self, arena: &Arena) -> Ptr { - (**self).upload(arena) - } -} diff --git a/src/utils/backend.rs b/src/utils/backend.rs index 89ddc0a..9c44412 100644 --- a/src/utils/backend.rs +++ b/src/utils/backend.rs @@ -1,18 +1,17 @@ use std::alloc::Layout; -pub trait GpuAllocator: Send + Sync { - /// Allocate `size` bytes with given alignment. - /// Returns a host-mapped pointer. +pub trait GpuAllocator: Send + Sync + Clone { unsafe fn alloc(&self, layout: Layout) -> *mut u8; unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout); } -/// CPU fallback — standard system allocator. +/// CPU fallback +#[derive(Clone)] pub struct SystemAllocator; impl Default for SystemAllocator { fn default() -> Self { - Self {} + Self } } @@ -21,101 +20,85 @@ impl GpuAllocator for SystemAllocator { if layout.size() == 0 { return layout.align() as *mut u8; } - unsafe { std::alloc::alloc(layout) } + std::alloc::alloc(layout) } unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) { if layout.size() > 0 { - unsafe { - std::alloc::dealloc(ptr, layout); - } + std::alloc::dealloc(ptr, layout); } } } -/// CUDA unified memory backend using CudaAllocator +/// CUDA unified memory. Still using CudaAllocator, might move over to #[cfg(feature = "cuda")] pub mod cuda { use super::GpuAllocator; + use cust::memory::{cuda_free_unified, cuda_malloc_unified, UnifiedPointer}; use std::alloc::Layout; + #[derive(Clone)] pub struct CudaAllocator; impl Default for CudaAllocator { fn default() -> Self { - Self {} + Self } } impl GpuAllocator for CudaAllocator { unsafe fn alloc(&self, layout: Layout) -> *mut u8 { - use cust::memory::cuda_malloc_unified; - use cust_raw::driver_sys::*; - - let size = layout.size().max(layout.align()); + let size = layout.size(); if size == 0 { return layout.align() as *mut u8; } - let mut ctx: CUcontext = std::ptr::null_mut(); - cuCtxGetCurrent(&mut ctx); - if ctx.is_null() { - let mut primary: CUcontext = std::ptr::null_mut(); - cuDevicePrimaryCtxRetain(&mut primary, 0); - cuCtxSetCurrent(primary); - } + let ptr = cuda_malloc_unified::(size) + .expect("cuda_malloc_unified failed — is a CUDA context current?"); - let mut unified_ptr = - unsafe { cuda_malloc_unified::(size).expect("cuda_malloc_unified failed") }; - let raw = unified_ptr.as_raw_mut(); - std::mem::forget(unified_ptr); + let raw = ptr.as_raw_mut(); + std::mem::forget(ptr); // Leak RAII wrapper; Arena owns the raw pointer raw } unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) { - use cust::memory::{UnifiedPointer, cuda_free_unified}; - if layout.size() > 0 { - let _ = unsafe { cuda_free_unified(UnifiedPointer::wrap(ptr)) }; + if layout.size() > 0 && !ptr.is_null() { + let _ = cuda_free_unified(UnifiedPointer::wrap(ptr)); } } } } -/// Vulkan backend (gpu-allocator for now, there might be a better solution) #[cfg(feature = "vulkan")] pub mod vulkan { use super::GpuAllocator; use ash::vk; - use gpu_allocator::MemoryLocation; use gpu_allocator::vulkan::{ Allocation, AllocationCreateDesc, AllocationScheme, Allocator, AllocatorCreateDesc, }; + use gpu_allocator::MemoryLocation; use parking_lot::Mutex; use std::alloc::Layout; use std::collections::HashMap; - use std::sync::OnceLock; + use std::sync::Arc; - // So, having a static allocator seems like a terrible idea - // But I cant find a way to get a functioning generic Arena constructor - // That might not even be a necessity, since rust-gpu/rust-cuda might actually handle that - // differently - static VK_ALLOCATOR: OnceLock = OnceLock::new(); - - struct VulkanAllocatorInner { - state: Mutex, + #[derive(Clone)] + pub struct VulkanAllocator { + inner: Arc>, } - struct VulkanState { + struct VulkanInner { + device: ash::Device, allocator: Allocator, allocations: HashMap, } - pub fn init_vulkan( - instance: &ash::Instance, - device: &ash::Device, - physical_device: vk::PhysicalDevice, - ) { - VK_ALLOCATOR.get_or_init(|| { + impl VulkanAllocator { + pub fn new( + instance: &ash::Instance, + device: ash::Device, + physical_device: vk::PhysicalDevice, + ) -> Self { let allocator = Allocator::new(&AllocatorCreateDesc { instance: instance.clone(), device: device.clone(), @@ -126,52 +109,29 @@ pub mod vulkan { }) .expect("Failed to create Vulkan allocator"); - VulkanAllocatorInner { - state: Mutex::new(VulkanState { + Self { + inner: Arc::new(Mutex::new(VulkanInner { + device, allocator, allocations: HashMap::new(), - }), + })), } - }); - } - - fn inner() -> &'static VulkanAllocatorInner { - VK_ALLOCATOR - .get() - .expect("Vulkan not initialized — call init_vulkan() before arena creation") - } - - impl Default for VulkanAllocator { - fn default() -> Self { - let _ = inner(); - Self } } - pub struct VulkanAllocator; - - // impl VulkanAllocator { - // pub fn new(allocator: Allocator) -> Self { - // Self { - // allocator: Mutex::new(allocator), - // allocations: Mutex::new(HashMap::new()), - // } - // } - // } - impl GpuAllocator for VulkanAllocator { unsafe fn alloc(&self, layout: Layout) -> *mut u8 { - let size = layout.size().max(layout.align()); + let size = layout.size(); if size == 0 { return layout.align() as *mut u8; } - let inner = inner(); - let mut state = inner.state.lock(); - let allocation = state + let mut inner = self.inner.lock(); + + let allocation = inner .allocator .allocate(&AllocationCreateDesc { - name: "arena", + name: "arena_chunk", requirements: vk::MemoryRequirements { size: size as u64, alignment: layout.align() as u64, @@ -188,18 +148,17 @@ pub mod vulkan { .expect("Vulkan allocation not host-mapped") .as_ptr() as *mut u8; - state.allocations.insert(ptr as usize, allocation); + inner.allocations.insert(ptr as usize, allocation); ptr } unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) { - if layout.size() == 0 { + if layout.size() == 0 || ptr.is_null() { return; } - let inner = inner(); - let mut state = inner.state.lock(); - if let Some(allocation) = state.allocations.remove(&(ptr as usize)) { - state + let mut inner = self.inner.lock(); + if let Some(allocation) = inner.allocations.remove(&(ptr as usize)) { + inner .allocator .free(allocation) .expect("Vulkan free failed"); diff --git a/src/utils/containers.rs b/src/utils/containers.rs index 3ec5dc4..128b0f2 100644 --- a/src/utils/containers.rs +++ b/src/utils/containers.rs @@ -34,6 +34,7 @@ where } #[derive(Debug, Clone)] +#[derive(DeviceRepr)] pub struct Array2D { pub device: DeviceArray2D, pub values: Vec, diff --git a/src/utils/mod.rs b/src/utils/mod.rs index aad0b5c..18ba826 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -12,7 +12,7 @@ pub mod parser; pub mod sampling; pub mod strings; -pub use arena::Upload; +pub use arena::DeviceRepr; pub use error::FileLoc; pub use file::{read_float_file, resolve_filename}; pub use parameters::{ @@ -28,3 +28,516 @@ pub type Arena = arena::Arena; #[cfg(not(any(feature = "cuda", feature = "vulkan")))] pub type Arena = arena::Arena; + +/// # Enum variant attributes +/// +/// | Attribute | Effect | +/// |-----------|--------| +/// | *(none)* | Inner type has `DeviceRepr`; auto-call `upload_value` | +/// | `#[device(clone)]` | Same type on both sides, just clone | +/// | `#[device(custom = "method")]` | You provide `fn method(inner: &T, arena) -> DeviceT` | +/// | `#[device(variant_type = "T")]` | Override the device-side variant's inner type | +/// +/// # Container attribute +/// +/// `#[device(name = "DeviceFoo")]` — override the generated type name (default: `Device{Name}`). +#[proc_macro_derive(Device, attributes(device))] +pub fn derive_device(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + match derive_impl(input) { + Ok(tokens) => tokens.into(), + Err(e) => e.to_compile_error().into(), + } +} + +fn derive_impl(input: DeriveInput) -> syn::Result { + match &input.data { + Data::Struct(_) => derive_struct(input), + Data::Enum(_) => derive_enum(input), + Data::Union(_) => Err(syn::Error::new_spanned( + &input.ident, + "Device derive does not support unions", + )), + } +} + +// Struct derivation + +fn derive_struct(input: DeriveInput) -> syn::Result { + let host_name = &input.ident; + let vis = &input.vis; + let device_name = get_device_name(&input.attrs, host_name)?; + + let fields = match &input.data { + Data::Struct(s) => match &s.fields { + Fields::Named(named) => &named.named, + _ => { + return Err(syn::Error::new_spanned( + host_name, + "Device derive only supports structs with named fields", + )) + } + }, + _ => unreachable!(), + }; + + let mut device_fields = Vec::new(); + let mut upload_stmts = Vec::new(); + let mut device_field_inits = Vec::new(); + let mut spread_expr: Option = None; + + for field in fields { + let field_name = field.ident.as_ref().unwrap(); + let attrs = parse_field_attrs(&field.attrs)?; + + if attrs.skip { + continue; + } + + if let Some(ref expr_str) = attrs.spread { + spread_expr = Some(syn::parse_str(expr_str).map_err(|e| { + syn::Error::new_spanned(field, format!("invalid device(spread): {}", e)) + })?); + continue; + } + + if let Some(expr_str) = &attrs.expr { + let expr: Expr = syn::parse_str(expr_str).map_err(|e| { + syn::Error::new_spanned(field, format!("invalid device(expr): {}", e)) + })?; + let ty = &field.ty; + device_fields.push(quote! { pub #field_name: #ty }); + upload_stmts.push(quote! { let #field_name = #expr; }); + device_field_inits.push(quote! { #field_name }); + continue; + } + + match classify_type(&field.ty) { + FieldClass::VecCopy(inner_ty) => { + let len_name = format_ident!("{}_len", field_name); + device_fields.push(quote! { pub #field_name: Ptr<#inner_ty> }); + device_fields.push(quote! { pub #len_name: usize }); + upload_stmts.push(quote! { + let (#field_name, #len_name) = arena.alloc_slice(&self.#field_name); + }); + device_field_inits.push(quote! { #field_name }); + device_field_inits.push(quote! { #len_name }); + } + FieldClass::VecUploadable(inner_ty) => { + let len_name = format_ident!("{}_len", field_name); + device_fields.push(quote! { + pub #field_name: Ptr<<#inner_ty as DeviceRepr>::Target> + }); + device_fields.push(quote! { pub #len_name: usize }); + upload_stmts.push(quote! { + let __up: Vec<<#inner_ty as DeviceRepr>::Target> = self.#field_name + .iter() + .map(|item| DeviceRepr::upload_value(item, arena)) + .collect(); + let (#field_name, #len_name) = arena.alloc_slice(&__up); + }); + device_field_inits.push(quote! { #field_name }); + device_field_inits.push(quote! { #len_name }); + } + FieldClass::Option(inner_ty) => { + device_fields.push(quote! { + pub #field_name: Ptr<<#inner_ty as DeviceRepr>::Target> + }); + upload_stmts.push(quote! { + let #field_name = match &self.#field_name { + Some(val) => DeviceRepr::upload(val, arena), + None => Ptr::null(), + }; + }); + device_field_inits.push(quote! { #field_name }); + } + FieldClass::Arc(inner_ty) => { + if attrs.flatten { + device_fields.push(quote! { + pub #field_name: <#inner_ty as DeviceRepr>::Target + }); + upload_stmts.push(quote! { + let #field_name = DeviceRepr::upload_value(&*self.#field_name, arena); + }); + } else { + device_fields.push(quote! { + pub #field_name: Ptr<<#inner_ty as DeviceRepr>::Target> + }); + upload_stmts.push(quote! { + let #field_name = DeviceRepr::upload(&*self.#field_name, arena); + }); + } + device_field_inits.push(quote! { #field_name }); + } + FieldClass::Plain => { + let ty = &field.ty; + if attrs.copy_upload { + device_fields.push(quote! { pub #field_name: #ty }); + upload_stmts.push(quote! { + let #field_name = self.#field_name.clone(); + }); + } else if attrs.flatten { + device_fields.push(quote! { + pub #field_name: <#ty as DeviceRepr>::Target + }); + upload_stmts.push(quote! { + let #field_name = DeviceRepr::upload_value(&self.#field_name, arena); + }); + } else if attrs.upload { + device_fields.push(quote! { + pub #field_name: Ptr<<#ty as DeviceRepr>::Target> + }); + upload_stmts.push(quote! { + let #field_name = DeviceRepr::upload(&self.#field_name, arena); + }); + } else { + device_fields.push(quote! { pub #field_name: #ty }); + upload_stmts.push(quote! { + let #field_name = self.#field_name; + }); + } + device_field_inits.push(quote! { #field_name }); + } + } + } + + let constructor = if let Some(spread) = spread_expr { + quote! { + #device_name { + #(#device_field_inits,)* + ..#spread + } + } + } else { + quote! { + #device_name { + #(#device_field_inits,)* + } + } + }; + + Ok(quote! { + #[repr(C)] + #[derive(Debug, Copy, Clone)] + #vis struct #device_name { + #(#device_fields,)* + } + + unsafe impl Send for #device_name {} + unsafe impl Sync for #device_name {} + + impl DeviceRepr for #host_name { + type Target = #device_name; + + fn upload_value(&self, arena: &Arena) -> Self::Target { + #(#upload_stmts)* + #constructor + } + } + }) +} + +// Enum derivation +fn derive_enum(input: DeriveInput) -> syn::Result { + let host_name = &input.ident; + let vis = &input.vis; + let device_name = get_device_name(&input.attrs, host_name)?; + + let variants = match &input.data { + Data::Enum(e) => &e.variants, + _ => unreachable!(), + }; + + let mut device_variants = Vec::new(); + let mut match_arms = Vec::new(); + + for variant in variants { + let var_name = &variant.ident; + let var_attrs = parse_variant_attrs(&variant.attrs)?; + let inner_ty = get_variant_inner_type(variant)?; + + // Determine the device-side inner type for this variant + let device_inner: Type = if let Some(ref ty_str) = var_attrs.variant_type { + syn::parse_str(ty_str).map_err(|e| { + syn::Error::new_spanned(variant, format!("invalid variant_type: {}", e)) + })? + } else if var_attrs.clone_variant { + // clone: same type on both sides + inner_ty.clone() + } else { + // auto-upload: use DeviceRepr::Target + syn::parse_str(&format!("<{} as DeviceRepr>::Target", quote!(#inner_ty))).map_err( + |e| { + syn::Error::new_spanned(variant, format!("cannot construct Target type: {}", e)) + }, + )? + }; + + device_variants.push(quote! { #var_name(#device_inner) }); + + if var_attrs.clone_variant { + match_arms.push(quote! { + #host_name::#var_name(inner) => #device_name::#var_name(inner.clone()) + }); + } else if let Some(ref method) = var_attrs.custom { + let method_ident = format_ident!("{}", method); + match_arms.push(quote! { + #host_name::#var_name(inner) => { + #device_name::#var_name(Self::#method_ident(inner, arena)) + } + }); + } else { + // Default: inner implements DeviceRepr + match_arms.push(quote! { + #host_name::#var_name(inner) => { + #device_name::#var_name(DeviceRepr::upload_value(inner, arena)) + } + }); + } + } + + Ok(quote! { + #[repr(C)] + #[derive(Debug, Copy, Clone)] + #vis enum #device_name { + #(#device_variants,)* + } + + unsafe impl Send for #device_name {} + unsafe impl Sync for #device_name {} + + impl DeviceRepr for #host_name { + type Target = #device_name; + + fn upload_value(&self, arena: &Arena) -> Self::Target { + match self { + #(#match_arms,)* + } + } + } + }) +} + +fn get_variant_inner_type(variant: &Variant) -> syn::Result { + match &variant.fields { + Fields::Unnamed(fields) if fields.unnamed.len() == 1 => { + Ok(fields.unnamed.first().unwrap().ty.clone()) + } + Fields::Unit => Err(syn::Error::new_spanned( + variant, + "Device derive: enum variants must have exactly one field, e.g. Variant(Type)", + )), + _ => Err(syn::Error::new_spanned( + variant, + "Device derive: only single-field tuple variants supported, e.g. Variant(Type)", + )), + } +} + +// Attribute parsing for variants +struct VariantAttrs { + clone_variant: bool, + custom: Option, + variant_type: Option, +} + +fn parse_variant_attrs(attrs: &[Attribute]) -> syn::Result { + let mut result = VariantAttrs { + clone_variant: false, + custom: None, + variant_type: None, + }; + + for attr in attrs { + if !attr.path().is_ident("device") { + continue; + } + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("clone") { + result.clone_variant = true; + Ok(()) + } else if meta.path.is_ident("custom") { + let value = meta.value()?; + let lit: Lit = value.parse()?; + if let Lit::Str(s) = lit { + result.custom = Some(s.value()); + Ok(()) + } else { + Err(meta.error("expected string literal")) + } + } else if meta.path.is_ident("variant_type") { + let value = meta.value()?; + let lit: Lit = value.parse()?; + if let Lit::Str(s) = lit { + result.variant_type = Some(s.value()); + Ok(()) + } else { + Err(meta.error("expected string literal")) + } + } else { + Err(meta.error("unknown device variant attribute")) + } + })?; + } + + Ok(result) +} + +// Attribute parsing for fields +struct FieldAttrs { + skip: bool, + expr: Option, + copy_upload: bool, + flatten: bool, + upload: bool, + spread: Option, +} + +fn parse_field_attrs(attrs: &[Attribute]) -> syn::Result { + let mut result = FieldAttrs { + skip: false, + expr: None, + copy_upload: false, + flatten: false, + upload: false, + spread: None, + }; + + for attr in attrs { + if !attr.path().is_ident("device") { + continue; + } + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("skip") { + result.skip = true; + Ok(()) + } else if meta.path.is_ident("expr") { + let value = meta.value()?; + let lit: Lit = value.parse()?; + if let Lit::Str(s) = lit { + result.expr = Some(s.value()); + Ok(()) + } else { + Err(meta.error("expected string literal")) + } + } else if meta.path.is_ident("copy_upload") { + result.copy_upload = true; + Ok(()) + } else if meta.path.is_ident("flatten") { + result.flatten = true; + Ok(()) + } else if meta.path.is_ident("upload") { + result.upload = true; + Ok(()) + } else if meta.path.is_ident("spread") { + let value = meta.value()?; + let lit: Lit = value.parse()?; + if let Lit::Str(s) = lit { + result.spread = Some(s.value()); + Ok(()) + } else { + Err(meta.error("expected string literal")) + } + } else { + Err(meta.error("unknown device attribute")) + } + })?; + } + + Ok(result) +} + +// Container-level name attribute +fn get_device_name(attrs: &[Attribute], host_name: &Ident) -> syn::Result { + for attr in attrs { + if !attr.path().is_ident("device") { + continue; + } + let mut name = None; + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("name") { + let value = meta.value()?; + let lit: Lit = value.parse()?; + if let Lit::Str(s) = lit { + name = Some(format_ident!("{}", s.value())); + Ok(()) + } else { + Err(meta.error("expected string literal")) + } + } else { + Ok(()) + } + })?; + if let Some(n) = name { + return Ok(n); + } + } + Ok(format_ident!("Device{}", host_name)) +} + +// Type classification +enum FieldClass { + VecCopy(Type), + VecUploadable(Type), + Option(Type), + Arc(Type), + Plain, +} + +fn classify_type(ty: &Type) -> FieldClass { + if let Some(inner) = extract_generic_arg(ty, "Vec") { + if is_copy_primitive(&inner) { + FieldClass::VecCopy(inner) + } else { + FieldClass::VecUploadable(inner) + } + } else if let Some(inner) = extract_generic_arg(ty, "Option") { + FieldClass::Option(inner) + } else if let Some(inner) = extract_generic_arg(ty, "Arc") { + FieldClass::Arc(inner) + } else { + FieldClass::Plain + } +} + +fn extract_generic_arg(ty: &Type, wrapper: &str) -> Option { + if let Type::Path(type_path) = ty { + let seg = type_path.path.segments.last()?; + if seg.ident != wrapper { + return None; + } + if let PathArguments::AngleBracketed(args) = &seg.arguments { + if let Some(GenericArgument::Type(inner)) = args.args.first() { + return Some(inner.clone()); + } + } + } + None +} + +fn is_copy_primitive(ty: &Type) -> bool { + if let Type::Path(type_path) = ty { + if let Some(seg) = type_path.path.segments.last() { + let name = seg.ident.to_string(); + return matches!( + name.as_str(), + "f32" + | "f64" + | "u8" + | "u16" + | "u32" + | "u64" + | "i8" + | "i16" + | "i32" + | "i64" + | "usize" + | "isize" + | "bool" + | "Float" + ); + } + } + false +} + diff --git a/src/utils/sampling.rs b/src/utils/sampling.rs index 591d287..37c87db 100644 --- a/src/utils/sampling.rs +++ b/src/utils/sampling.rs @@ -13,43 +13,25 @@ use std::sync::Arc; #[derive(Debug, Clone)] pub struct PiecewiseConstant1D { - func: Box<[Float]>, - cdf: Box<[Float]>, - pub device: DevicePiecewiseConstant1D, + func: Vec, + cdf: Vec, + pub min: Float, + pub max: Float, } impl PiecewiseConstant1D { - // Constructors pub fn new(f: &[Float]) -> Self { Self::new_with_bounds(f.to_vec(), 0.0, 1.0) } - pub fn to_shared(&self, arena: &Arena) -> DevicePiecewiseConstant1D { - let (func_ptr, _) = arena.alloc_slice(&self.func); - let (cdf_ptr, _) = arena.alloc_slice(&self.cdf); - - DevicePiecewiseConstant1D { - func: func_ptr, - cdf: cdf_ptr, - func_integral: self.func_integral, - n: self.func.len() as u32, - min: self.min, - max: self.max, - } - } - pub fn from_func(f: F, min: Float, max: Float, n: usize) -> Self where F: Fn(Float) -> Float, { let delta = (max - min) / n as Float; let values: Vec = (0..n) - .map(|i| { - let x = min + (i as Float + 0.5) * delta; - f(x) - }) + .map(|i| f(min + (i as Float + 0.5) * delta)) .collect(); - Self::new_with_bounds(values, min, max) } @@ -64,74 +46,92 @@ impl PiecewiseConstant1D { } let func_integral = cdf[n]; - if func_integral > 0.0 { for c in &mut cdf { *c /= func_integral; } } - // Convert to boxed slices (no more reallocation possible) - let func: Box<[Float]> = f.into_boxed_slice(); - let cdf: Box<[Float]> = cdf.into_boxed_slice(); - - let device = DevicePiecewiseConstant1D { - func: func.as_ptr().into(), - cdf: cdf.as_ptr().into(), - min, - max, - n: n as u32, - func_integral, - }; - - Self { func, cdf, device } + Self { func: f, cdf, min, max } } // Accessors - pub fn min(&self) -> Float { - self.device.min - } - pub fn max(&self) -> Float { - self.device.max - } - pub fn n(&self) -> usize { - self.device.n as usize - } + pub fn n(&self) -> usize { self.func.len() } + pub fn func(&self) -> &[Float] { &self.func } + pub fn cdf(&self) -> &[Float] { &self.cdf } + pub fn integral(&self) -> Float { - self.device.func_integral + // func_integral is the un-normalized sum. After normalization cdf[n] == 1.0, + // so we reconstruct from the last CDF entry before normalization. + // But since we normalized in-place, we need to store it. Let's compute it. + let n = self.func.len(); + let delta = (self.max - self.min) / n as Float; + self.func.iter().sum::() * delta } - pub fn func(&self) -> &[Float] { - &self.func + /// Host-side sampling (for scene construction, not rendering). + /// During rendering, use the device struct via arena-uploaded Ptrs. + pub fn sample_host(&self, u: Float) -> (Float, Float, usize) { + let n = self.func.len(); + let offset = self.find_interval_host(u); + let cdf_offset = self.cdf[offset]; + let cdf_next = self.cdf[offset + 1]; + let du = if cdf_next - cdf_offset > 0.0 { + (u - cdf_offset) / (cdf_next - cdf_offset) + } else { + 0.0 + }; + let delta = (self.max - self.min) / n as Float; + let x = self.min + (offset as Float + du) * delta; + let func_integral = self.integral(); + let pdf = if func_integral > 0.0 { + self.func[offset] / func_integral + } else { + 0.0 + }; + (x, pdf, offset) } - pub fn cdf(&self) -> &[Float] { - &self.cdf + + fn find_interval_host(&self, u: Float) -> usize { + let n = self.func.len(); + let mut size = n; + let mut first = 0usize; + while size > 0 { + let half = size >> 1; + let middle = first + half; + if self.cdf[middle] <= u { + first = middle + 1; + size -= half + 1; + } else { + size = half; + } + } + first.saturating_sub(1).min(n - 1) } } -impl std::ops::Deref for PiecewiseConstant1D { - type Target = DevicePiecewiseConstant1D; - - fn deref(&self) -> &Self::Target { - &self.device - } +#[derive(DeviceRepr)] +#[device(name = "DevicePiecewiseConstant1D")] +pub struct PiecewiseConstant1D { + pub func: Vec, + pub cdf: Vec, + pub min: Float, + pub max: Float, + pub n: u32, + #[device(expr = "self.integral()")] + pub func_integral: Float, } -#[derive(Debug, Clone)] +#[derive(Clone, Debug, DeviceRepr)] +#[device(name = "DevicePiecewiseConstant2D")] pub struct PiecewiseConstant2D { pub conditionals: Vec, + #[device(flatten)] pub marginal: PiecewiseConstant1D, - pub conditional_devices: Box<[DevicePiecewiseConstant1D]>, - pub device: DevicePiecewiseConstant2D, + pub n_u: u32, + pub n_v: u32, } -impl std::ops::Deref for PiecewiseConstant2D { - type Target = DevicePiecewiseConstant2D; - - fn deref(&self) -> &Self::Target { - &self.device - } -} impl PiecewiseConstant2D { pub fn new(data: &Array2D) -> Self { @@ -141,8 +141,8 @@ impl PiecewiseConstant2D { pub fn new_with_bounds(data: &Array2D, domain: Bounds2f) -> Self { Self::from_slice( data.as_slice(), - data.x_size() as usize, - data.y_size() as usize, + data.x_size(), + data.y_size(), domain, ) } @@ -154,36 +154,23 @@ impl PiecewiseConstant2D { let mut marginal_func = Vec::with_capacity(n_v); for v in 0..n_v { - let row_start = v * n_u; - let row: Vec = data[row_start..row_start + n_u].to_vec(); - let conditional = - PiecewiseConstant1D::new_with_bounds(row, domain.p_min.x(), domain.p_max.x()); + let row = data[v * n_u..(v + 1) * n_u].to_vec(); + let conditional = PiecewiseConstant1D::new_with_bounds( + row, + domain.p_min.x(), + domain.p_max.x(), + ); marginal_func.push(conditional.integral()); conditionals.push(conditional); } - let marginal = - PiecewiseConstant1D::new_with_bounds(marginal_func, domain.p_min.y(), domain.p_max.y()); + let marginal = PiecewiseConstant1D::new_with_bounds( + marginal_func, + domain.p_min.y(), + domain.p_max.y(), + ); - let conditional_devices: Box<[DevicePiecewiseConstant1D]> = conditionals - .iter() - .map(|c| c.device) - .collect::>() - .into_boxed_slice(); - - let device = DevicePiecewiseConstant2D { - conditionals: conditional_devices.as_ptr().into(), - marginal: marginal.device, - n_u: n_u as u32, - n_v: n_v as u32, - }; - - Self { - conditionals, - marginal, - conditional_devices, - device, - } + Self { conditionals, marginal, n_u, n_v } } pub fn from_image(image: &Image) -> Self { @@ -192,12 +179,9 @@ impl PiecewiseConstant2D { let n_v = res.y() as usize; let mut data = Vec::with_capacity(n_u * n_v); - for v in 0..n_v { for u in 0..n_u { - let p = Point2i::new(u as i32, v as i32); - let luminance = image.get_channels(p).average(); - data.push(luminance); + data.push(image.get_channels(Point2i::new(u as i32, v as i32)).average()); } } @@ -217,10 +201,14 @@ struct PiecewiseLinear2DStorage { } pub struct PiecewiseLinear2DHost { - pub view: PiecewiseLinear2D, - _storage: Arc>, + size: Vector2i, + inv_patch_size: Vector2f, + param_size: [u32; N], + param_strides: [u32; N], + storage: Arc>, } + impl PiecewiseLinear2DHost { pub fn new( data: &[Float], @@ -354,27 +342,49 @@ impl PiecewiseLinear2DHost { } } +impl DeviceRepr for PiecewiseLinear2DHost { + type Target = PiecewiseLinear2D; + + fn upload_value(&self, arena: &Arena) -> PiecewiseLinear2D { + let s = &self.storage; + + let (data_ptr, _) = arena.alloc_slice(&s.data); + let (marginal_ptr, _) = arena.alloc_slice(&s.marginal_cdf); + let (conditional_ptr, _) = arena.alloc_slice(&s.conditional_cdf); + + let param_ptrs: [Ptr; N] = std::array::from_fn(|i| { + let (ptr, _) = arena.alloc_slice(&s.param_values[i]); + ptr + }); + + PiecewiseLinear2D { + size: self.size, + inv_patch_size: self.inv_patch_size, + param_size: self.param_size, + param_strides: self.param_strides, + param_values: param_ptrs, + data: data_ptr, + marginal_cdf: marginal_ptr, + conditional_cdf: conditional_ptr, + } + } +} + #[derive(Debug, Clone)] pub struct AliasTableHost { - pub view: AliasTable, - _storage: Vec, + bins: Vec, } impl AliasTableHost { pub fn new(weights: &[Float]) -> Self { let n = weights.len(); if n == 0 { - return Self { - view: AliasTable { - bins: Ptr::null(), - size: 0, - }, - _storage: Vec::new(), - }; + return Self { bins: Vec::new() }; } let sum: f64 = weights.iter().map(|&w| w as f64).sum(); assert!(sum > 0.0, "Sum of weights must be positive"); + let mut bins = Vec::with_capacity(n); for &w in weights { bins.push(Bin { @@ -384,10 +394,7 @@ impl AliasTableHost { }); } - struct Outcome { - p_hat: f64, - index: usize, - } + struct Outcome { p_hat: f64, index: usize } let mut under = Vec::with_capacity(n); let mut over = Vec::with_capacity(n); @@ -409,17 +416,10 @@ impl AliasTableHost { bins[un.index].alias = ov.index as u32; let p_excess = un.p_hat + ov.p_hat - 1.0; - if p_excess < 1.0 { - under.push(Outcome { - p_hat: p_excess, - index: ov.index, - }); + under.push(Outcome { p_hat: p_excess, index: ov.index }); } else { - over.push(Outcome { - p_hat: p_excess, - index: ov.index, - }); + over.push(Outcome { p_hat: p_excess, index: ov.index }); } } @@ -427,51 +427,40 @@ impl AliasTableHost { bins[ov.index].q = 1.0; bins[ov.index].alias = ov.index as u32; } - while let Some(un) = under.pop() { bins[un.index].q = 1.0; bins[un.index].alias = un.index as u32; } - let view = AliasTable { - bins: bins.as_ptr().into(), - size: bins.len() as u32, - }; - - Self { - view, - _storage: bins, - } + Self { bins } } - pub fn to_device(&self, arena: &Arena) -> AliasTable { - if self._storage.is_empty() { - return AliasTable { - bins: Ptr::null(), - size: 0, - }; + pub fn size(&self) -> usize { self.bins.len() } + pub fn is_empty(&self) -> bool { self.bins.is_empty() } +} + +impl DeviceRepr for AliasTableHost { + type Target = AliasTable; + + fn upload_value(&self, arena: &Arena) -> AliasTable { + if self.bins.is_empty() { + return AliasTable { bins: Ptr::null(), size: 0 }; } - let (bins_ptr, _) = arena.alloc_slice(&self._storage); + let (bins_ptr, _) = arena.alloc_slice(&self.bins); AliasTable { bins: bins_ptr, - size: self._storage.len() as u32, + size: self.bins.len() as u32, } } } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, DeviceRepr)] +#[device(name = "DeviceSummedAreaTable")] pub struct SummedAreaTable { - pub device: DeviceSummedAreaTable, + #[device(flatten)] sum: Array2D, } -impl std::ops::Deref for SummedAreaTable { - type Target = DeviceSummedAreaTable; - fn deref(&self) -> &Self::Target { - &self.device - } -} - impl SummedAreaTable { pub fn new(values: &Array2D) -> Self { let width = values.x_size() as i32; @@ -483,46 +472,194 @@ impl SummedAreaTable { for x in 1..width { sum[(x, 0)] = values[(x, 0)] as f64 + sum[(x - 1, 0)]; } - for y in 1..height { sum[(0, y)] = values[(0, y)] as f64 + sum[(0, y - 1)]; } - for y in 1..height { for x in 1..width { - let term = values[(x, y)] as f64; - let left = sum[(x - 1, y)]; - let up = sum[(x, y - 1)]; - let diag = sum[(x - 1, y - 1)]; - - sum[(x, y)] = term + left + up - diag; + sum[(x, y)] = values[(x, y)] as f64 + + sum[(x - 1, y)] + + sum[(x, y - 1)] + - sum[(x - 1, y - 1)]; } } - let device = DeviceSummedAreaTable { sum: *sum }; - Self { device, sum } + Self { sum } } } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, DeviceRepr)] +#[device(name = "DeviceWindowedPiecewiseConstant2D")] pub struct WindowedPiecewiseConstant2D { - pub device: DeviceWindowedPiecewiseConstant2D, - sat: DeviceSummedAreaTable, + #[device(flatten)] + sat: SummedAreaTable, + #[device(flatten)] func: Array2D, } -impl std::ops::Deref for WindowedPiecewiseConstant2D { - type Target = DeviceWindowedPiecewiseConstant2D; - fn deref(&self) -> &Self::Target { - &self.device - } -} - impl WindowedPiecewiseConstant2D { pub fn new(func: Array2D) -> Self { - let sat = *SummedAreaTable::new(&func); - let device = DeviceWindowedPiecewiseConstant2D { sat, func: *func }; - - Self { sat, func, device } + let sat = SummedAreaTable::new(&func); + Self { sat, func } + } +} + +struct PiecewiseLinear2DStorage { + data: Vec, + marginal_cdf: Vec, + conditional_cdf: Vec, + param_values: [Vec; N], +} + +pub struct PiecewiseLinear2DHost { + size: Vector2i, + inv_patch_size: Vector2f, + param_size: [u32; N], + param_strides: [u32; N], + storage: Arc>, +} + +impl PiecewiseLinear2DHost { + pub fn new( + data: &[Float], + x_size: i32, + y_size: i32, + param_res: [usize; N], + param_values: [&[Float]; N], + normalize: bool, + build_cdf: bool, + ) -> Self { + if build_cdf && !normalize { + panic!("PiecewiseLinear2D: build_cdf implies normalize=true"); + } + + let size = Vector2i::new(x_size, y_size); + let inv_patch_size = Vector2f::new( + 1.0 / (x_size - 1) as Float, + 1.0 / (y_size - 1) as Float, + ); + + let mut param_size = [0u32; N]; + let mut param_strides = [0u32; N]; + let owned_param_values: [Vec; N] = gpu_array_from_fn(|i| param_values[i].to_vec()); + + let mut slices: u32 = 1; + for i in (0..N).rev() { + assert!(param_res[i] >= 1, "Parameter resolution must be >= 1"); + param_size[i] = param_res[i] as u32; + param_strides[i] = if param_res[i] > 1 { slices } else { 0 }; + slices *= param_size[i]; + } + + let n_values = (x_size * y_size) as usize; + let mut new_data = vec![0.0; slices as usize * n_values]; + let mut marginal_cdf = if build_cdf { + vec![0.0; slices as usize * y_size as usize] + } else { + Vec::new() + }; + let mut conditional_cdf = if build_cdf { + vec![0.0; slices as usize * n_values] + } else { + Vec::new() + }; + + let mut data_offset = 0; + for slice in 0..slices as usize { + let slice_offset = slice * n_values; + let current_data = &data[data_offset..data_offset + n_values]; + let mut sum = 0.0_f64; + + if normalize { + for y in 0..(y_size - 1) { + for x in 0..(x_size - 1) { + let i = (y * x_size + x) as usize; + let v00 = current_data[i] as f64; + let v10 = current_data[i + 1] as f64; + let v01 = current_data[i + x_size as usize] as f64; + let v11 = current_data[i + 1 + x_size as usize] as f64; + sum += 0.25 * (v00 + v10 + v01 + v11); + } + } + } + + let normalization = if normalize && sum > 0.0 { + 1.0 / sum as Float + } else { + 1.0 + }; + for k in 0..n_values { + new_data[slice_offset + k] = current_data[k] * normalization; + } + + if build_cdf { + let marginal_slice_offset = slice * y_size as usize; + for y in 0..y_size as usize { + let mut cdf_sum = 0.0; + let i_base = y * x_size as usize; + conditional_cdf[slice_offset + i_base] = 0.0; + for x in 0..(x_size - 1) as usize { + let i = i_base + x; + cdf_sum += 0.5 + * (new_data[slice_offset + i] + new_data[slice_offset + i + 1]); + conditional_cdf[slice_offset + i + 1] = cdf_sum; + } + } + marginal_cdf[marginal_slice_offset] = 0.0; + let mut marginal_sum = 0.0; + for y in 0..(y_size - 1) as usize { + let cdf1 = + conditional_cdf[slice_offset + (y + 1) * x_size as usize - 1]; + let cdf2 = + conditional_cdf[slice_offset + (y + 2) * x_size as usize - 1]; + marginal_sum += 0.5 * (cdf1 + cdf2); + marginal_cdf[marginal_slice_offset + y + 1] = marginal_sum; + } + } + data_offset += n_values; + } + + let storage = Arc::new(PiecewiseLinear2DStorage { + data: new_data, + marginal_cdf, + conditional_cdf, + param_values: owned_param_values, + }); + + Self { + size, + inv_patch_size, + param_size, + param_strides, + storage, + } + } +} + +impl DeviceRepr for PiecewiseLinear2DHost { + type Target = PiecewiseLinear2D; + + fn upload_value(&self, arena: &Arena) -> PiecewiseLinear2D { + let s = &self.storage; + + let (data_ptr, _) = arena.alloc_slice(&s.data); + let (marginal_ptr, _) = arena.alloc_slice(&s.marginal_cdf); + let (conditional_ptr, _) = arena.alloc_slice(&s.conditional_cdf); + + let param_ptrs: [Ptr; N] = std::array::from_fn(|i| { + let (ptr, _) = arena.alloc_slice(&s.param_values[i]); + ptr + }); + + PiecewiseLinear2D { + size: self.size, + inv_patch_size: self.inv_patch_size, + param_size: self.param_size, + param_strides: self.param_strides, + param_values: param_ptrs, + data: data_ptr, + marginal_cdf: marginal_ptr, + conditional_cdf: conditional_ptr, + } } }