Hitting myself, and going for a trait for uploading objs to GPU. Keep forgetting to actually upload them
This commit is contained in:
parent
645556da22
commit
dad7300a14
20 changed files with 1917 additions and 782 deletions
|
|
@ -13,13 +13,13 @@ pub struct FilterSample {
|
||||||
|
|
||||||
#[repr(C)]
|
#[repr(C)]
|
||||||
#[derive(Clone, Debug, Copy)]
|
#[derive(Clone, Debug, Copy)]
|
||||||
pub struct FilterSampler {
|
pub struct DeviceFilterSampler {
|
||||||
pub domain: Bounds2f,
|
pub domain: Bounds2f,
|
||||||
pub distrib: DevicePiecewiseConstant2D,
|
pub distrib: DevicePiecewiseConstant2D,
|
||||||
pub f: DeviceArray2D<Float>,
|
pub f: DeviceArray2D<Float>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FilterSampler {
|
impl DeviceFilterSampler {
|
||||||
pub fn sample(&self, u: Point2f) -> FilterSample {
|
pub fn sample(&self, u: Point2f) -> FilterSample {
|
||||||
let (p, pdf, pi) = self.distrib.sample(u);
|
let (p, pdf, pi) = self.distrib.sample(u);
|
||||||
|
|
||||||
|
|
@ -38,7 +38,7 @@ pub trait FilterTrait {
|
||||||
fn radius(&self) -> Vector2f;
|
fn radius(&self) -> Vector2f;
|
||||||
fn evaluate(&self, p: Point2f) -> Float;
|
fn evaluate(&self, p: Point2f) -> Float;
|
||||||
fn integral(&self) -> Float;
|
fn integral(&self) -> Float;
|
||||||
fn sample(&self, u: Point2f) -> FilterSample;
|
fn sample(&self, u: Point2f) -> DeviceFilterSample;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[repr(C)]
|
#[repr(C)]
|
||||||
|
|
|
||||||
|
|
@ -205,6 +205,7 @@ impl PrimitiveTrait for KdTreeAggregate {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[repr(C)]
|
||||||
#[derive(Clone, Debug, Copy)]
|
#[derive(Clone, Debug, Copy)]
|
||||||
#[enum_dispatch(PrimitiveTrait)]
|
#[enum_dispatch(PrimitiveTrait)]
|
||||||
pub enum Primitive {
|
pub enum Primitive {
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
use crate::Float;
|
use crate::Float;
|
||||||
use crate::core::filter::{FilterSample, FilterTrait};
|
use crate::core::filter::{DeviceFilterSample, FilterTrait};
|
||||||
use crate::core::geometry::{Point2f, Vector2f};
|
use crate::core::geometry::{Point2f, Vector2f};
|
||||||
use crate::utils::math::lerp;
|
use crate::utils::math::lerp;
|
||||||
|
|
||||||
|
|
@ -31,7 +31,7 @@ impl FilterTrait for BoxFilter {
|
||||||
(2.0 * self.radius.x()) * (2.0 * self.radius.y())
|
(2.0 * self.radius.x()) * (2.0 * self.radius.y())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sample(&self, u: Point2f) -> FilterSample {
|
fn sample(&self, u: Point2f) -> DeviceFilterSample {
|
||||||
let p = Point2f::new(
|
let p = Point2f::new(
|
||||||
lerp(u[0], -self.radius.x(), self.radius.x()),
|
lerp(u[0], -self.radius.x(), self.radius.x()),
|
||||||
lerp(u[1], -self.radius.y(), self.radius.y()),
|
lerp(u[1], -self.radius.y(), self.radius.y()),
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
use crate::Float;
|
use crate::Float;
|
||||||
use crate::core::filter::{FilterSample, FilterSampler, FilterTrait};
|
use crate::core::filter::{DeviceFilterSample, FilterSampler, FilterTrait};
|
||||||
use crate::core::geometry::{Point2f, Vector2f};
|
use crate::core::geometry::{Point2f, Vector2f};
|
||||||
use crate::utils::math::{gaussian, gaussian_integral};
|
use crate::utils::math::{gaussian, gaussian_integral};
|
||||||
|
|
||||||
|
|
@ -30,7 +30,7 @@ impl FilterTrait for GaussianFilter {
|
||||||
- 2.0 * self.radius.y() * self.exp_y)
|
- 2.0 * self.radius.y() * self.exp_y)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sample(&self, u: Point2f) -> FilterSample {
|
fn sample(&self, u: Point2f) -> DeviceFilterSample {
|
||||||
self.sampler.sample(u)
|
self.sampler.sample(u)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
use crate::Float;
|
use crate::Float;
|
||||||
use crate::core::filter::{FilterSample, FilterSampler, FilterTrait};
|
use crate::core::filter::{DeviceFilterSample, FilterSampler, FilterTrait};
|
||||||
use crate::core::geometry::{Point2f, Vector2f};
|
use crate::core::geometry::{Point2f, Vector2f};
|
||||||
use crate::utils::math::{lerp, windowed_sinc};
|
use crate::utils::math::{lerp, windowed_sinc};
|
||||||
|
|
||||||
|
|
@ -26,7 +26,7 @@ impl FilterTrait for LanczosSincFilter {
|
||||||
self.integral
|
self.integral
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sample(&self, u: Point2f) -> FilterSample {
|
fn sample(&self, u: Point2f) -> DeviceFilterSample {
|
||||||
self.sampler.sample(u)
|
self.sampler.sample(u)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
use crate::Float;
|
use crate::Float;
|
||||||
use crate::core::filter::{FilterSample, FilterSampler, FilterTrait};
|
use crate::core::filter::{DeviceFilterSample, FilterSampler, FilterTrait};
|
||||||
use crate::core::geometry::{Point2f, Vector2f};
|
use crate::core::geometry::{Point2f, Vector2f};
|
||||||
use num_traits::Float as NumFloat;
|
use num_traits::Float as NumFloat;
|
||||||
|
|
||||||
|
|
@ -9,7 +9,7 @@ pub struct MitchellFilter {
|
||||||
pub radius: Vector2f,
|
pub radius: Vector2f,
|
||||||
pub b: Float,
|
pub b: Float,
|
||||||
pub c: Float,
|
pub c: Float,
|
||||||
pub sampler: FilterSampler,
|
pub sampler: DeviceFilterSampler,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MitchellFilter {
|
impl MitchellFilter {
|
||||||
|
|
@ -50,7 +50,7 @@ impl FilterTrait for MitchellFilter {
|
||||||
self.radius.x() * self.radius.y() / 4.0
|
self.radius.x() * self.radius.y() / 4.0
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sample(&self, u: Point2f) -> FilterSample {
|
fn sample(&self, u: Point2f) -> DeviceFilterSample {
|
||||||
self.sampler.sample(u)
|
self.sampler.sample(u)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
use crate::Float;
|
use crate::Float;
|
||||||
use crate::core::filter::{FilterSample, FilterTrait};
|
use crate::core::filter::{DeviceFilterSample, FilterTrait};
|
||||||
use crate::core::geometry::{Point2f, Vector2f};
|
use crate::core::geometry::{Point2f, Vector2f};
|
||||||
use crate::utils::math::sample_tent;
|
use crate::utils::math::sample_tent;
|
||||||
use num_traits::Float as NumFloat;
|
use num_traits::Float as NumFloat;
|
||||||
|
|
@ -29,11 +29,11 @@ impl FilterTrait for TriangleFilter {
|
||||||
self.radius.x().powi(2) * self.radius.y().powi(2)
|
self.radius.x().powi(2) * self.radius.y().powi(2)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sample(&self, u: Point2f) -> FilterSample {
|
fn sample(&self, u: Point2f) -> DeviceFilterSample {
|
||||||
let p = Point2f::new(
|
let p = Point2f::new(
|
||||||
sample_tent(u[0], self.radius.x()),
|
sample_tent(u[0], self.radius.x()),
|
||||||
sample_tent(u[1], self.radius.y()),
|
sample_tent(u[1], self.radius.y()),
|
||||||
);
|
);
|
||||||
FilterSample { p, weight: 1.0 }
|
DeviceFilterSample { p, weight: 1.0 }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,15 @@ pub use options::PBRTOptions;
|
||||||
pub use ptr::Ptr;
|
pub use ptr::Ptr;
|
||||||
pub use transform::{AnimatedTransform, Transform, TransformGeneric};
|
pub use transform::{AnimatedTransform, Transform, TransformGeneric};
|
||||||
|
|
||||||
|
use proc_macro::TokenStream;
|
||||||
|
use proc_macro2::TokenStream as TokenStream2;
|
||||||
|
use quote::{format_ident, quote};
|
||||||
|
use syn::{
|
||||||
|
parse_macro_input, Attribute, Data, DeriveInput, Expr, Fields, GenericArgument, Ident, Lit,
|
||||||
|
PathArguments, Type, Variant,
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
use crate::Float;
|
use crate::Float;
|
||||||
use core::sync::atomic::{AtomicU32, Ordering};
|
use core::sync::atomic::{AtomicU32, Ordering};
|
||||||
|
|
||||||
|
|
@ -128,3 +137,515 @@ pub fn gpu_array_from_fn<T, const N: usize>(mut f: impl FnMut(usize) -> T) -> [T
|
||||||
arr.assume_init()
|
arr.assume_init()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// # Enum variant attributes
|
||||||
|
///
|
||||||
|
/// | Attribute | Effect |
|
||||||
|
/// |-----------|--------|
|
||||||
|
/// | *(none)* | Inner type has `DeviceRepr`; auto-call `upload_value` |
|
||||||
|
/// | `#[device(clone)]` | Same type on both sides, just clone |
|
||||||
|
/// | `#[device(custom = "method")]` | You provide `fn method(inner: &T, arena) -> DeviceT` |
|
||||||
|
/// | `#[device(variant_type = "T")]` | Override the device-side variant's inner type |
|
||||||
|
///
|
||||||
|
/// # Container attribute
|
||||||
|
///
|
||||||
|
/// `#[device(name = "DeviceFoo")]` — override the generated type name (default: `Device{Name}`).
|
||||||
|
#[proc_macro_derive(Device, attributes(device))]
|
||||||
|
pub fn derive_device(input: TokenStream) -> TokenStream {
|
||||||
|
let input = parse_macro_input!(input as DeriveInput);
|
||||||
|
match derive_impl(input) {
|
||||||
|
Ok(tokens) => tokens.into(),
|
||||||
|
Err(e) => e.to_compile_error().into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn derive_impl(input: DeriveInput) -> syn::Result<TokenStream2> {
|
||||||
|
match &input.data {
|
||||||
|
Data::Struct(_) => derive_struct(input),
|
||||||
|
Data::Enum(_) => derive_enum(input),
|
||||||
|
Data::Union(_) => Err(syn::Error::new_spanned(
|
||||||
|
&input.ident,
|
||||||
|
"Device derive does not support unions",
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Struct derivation
|
||||||
|
|
||||||
|
fn derive_struct(input: DeriveInput) -> syn::Result<TokenStream2> {
|
||||||
|
let host_name = &input.ident;
|
||||||
|
let vis = &input.vis;
|
||||||
|
let device_name = get_device_name(&input.attrs, host_name)?;
|
||||||
|
|
||||||
|
let fields = match &input.data {
|
||||||
|
Data::Struct(s) => match &s.fields {
|
||||||
|
Fields::Named(named) => &named.named,
|
||||||
|
_ => {
|
||||||
|
return Err(syn::Error::new_spanned(
|
||||||
|
host_name,
|
||||||
|
"Device derive only supports structs with named fields",
|
||||||
|
))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut device_fields = Vec::new();
|
||||||
|
let mut upload_stmts = Vec::new();
|
||||||
|
let mut device_field_inits = Vec::new();
|
||||||
|
let mut spread_expr: Option<Expr> = None;
|
||||||
|
|
||||||
|
for field in fields {
|
||||||
|
let field_name = field.ident.as_ref().unwrap();
|
||||||
|
let attrs = parse_field_attrs(&field.attrs)?;
|
||||||
|
|
||||||
|
if attrs.skip {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ref expr_str) = attrs.spread {
|
||||||
|
spread_expr = Some(syn::parse_str(expr_str).map_err(|e| {
|
||||||
|
syn::Error::new_spanned(field, format!("invalid device(spread): {}", e))
|
||||||
|
})?);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(expr_str) = &attrs.expr {
|
||||||
|
let expr: Expr = syn::parse_str(expr_str).map_err(|e| {
|
||||||
|
syn::Error::new_spanned(field, format!("invalid device(expr): {}", e))
|
||||||
|
})?;
|
||||||
|
let ty = &field.ty;
|
||||||
|
device_fields.push(quote! { pub #field_name: #ty });
|
||||||
|
upload_stmts.push(quote! { let #field_name = #expr; });
|
||||||
|
device_field_inits.push(quote! { #field_name });
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
match classify_type(&field.ty) {
|
||||||
|
FieldClass::VecCopy(inner_ty) => {
|
||||||
|
let len_name = format_ident!("{}_len", field_name);
|
||||||
|
device_fields.push(quote! { pub #field_name: Ptr<#inner_ty> });
|
||||||
|
device_fields.push(quote! { pub #len_name: usize });
|
||||||
|
upload_stmts.push(quote! {
|
||||||
|
let (#field_name, #len_name) = arena.alloc_slice(&self.#field_name);
|
||||||
|
});
|
||||||
|
device_field_inits.push(quote! { #field_name });
|
||||||
|
device_field_inits.push(quote! { #len_name });
|
||||||
|
}
|
||||||
|
FieldClass::VecUploadable(inner_ty) => {
|
||||||
|
let len_name = format_ident!("{}_len", field_name);
|
||||||
|
device_fields.push(quote! {
|
||||||
|
pub #field_name: Ptr<<#inner_ty as DeviceRepr>::Target>
|
||||||
|
});
|
||||||
|
device_fields.push(quote! { pub #len_name: usize });
|
||||||
|
upload_stmts.push(quote! {
|
||||||
|
let __up: Vec<<#inner_ty as DeviceRepr>::Target> = self.#field_name
|
||||||
|
.iter()
|
||||||
|
.map(|item| DeviceRepr::upload_value(item, arena))
|
||||||
|
.collect();
|
||||||
|
let (#field_name, #len_name) = arena.alloc_slice(&__up);
|
||||||
|
});
|
||||||
|
device_field_inits.push(quote! { #field_name });
|
||||||
|
device_field_inits.push(quote! { #len_name });
|
||||||
|
}
|
||||||
|
FieldClass::Option(inner_ty) => {
|
||||||
|
device_fields.push(quote! {
|
||||||
|
pub #field_name: Ptr<<#inner_ty as DeviceRepr>::Target>
|
||||||
|
});
|
||||||
|
upload_stmts.push(quote! {
|
||||||
|
let #field_name = match &self.#field_name {
|
||||||
|
Some(val) => DeviceRepr::upload(val, arena),
|
||||||
|
None => Ptr::null(),
|
||||||
|
};
|
||||||
|
});
|
||||||
|
device_field_inits.push(quote! { #field_name });
|
||||||
|
}
|
||||||
|
FieldClass::Arc(inner_ty) => {
|
||||||
|
if attrs.flatten {
|
||||||
|
device_fields.push(quote! {
|
||||||
|
pub #field_name: <#inner_ty as DeviceRepr>::Target
|
||||||
|
});
|
||||||
|
upload_stmts.push(quote! {
|
||||||
|
let #field_name = DeviceRepr::upload_value(&*self.#field_name, arena);
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
device_fields.push(quote! {
|
||||||
|
pub #field_name: Ptr<<#inner_ty as DeviceRepr>::Target>
|
||||||
|
});
|
||||||
|
upload_stmts.push(quote! {
|
||||||
|
let #field_name = DeviceRepr::upload(&*self.#field_name, arena);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
device_field_inits.push(quote! { #field_name });
|
||||||
|
}
|
||||||
|
FieldClass::Plain => {
|
||||||
|
let ty = &field.ty;
|
||||||
|
if attrs.copy_upload {
|
||||||
|
device_fields.push(quote! { pub #field_name: #ty });
|
||||||
|
upload_stmts.push(quote! {
|
||||||
|
let #field_name = self.#field_name.clone();
|
||||||
|
});
|
||||||
|
} else if attrs.flatten {
|
||||||
|
device_fields.push(quote! {
|
||||||
|
pub #field_name: <#ty as DeviceRepr>::Target
|
||||||
|
});
|
||||||
|
upload_stmts.push(quote! {
|
||||||
|
let #field_name = DeviceRepr::upload_value(&self.#field_name, arena);
|
||||||
|
});
|
||||||
|
} else if attrs.upload {
|
||||||
|
device_fields.push(quote! {
|
||||||
|
pub #field_name: Ptr<<#ty as DeviceRepr>::Target>
|
||||||
|
});
|
||||||
|
upload_stmts.push(quote! {
|
||||||
|
let #field_name = DeviceRepr::upload(&self.#field_name, arena);
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
device_fields.push(quote! { pub #field_name: #ty });
|
||||||
|
upload_stmts.push(quote! {
|
||||||
|
let #field_name = self.#field_name;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
device_field_inits.push(quote! { #field_name });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let constructor = if let Some(spread) = spread_expr {
|
||||||
|
quote! {
|
||||||
|
#device_name {
|
||||||
|
#(#device_field_inits,)*
|
||||||
|
..#spread
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
quote! {
|
||||||
|
#device_name {
|
||||||
|
#(#device_field_inits,)*
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(quote! {
|
||||||
|
#[repr(C)]
|
||||||
|
#[derive(Debug, Copy, Clone)]
|
||||||
|
#vis struct #device_name {
|
||||||
|
#(#device_fields,)*
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe impl Send for #device_name {}
|
||||||
|
unsafe impl Sync for #device_name {}
|
||||||
|
|
||||||
|
impl DeviceRepr for #host_name {
|
||||||
|
type Target = #device_name;
|
||||||
|
|
||||||
|
fn upload_value<A: GpuAllocator>(&self, arena: &Arena<A>) -> Self::Target {
|
||||||
|
#(#upload_stmts)*
|
||||||
|
#constructor
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enum derivation
|
||||||
|
fn derive_enum(input: DeriveInput) -> syn::Result<TokenStream2> {
|
||||||
|
let host_name = &input.ident;
|
||||||
|
let vis = &input.vis;
|
||||||
|
let device_name = get_device_name(&input.attrs, host_name)?;
|
||||||
|
|
||||||
|
let variants = match &input.data {
|
||||||
|
Data::Enum(e) => &e.variants,
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut device_variants = Vec::new();
|
||||||
|
let mut match_arms = Vec::new();
|
||||||
|
|
||||||
|
for variant in variants {
|
||||||
|
let var_name = &variant.ident;
|
||||||
|
let var_attrs = parse_variant_attrs(&variant.attrs)?;
|
||||||
|
let inner_ty = get_variant_inner_type(variant)?;
|
||||||
|
|
||||||
|
// Determine the device-side inner type for this variant
|
||||||
|
let device_inner: Type = if let Some(ref ty_str) = var_attrs.variant_type {
|
||||||
|
syn::parse_str(ty_str).map_err(|e| {
|
||||||
|
syn::Error::new_spanned(variant, format!("invalid variant_type: {}", e))
|
||||||
|
})?
|
||||||
|
} else if var_attrs.clone_variant {
|
||||||
|
// clone: same type on both sides
|
||||||
|
inner_ty.clone()
|
||||||
|
} else {
|
||||||
|
// auto-upload: use DeviceRepr::Target
|
||||||
|
syn::parse_str(&format!("<{} as DeviceRepr>::Target", quote!(#inner_ty))).map_err(
|
||||||
|
|e| {
|
||||||
|
syn::Error::new_spanned(variant, format!("cannot construct Target type: {}", e))
|
||||||
|
},
|
||||||
|
)?
|
||||||
|
};
|
||||||
|
|
||||||
|
device_variants.push(quote! { #var_name(#device_inner) });
|
||||||
|
|
||||||
|
if var_attrs.clone_variant {
|
||||||
|
match_arms.push(quote! {
|
||||||
|
#host_name::#var_name(inner) => #device_name::#var_name(inner.clone())
|
||||||
|
});
|
||||||
|
} else if let Some(ref method) = var_attrs.custom {
|
||||||
|
let method_ident = format_ident!("{}", method);
|
||||||
|
match_arms.push(quote! {
|
||||||
|
#host_name::#var_name(inner) => {
|
||||||
|
#device_name::#var_name(Self::#method_ident(inner, arena))
|
||||||
|
}
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
// Default: inner implements DeviceRepr
|
||||||
|
match_arms.push(quote! {
|
||||||
|
#host_name::#var_name(inner) => {
|
||||||
|
#device_name::#var_name(DeviceRepr::upload_value(inner, arena))
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(quote! {
|
||||||
|
#[repr(C)]
|
||||||
|
#[derive(Debug, Copy, Clone)]
|
||||||
|
#vis enum #device_name {
|
||||||
|
#(#device_variants,)*
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe impl Send for #device_name {}
|
||||||
|
unsafe impl Sync for #device_name {}
|
||||||
|
|
||||||
|
impl DeviceRepr for #host_name {
|
||||||
|
type Target = #device_name;
|
||||||
|
|
||||||
|
fn upload_value<A: GpuAllocator>(&self, arena: &Arena<A>) -> Self::Target {
|
||||||
|
match self {
|
||||||
|
#(#match_arms,)*
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_variant_inner_type(variant: &Variant) -> syn::Result<Type> {
|
||||||
|
match &variant.fields {
|
||||||
|
Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
|
||||||
|
Ok(fields.unnamed.first().unwrap().ty.clone())
|
||||||
|
}
|
||||||
|
Fields::Unit => Err(syn::Error::new_spanned(
|
||||||
|
variant,
|
||||||
|
"Device derive: enum variants must have exactly one field, e.g. Variant(Type)",
|
||||||
|
)),
|
||||||
|
_ => Err(syn::Error::new_spanned(
|
||||||
|
variant,
|
||||||
|
"Device derive: only single-field tuple variants supported, e.g. Variant(Type)",
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attribute parsing for variants
|
||||||
|
struct VariantAttrs {
|
||||||
|
clone_variant: bool,
|
||||||
|
custom: Option<String>,
|
||||||
|
variant_type: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_variant_attrs(attrs: &[Attribute]) -> syn::Result<VariantAttrs> {
|
||||||
|
let mut result = VariantAttrs {
|
||||||
|
clone_variant: false,
|
||||||
|
custom: None,
|
||||||
|
variant_type: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
for attr in attrs {
|
||||||
|
if !attr.path().is_ident("device") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
attr.parse_nested_meta(|meta| {
|
||||||
|
if meta.path.is_ident("clone") {
|
||||||
|
result.clone_variant = true;
|
||||||
|
Ok(())
|
||||||
|
} else if meta.path.is_ident("custom") {
|
||||||
|
let value = meta.value()?;
|
||||||
|
let lit: Lit = value.parse()?;
|
||||||
|
if let Lit::Str(s) = lit {
|
||||||
|
result.custom = Some(s.value());
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(meta.error("expected string literal"))
|
||||||
|
}
|
||||||
|
} else if meta.path.is_ident("variant_type") {
|
||||||
|
let value = meta.value()?;
|
||||||
|
let lit: Lit = value.parse()?;
|
||||||
|
if let Lit::Str(s) = lit {
|
||||||
|
result.variant_type = Some(s.value());
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(meta.error("expected string literal"))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Err(meta.error("unknown device variant attribute"))
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attribute parsing for fields
|
||||||
|
struct FieldAttrs {
|
||||||
|
skip: bool,
|
||||||
|
expr: Option<String>,
|
||||||
|
copy_upload: bool,
|
||||||
|
flatten: bool,
|
||||||
|
upload: bool,
|
||||||
|
spread: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_field_attrs(attrs: &[Attribute]) -> syn::Result<FieldAttrs> {
|
||||||
|
let mut result = FieldAttrs {
|
||||||
|
skip: false,
|
||||||
|
expr: None,
|
||||||
|
copy_upload: false,
|
||||||
|
flatten: false,
|
||||||
|
upload: false,
|
||||||
|
spread: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
for attr in attrs {
|
||||||
|
if !attr.path().is_ident("device") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
attr.parse_nested_meta(|meta| {
|
||||||
|
if meta.path.is_ident("skip") {
|
||||||
|
result.skip = true;
|
||||||
|
Ok(())
|
||||||
|
} else if meta.path.is_ident("expr") {
|
||||||
|
let value = meta.value()?;
|
||||||
|
let lit: Lit = value.parse()?;
|
||||||
|
if let Lit::Str(s) = lit {
|
||||||
|
result.expr = Some(s.value());
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(meta.error("expected string literal"))
|
||||||
|
}
|
||||||
|
} else if meta.path.is_ident("copy_upload") {
|
||||||
|
result.copy_upload = true;
|
||||||
|
Ok(())
|
||||||
|
} else if meta.path.is_ident("flatten") {
|
||||||
|
result.flatten = true;
|
||||||
|
Ok(())
|
||||||
|
} else if meta.path.is_ident("upload") {
|
||||||
|
result.upload = true;
|
||||||
|
Ok(())
|
||||||
|
} else if meta.path.is_ident("spread") {
|
||||||
|
let value = meta.value()?;
|
||||||
|
let lit: Lit = value.parse()?;
|
||||||
|
if let Lit::Str(s) = lit {
|
||||||
|
result.spread = Some(s.value());
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(meta.error("expected string literal"))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Err(meta.error("unknown device attribute"))
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Container-level name attribute
|
||||||
|
fn get_device_name(attrs: &[Attribute], host_name: &Ident) -> syn::Result<Ident> {
|
||||||
|
for attr in attrs {
|
||||||
|
if !attr.path().is_ident("device") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let mut name = None;
|
||||||
|
attr.parse_nested_meta(|meta| {
|
||||||
|
if meta.path.is_ident("name") {
|
||||||
|
let value = meta.value()?;
|
||||||
|
let lit: Lit = value.parse()?;
|
||||||
|
if let Lit::Str(s) = lit {
|
||||||
|
name = Some(format_ident!("{}", s.value()));
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(meta.error("expected string literal"))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
if let Some(n) = name {
|
||||||
|
return Ok(n);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(format_ident!("Device{}", host_name))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type classification
|
||||||
|
enum FieldClass {
|
||||||
|
VecCopy(Type),
|
||||||
|
VecUploadable(Type),
|
||||||
|
Option(Type),
|
||||||
|
Arc(Type),
|
||||||
|
Plain,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn classify_type(ty: &Type) -> FieldClass {
|
||||||
|
if let Some(inner) = extract_generic_arg(ty, "Vec") {
|
||||||
|
if is_copy_primitive(&inner) {
|
||||||
|
FieldClass::VecCopy(inner)
|
||||||
|
} else {
|
||||||
|
FieldClass::VecUploadable(inner)
|
||||||
|
}
|
||||||
|
} else if let Some(inner) = extract_generic_arg(ty, "Option") {
|
||||||
|
FieldClass::Option(inner)
|
||||||
|
} else if let Some(inner) = extract_generic_arg(ty, "Arc") {
|
||||||
|
FieldClass::Arc(inner)
|
||||||
|
} else {
|
||||||
|
FieldClass::Plain
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_generic_arg(ty: &Type, wrapper: &str) -> Option<Type> {
|
||||||
|
if let Type::Path(type_path) = ty {
|
||||||
|
let seg = type_path.path.segments.last()?;
|
||||||
|
if seg.ident != wrapper {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
if let PathArguments::AngleBracketed(args) = &seg.arguments {
|
||||||
|
if let Some(GenericArgument::Type(inner)) = args.args.first() {
|
||||||
|
return Some(inner.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_copy_primitive(ty: &Type) -> bool {
|
||||||
|
if let Type::Path(type_path) = ty {
|
||||||
|
if let Some(seg) = type_path.path.segments.last() {
|
||||||
|
let name = seg.ident.to_string();
|
||||||
|
return matches!(
|
||||||
|
name.as_str(),
|
||||||
|
"f32"
|
||||||
|
| "f64"
|
||||||
|
| "u8"
|
||||||
|
| "u16"
|
||||||
|
| "u32"
|
||||||
|
| "u64"
|
||||||
|
| "i8"
|
||||||
|
| "i16"
|
||||||
|
| "i32"
|
||||||
|
| "i64"
|
||||||
|
| "usize"
|
||||||
|
| "isize"
|
||||||
|
| "bool"
|
||||||
|
| "Float"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -695,8 +695,8 @@ impl VarianceEstimator {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Copy, Clone, Default)]
|
|
||||||
#[repr(C)]
|
#[repr(C)]
|
||||||
|
#[derive(Debug, Copy, Clone, Default)]
|
||||||
pub struct PLSample {
|
pub struct PLSample {
|
||||||
pub p: Point2f,
|
pub p: Point2f,
|
||||||
pub pdf: Float,
|
pub pdf: Float,
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
use rayon::prelude::*;
|
use rayon::prelude::*;
|
||||||
use shared::core::aggregates::DeviceBVHAggregate;
|
use shared::core::aggregates::{{DeviceBVHAggregate, LinearBVHNode};
|
||||||
use shared::core::geometry::{Bounds3f, Point3f, Ray, Vector3f};
|
use shared::core::geometry::{Bounds3f, Point3f, Ray, Vector3f};
|
||||||
use shared::core::primitive::{Primitive, PrimitiveTrait};
|
use shared::core::primitive::{Primitive, PrimitiveTrait};
|
||||||
use shared::core::shape::ShapeIntersection;
|
use shared::core::shape::ShapeIntersection;
|
||||||
|
|
@ -26,14 +26,6 @@ struct BVHSplitBucket {
|
||||||
pub bounds: Bounds3f,
|
pub bounds: Bounds3f,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default)]
|
|
||||||
pub struct LinearBVHNode {
|
|
||||||
pub bounds: Bounds3f,
|
|
||||||
pub primitives_offset: usize,
|
|
||||||
pub n_primitives: u16,
|
|
||||||
pub axis: u8,
|
|
||||||
pub pad: u8,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, Default)]
|
#[derive(Debug, Clone, Copy, Default)]
|
||||||
struct MortonPrimitive {
|
struct MortonPrimitive {
|
||||||
|
|
@ -48,7 +40,7 @@ struct LBVHTreelet {
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct BVHPrimitiveInfo {
|
pub struct BVHPrimitiveInfo {
|
||||||
primitive_number: usize, // Index into the original primitives vector
|
primitive_number: usize,
|
||||||
bounds: Bounds3f,
|
bounds: Bounds3f,
|
||||||
centroid: Point3f,
|
centroid: Point3f,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,13 @@
|
||||||
use crate::filters::*;
|
use crate::filters::*;
|
||||||
use crate::utils::containers::Array2D;
|
use crate::utils::containers::Array2D;
|
||||||
use crate::utils::sampling::PiecewiseConstant2D;
|
use crate::utils::sampling::PiecewiseConstant2D;
|
||||||
|
use crate::utils::DeviceRepr;
|
||||||
use crate::utils::{FileLoc, ParameterDictionary};
|
use crate::utils::{FileLoc, ParameterDictionary};
|
||||||
use anyhow::{Result, anyhow};
|
use anyhow::{anyhow, Result};
|
||||||
use shared::Float;
|
use shared::core::filter::{DeviceFilterSampler, Filter};
|
||||||
use shared::core::filter::{Filter, FilterSampler};
|
|
||||||
use shared::core::geometry::{Bounds2f, Point2f, Vector2f};
|
use shared::core::geometry::{Bounds2f, Point2f, Vector2f};
|
||||||
use shared::filters::*;
|
use shared::filters::*;
|
||||||
|
use shared::Float;
|
||||||
|
|
||||||
pub trait FilterFactory {
|
pub trait FilterFactory {
|
||||||
fn create(name: &str, params: &ParameterDictionary, loc: &FileLoc) -> Result<Filter>;
|
fn create(name: &str, params: &ParameterDictionary, loc: &FileLoc) -> Result<Filter>;
|
||||||
|
|
@ -54,14 +55,16 @@ impl FilterFactory for Filter {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait CreateFilterSampler {
|
#[repr(C)]
|
||||||
fn new<F>(radius: Vector2f, func: F) -> Self
|
#[derive(Clone, Debug, Copy)]
|
||||||
where
|
pub struct FilterSampler {
|
||||||
F: Fn(Point2f) -> Float;
|
pub domain: Bounds2f,
|
||||||
|
pub distrib: PiecewiseConstant2D,
|
||||||
|
pub f: Array2D<Float>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CreateFilterSampler for FilterSampler {
|
impl FilterSampler {
|
||||||
fn new<F>(radius: Vector2f, func: F) -> Self
|
pub fn new<F>(radius: Vector2f, func: F) -> Self
|
||||||
where
|
where
|
||||||
F: Fn(Point2f) -> Float,
|
F: Fn(Point2f) -> Float,
|
||||||
{
|
{
|
||||||
|
|
@ -72,7 +75,6 @@ impl CreateFilterSampler for FilterSampler {
|
||||||
|
|
||||||
let nx = (32.0 * radius.x()) as i32;
|
let nx = (32.0 * radius.x()) as i32;
|
||||||
let ny = (32.0 * radius.y()) as i32;
|
let ny = (32.0 * radius.y()) as i32;
|
||||||
|
|
||||||
let mut f = Array2D::new_dims(nx, ny);
|
let mut f = Array2D::new_dims(nx, ny);
|
||||||
for y in 0..f.y_size() {
|
for y in 0..f.y_size() {
|
||||||
for x in 0..f.x_size() {
|
for x in 0..f.x_size() {
|
||||||
|
|
@ -83,11 +85,21 @@ impl CreateFilterSampler for FilterSampler {
|
||||||
f[(x as i32, y as i32)] = func(p);
|
f[(x as i32, y as i32)] = func(p);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let distrib = PiecewiseConstant2D::new_with_bounds(&f, domain);
|
let distrib = PiecewiseConstant2D::new_with_bounds(&f, domain);
|
||||||
Self {
|
|
||||||
domain,
|
Self { domain, distrib, f }
|
||||||
f: *f.device(),
|
}
|
||||||
distrib: distrib.device,
|
}
|
||||||
|
|
||||||
|
impl DeviceRepr for FilterSampler {
|
||||||
|
type Target = DeviceFilterSampler;
|
||||||
|
|
||||||
|
fn upload_value<A: GpuAllocator>(&self, arena: &Arena<A>) -> DeviceFilterSampler {
|
||||||
|
DeviceFilterSampler {
|
||||||
|
domain: self.domain,
|
||||||
|
distrib: self.distrib.upload_value(arena),
|
||||||
|
f: self.f.upload_value(arena),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,27 +5,26 @@ use shared::core::geometry::{Point2f, Vector2f};
|
||||||
use shared::filters::GaussianFilter;
|
use shared::filters::GaussianFilter;
|
||||||
use shared::utils::math::gaussian;
|
use shared::utils::math::gaussian;
|
||||||
|
|
||||||
pub trait GaussianFilterCreator {
|
#[derive(Clone, Debug, Device)]
|
||||||
fn new(radius: Vector2f, sigma: Float) -> Self;
|
#[device(name = "GaussianFilter")]
|
||||||
|
pub struct GaussianFilterHost {
|
||||||
|
pub radius: Vector2f,
|
||||||
|
pub sigma: Float,
|
||||||
|
pub exp_x: Float,
|
||||||
|
pub exp_y: Float,
|
||||||
|
#[device(flatten)]
|
||||||
|
pub sampler: FilterSampler,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl GaussianFilterCreator for GaussianFilter {
|
impl GaussianFilterHost {
|
||||||
fn new(radius: Vector2f, sigma: Float) -> Self {
|
pub fn new(radius: Vector2f, sigma: Float) -> Self {
|
||||||
let exp_x = gaussian(radius.x(), 0., sigma);
|
let exp_x = gaussian(radius.x(), 0.0, sigma);
|
||||||
let exp_y = gaussian(radius.y(), 0., sigma);
|
let exp_y = gaussian(radius.y(), 0.0, sigma);
|
||||||
|
|
||||||
let sampler = FilterSampler::new(radius, move |p: Point2f| {
|
let sampler = FilterSampler::new(radius, move |p: Point2f| {
|
||||||
let gx = (gaussian(p.x(), 0., sigma) - exp_x).max(0.0);
|
let gx = (gaussian(p.x(), 0.0, sigma) - exp_x).max(0.0);
|
||||||
let gy = (gaussian(p.y(), 0., sigma) - exp_y).max(0.0);
|
let gy = (gaussian(p.y(), 0.0, sigma) - exp_y).max(0.0);
|
||||||
gx * gy
|
gx * gy
|
||||||
});
|
});
|
||||||
|
Self { radius, sigma, exp_x, exp_y, sampler }
|
||||||
Self {
|
|
||||||
radius,
|
|
||||||
sigma,
|
|
||||||
exp_x: gaussian(radius.x(), 0., sigma),
|
|
||||||
exp_y: gaussian(radius.y(), 0., sigma),
|
|
||||||
sampler,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -6,43 +6,35 @@ use shared::core::geometry::{Point2f, Vector2f};
|
||||||
use shared::filters::LanczosSincFilter;
|
use shared::filters::LanczosSincFilter;
|
||||||
use shared::utils::math::{lerp, windowed_sinc};
|
use shared::utils::math::{lerp, windowed_sinc};
|
||||||
|
|
||||||
pub trait LanczosFilterCreator {
|
#[derive(Clone, Debug, Device)]
|
||||||
fn new(radius: Vector2f, tau: Float) -> Self;
|
#[device(name = "LanczosSincFilter")]
|
||||||
|
pub struct LanczosSincFilterHost {
|
||||||
|
pub radius: Vector2f,
|
||||||
|
pub tau: Float,
|
||||||
|
#[device(flatten)]
|
||||||
|
pub sampler: FilterSampler,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LanczosFilterCreator for LanczosSincFilter {
|
impl LanczosSincFilterHost {
|
||||||
fn new(radius: Vector2f, tau: Float) -> Self {
|
pub fn new(radius: Vector2f, tau: Float) -> Self {
|
||||||
let evaluate = |p: Point2f| -> Float {
|
let sampler = FilterSampler::new(radius, move |p: Point2f| {
|
||||||
windowed_sinc(p.x(), radius.x(), tau) * windowed_sinc(p.y(), radius.y(), tau)
|
windowed_sinc(p.x(), radius.x(), tau) * windowed_sinc(p.y(), radius.y(), tau)
|
||||||
};
|
});
|
||||||
|
Self { radius, tau, sampler }
|
||||||
let sampler = FilterSampler::new(radius, evaluate);
|
}
|
||||||
let sqrt_samples = 64;
|
}
|
||||||
let n_samples = sqrt_samples * sqrt_samples;
|
|
||||||
let area = (2.0 * radius.x()) * (2.0 * radius.y());
|
fn windowed_sinc(x: Float, radius: Float, tau: Float) -> Float {
|
||||||
let mut sum = 0.0;
|
use std::f32::consts::PI;
|
||||||
let mut rng = rand::rng();
|
let x = x.abs();
|
||||||
|
if x > radius {
|
||||||
for y in 0..sqrt_samples {
|
return 0.0;
|
||||||
for x in 0..sqrt_samples {
|
}
|
||||||
let u = Point2f::new(
|
if x < 1e-5 {
|
||||||
(x as Float + rng.random::<Float>()) / sqrt_samples as Float,
|
1.0
|
||||||
(y as Float + rng.random::<Float>()) / sqrt_samples as Float,
|
} else {
|
||||||
);
|
let xpi = x * PI;
|
||||||
let p = Point2f::new(
|
let xpit = xpi * tau;
|
||||||
lerp(u.x(), -radius.x(), radius.x()),
|
(xpi.sin() / xpi) * (xpit.sin() / xpit)
|
||||||
lerp(u.y(), -radius.y(), radius.y()),
|
|
||||||
);
|
|
||||||
sum += evaluate(p);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let integral = sum / n_samples as Float * area;
|
|
||||||
|
|
||||||
Self {
|
|
||||||
radius,
|
|
||||||
tau,
|
|
||||||
sampler,
|
|
||||||
integral,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -4,23 +4,39 @@ use shared::core::filter::FilterSampler;
|
||||||
use shared::core::geometry::{Point2f, Vector2f};
|
use shared::core::geometry::{Point2f, Vector2f};
|
||||||
use shared::filters::MitchellFilter;
|
use shared::filters::MitchellFilter;
|
||||||
|
|
||||||
pub trait MitchellFilterCreator {
|
#[derive(Clone, Debug, Device)]
|
||||||
fn new(radius: Vector2f, b: Float, c: Float) -> Self;
|
#[device(name = "MitchellFilter")]
|
||||||
|
pub struct MitchellFilterHost {
|
||||||
|
pub radius: Vector2f,
|
||||||
|
pub b: Float,
|
||||||
|
pub c: Float,
|
||||||
|
#[device(flatten)]
|
||||||
|
pub sampler: FilterSampler,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MitchellFilterCreator for MitchellFilter {
|
impl MitchellFilterHost {
|
||||||
fn new(radius: Vector2f, b: Float, c: Float) -> Self {
|
pub fn new(radius: Vector2f, b: Float, c: Float) -> Self {
|
||||||
let sampler = FilterSampler::new(radius, move |p: Point2f| {
|
let sampler = FilterSampler::new(radius, move |p: Point2f| {
|
||||||
let nx = 2.0 * p.x() / radius.x();
|
mitchell_1d(p.x() / radius.x(), b, c) * mitchell_1d(p.y() / radius.y(), b, c)
|
||||||
let ny = 2.0 * p.y() / radius.y();
|
|
||||||
Self::mitchell_1d_eval(b, c, nx) * Self::mitchell_1d_eval(b, c, ny)
|
|
||||||
});
|
});
|
||||||
|
Self { radius, b, c, sampler }
|
||||||
Self {
|
}
|
||||||
radius,
|
}
|
||||||
b,
|
|
||||||
c,
|
fn mitchell_1d(x: Float, b: Float, c: Float) -> Float {
|
||||||
sampler,
|
let x = (2.0 * x).abs();
|
||||||
}
|
if x <= 1.0 {
|
||||||
|
((12.0 - 9.0 * b - 6.0 * c) * x * x * x
|
||||||
|
+ (-18.0 + 12.0 * b + 6.0 * c) * x * x
|
||||||
|
+ (6.0 - 2.0 * b))
|
||||||
|
/ 6.0
|
||||||
|
} else if x <= 2.0 {
|
||||||
|
((-b - 6.0 * c) * x * x * x
|
||||||
|
+ (6.0 * b + 30.0 * c) * x * x
|
||||||
|
+ (-12.0 * b - 48.0 * c) * x
|
||||||
|
+ (8.0 * b + 24.0 * c))
|
||||||
|
/ 6.0
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -21,16 +21,18 @@ pub struct TriQuadMesh {
|
||||||
pub quad_indices: Vec<i32>,
|
pub quad_indices: Vec<i32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(DeviceRepr)]
|
||||||
pub(crate) struct TriangleMeshStorage {
|
#[device(name = "DeviceTriangleMesh")]
|
||||||
pub p: Vec<Point3f>,
|
pub struct TriangleMeshStorage {
|
||||||
pub n: Vec<Normal3f>,
|
pub vertex_indices: Vec<i32>, // → Ptr<i32> + len (always present)
|
||||||
pub s: Vec<Vector3f>,
|
pub p: Vec<Point3f>, // → Ptr<Point3f> + len (always present)
|
||||||
pub uv: Vec<Point2f>,
|
pub n: Vec<Normal3f>, // → Ptr<Normal3f> + len (empty → null Ptr, len 0)
|
||||||
pub vertex_indices: Vec<i32>,
|
pub s: Vec<Vector3f>, // → Ptr<Vector3f> + len
|
||||||
pub face_indices: Vec<i32>,
|
pub uv: Vec<Point2f>, // → Ptr<Point2f> + len
|
||||||
|
pub face_indices: Vec<i32>, // → Ptr<i32> + len
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub(crate) struct BilinearMeshStorage {
|
pub(crate) struct BilinearMeshStorage {
|
||||||
pub vertex_indices: Vec<i32>,
|
pub vertex_indices: Vec<i32>,
|
||||||
|
|
|
||||||
|
|
@ -1,147 +1,148 @@
|
||||||
use crate::core::image::Image;
|
|
||||||
use crate::core::texture::{FloatTexture, SpectrumTexture};
|
|
||||||
use crate::shapes::{BilinearPatchMesh, TriangleMesh};
|
|
||||||
use crate::spectra::DenselySampledSpectrumBuffer;
|
|
||||||
use crate::utils::backend::GpuAllocator;
|
use crate::utils::backend::GpuAllocator;
|
||||||
use crate::utils::mipmap::MIPMap;
|
|
||||||
use crate::utils::sampling::{PiecewiseConstant2D, WindowedPiecewiseConstant2D};
|
|
||||||
use parking_lot::Mutex;
|
|
||||||
use shared::core::color::RGBToSpectrumTable;
|
|
||||||
use shared::core::image::DeviceImage;
|
|
||||||
use shared::core::light::Light;
|
|
||||||
use shared::core::material::Material;
|
|
||||||
use shared::core::shape::Shape;
|
use shared::core::shape::Shape;
|
||||||
|
use shared::core::material::Material;
|
||||||
use shared::core::spectrum::Spectrum;
|
use shared::core::spectrum::Spectrum;
|
||||||
use shared::core::texture::{GPUFloatTexture, GPUSpectrumTexture};
|
use parking_lot::Mutex;
|
||||||
use shared::spectra::{DenselySampledSpectrum, DeviceStandardColorSpaces, RGBColorSpace};
|
use shared::Ptr;
|
||||||
use shared::textures::*;
|
|
||||||
use shared::utils::mesh::{DeviceBilinearPatchMesh, DeviceTriangleMesh};
|
|
||||||
use shared::utils::sampling::{
|
|
||||||
DevicePiecewiseConstant1D, DevicePiecewiseConstant2D, DeviceWindowedPiecewiseConstant2D,
|
|
||||||
};
|
|
||||||
use shared::utils::Ptr;
|
|
||||||
use std::alloc::Layout;
|
use std::alloc::Layout;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::panic::Location;
|
use std::panic::Location;
|
||||||
use std::slice::from_raw_parts;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
pub struct Arena<A: GpuAllocator> {
|
struct Chunk {
|
||||||
|
ptr: *mut u8,
|
||||||
|
layout: Layout,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct GpuBump<A: GpuAllocator> {
|
||||||
allocator: A,
|
allocator: A,
|
||||||
inner: Mutex<ArenaInner>,
|
current: *mut u8,
|
||||||
|
end: *mut u8,
|
||||||
|
chunks: Vec<Chunk>,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ArenaInner {
|
const CHUNK_SIZE: usize = 256 * 1024;
|
||||||
blocks: Vec<(*mut u8, Layout)>,
|
|
||||||
current_block: *mut u8,
|
impl<A: GpuAllocator> GpuBump<A> {
|
||||||
current_offset: usize,
|
fn new(allocator: A) -> Self {
|
||||||
current_capacity: usize,
|
Self {
|
||||||
current_align: usize,
|
allocator,
|
||||||
texture_cache: HashMap<usize, u64>,
|
current: std::ptr::null_mut(),
|
||||||
|
end: std::ptr::null_mut(),
|
||||||
|
chunks: Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn alloc<T>(&mut self, value: T) -> *mut T {
|
||||||
|
let layout = Layout::new::<T>();
|
||||||
|
let ptr = self.alloc_layout(layout) as *mut T;
|
||||||
|
unsafe { ptr.write(value) };
|
||||||
|
ptr
|
||||||
|
}
|
||||||
|
|
||||||
|
fn alloc_slice<T: Copy>(&mut self, values: &[T]) -> (*mut T, usize) {
|
||||||
|
if values.is_empty() {
|
||||||
|
return (std::ptr::null_mut(), 0);
|
||||||
|
}
|
||||||
|
let layout = Layout::array::<T>(values.len()).unwrap();
|
||||||
|
let ptr = self.alloc_layout(layout) as *mut T;
|
||||||
|
unsafe {
|
||||||
|
std::ptr::copy_nonoverlapping(values.as_ptr(), ptr, values.len());
|
||||||
|
}
|
||||||
|
(ptr, values.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn alloc_layout(&mut self, layout: Layout) -> *mut u8 {
|
||||||
|
let size = layout.size();
|
||||||
|
let align = layout.align();
|
||||||
|
|
||||||
|
if size == 0 {
|
||||||
|
return align as *mut u8;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fast path: bump from current chunk
|
||||||
|
let start = self.current as usize;
|
||||||
|
let aligned = (start + align - 1) & !(align - 1);
|
||||||
|
let end = aligned + size;
|
||||||
|
|
||||||
|
if end <= self.end as usize {
|
||||||
|
self.current = end as *mut u8;
|
||||||
|
return aligned as *mut u8;
|
||||||
|
}
|
||||||
|
|
||||||
|
let chunk_size = if size > CHUNK_SIZE {
|
||||||
|
size.next_multiple_of(align.max(16))
|
||||||
|
} else {
|
||||||
|
CHUNK_SIZE.max(size.next_multiple_of(align.max(16)))
|
||||||
|
};
|
||||||
|
|
||||||
|
let chunk_layout = Layout::from_size_align(chunk_size, align.max(16)).unwrap();
|
||||||
|
let chunk = unsafe { self.allocator.alloc(chunk_layout) };
|
||||||
|
|
||||||
|
let caller = Location::caller();
|
||||||
|
if chunk.is_null() {
|
||||||
|
panic!(
|
||||||
|
"GpuBump OOM {} {}: chunk_size={} align={} backend={}",
|
||||||
|
caller.file(),
|
||||||
|
caller.line(),
|
||||||
|
chunk_size,
|
||||||
|
chunk_layout.align(),
|
||||||
|
std::any::type_name::<A>()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
self.chunks.push(Chunk {
|
||||||
|
ptr: chunk,
|
||||||
|
layout: chunk_layout,
|
||||||
|
});
|
||||||
|
|
||||||
|
// If this object alone fills the chunk, mark it consumed and return.
|
||||||
|
if size > CHUNK_SIZE {
|
||||||
|
self.current = unsafe { chunk.add(chunk_size) };
|
||||||
|
self.end = self.current;
|
||||||
|
return chunk;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up bump pointers inside the new chunk.
|
||||||
|
self.current = chunk;
|
||||||
|
self.end = unsafe { chunk.add(chunk_size) };
|
||||||
|
|
||||||
|
let aligned = (self.current as usize + align - 1) & !(align - 1);
|
||||||
|
self.current = (aligned + size) as *mut u8;
|
||||||
|
aligned as *mut u8
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const DEFAULT_BLOCK_SIZE: usize = 256 * 1024;
|
impl<A: GpuAllocator> Drop for GpuBump<A> {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
for chunk in self.chunks.drain(..) {
|
||||||
|
unsafe { self.allocator.dealloc(chunk.ptr, chunk.layout) };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe impl<A: GpuAllocator> Send for GpuBump<A> {}
|
||||||
|
unsafe impl<A: GpuAllocator> Sync for GpuBump<A> {}
|
||||||
|
|
||||||
|
pub struct Arena<A: GpuAllocator> {
|
||||||
|
bump: Mutex<GpuBump<A>>,
|
||||||
|
texture_cache: Mutex<HashMap<usize, u64>>,
|
||||||
|
}
|
||||||
|
|
||||||
impl<A: GpuAllocator> Arena<A> {
|
impl<A: GpuAllocator> Arena<A> {
|
||||||
pub fn new(allocator: A) -> Self {
|
pub fn new(allocator: A) -> Self {
|
||||||
Self {
|
Self {
|
||||||
allocator,
|
bump: Mutex::new(GpuBump::new(allocator)),
|
||||||
inner: Mutex::new(ArenaInner {
|
texture_cache: Mutex::new(HashMap::new()),
|
||||||
blocks: Vec::new(),
|
|
||||||
current_block: std::ptr::null_mut(),
|
|
||||||
current_offset: 0,
|
|
||||||
current_capacity: 0,
|
|
||||||
current_align: 1,
|
|
||||||
texture_cache: HashMap::new(),
|
|
||||||
}),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn alloc<T>(&self, value: T) -> Ptr<T> {
|
pub fn alloc<T>(&self, value: T) -> Ptr<T> {
|
||||||
let layout = Layout::new::<T>();
|
let mut bump = self.bump.lock();
|
||||||
let mut inner = self.inner.lock();
|
let ptr = bump.alloc(value);
|
||||||
|
|
||||||
let aligned = (inner.current_offset + layout.align() - 1) & !(layout.align() - 1);
|
|
||||||
|
|
||||||
// Checking if current block alignment is sufficient
|
|
||||||
if aligned + layout.size() > inner.current_capacity || inner.current_align < layout.align()
|
|
||||||
{
|
|
||||||
let block_size = DEFAULT_BLOCK_SIZE.max(layout.size() * 2);
|
|
||||||
let block_layout = Layout::from_size_align(block_size, layout.align().max(16)).unwrap();
|
|
||||||
let block = unsafe { self.allocator.alloc(block_layout) };
|
|
||||||
|
|
||||||
// null check
|
|
||||||
if block.is_null() {
|
|
||||||
panic!(
|
|
||||||
"Arena alloc failed at {}:{} — size={} align={}",
|
|
||||||
caller.file(),
|
|
||||||
caller.line(),
|
|
||||||
block_layout.size(),
|
|
||||||
block_layout.align()
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
inner.blocks.push((block, block_layout));
|
|
||||||
inner.current_block = block;
|
|
||||||
inner.current_offset = 0;
|
|
||||||
inner.current_capacity = block_size;
|
|
||||||
inner.current_align = block_layout.align(); // NEW
|
|
||||||
|
|
||||||
let aligned = (0 + layout.align() - 1) & !(layout.align() - 1);
|
|
||||||
let ptr = unsafe { inner.current_block.add(aligned) as *mut T };
|
|
||||||
unsafe { ptr.write(value) };
|
|
||||||
inner.current_offset = aligned + layout.size();
|
|
||||||
return Ptr::from_raw(ptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
let ptr = unsafe { inner.current_block.add(aligned) as *mut T };
|
|
||||||
unsafe { ptr.write(value) };
|
|
||||||
inner.current_offset = aligned + layout.size();
|
|
||||||
Ptr::from_raw(ptr)
|
Ptr::from_raw(ptr)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn alloc_slice<T: Copy>(&self, values: &[T]) -> (Ptr<T>, usize) {
|
|
||||||
if values.is_empty() {
|
|
||||||
return (Ptr::null(), 0);
|
|
||||||
}
|
|
||||||
let layout = Layout::array::<T>(values.len()).unwrap();
|
|
||||||
let mut inner = self.inner.lock();
|
|
||||||
|
|
||||||
let aligned = (inner.current_offset + layout.align() - 1) & !(layout.align() - 1);
|
|
||||||
|
|
||||||
if aligned + layout.size() > inner.current_capacity || inner.current_align < layout.align()
|
|
||||||
{
|
|
||||||
let block_size = DEFAULT_BLOCK_SIZE.max(layout.size() * 2);
|
|
||||||
let block_layout = Layout::from_size_align(block_size, layout.align().max(16)).unwrap();
|
|
||||||
let block = unsafe { self.allocator.alloc(block_layout) };
|
|
||||||
|
|
||||||
if block.is_null() {
|
|
||||||
panic!(
|
|
||||||
"Arena allocation failed: size={} align={}",
|
|
||||||
block_layout.size(),
|
|
||||||
block_layout.align()
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
inner.blocks.push((block, block_layout));
|
|
||||||
inner.current_block = block;
|
|
||||||
inner.current_offset = 0;
|
|
||||||
inner.current_capacity = block_size;
|
|
||||||
inner.current_align = block_layout.align(); // NEW
|
|
||||||
|
|
||||||
let aligned = 0;
|
|
||||||
let ptr = unsafe { inner.current_block.add(aligned) as *mut T };
|
|
||||||
unsafe { std::ptr::copy_nonoverlapping(values.as_ptr(), ptr, values.len()) };
|
|
||||||
inner.current_offset = aligned + layout.size();
|
|
||||||
return (Ptr::from_raw(ptr), values.len());
|
|
||||||
}
|
|
||||||
|
|
||||||
let ptr = unsafe { inner.current_block.add(aligned) as *mut T };
|
|
||||||
unsafe { std::ptr::copy_nonoverlapping(values.as_ptr(), ptr, values.len()) };
|
|
||||||
inner.current_offset = aligned + layout.size();
|
|
||||||
(Ptr::from_raw(ptr), values.len())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn alloc_opt<T>(&self, value: Option<T>) -> Ptr<T> {
|
pub fn alloc_opt<T>(&self, value: Option<T>) -> Ptr<T> {
|
||||||
match value {
|
match value {
|
||||||
Some(v) => self.alloc(v),
|
Some(v) => self.alloc(v),
|
||||||
|
|
@ -149,20 +150,20 @@ impl<A: GpuAllocator> Arena<A> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn alloc_slice<T: Copy>(&self, values: &[T]) -> (Ptr<T>, usize) {
|
||||||
|
let mut bump = self.bump.lock();
|
||||||
|
let (ptr, len) = bump.alloc_slice(values);
|
||||||
|
(Ptr::from_raw(ptr), len)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn get_texture_object(&self, mipmap: &Arc<MIPMap>) -> u64 {
|
pub fn get_texture_object(&self, mipmap: &Arc<MIPMap>) -> u64 {
|
||||||
let key = Arc::as_ptr(mipmap) as usize;
|
let key = Arc::as_ptr(mipmap) as usize;
|
||||||
let mut inner = self.inner.lock();
|
let mut cache = self.texture_cache.lock();
|
||||||
|
if let Some(&tex_obj) = cache.get(&key) {
|
||||||
if let Some(&tex_obj) = inner.texture_cache.get(&key) {
|
|
||||||
return tex_obj;
|
return tex_obj;
|
||||||
}
|
}
|
||||||
|
let tex_obj = 0u64; // TODO: backend-specific creation
|
||||||
// TODO: Backend-specific texture object creation.
|
cache.insert(key, tex_obj);
|
||||||
// CUDA: cudaCreateTextureObject
|
|
||||||
// Vulkan: VkImageView + VkSampler -> descriptor index
|
|
||||||
let tex_obj = 0u64;
|
|
||||||
|
|
||||||
inner.texture_cache.insert(key, tex_obj);
|
|
||||||
tex_obj
|
tex_obj
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -173,76 +174,279 @@ impl<A: GpuAllocator + Default> Default for Arena<A> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<A: GpuAllocator> Drop for Arena<A> {
|
pub trait DeviceRepr {
|
||||||
fn drop(&mut self) {
|
/// The `#[repr(C)] Copy` device-side struct.
|
||||||
let inner = self.inner.get_mut();
|
type Target: Copy;
|
||||||
for (ptr, layout) in inner.blocks.drain(..) {
|
|
||||||
unsafe { self.allocator.dealloc(ptr, layout) };
|
/// Upload into the arena and return the device struct by value.
|
||||||
|
/// Use this when embedding the result inline in another device struct.
|
||||||
|
fn upload_value<A: GpuAllocator>(&self, arena: &Arena<A>) -> Self::Target;
|
||||||
|
|
||||||
|
/// Upload into the arena and return a Ptr to the device struct.
|
||||||
|
/// This is the common entry point — allocates the Target in the arena.
|
||||||
|
fn upload<A: GpuAllocator>(&self, arena: &Arena<A>) -> Ptr<Self::Target> {
|
||||||
|
let value = self.upload_value(arena);
|
||||||
|
arena.alloc(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: DeviceRepr> DeviceRepr for Option<T> {
|
||||||
|
type Target = T::Target;
|
||||||
|
|
||||||
|
fn upload_value<A: GpuAllocator>(&self, arena: &Arena<A>) -> Self::Target {
|
||||||
|
match self {
|
||||||
|
Some(val) => val.upload_value(arena),
|
||||||
|
None => panic!("Cannot upload_value on None — use upload() which returns Ptr::null()"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn upload<A: GpuAllocator>(&self, arena: &Arena<A>) -> Ptr<Self::Target> {
|
||||||
|
match self {
|
||||||
|
Some(val) => val.upload(arena),
|
||||||
|
None => Ptr::null(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe impl<A: GpuAllocator> Send for Arena<A> {}
|
impl<T: DeviceRepr> DeviceRepr for std::sync::Arc<T> {
|
||||||
unsafe impl<A: GpuAllocator> Sync for Arena<A> {}
|
type Target = T::Target;
|
||||||
|
|
||||||
pub trait Upload {
|
fn upload_value<A: GpuAllocator>(&self, arena: &Arena<A>) -> Self::Target {
|
||||||
type Target: Copy;
|
(**self).upload_value(arena)
|
||||||
|
}
|
||||||
|
|
||||||
fn upload<A: GpuAllocator>(&self, arena: &Arena<A>) -> Ptr<Self::Target>;
|
fn upload<A: GpuAllocator>(&self, arena: &Arena<A>) -> Ptr<Self::Target> {
|
||||||
|
(**self).upload(arena)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Upload for Shape {
|
impl<T: DeviceRepr> DeviceRepr for Box<T> {
|
||||||
|
type Target = T::Target;
|
||||||
|
|
||||||
|
fn upload_value<A: GpuAllocator>(&self, arena: &Arena<A>) -> Self::Target {
|
||||||
|
(**self).upload_value(arena)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn upload<A: GpuAllocator>(&self, arena: &Arena<A>) -> Ptr<Self::Target> {
|
||||||
|
(**self).upload(arena)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DeviceRepr for Shape {
|
||||||
type Target = Shape;
|
type Target = Shape;
|
||||||
fn upload<A: GpuAllocator>(&self, arena: &Arena<A>) -> Ptr<Self::Target> {
|
fn upload_value<A: GpuAllocator>(&self, _arena: &Arena<A>) -> Shape {
|
||||||
arena.alloc(self.clone())
|
self.clone()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Upload for Light {
|
impl DeviceRepr for Light {
|
||||||
type Target = Light;
|
type Target = Light;
|
||||||
fn upload<A: GpuAllocator>(&self, arena: &Arena<A>) -> Ptr<Self::Target> {
|
fn upload_value<A: GpuAllocator>(&self, _arena: &Arena<A>) -> Light {
|
||||||
arena.alloc(self.clone())
|
self.clone()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Upload for Image {
|
impl DeviceRepr for Spectrum {
|
||||||
type Target = DeviceImage;
|
|
||||||
fn upload<A: GpuAllocator>(&self, arena: &Arena<A>) -> Ptr<Self::Target> {
|
|
||||||
arena.alloc(*self.device())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Upload for Spectrum {
|
|
||||||
type Target = Spectrum;
|
type Target = Spectrum;
|
||||||
fn upload<A: GpuAllocator>(&self, arena: &Arena<A>) -> Ptr<Self::Target> {
|
fn upload_value<A: GpuAllocator>(&self, _arena: &Arena<A>) -> Spectrum {
|
||||||
arena.alloc(self.clone())
|
self.clone()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Upload for Material {
|
impl DeviceRepr for Material {
|
||||||
type Target = Material;
|
type Target = Material;
|
||||||
fn upload<A: GpuAllocator>(&self, arena: &Arena<A>) -> Ptr<Self::Target> {
|
fn upload_value<A: GpuAllocator>(&self, _arena: &Arena<A>) -> Material {
|
||||||
arena.alloc(self.clone())
|
self.clone()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Upload for DenselySampledSpectrumBuffer {
|
// =============================================================================
|
||||||
|
// Image → DeviceImage
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
impl DeviceRepr for Image {
|
||||||
|
type Target = DeviceImage;
|
||||||
|
fn upload_value<A: GpuAllocator>(&self, _arena: &Arena<A>) -> DeviceImage {
|
||||||
|
*self.device()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// DenselySampledSpectrumBuffer → DenselySampledSpectrum
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
impl DeviceRepr for DenselySampledSpectrumBuffer {
|
||||||
type Target = DenselySampledSpectrum;
|
type Target = DenselySampledSpectrum;
|
||||||
fn upload<A: GpuAllocator>(&self, arena: &Arena<A>) -> Ptr<Self::Target> {
|
fn upload_value<A: GpuAllocator>(&self, _arena: &Arena<A>) -> DenselySampledSpectrum {
|
||||||
arena.alloc(*&self.device())
|
self.device()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Upload for SpectrumTexture {
|
// =============================================================================
|
||||||
|
// RGBToSpectrumTable — re-uploads Ptr fields into arena
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
impl DeviceRepr for RGBToSpectrumTable {
|
||||||
|
type Target = RGBToSpectrumTable;
|
||||||
|
|
||||||
|
fn upload_value<A: GpuAllocator>(&self, arena: &Arena<A>) -> RGBToSpectrumTable {
|
||||||
|
let n_nodes = self.n_nodes as usize;
|
||||||
|
|
||||||
|
// Safety: these Ptrs point into static or previously-uploaded data;
|
||||||
|
// we're copying the contents into the arena for a new lifetime.
|
||||||
|
let z_slice = unsafe { from_raw_parts(self.z_nodes.as_raw(), n_nodes) };
|
||||||
|
let (z_ptr, _) = arena.alloc_slice(z_slice);
|
||||||
|
|
||||||
|
let n_coeffs = 3 * n_nodes.pow(3);
|
||||||
|
let coeffs_slice = unsafe { from_raw_parts(self.coeffs.as_raw(), n_coeffs) };
|
||||||
|
let (c_ptr, _) = arena.alloc_slice(coeffs_slice);
|
||||||
|
|
||||||
|
RGBToSpectrumTable {
|
||||||
|
z_nodes: z_ptr,
|
||||||
|
coeffs: c_ptr,
|
||||||
|
n_nodes: self.n_nodes,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// RGBColorSpace — nested upload of spectrum table
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
impl DeviceRepr for RGBColorSpace {
|
||||||
|
type Target = RGBColorSpace;
|
||||||
|
|
||||||
|
fn upload_value<A: GpuAllocator>(&self, arena: &Arena<A>) -> RGBColorSpace {
|
||||||
|
let table_ptr = self.rgb_to_spectrum_table.upload(arena);
|
||||||
|
|
||||||
|
RGBColorSpace {
|
||||||
|
r: self.r,
|
||||||
|
g: self.g,
|
||||||
|
b: self.b,
|
||||||
|
w: self.w,
|
||||||
|
illuminant: self.illuminant.clone(),
|
||||||
|
rgb_to_spectrum_table: table_ptr,
|
||||||
|
xyz_from_rgb: self.xyz_from_rgb,
|
||||||
|
rgb_from_xyz: self.rgb_from_xyz,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// DeviceStandardColorSpaces — composition of color space uploads
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
impl DeviceRepr for DeviceStandardColorSpaces {
|
||||||
|
type Target = DeviceStandardColorSpaces;
|
||||||
|
|
||||||
|
fn upload_value<A: GpuAllocator>(&self, arena: &Arena<A>) -> DeviceStandardColorSpaces {
|
||||||
|
DeviceStandardColorSpaces {
|
||||||
|
srgb: self.srgb.upload(arena),
|
||||||
|
dci_p3: self.dci_p3.upload(arena),
|
||||||
|
rec2020: self.rec2020.upload(arena),
|
||||||
|
aces2065_1: self.aces2065_1.upload(arena),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// TriangleMesh → DeviceTriangleMesh
|
||||||
|
// =============================================================================
|
||||||
|
// TriangleMesh should own all its fields directly — no .device field.
|
||||||
|
// The host struct holds Vec arrays + scalar metadata. derive(Device) handles it:
|
||||||
|
//
|
||||||
|
// #[derive(Device)]
|
||||||
|
// #[device(name = "DeviceTriangleMesh")]
|
||||||
|
// pub struct TriangleMesh {
|
||||||
|
// pub vertex_indices: Vec<i32>,
|
||||||
|
// pub p: Vec<Point3f>,
|
||||||
|
// pub n: Vec<Normal3f>,
|
||||||
|
// pub s: Vec<Vector3f>,
|
||||||
|
// pub uv: Vec<Point2f>,
|
||||||
|
// pub face_indices: Vec<i32>,
|
||||||
|
// pub n_triangles: u32,
|
||||||
|
// pub reverse_orientation: bool,
|
||||||
|
// pub transform_swaps_handedness: bool,
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// Until the mesh struct is refactored to remove .device, here's a manual
|
||||||
|
// impl that lists every field. This is the MIGRATION TARGET — once TriangleMesh
|
||||||
|
// drops its .device field and puts scalars at the top level, switch to derive(Device).
|
||||||
|
|
||||||
|
impl DeviceRepr for TriangleMesh {
|
||||||
|
type Target = DeviceTriangleMesh;
|
||||||
|
|
||||||
|
fn upload_value<A: GpuAllocator>(&self, arena: &Arena<A>) -> DeviceTriangleMesh {
|
||||||
|
let s = &self.storage;
|
||||||
|
|
||||||
|
let (vertex_indices_ptr, _) = arena.alloc_slice(&s.vertex_indices);
|
||||||
|
let (p_ptr, _) = arena.alloc_slice(&s.p);
|
||||||
|
let (n_ptr, _) = arena.alloc_slice(&s.n);
|
||||||
|
let (s_ptr, _) = arena.alloc_slice(&s.s);
|
||||||
|
let (uv_ptr, _) = arena.alloc_slice(&s.uv);
|
||||||
|
let (face_indices_ptr, _) = arena.alloc_slice(&s.face_indices);
|
||||||
|
|
||||||
|
DeviceTriangleMesh {
|
||||||
|
vertex_indices: vertex_indices_ptr,
|
||||||
|
p: p_ptr,
|
||||||
|
n: n_ptr,
|
||||||
|
s: s_ptr,
|
||||||
|
uv: uv_ptr,
|
||||||
|
face_indices: face_indices_ptr,
|
||||||
|
n_triangles: self.n_triangles,
|
||||||
|
reverse_orientation: self.reverse_orientation,
|
||||||
|
transform_swaps_handedness: self.transform_swaps_handedness,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// BilinearPatchMesh → DeviceBilinearPatchMesh
|
||||||
|
// =============================================================================
|
||||||
|
// Same as TriangleMesh: once the host struct is refactored to own scalars
|
||||||
|
// directly (no .device field), this switches to derive(Device).
|
||||||
|
|
||||||
|
impl DeviceRepr for BilinearPatchMesh {
|
||||||
|
type Target = DeviceBilinearPatchMesh;
|
||||||
|
|
||||||
|
fn upload_value<A: GpuAllocator>(&self, arena: &Arena<A>) -> DeviceBilinearPatchMesh {
|
||||||
|
let s = &self.storage;
|
||||||
|
|
||||||
|
let (vertex_indices_ptr, _) = arena.alloc_slice(&s.vertex_indices);
|
||||||
|
let (p_ptr, _) = arena.alloc_slice(&s.p);
|
||||||
|
let (n_ptr, _) = arena.alloc_slice(&s.n);
|
||||||
|
let (uv_ptr, _) = arena.alloc_slice(&s.uv);
|
||||||
|
|
||||||
|
let image_dist_ptr = s.image_distribution.upload(arena);
|
||||||
|
|
||||||
|
DeviceBilinearPatchMesh {
|
||||||
|
vertex_indices: vertex_indices_ptr,
|
||||||
|
p: p_ptr,
|
||||||
|
n: n_ptr,
|
||||||
|
uv: uv_ptr,
|
||||||
|
image_distribution: image_dist_ptr,
|
||||||
|
n_patches: self.n_patches,
|
||||||
|
reverse_orientation: self.reverse_orientation,
|
||||||
|
transform_swaps_handedness: self.transform_swaps_handedness,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// SpectrumTexture → GPUSpectrumTexture
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
impl DeviceRepr for SpectrumTexture {
|
||||||
type Target = GPUSpectrumTexture;
|
type Target = GPUSpectrumTexture;
|
||||||
fn upload<A: GpuAllocator>(&self, arena: &Arena<A>) -> Ptr<Self::Target> {
|
|
||||||
let gpu_variant = match self {
|
fn upload_value<A: GpuAllocator>(&self, arena: &Arena<A>) -> GPUSpectrumTexture {
|
||||||
|
match self {
|
||||||
SpectrumTexture::Constant(tex) => GPUSpectrumTexture::Constant(tex.clone()),
|
SpectrumTexture::Constant(tex) => GPUSpectrumTexture::Constant(tex.clone()),
|
||||||
SpectrumTexture::Checkerboard(tex) => GPUSpectrumTexture::Checkerboard(tex.clone()),
|
SpectrumTexture::Checkerboard(tex) => GPUSpectrumTexture::Checkerboard(tex.clone()),
|
||||||
SpectrumTexture::Dots(tex) => GPUSpectrumTexture::Dots(tex.clone()),
|
SpectrumTexture::Dots(tex) => GPUSpectrumTexture::Dots(tex.clone()),
|
||||||
SpectrumTexture::Image(tex) => {
|
SpectrumTexture::Image(tex) => {
|
||||||
let tex_obj = arena.get_texture_object(&tex.base.mipmap);
|
let tex_obj = arena.get_texture_object(&tex.base.mipmap);
|
||||||
let gpu_img = GPUSpectrumImageTexture {
|
GPUSpectrumTexture::Image(GPUSpectrumImageTexture {
|
||||||
mapping: tex.base.mapping,
|
mapping: tex.base.mapping,
|
||||||
tex_obj,
|
tex_obj,
|
||||||
scale: tex.base.scale,
|
scale: tex.base.scale,
|
||||||
|
|
@ -255,55 +459,50 @@ impl Upload for SpectrumTexture {
|
||||||
.clone()
|
.clone()
|
||||||
.unwrap_or_else(crate::spectra::default_colorspace),
|
.unwrap_or_else(crate::spectra::default_colorspace),
|
||||||
spectrum_type: tex.spectrum_type,
|
spectrum_type: tex.spectrum_type,
|
||||||
};
|
})
|
||||||
GPUSpectrumTexture::Image(gpu_img)
|
|
||||||
}
|
}
|
||||||
SpectrumTexture::Bilerp(tex) => GPUSpectrumTexture::Bilerp(tex.clone()),
|
SpectrumTexture::Bilerp(tex) => GPUSpectrumTexture::Bilerp(tex.clone()),
|
||||||
SpectrumTexture::Scaled(tex) => {
|
SpectrumTexture::Scaled(tex) => {
|
||||||
let child_ptr = tex.tex.upload(arena);
|
let child_ptr = tex.tex.upload(arena);
|
||||||
|
let scale_ptr = tex.scale.upload(arena);
|
||||||
let gpu_scaled = GPUSpectrumScaledTexture {
|
GPUSpectrumTexture::Scaled(GPUSpectrumScaledTexture {
|
||||||
tex: child_ptr,
|
tex: child_ptr,
|
||||||
scale: tex.scale.upload(arena),
|
scale: scale_ptr,
|
||||||
};
|
})
|
||||||
GPUSpectrumTexture::Scaled(gpu_scaled)
|
|
||||||
}
|
}
|
||||||
SpectrumTexture::Marble(tex) => GPUSpectrumTexture::Marble(tex.clone()),
|
SpectrumTexture::Marble(tex) => GPUSpectrumTexture::Marble(tex.clone()),
|
||||||
SpectrumTexture::Mix(tex) => {
|
SpectrumTexture::Mix(tex) => {
|
||||||
let tex1_ptr = tex.tex1.upload(arena);
|
let tex1_ptr = tex.tex1.upload(arena);
|
||||||
let tex2_ptr = tex.tex2.upload(arena);
|
let tex2_ptr = tex.tex2.upload(arena);
|
||||||
let amount_ptr = tex.amount.upload(arena);
|
let amount_ptr = tex.amount.upload(arena);
|
||||||
|
GPUSpectrumTexture::Mix(GPUSpectrumMixTexture {
|
||||||
let gpu_mix = GPUSpectrumMixTexture {
|
|
||||||
tex1: tex1_ptr,
|
tex1: tex1_ptr,
|
||||||
tex2: tex2_ptr,
|
tex2: tex2_ptr,
|
||||||
amount: amount_ptr,
|
amount: amount_ptr,
|
||||||
};
|
})
|
||||||
|
|
||||||
GPUSpectrumTexture::Mix(gpu_mix)
|
|
||||||
}
|
}
|
||||||
SpectrumTexture::DirectionMix(tex) => {
|
SpectrumTexture::DirectionMix(tex) => {
|
||||||
let tex1_ptr = tex.tex1.upload(arena);
|
let tex1_ptr = tex.tex1.upload(arena);
|
||||||
let tex2_ptr = tex.tex2.upload(arena);
|
let tex2_ptr = tex.tex2.upload(arena);
|
||||||
|
GPUSpectrumTexture::DirectionMix(GPUSpectrumDirectionMixTexture {
|
||||||
let gpu_mix = GPUSpectrumDirectionMixTexture {
|
|
||||||
tex1: tex1_ptr,
|
tex1: tex1_ptr,
|
||||||
tex2: tex2_ptr,
|
tex2: tex2_ptr,
|
||||||
dir: tex.dir,
|
dir: tex.dir,
|
||||||
};
|
})
|
||||||
|
|
||||||
GPUSpectrumTexture::DirectionMix(gpu_mix)
|
|
||||||
}
|
}
|
||||||
};
|
}
|
||||||
|
|
||||||
arena.alloc(gpu_variant)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Upload for FloatTexture {
|
// =============================================================================
|
||||||
|
// FloatTexture → GPUFloatTexture
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
impl DeviceRepr for FloatTexture {
|
||||||
type Target = GPUFloatTexture;
|
type Target = GPUFloatTexture;
|
||||||
fn upload<A: GpuAllocator>(&self, arena: &Arena<A>) -> Ptr<Self::Target> {
|
|
||||||
let gpu_variant = match self {
|
fn upload_value<A: GpuAllocator>(&self, arena: &Arena<A>) -> GPUFloatTexture {
|
||||||
|
match self {
|
||||||
FloatTexture::Constant(tex) => GPUFloatTexture::Constant(tex.clone()),
|
FloatTexture::Constant(tex) => GPUFloatTexture::Constant(tex.clone()),
|
||||||
FloatTexture::Checkerboard(tex) => GPUFloatTexture::Checkerboard(tex.clone()),
|
FloatTexture::Checkerboard(tex) => GPUFloatTexture::Checkerboard(tex.clone()),
|
||||||
FloatTexture::Dots(tex) => GPUFloatTexture::Dots(tex.clone()),
|
FloatTexture::Dots(tex) => GPUFloatTexture::Dots(tex.clone()),
|
||||||
|
|
@ -312,249 +511,40 @@ impl Upload for FloatTexture {
|
||||||
FloatTexture::Wrinkled(tex) => GPUFloatTexture::Wrinkled(tex.clone()),
|
FloatTexture::Wrinkled(tex) => GPUFloatTexture::Wrinkled(tex.clone()),
|
||||||
FloatTexture::Scaled(tex) => {
|
FloatTexture::Scaled(tex) => {
|
||||||
let child_ptr = tex.tex.upload(arena);
|
let child_ptr = tex.tex.upload(arena);
|
||||||
|
let scale_ptr = tex.scale.upload(arena);
|
||||||
let gpu_scaled = GPUFloatScaledTexture {
|
GPUFloatTexture::Scaled(GPUFloatScaledTexture {
|
||||||
tex: child_ptr,
|
tex: child_ptr,
|
||||||
scale: tex.scale.upload(arena),
|
scale: scale_ptr,
|
||||||
};
|
})
|
||||||
GPUFloatTexture::Scaled(gpu_scaled)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
FloatTexture::Mix(tex) => {
|
FloatTexture::Mix(tex) => {
|
||||||
let tex1_ptr = tex.tex1.upload(arena);
|
let tex1_ptr = tex.tex1.upload(arena);
|
||||||
let tex2_ptr = tex.tex2.upload(arena);
|
let tex2_ptr = tex.tex2.upload(arena);
|
||||||
let amount_ptr = tex.amount.upload(arena);
|
let amount_ptr = tex.amount.upload(arena);
|
||||||
|
GPUFloatTexture::Mix(GPUFloatMixTexture {
|
||||||
let gpu_mix = GPUFloatMixTexture {
|
|
||||||
tex1: tex1_ptr,
|
tex1: tex1_ptr,
|
||||||
tex2: tex2_ptr,
|
tex2: tex2_ptr,
|
||||||
amount: amount_ptr,
|
amount: amount_ptr,
|
||||||
};
|
})
|
||||||
GPUFloatTexture::Mix(gpu_mix)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
FloatTexture::DirectionMix(tex) => {
|
FloatTexture::DirectionMix(tex) => {
|
||||||
let tex1_ptr = tex.tex1.upload(arena);
|
let tex1_ptr = tex.tex1.upload(arena);
|
||||||
let tex2_ptr = tex.tex2.upload(arena);
|
let tex2_ptr = tex.tex2.upload(arena);
|
||||||
let gpu_dmix = GPUFloatDirectionMixTexture {
|
GPUFloatTexture::DirectionMix(GPUFloatDirectionMixTexture {
|
||||||
tex1: tex1_ptr,
|
tex1: tex1_ptr,
|
||||||
tex2: tex2_ptr,
|
tex2: tex2_ptr,
|
||||||
dir: tex.dir,
|
dir: tex.dir,
|
||||||
};
|
})
|
||||||
GPUFloatTexture::DirectionMix(gpu_dmix)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
FloatTexture::Image(tex) => {
|
FloatTexture::Image(tex) => {
|
||||||
let gpu_image_tex = GPUFloatImageTexture {
|
GPUFloatTexture::Image(GPUFloatImageTexture {
|
||||||
mapping: tex.base.mapping,
|
mapping: tex.base.mapping,
|
||||||
tex_obj: tex.base.mipmap.texture_object(),
|
tex_obj: tex.base.mipmap.texture_object(),
|
||||||
scale: tex.base.scale,
|
scale: tex.base.scale,
|
||||||
invert: tex.base.invert,
|
invert: tex.base.invert,
|
||||||
};
|
})
|
||||||
GPUFloatTexture::Image(gpu_image_tex)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
FloatTexture::Bilerp(tex) => GPUFloatTexture::Bilerp(tex.clone()),
|
FloatTexture::Bilerp(tex) => GPUFloatTexture::Bilerp(tex.clone()),
|
||||||
};
|
|
||||||
|
|
||||||
arena.alloc(gpu_variant)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Upload for RGBToSpectrumTable {
|
|
||||||
type Target = RGBToSpectrumTable;
|
|
||||||
|
|
||||||
fn upload<A: GpuAllocator>(&self, arena: &Arena<A>) -> Ptr<Self::Target> {
|
|
||||||
let n_nodes = self.n_nodes as usize;
|
|
||||||
let z_slice = unsafe { from_raw_parts(self.z_nodes.as_raw(), n_nodes) };
|
|
||||||
let (z_ptr, _) = arena.alloc_slice(z_slice);
|
|
||||||
|
|
||||||
let n_coeffs = 3 * (n_nodes as usize).pow(3);
|
|
||||||
let coeffs_slice = unsafe { from_raw_parts(self.coeffs.as_raw(), n_coeffs) };
|
|
||||||
let (c_ptr, _) = arena.alloc_slice(coeffs_slice);
|
|
||||||
|
|
||||||
arena.alloc(RGBToSpectrumTable {
|
|
||||||
z_nodes: z_ptr,
|
|
||||||
coeffs: c_ptr,
|
|
||||||
n_nodes: self.n_nodes,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Upload for RGBColorSpace {
|
|
||||||
type Target = RGBColorSpace;
|
|
||||||
|
|
||||||
fn upload<A: GpuAllocator>(&self, arena: &Arena<A>) -> Ptr<Self::Target> {
|
|
||||||
let table_ptr = self.rgb_to_spectrum_table.upload(arena);
|
|
||||||
|
|
||||||
let shared_space = RGBColorSpace {
|
|
||||||
r: self.r,
|
|
||||||
g: self.g,
|
|
||||||
b: self.b,
|
|
||||||
w: self.w,
|
|
||||||
illuminant: self.illuminant.clone(),
|
|
||||||
rgb_to_spectrum_table: table_ptr,
|
|
||||||
xyz_from_rgb: self.xyz_from_rgb,
|
|
||||||
rgb_from_xyz: self.rgb_from_xyz,
|
|
||||||
};
|
|
||||||
|
|
||||||
arena.alloc(shared_space)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Upload for DeviceStandardColorSpaces {
|
|
||||||
type Target = DeviceStandardColorSpaces;
|
|
||||||
|
|
||||||
fn upload<A: GpuAllocator>(&self, arena: &Arena<A>) -> Ptr<Self::Target> {
|
|
||||||
let srgb_ptr = self.srgb.upload(arena);
|
|
||||||
let dci_ptr = self.dci_p3.upload(arena);
|
|
||||||
let rec_ptr = self.rec2020.upload(arena);
|
|
||||||
let aces_ptr = self.aces2065_1.upload(arena);
|
|
||||||
|
|
||||||
let registry = DeviceStandardColorSpaces {
|
|
||||||
srgb: srgb_ptr,
|
|
||||||
dci_p3: dci_ptr,
|
|
||||||
rec2020: rec_ptr,
|
|
||||||
aces2065_1: aces_ptr,
|
|
||||||
};
|
|
||||||
|
|
||||||
arena.alloc(registry)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Upload for PiecewiseConstant2D {
|
|
||||||
type Target = DevicePiecewiseConstant2D;
|
|
||||||
|
|
||||||
fn upload<A: GpuAllocator>(&self, arena: &Arena<A>) -> Ptr<Self::Target> {
|
|
||||||
let marginal_shared = self.marginal.to_shared(arena);
|
|
||||||
|
|
||||||
let conditionals_shared: Vec<DevicePiecewiseConstant1D> = self
|
|
||||||
.conditionals
|
|
||||||
.iter()
|
|
||||||
.map(|c| c.to_shared(arena))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let (conditionals_ptr, _) = arena.alloc_slice(&conditionals_shared);
|
|
||||||
|
|
||||||
let shared_2d = DevicePiecewiseConstant2D {
|
|
||||||
conditionals: conditionals_ptr,
|
|
||||||
marginal: marginal_shared,
|
|
||||||
..self.device
|
|
||||||
};
|
|
||||||
|
|
||||||
arena.alloc(shared_2d)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Upload for WindowedPiecewiseConstant2D {
|
|
||||||
type Target = DeviceWindowedPiecewiseConstant2D;
|
|
||||||
fn upload<A: GpuAllocator>(&self, arena: &Arena<A>) -> Ptr<Self::Target> {
|
|
||||||
let specific = DeviceWindowedPiecewiseConstant2D {
|
|
||||||
sat: self.sat,
|
|
||||||
func: self.func,
|
|
||||||
};
|
|
||||||
arena.alloc(specific)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Upload for TriangleMesh {
|
|
||||||
type Target = DeviceTriangleMesh;
|
|
||||||
|
|
||||||
fn upload<A: GpuAllocator>(&self, arena: &Arena<A>) -> Ptr<Self::Target> {
|
|
||||||
let storage = &self.storage;
|
|
||||||
|
|
||||||
// Upload all arrays to arena
|
|
||||||
let (vertex_indices_ptr, _) = arena.alloc_slice(&storage.vertex_indices);
|
|
||||||
let (p_ptr, _) = arena.alloc_slice(&storage.p);
|
|
||||||
|
|
||||||
let (n_ptr, _) = if storage.n.is_empty() {
|
|
||||||
(Ptr::null(), 0)
|
|
||||||
} else {
|
|
||||||
arena.alloc_slice(&storage.n)
|
|
||||||
};
|
|
||||||
|
|
||||||
let (s_ptr, _) = if storage.s.is_empty() {
|
|
||||||
(Ptr::null(), 0)
|
|
||||||
} else {
|
|
||||||
arena.alloc_slice(&storage.s)
|
|
||||||
};
|
|
||||||
|
|
||||||
let (uv_ptr, _) = if storage.uv.is_empty() {
|
|
||||||
(Ptr::null(), 0)
|
|
||||||
} else {
|
|
||||||
arena.alloc_slice(&storage.uv)
|
|
||||||
};
|
|
||||||
|
|
||||||
let (face_indices_ptr, _) = if storage.face_indices.is_empty() {
|
|
||||||
(Ptr::null(), 0)
|
|
||||||
} else {
|
|
||||||
arena.alloc_slice(&storage.face_indices)
|
|
||||||
};
|
|
||||||
|
|
||||||
let device = DeviceTriangleMesh {
|
|
||||||
vertex_indices: vertex_indices_ptr,
|
|
||||||
p: p_ptr,
|
|
||||||
n: n_ptr,
|
|
||||||
s: s_ptr,
|
|
||||||
uv: uv_ptr,
|
|
||||||
face_indices: face_indices_ptr,
|
|
||||||
..self.device
|
|
||||||
};
|
|
||||||
|
|
||||||
arena.alloc(device)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Upload for BilinearPatchMesh {
|
|
||||||
type Target = DeviceBilinearPatchMesh;
|
|
||||||
|
|
||||||
fn upload<A: GpuAllocator>(&self, arena: &Arena<A>) -> Ptr<Self::Target> {
|
|
||||||
let storage = &self.storage;
|
|
||||||
|
|
||||||
let (vertex_indices_ptr, _) = arena.alloc_slice(&storage.vertex_indices);
|
|
||||||
let (p_ptr, _) = arena.alloc_slice(&storage.p);
|
|
||||||
|
|
||||||
let (n_ptr, _) = if storage.n.is_empty() {
|
|
||||||
(Ptr::null(), 0)
|
|
||||||
} else {
|
|
||||||
arena.alloc_slice(&storage.n)
|
|
||||||
};
|
|
||||||
|
|
||||||
let (uv_ptr, _) = if storage.uv.is_empty() {
|
|
||||||
(Ptr::null(), 0)
|
|
||||||
} else {
|
|
||||||
arena.alloc_slice(&storage.uv)
|
|
||||||
};
|
|
||||||
|
|
||||||
let image_dist_ptr = storage.image_distribution.upload(arena);
|
|
||||||
|
|
||||||
let device = DeviceBilinearPatchMesh {
|
|
||||||
vertex_indices: vertex_indices_ptr,
|
|
||||||
p: p_ptr,
|
|
||||||
n: n_ptr,
|
|
||||||
uv: uv_ptr,
|
|
||||||
image_distribution: image_dist_ptr,
|
|
||||||
..self.device
|
|
||||||
};
|
|
||||||
|
|
||||||
arena.alloc(device)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: Upload> Upload for Option<T> {
|
|
||||||
type Target = T::Target;
|
|
||||||
fn upload<A: GpuAllocator>(&self, arena: &Arena<A>) -> Ptr<Self::Target> {
|
|
||||||
match self {
|
|
||||||
Some(val) => val.upload(arena),
|
|
||||||
None => Ptr::null(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Upload> Upload for Arc<T> {
|
|
||||||
type Target = T::Target;
|
|
||||||
|
|
||||||
fn upload<A: GpuAllocator>(&self, arena: &Arena<A>) -> Ptr<Self::Target> {
|
|
||||||
(**self).upload(arena)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,17 @@
|
||||||
use std::alloc::Layout;
|
use std::alloc::Layout;
|
||||||
|
|
||||||
pub trait GpuAllocator: Send + Sync {
|
pub trait GpuAllocator: Send + Sync + Clone {
|
||||||
/// Allocate `size` bytes with given alignment.
|
|
||||||
/// Returns a host-mapped pointer.
|
|
||||||
unsafe fn alloc(&self, layout: Layout) -> *mut u8;
|
unsafe fn alloc(&self, layout: Layout) -> *mut u8;
|
||||||
unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout);
|
unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// CPU fallback — standard system allocator.
|
/// CPU fallback
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct SystemAllocator;
|
pub struct SystemAllocator;
|
||||||
|
|
||||||
impl Default for SystemAllocator {
|
impl Default for SystemAllocator {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {}
|
Self
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -21,101 +20,85 @@ impl GpuAllocator for SystemAllocator {
|
||||||
if layout.size() == 0 {
|
if layout.size() == 0 {
|
||||||
return layout.align() as *mut u8;
|
return layout.align() as *mut u8;
|
||||||
}
|
}
|
||||||
unsafe { std::alloc::alloc(layout) }
|
std::alloc::alloc(layout)
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
|
unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
|
||||||
if layout.size() > 0 {
|
if layout.size() > 0 {
|
||||||
unsafe {
|
std::alloc::dealloc(ptr, layout);
|
||||||
std::alloc::dealloc(ptr, layout);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// CUDA unified memory backend using CudaAllocator
|
/// CUDA unified memory. Still using CudaAllocator, might move over to
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
pub mod cuda {
|
pub mod cuda {
|
||||||
use super::GpuAllocator;
|
use super::GpuAllocator;
|
||||||
|
use cust::memory::{cuda_free_unified, cuda_malloc_unified, UnifiedPointer};
|
||||||
use std::alloc::Layout;
|
use std::alloc::Layout;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct CudaAllocator;
|
pub struct CudaAllocator;
|
||||||
|
|
||||||
impl Default for CudaAllocator {
|
impl Default for CudaAllocator {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {}
|
Self
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl GpuAllocator for CudaAllocator {
|
impl GpuAllocator for CudaAllocator {
|
||||||
unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
|
unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
|
||||||
use cust::memory::cuda_malloc_unified;
|
let size = layout.size();
|
||||||
use cust_raw::driver_sys::*;
|
|
||||||
|
|
||||||
let size = layout.size().max(layout.align());
|
|
||||||
if size == 0 {
|
if size == 0 {
|
||||||
return layout.align() as *mut u8;
|
return layout.align() as *mut u8;
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut ctx: CUcontext = std::ptr::null_mut();
|
let ptr = cuda_malloc_unified::<u8>(size)
|
||||||
cuCtxGetCurrent(&mut ctx);
|
.expect("cuda_malloc_unified failed — is a CUDA context current?");
|
||||||
if ctx.is_null() {
|
|
||||||
let mut primary: CUcontext = std::ptr::null_mut();
|
|
||||||
cuDevicePrimaryCtxRetain(&mut primary, 0);
|
|
||||||
cuCtxSetCurrent(primary);
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut unified_ptr =
|
let raw = ptr.as_raw_mut();
|
||||||
unsafe { cuda_malloc_unified::<u8>(size).expect("cuda_malloc_unified failed") };
|
std::mem::forget(ptr); // Leak RAII wrapper; Arena owns the raw pointer
|
||||||
let raw = unified_ptr.as_raw_mut();
|
|
||||||
std::mem::forget(unified_ptr);
|
|
||||||
raw
|
raw
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
|
unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
|
||||||
use cust::memory::{UnifiedPointer, cuda_free_unified};
|
if layout.size() > 0 && !ptr.is_null() {
|
||||||
if layout.size() > 0 {
|
let _ = cuda_free_unified(UnifiedPointer::wrap(ptr));
|
||||||
let _ = unsafe { cuda_free_unified(UnifiedPointer::wrap(ptr)) };
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Vulkan backend (gpu-allocator for now, there might be a better solution)
|
|
||||||
#[cfg(feature = "vulkan")]
|
#[cfg(feature = "vulkan")]
|
||||||
pub mod vulkan {
|
pub mod vulkan {
|
||||||
use super::GpuAllocator;
|
use super::GpuAllocator;
|
||||||
use ash::vk;
|
use ash::vk;
|
||||||
use gpu_allocator::MemoryLocation;
|
|
||||||
use gpu_allocator::vulkan::{
|
use gpu_allocator::vulkan::{
|
||||||
Allocation, AllocationCreateDesc, AllocationScheme, Allocator, AllocatorCreateDesc,
|
Allocation, AllocationCreateDesc, AllocationScheme, Allocator, AllocatorCreateDesc,
|
||||||
};
|
};
|
||||||
|
use gpu_allocator::MemoryLocation;
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
use std::alloc::Layout;
|
use std::alloc::Layout;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::OnceLock;
|
use std::sync::Arc;
|
||||||
|
|
||||||
// So, having a static allocator seems like a terrible idea
|
#[derive(Clone)]
|
||||||
// But I cant find a way to get a functioning generic Arena constructor
|
pub struct VulkanAllocator {
|
||||||
// That might not even be a necessity, since rust-gpu/rust-cuda might actually handle that
|
inner: Arc<Mutex<VulkanInner>>,
|
||||||
// differently
|
|
||||||
static VK_ALLOCATOR: OnceLock<VulkanAllocatorInner> = OnceLock::new();
|
|
||||||
|
|
||||||
struct VulkanAllocatorInner {
|
|
||||||
state: Mutex<VulkanState>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
struct VulkanState {
|
struct VulkanInner {
|
||||||
|
device: ash::Device,
|
||||||
allocator: Allocator,
|
allocator: Allocator,
|
||||||
allocations: HashMap<usize, Allocation>,
|
allocations: HashMap<usize, Allocation>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn init_vulkan(
|
impl VulkanAllocator {
|
||||||
instance: &ash::Instance,
|
pub fn new(
|
||||||
device: &ash::Device,
|
instance: &ash::Instance,
|
||||||
physical_device: vk::PhysicalDevice,
|
device: ash::Device,
|
||||||
) {
|
physical_device: vk::PhysicalDevice,
|
||||||
VK_ALLOCATOR.get_or_init(|| {
|
) -> Self {
|
||||||
let allocator = Allocator::new(&AllocatorCreateDesc {
|
let allocator = Allocator::new(&AllocatorCreateDesc {
|
||||||
instance: instance.clone(),
|
instance: instance.clone(),
|
||||||
device: device.clone(),
|
device: device.clone(),
|
||||||
|
|
@ -126,52 +109,29 @@ pub mod vulkan {
|
||||||
})
|
})
|
||||||
.expect("Failed to create Vulkan allocator");
|
.expect("Failed to create Vulkan allocator");
|
||||||
|
|
||||||
VulkanAllocatorInner {
|
Self {
|
||||||
state: Mutex::new(VulkanState {
|
inner: Arc::new(Mutex::new(VulkanInner {
|
||||||
|
device,
|
||||||
allocator,
|
allocator,
|
||||||
allocations: HashMap::new(),
|
allocations: HashMap::new(),
|
||||||
}),
|
})),
|
||||||
}
|
}
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
fn inner() -> &'static VulkanAllocatorInner {
|
|
||||||
VK_ALLOCATOR
|
|
||||||
.get()
|
|
||||||
.expect("Vulkan not initialized — call init_vulkan() before arena creation")
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for VulkanAllocator {
|
|
||||||
fn default() -> Self {
|
|
||||||
let _ = inner();
|
|
||||||
Self
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct VulkanAllocator;
|
|
||||||
|
|
||||||
// impl VulkanAllocator {
|
|
||||||
// pub fn new(allocator: Allocator) -> Self {
|
|
||||||
// Self {
|
|
||||||
// allocator: Mutex::new(allocator),
|
|
||||||
// allocations: Mutex::new(HashMap::new()),
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
impl GpuAllocator for VulkanAllocator {
|
impl GpuAllocator for VulkanAllocator {
|
||||||
unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
|
unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
|
||||||
let size = layout.size().max(layout.align());
|
let size = layout.size();
|
||||||
if size == 0 {
|
if size == 0 {
|
||||||
return layout.align() as *mut u8;
|
return layout.align() as *mut u8;
|
||||||
}
|
}
|
||||||
|
|
||||||
let inner = inner();
|
let mut inner = self.inner.lock();
|
||||||
let mut state = inner.state.lock();
|
|
||||||
let allocation = state
|
let allocation = inner
|
||||||
.allocator
|
.allocator
|
||||||
.allocate(&AllocationCreateDesc {
|
.allocate(&AllocationCreateDesc {
|
||||||
name: "arena",
|
name: "arena_chunk",
|
||||||
requirements: vk::MemoryRequirements {
|
requirements: vk::MemoryRequirements {
|
||||||
size: size as u64,
|
size: size as u64,
|
||||||
alignment: layout.align() as u64,
|
alignment: layout.align() as u64,
|
||||||
|
|
@ -188,18 +148,17 @@ pub mod vulkan {
|
||||||
.expect("Vulkan allocation not host-mapped")
|
.expect("Vulkan allocation not host-mapped")
|
||||||
.as_ptr() as *mut u8;
|
.as_ptr() as *mut u8;
|
||||||
|
|
||||||
state.allocations.insert(ptr as usize, allocation);
|
inner.allocations.insert(ptr as usize, allocation);
|
||||||
ptr
|
ptr
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
|
unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
|
||||||
if layout.size() == 0 {
|
if layout.size() == 0 || ptr.is_null() {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
let inner = inner();
|
let mut inner = self.inner.lock();
|
||||||
let mut state = inner.state.lock();
|
if let Some(allocation) = inner.allocations.remove(&(ptr as usize)) {
|
||||||
if let Some(allocation) = state.allocations.remove(&(ptr as usize)) {
|
inner
|
||||||
state
|
|
||||||
.allocator
|
.allocator
|
||||||
.free(allocation)
|
.free(allocation)
|
||||||
.expect("Vulkan free failed");
|
.expect("Vulkan free failed");
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,7 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
|
#[derive(DeviceRepr)]
|
||||||
pub struct Array2D<T> {
|
pub struct Array2D<T> {
|
||||||
pub device: DeviceArray2D<T>,
|
pub device: DeviceArray2D<T>,
|
||||||
pub values: Vec<T>,
|
pub values: Vec<T>,
|
||||||
|
|
|
||||||
515
src/utils/mod.rs
515
src/utils/mod.rs
|
|
@ -12,7 +12,7 @@ pub mod parser;
|
||||||
pub mod sampling;
|
pub mod sampling;
|
||||||
pub mod strings;
|
pub mod strings;
|
||||||
|
|
||||||
pub use arena::Upload;
|
pub use arena::DeviceRepr;
|
||||||
pub use error::FileLoc;
|
pub use error::FileLoc;
|
||||||
pub use file::{read_float_file, resolve_filename};
|
pub use file::{read_float_file, resolve_filename};
|
||||||
pub use parameters::{
|
pub use parameters::{
|
||||||
|
|
@ -28,3 +28,516 @@ pub type Arena = arena::Arena<backend::cuda::CudaAllocator>;
|
||||||
|
|
||||||
#[cfg(not(any(feature = "cuda", feature = "vulkan")))]
|
#[cfg(not(any(feature = "cuda", feature = "vulkan")))]
|
||||||
pub type Arena = arena::Arena<backend::SystemAllocator>;
|
pub type Arena = arena::Arena<backend::SystemAllocator>;
|
||||||
|
|
||||||
|
/// # Enum variant attributes
|
||||||
|
///
|
||||||
|
/// | Attribute | Effect |
|
||||||
|
/// |-----------|--------|
|
||||||
|
/// | *(none)* | Inner type has `DeviceRepr`; auto-call `upload_value` |
|
||||||
|
/// | `#[device(clone)]` | Same type on both sides, just clone |
|
||||||
|
/// | `#[device(custom = "method")]` | You provide `fn method(inner: &T, arena) -> DeviceT` |
|
||||||
|
/// | `#[device(variant_type = "T")]` | Override the device-side variant's inner type |
|
||||||
|
///
|
||||||
|
/// # Container attribute
|
||||||
|
///
|
||||||
|
/// `#[device(name = "DeviceFoo")]` — override the generated type name (default: `Device{Name}`).
|
||||||
|
#[proc_macro_derive(Device, attributes(device))]
|
||||||
|
pub fn derive_device(input: TokenStream) -> TokenStream {
|
||||||
|
let input = parse_macro_input!(input as DeriveInput);
|
||||||
|
match derive_impl(input) {
|
||||||
|
Ok(tokens) => tokens.into(),
|
||||||
|
Err(e) => e.to_compile_error().into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn derive_impl(input: DeriveInput) -> syn::Result<TokenStream2> {
|
||||||
|
match &input.data {
|
||||||
|
Data::Struct(_) => derive_struct(input),
|
||||||
|
Data::Enum(_) => derive_enum(input),
|
||||||
|
Data::Union(_) => Err(syn::Error::new_spanned(
|
||||||
|
&input.ident,
|
||||||
|
"Device derive does not support unions",
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Struct derivation
|
||||||
|
|
||||||
|
fn derive_struct(input: DeriveInput) -> syn::Result<TokenStream2> {
|
||||||
|
let host_name = &input.ident;
|
||||||
|
let vis = &input.vis;
|
||||||
|
let device_name = get_device_name(&input.attrs, host_name)?;
|
||||||
|
|
||||||
|
let fields = match &input.data {
|
||||||
|
Data::Struct(s) => match &s.fields {
|
||||||
|
Fields::Named(named) => &named.named,
|
||||||
|
_ => {
|
||||||
|
return Err(syn::Error::new_spanned(
|
||||||
|
host_name,
|
||||||
|
"Device derive only supports structs with named fields",
|
||||||
|
))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut device_fields = Vec::new();
|
||||||
|
let mut upload_stmts = Vec::new();
|
||||||
|
let mut device_field_inits = Vec::new();
|
||||||
|
let mut spread_expr: Option<Expr> = None;
|
||||||
|
|
||||||
|
for field in fields {
|
||||||
|
let field_name = field.ident.as_ref().unwrap();
|
||||||
|
let attrs = parse_field_attrs(&field.attrs)?;
|
||||||
|
|
||||||
|
if attrs.skip {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ref expr_str) = attrs.spread {
|
||||||
|
spread_expr = Some(syn::parse_str(expr_str).map_err(|e| {
|
||||||
|
syn::Error::new_spanned(field, format!("invalid device(spread): {}", e))
|
||||||
|
})?);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(expr_str) = &attrs.expr {
|
||||||
|
let expr: Expr = syn::parse_str(expr_str).map_err(|e| {
|
||||||
|
syn::Error::new_spanned(field, format!("invalid device(expr): {}", e))
|
||||||
|
})?;
|
||||||
|
let ty = &field.ty;
|
||||||
|
device_fields.push(quote! { pub #field_name: #ty });
|
||||||
|
upload_stmts.push(quote! { let #field_name = #expr; });
|
||||||
|
device_field_inits.push(quote! { #field_name });
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
match classify_type(&field.ty) {
|
||||||
|
FieldClass::VecCopy(inner_ty) => {
|
||||||
|
let len_name = format_ident!("{}_len", field_name);
|
||||||
|
device_fields.push(quote! { pub #field_name: Ptr<#inner_ty> });
|
||||||
|
device_fields.push(quote! { pub #len_name: usize });
|
||||||
|
upload_stmts.push(quote! {
|
||||||
|
let (#field_name, #len_name) = arena.alloc_slice(&self.#field_name);
|
||||||
|
});
|
||||||
|
device_field_inits.push(quote! { #field_name });
|
||||||
|
device_field_inits.push(quote! { #len_name });
|
||||||
|
}
|
||||||
|
FieldClass::VecUploadable(inner_ty) => {
|
||||||
|
let len_name = format_ident!("{}_len", field_name);
|
||||||
|
device_fields.push(quote! {
|
||||||
|
pub #field_name: Ptr<<#inner_ty as DeviceRepr>::Target>
|
||||||
|
});
|
||||||
|
device_fields.push(quote! { pub #len_name: usize });
|
||||||
|
upload_stmts.push(quote! {
|
||||||
|
let __up: Vec<<#inner_ty as DeviceRepr>::Target> = self.#field_name
|
||||||
|
.iter()
|
||||||
|
.map(|item| DeviceRepr::upload_value(item, arena))
|
||||||
|
.collect();
|
||||||
|
let (#field_name, #len_name) = arena.alloc_slice(&__up);
|
||||||
|
});
|
||||||
|
device_field_inits.push(quote! { #field_name });
|
||||||
|
device_field_inits.push(quote! { #len_name });
|
||||||
|
}
|
||||||
|
FieldClass::Option(inner_ty) => {
|
||||||
|
device_fields.push(quote! {
|
||||||
|
pub #field_name: Ptr<<#inner_ty as DeviceRepr>::Target>
|
||||||
|
});
|
||||||
|
upload_stmts.push(quote! {
|
||||||
|
let #field_name = match &self.#field_name {
|
||||||
|
Some(val) => DeviceRepr::upload(val, arena),
|
||||||
|
None => Ptr::null(),
|
||||||
|
};
|
||||||
|
});
|
||||||
|
device_field_inits.push(quote! { #field_name });
|
||||||
|
}
|
||||||
|
FieldClass::Arc(inner_ty) => {
|
||||||
|
if attrs.flatten {
|
||||||
|
device_fields.push(quote! {
|
||||||
|
pub #field_name: <#inner_ty as DeviceRepr>::Target
|
||||||
|
});
|
||||||
|
upload_stmts.push(quote! {
|
||||||
|
let #field_name = DeviceRepr::upload_value(&*self.#field_name, arena);
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
device_fields.push(quote! {
|
||||||
|
pub #field_name: Ptr<<#inner_ty as DeviceRepr>::Target>
|
||||||
|
});
|
||||||
|
upload_stmts.push(quote! {
|
||||||
|
let #field_name = DeviceRepr::upload(&*self.#field_name, arena);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
device_field_inits.push(quote! { #field_name });
|
||||||
|
}
|
||||||
|
FieldClass::Plain => {
|
||||||
|
let ty = &field.ty;
|
||||||
|
if attrs.copy_upload {
|
||||||
|
device_fields.push(quote! { pub #field_name: #ty });
|
||||||
|
upload_stmts.push(quote! {
|
||||||
|
let #field_name = self.#field_name.clone();
|
||||||
|
});
|
||||||
|
} else if attrs.flatten {
|
||||||
|
device_fields.push(quote! {
|
||||||
|
pub #field_name: <#ty as DeviceRepr>::Target
|
||||||
|
});
|
||||||
|
upload_stmts.push(quote! {
|
||||||
|
let #field_name = DeviceRepr::upload_value(&self.#field_name, arena);
|
||||||
|
});
|
||||||
|
} else if attrs.upload {
|
||||||
|
device_fields.push(quote! {
|
||||||
|
pub #field_name: Ptr<<#ty as DeviceRepr>::Target>
|
||||||
|
});
|
||||||
|
upload_stmts.push(quote! {
|
||||||
|
let #field_name = DeviceRepr::upload(&self.#field_name, arena);
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
device_fields.push(quote! { pub #field_name: #ty });
|
||||||
|
upload_stmts.push(quote! {
|
||||||
|
let #field_name = self.#field_name;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
device_field_inits.push(quote! { #field_name });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let constructor = if let Some(spread) = spread_expr {
|
||||||
|
quote! {
|
||||||
|
#device_name {
|
||||||
|
#(#device_field_inits,)*
|
||||||
|
..#spread
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
quote! {
|
||||||
|
#device_name {
|
||||||
|
#(#device_field_inits,)*
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(quote! {
|
||||||
|
#[repr(C)]
|
||||||
|
#[derive(Debug, Copy, Clone)]
|
||||||
|
#vis struct #device_name {
|
||||||
|
#(#device_fields,)*
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe impl Send for #device_name {}
|
||||||
|
unsafe impl Sync for #device_name {}
|
||||||
|
|
||||||
|
impl DeviceRepr for #host_name {
|
||||||
|
type Target = #device_name;
|
||||||
|
|
||||||
|
fn upload_value<A: GpuAllocator>(&self, arena: &Arena<A>) -> Self::Target {
|
||||||
|
#(#upload_stmts)*
|
||||||
|
#constructor
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enum derivation
|
||||||
|
fn derive_enum(input: DeriveInput) -> syn::Result<TokenStream2> {
|
||||||
|
let host_name = &input.ident;
|
||||||
|
let vis = &input.vis;
|
||||||
|
let device_name = get_device_name(&input.attrs, host_name)?;
|
||||||
|
|
||||||
|
let variants = match &input.data {
|
||||||
|
Data::Enum(e) => &e.variants,
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut device_variants = Vec::new();
|
||||||
|
let mut match_arms = Vec::new();
|
||||||
|
|
||||||
|
for variant in variants {
|
||||||
|
let var_name = &variant.ident;
|
||||||
|
let var_attrs = parse_variant_attrs(&variant.attrs)?;
|
||||||
|
let inner_ty = get_variant_inner_type(variant)?;
|
||||||
|
|
||||||
|
// Determine the device-side inner type for this variant
|
||||||
|
let device_inner: Type = if let Some(ref ty_str) = var_attrs.variant_type {
|
||||||
|
syn::parse_str(ty_str).map_err(|e| {
|
||||||
|
syn::Error::new_spanned(variant, format!("invalid variant_type: {}", e))
|
||||||
|
})?
|
||||||
|
} else if var_attrs.clone_variant {
|
||||||
|
// clone: same type on both sides
|
||||||
|
inner_ty.clone()
|
||||||
|
} else {
|
||||||
|
// auto-upload: use DeviceRepr::Target
|
||||||
|
syn::parse_str(&format!("<{} as DeviceRepr>::Target", quote!(#inner_ty))).map_err(
|
||||||
|
|e| {
|
||||||
|
syn::Error::new_spanned(variant, format!("cannot construct Target type: {}", e))
|
||||||
|
},
|
||||||
|
)?
|
||||||
|
};
|
||||||
|
|
||||||
|
device_variants.push(quote! { #var_name(#device_inner) });
|
||||||
|
|
||||||
|
if var_attrs.clone_variant {
|
||||||
|
match_arms.push(quote! {
|
||||||
|
#host_name::#var_name(inner) => #device_name::#var_name(inner.clone())
|
||||||
|
});
|
||||||
|
} else if let Some(ref method) = var_attrs.custom {
|
||||||
|
let method_ident = format_ident!("{}", method);
|
||||||
|
match_arms.push(quote! {
|
||||||
|
#host_name::#var_name(inner) => {
|
||||||
|
#device_name::#var_name(Self::#method_ident(inner, arena))
|
||||||
|
}
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
// Default: inner implements DeviceRepr
|
||||||
|
match_arms.push(quote! {
|
||||||
|
#host_name::#var_name(inner) => {
|
||||||
|
#device_name::#var_name(DeviceRepr::upload_value(inner, arena))
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(quote! {
|
||||||
|
#[repr(C)]
|
||||||
|
#[derive(Debug, Copy, Clone)]
|
||||||
|
#vis enum #device_name {
|
||||||
|
#(#device_variants,)*
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe impl Send for #device_name {}
|
||||||
|
unsafe impl Sync for #device_name {}
|
||||||
|
|
||||||
|
impl DeviceRepr for #host_name {
|
||||||
|
type Target = #device_name;
|
||||||
|
|
||||||
|
fn upload_value<A: GpuAllocator>(&self, arena: &Arena<A>) -> Self::Target {
|
||||||
|
match self {
|
||||||
|
#(#match_arms,)*
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_variant_inner_type(variant: &Variant) -> syn::Result<Type> {
|
||||||
|
match &variant.fields {
|
||||||
|
Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
|
||||||
|
Ok(fields.unnamed.first().unwrap().ty.clone())
|
||||||
|
}
|
||||||
|
Fields::Unit => Err(syn::Error::new_spanned(
|
||||||
|
variant,
|
||||||
|
"Device derive: enum variants must have exactly one field, e.g. Variant(Type)",
|
||||||
|
)),
|
||||||
|
_ => Err(syn::Error::new_spanned(
|
||||||
|
variant,
|
||||||
|
"Device derive: only single-field tuple variants supported, e.g. Variant(Type)",
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attribute parsing for variants
|
||||||
|
struct VariantAttrs {
|
||||||
|
clone_variant: bool,
|
||||||
|
custom: Option<String>,
|
||||||
|
variant_type: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_variant_attrs(attrs: &[Attribute]) -> syn::Result<VariantAttrs> {
|
||||||
|
let mut result = VariantAttrs {
|
||||||
|
clone_variant: false,
|
||||||
|
custom: None,
|
||||||
|
variant_type: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
for attr in attrs {
|
||||||
|
if !attr.path().is_ident("device") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
attr.parse_nested_meta(|meta| {
|
||||||
|
if meta.path.is_ident("clone") {
|
||||||
|
result.clone_variant = true;
|
||||||
|
Ok(())
|
||||||
|
} else if meta.path.is_ident("custom") {
|
||||||
|
let value = meta.value()?;
|
||||||
|
let lit: Lit = value.parse()?;
|
||||||
|
if let Lit::Str(s) = lit {
|
||||||
|
result.custom = Some(s.value());
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(meta.error("expected string literal"))
|
||||||
|
}
|
||||||
|
} else if meta.path.is_ident("variant_type") {
|
||||||
|
let value = meta.value()?;
|
||||||
|
let lit: Lit = value.parse()?;
|
||||||
|
if let Lit::Str(s) = lit {
|
||||||
|
result.variant_type = Some(s.value());
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(meta.error("expected string literal"))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Err(meta.error("unknown device variant attribute"))
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attribute parsing for fields
|
||||||
|
struct FieldAttrs {
|
||||||
|
skip: bool,
|
||||||
|
expr: Option<String>,
|
||||||
|
copy_upload: bool,
|
||||||
|
flatten: bool,
|
||||||
|
upload: bool,
|
||||||
|
spread: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_field_attrs(attrs: &[Attribute]) -> syn::Result<FieldAttrs> {
|
||||||
|
let mut result = FieldAttrs {
|
||||||
|
skip: false,
|
||||||
|
expr: None,
|
||||||
|
copy_upload: false,
|
||||||
|
flatten: false,
|
||||||
|
upload: false,
|
||||||
|
spread: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
for attr in attrs {
|
||||||
|
if !attr.path().is_ident("device") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
attr.parse_nested_meta(|meta| {
|
||||||
|
if meta.path.is_ident("skip") {
|
||||||
|
result.skip = true;
|
||||||
|
Ok(())
|
||||||
|
} else if meta.path.is_ident("expr") {
|
||||||
|
let value = meta.value()?;
|
||||||
|
let lit: Lit = value.parse()?;
|
||||||
|
if let Lit::Str(s) = lit {
|
||||||
|
result.expr = Some(s.value());
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(meta.error("expected string literal"))
|
||||||
|
}
|
||||||
|
} else if meta.path.is_ident("copy_upload") {
|
||||||
|
result.copy_upload = true;
|
||||||
|
Ok(())
|
||||||
|
} else if meta.path.is_ident("flatten") {
|
||||||
|
result.flatten = true;
|
||||||
|
Ok(())
|
||||||
|
} else if meta.path.is_ident("upload") {
|
||||||
|
result.upload = true;
|
||||||
|
Ok(())
|
||||||
|
} else if meta.path.is_ident("spread") {
|
||||||
|
let value = meta.value()?;
|
||||||
|
let lit: Lit = value.parse()?;
|
||||||
|
if let Lit::Str(s) = lit {
|
||||||
|
result.spread = Some(s.value());
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(meta.error("expected string literal"))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Err(meta.error("unknown device attribute"))
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Container-level name attribute
|
||||||
|
fn get_device_name(attrs: &[Attribute], host_name: &Ident) -> syn::Result<Ident> {
|
||||||
|
for attr in attrs {
|
||||||
|
if !attr.path().is_ident("device") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let mut name = None;
|
||||||
|
attr.parse_nested_meta(|meta| {
|
||||||
|
if meta.path.is_ident("name") {
|
||||||
|
let value = meta.value()?;
|
||||||
|
let lit: Lit = value.parse()?;
|
||||||
|
if let Lit::Str(s) = lit {
|
||||||
|
name = Some(format_ident!("{}", s.value()));
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(meta.error("expected string literal"))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
if let Some(n) = name {
|
||||||
|
return Ok(n);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(format_ident!("Device{}", host_name))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type classification
|
||||||
|
enum FieldClass {
|
||||||
|
VecCopy(Type),
|
||||||
|
VecUploadable(Type),
|
||||||
|
Option(Type),
|
||||||
|
Arc(Type),
|
||||||
|
Plain,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn classify_type(ty: &Type) -> FieldClass {
|
||||||
|
if let Some(inner) = extract_generic_arg(ty, "Vec") {
|
||||||
|
if is_copy_primitive(&inner) {
|
||||||
|
FieldClass::VecCopy(inner)
|
||||||
|
} else {
|
||||||
|
FieldClass::VecUploadable(inner)
|
||||||
|
}
|
||||||
|
} else if let Some(inner) = extract_generic_arg(ty, "Option") {
|
||||||
|
FieldClass::Option(inner)
|
||||||
|
} else if let Some(inner) = extract_generic_arg(ty, "Arc") {
|
||||||
|
FieldClass::Arc(inner)
|
||||||
|
} else {
|
||||||
|
FieldClass::Plain
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_generic_arg(ty: &Type, wrapper: &str) -> Option<Type> {
|
||||||
|
if let Type::Path(type_path) = ty {
|
||||||
|
let seg = type_path.path.segments.last()?;
|
||||||
|
if seg.ident != wrapper {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
if let PathArguments::AngleBracketed(args) = &seg.arguments {
|
||||||
|
if let Some(GenericArgument::Type(inner)) = args.args.first() {
|
||||||
|
return Some(inner.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_copy_primitive(ty: &Type) -> bool {
|
||||||
|
if let Type::Path(type_path) = ty {
|
||||||
|
if let Some(seg) = type_path.path.segments.last() {
|
||||||
|
let name = seg.ident.to_string();
|
||||||
|
return matches!(
|
||||||
|
name.as_str(),
|
||||||
|
"f32"
|
||||||
|
| "f64"
|
||||||
|
| "u8"
|
||||||
|
| "u16"
|
||||||
|
| "u32"
|
||||||
|
| "u64"
|
||||||
|
| "i8"
|
||||||
|
| "i16"
|
||||||
|
| "i32"
|
||||||
|
| "i64"
|
||||||
|
| "usize"
|
||||||
|
| "isize"
|
||||||
|
| "bool"
|
||||||
|
| "Float"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,43 +13,25 @@ use std::sync::Arc;
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct PiecewiseConstant1D {
|
pub struct PiecewiseConstant1D {
|
||||||
func: Box<[Float]>,
|
func: Vec<Float>,
|
||||||
cdf: Box<[Float]>,
|
cdf: Vec<Float>,
|
||||||
pub device: DevicePiecewiseConstant1D,
|
pub min: Float,
|
||||||
|
pub max: Float,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PiecewiseConstant1D {
|
impl PiecewiseConstant1D {
|
||||||
// Constructors
|
|
||||||
pub fn new(f: &[Float]) -> Self {
|
pub fn new(f: &[Float]) -> Self {
|
||||||
Self::new_with_bounds(f.to_vec(), 0.0, 1.0)
|
Self::new_with_bounds(f.to_vec(), 0.0, 1.0)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn to_shared<A: GpuAllocator>(&self, arena: &Arena<A>) -> DevicePiecewiseConstant1D {
|
|
||||||
let (func_ptr, _) = arena.alloc_slice(&self.func);
|
|
||||||
let (cdf_ptr, _) = arena.alloc_slice(&self.cdf);
|
|
||||||
|
|
||||||
DevicePiecewiseConstant1D {
|
|
||||||
func: func_ptr,
|
|
||||||
cdf: cdf_ptr,
|
|
||||||
func_integral: self.func_integral,
|
|
||||||
n: self.func.len() as u32,
|
|
||||||
min: self.min,
|
|
||||||
max: self.max,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn from_func<F>(f: F, min: Float, max: Float, n: usize) -> Self
|
pub fn from_func<F>(f: F, min: Float, max: Float, n: usize) -> Self
|
||||||
where
|
where
|
||||||
F: Fn(Float) -> Float,
|
F: Fn(Float) -> Float,
|
||||||
{
|
{
|
||||||
let delta = (max - min) / n as Float;
|
let delta = (max - min) / n as Float;
|
||||||
let values: Vec<Float> = (0..n)
|
let values: Vec<Float> = (0..n)
|
||||||
.map(|i| {
|
.map(|i| f(min + (i as Float + 0.5) * delta))
|
||||||
let x = min + (i as Float + 0.5) * delta;
|
|
||||||
f(x)
|
|
||||||
})
|
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
Self::new_with_bounds(values, min, max)
|
Self::new_with_bounds(values, min, max)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -64,74 +46,92 @@ impl PiecewiseConstant1D {
|
||||||
}
|
}
|
||||||
|
|
||||||
let func_integral = cdf[n];
|
let func_integral = cdf[n];
|
||||||
|
|
||||||
if func_integral > 0.0 {
|
if func_integral > 0.0 {
|
||||||
for c in &mut cdf {
|
for c in &mut cdf {
|
||||||
*c /= func_integral;
|
*c /= func_integral;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert to boxed slices (no more reallocation possible)
|
Self { func: f, cdf, min, max }
|
||||||
let func: Box<[Float]> = f.into_boxed_slice();
|
|
||||||
let cdf: Box<[Float]> = cdf.into_boxed_slice();
|
|
||||||
|
|
||||||
let device = DevicePiecewiseConstant1D {
|
|
||||||
func: func.as_ptr().into(),
|
|
||||||
cdf: cdf.as_ptr().into(),
|
|
||||||
min,
|
|
||||||
max,
|
|
||||||
n: n as u32,
|
|
||||||
func_integral,
|
|
||||||
};
|
|
||||||
|
|
||||||
Self { func, cdf, device }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Accessors
|
// Accessors
|
||||||
pub fn min(&self) -> Float {
|
pub fn n(&self) -> usize { self.func.len() }
|
||||||
self.device.min
|
pub fn func(&self) -> &[Float] { &self.func }
|
||||||
}
|
pub fn cdf(&self) -> &[Float] { &self.cdf }
|
||||||
pub fn max(&self) -> Float {
|
|
||||||
self.device.max
|
|
||||||
}
|
|
||||||
pub fn n(&self) -> usize {
|
|
||||||
self.device.n as usize
|
|
||||||
}
|
|
||||||
pub fn integral(&self) -> Float {
|
pub fn integral(&self) -> Float {
|
||||||
self.device.func_integral
|
// func_integral is the un-normalized sum. After normalization cdf[n] == 1.0,
|
||||||
|
// so we reconstruct from the last CDF entry before normalization.
|
||||||
|
// But since we normalized in-place, we need to store it. Let's compute it.
|
||||||
|
let n = self.func.len();
|
||||||
|
let delta = (self.max - self.min) / n as Float;
|
||||||
|
self.func.iter().sum::<Float>() * delta
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn func(&self) -> &[Float] {
|
/// Host-side sampling (for scene construction, not rendering).
|
||||||
&self.func
|
/// During rendering, use the device struct via arena-uploaded Ptrs.
|
||||||
|
pub fn sample_host(&self, u: Float) -> (Float, Float, usize) {
|
||||||
|
let n = self.func.len();
|
||||||
|
let offset = self.find_interval_host(u);
|
||||||
|
let cdf_offset = self.cdf[offset];
|
||||||
|
let cdf_next = self.cdf[offset + 1];
|
||||||
|
let du = if cdf_next - cdf_offset > 0.0 {
|
||||||
|
(u - cdf_offset) / (cdf_next - cdf_offset)
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
};
|
||||||
|
let delta = (self.max - self.min) / n as Float;
|
||||||
|
let x = self.min + (offset as Float + du) * delta;
|
||||||
|
let func_integral = self.integral();
|
||||||
|
let pdf = if func_integral > 0.0 {
|
||||||
|
self.func[offset] / func_integral
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
};
|
||||||
|
(x, pdf, offset)
|
||||||
}
|
}
|
||||||
pub fn cdf(&self) -> &[Float] {
|
|
||||||
&self.cdf
|
fn find_interval_host(&self, u: Float) -> usize {
|
||||||
|
let n = self.func.len();
|
||||||
|
let mut size = n;
|
||||||
|
let mut first = 0usize;
|
||||||
|
while size > 0 {
|
||||||
|
let half = size >> 1;
|
||||||
|
let middle = first + half;
|
||||||
|
if self.cdf[middle] <= u {
|
||||||
|
first = middle + 1;
|
||||||
|
size -= half + 1;
|
||||||
|
} else {
|
||||||
|
size = half;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
first.saturating_sub(1).min(n - 1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::ops::Deref for PiecewiseConstant1D {
|
#[derive(DeviceRepr)]
|
||||||
type Target = DevicePiecewiseConstant1D;
|
#[device(name = "DevicePiecewiseConstant1D")]
|
||||||
|
pub struct PiecewiseConstant1D {
|
||||||
fn deref(&self) -> &Self::Target {
|
pub func: Vec<Float>,
|
||||||
&self.device
|
pub cdf: Vec<Float>,
|
||||||
}
|
pub min: Float,
|
||||||
|
pub max: Float,
|
||||||
|
pub n: u32,
|
||||||
|
#[device(expr = "self.integral()")]
|
||||||
|
pub func_integral: Float,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Clone, Debug, DeviceRepr)]
|
||||||
|
#[device(name = "DevicePiecewiseConstant2D")]
|
||||||
pub struct PiecewiseConstant2D {
|
pub struct PiecewiseConstant2D {
|
||||||
pub conditionals: Vec<PiecewiseConstant1D>,
|
pub conditionals: Vec<PiecewiseConstant1D>,
|
||||||
|
#[device(flatten)]
|
||||||
pub marginal: PiecewiseConstant1D,
|
pub marginal: PiecewiseConstant1D,
|
||||||
pub conditional_devices: Box<[DevicePiecewiseConstant1D]>,
|
pub n_u: u32,
|
||||||
pub device: DevicePiecewiseConstant2D,
|
pub n_v: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::ops::Deref for PiecewiseConstant2D {
|
|
||||||
type Target = DevicePiecewiseConstant2D;
|
|
||||||
|
|
||||||
fn deref(&self) -> &Self::Target {
|
|
||||||
&self.device
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PiecewiseConstant2D {
|
impl PiecewiseConstant2D {
|
||||||
pub fn new(data: &Array2D<Float>) -> Self {
|
pub fn new(data: &Array2D<Float>) -> Self {
|
||||||
|
|
@ -141,8 +141,8 @@ impl PiecewiseConstant2D {
|
||||||
pub fn new_with_bounds(data: &Array2D<Float>, domain: Bounds2f) -> Self {
|
pub fn new_with_bounds(data: &Array2D<Float>, domain: Bounds2f) -> Self {
|
||||||
Self::from_slice(
|
Self::from_slice(
|
||||||
data.as_slice(),
|
data.as_slice(),
|
||||||
data.x_size() as usize,
|
data.x_size(),
|
||||||
data.y_size() as usize,
|
data.y_size(),
|
||||||
domain,
|
domain,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
@ -154,36 +154,23 @@ impl PiecewiseConstant2D {
|
||||||
let mut marginal_func = Vec::with_capacity(n_v);
|
let mut marginal_func = Vec::with_capacity(n_v);
|
||||||
|
|
||||||
for v in 0..n_v {
|
for v in 0..n_v {
|
||||||
let row_start = v * n_u;
|
let row = data[v * n_u..(v + 1) * n_u].to_vec();
|
||||||
let row: Vec<Float> = data[row_start..row_start + n_u].to_vec();
|
let conditional = PiecewiseConstant1D::new_with_bounds(
|
||||||
let conditional =
|
row,
|
||||||
PiecewiseConstant1D::new_with_bounds(row, domain.p_min.x(), domain.p_max.x());
|
domain.p_min.x(),
|
||||||
|
domain.p_max.x(),
|
||||||
|
);
|
||||||
marginal_func.push(conditional.integral());
|
marginal_func.push(conditional.integral());
|
||||||
conditionals.push(conditional);
|
conditionals.push(conditional);
|
||||||
}
|
}
|
||||||
|
|
||||||
let marginal =
|
let marginal = PiecewiseConstant1D::new_with_bounds(
|
||||||
PiecewiseConstant1D::new_with_bounds(marginal_func, domain.p_min.y(), domain.p_max.y());
|
marginal_func,
|
||||||
|
domain.p_min.y(),
|
||||||
|
domain.p_max.y(),
|
||||||
|
);
|
||||||
|
|
||||||
let conditional_devices: Box<[DevicePiecewiseConstant1D]> = conditionals
|
Self { conditionals, marginal, n_u, n_v }
|
||||||
.iter()
|
|
||||||
.map(|c| c.device)
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
.into_boxed_slice();
|
|
||||||
|
|
||||||
let device = DevicePiecewiseConstant2D {
|
|
||||||
conditionals: conditional_devices.as_ptr().into(),
|
|
||||||
marginal: marginal.device,
|
|
||||||
n_u: n_u as u32,
|
|
||||||
n_v: n_v as u32,
|
|
||||||
};
|
|
||||||
|
|
||||||
Self {
|
|
||||||
conditionals,
|
|
||||||
marginal,
|
|
||||||
conditional_devices,
|
|
||||||
device,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn from_image(image: &Image) -> Self {
|
pub fn from_image(image: &Image) -> Self {
|
||||||
|
|
@ -192,12 +179,9 @@ impl PiecewiseConstant2D {
|
||||||
let n_v = res.y() as usize;
|
let n_v = res.y() as usize;
|
||||||
|
|
||||||
let mut data = Vec::with_capacity(n_u * n_v);
|
let mut data = Vec::with_capacity(n_u * n_v);
|
||||||
|
|
||||||
for v in 0..n_v {
|
for v in 0..n_v {
|
||||||
for u in 0..n_u {
|
for u in 0..n_u {
|
||||||
let p = Point2i::new(u as i32, v as i32);
|
data.push(image.get_channels(Point2i::new(u as i32, v as i32)).average());
|
||||||
let luminance = image.get_channels(p).average();
|
|
||||||
data.push(luminance);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -217,10 +201,14 @@ struct PiecewiseLinear2DStorage<const N: usize> {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct PiecewiseLinear2DHost<const N: usize> {
|
pub struct PiecewiseLinear2DHost<const N: usize> {
|
||||||
pub view: PiecewiseLinear2D<N>,
|
size: Vector2i,
|
||||||
_storage: Arc<PiecewiseLinear2DStorage<N>>,
|
inv_patch_size: Vector2f,
|
||||||
|
param_size: [u32; N],
|
||||||
|
param_strides: [u32; N],
|
||||||
|
storage: Arc<PiecewiseLinear2DStorage<N>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
impl<const N: usize> PiecewiseLinear2DHost<N> {
|
impl<const N: usize> PiecewiseLinear2DHost<N> {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
data: &[Float],
|
data: &[Float],
|
||||||
|
|
@ -354,27 +342,49 @@ impl<const N: usize> PiecewiseLinear2DHost<N> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<const N: usize> DeviceRepr for PiecewiseLinear2DHost<N> {
|
||||||
|
type Target = PiecewiseLinear2D<N>;
|
||||||
|
|
||||||
|
fn upload_value<A: GpuAllocator>(&self, arena: &Arena<A>) -> PiecewiseLinear2D<N> {
|
||||||
|
let s = &self.storage;
|
||||||
|
|
||||||
|
let (data_ptr, _) = arena.alloc_slice(&s.data);
|
||||||
|
let (marginal_ptr, _) = arena.alloc_slice(&s.marginal_cdf);
|
||||||
|
let (conditional_ptr, _) = arena.alloc_slice(&s.conditional_cdf);
|
||||||
|
|
||||||
|
let param_ptrs: [Ptr<Float>; N] = std::array::from_fn(|i| {
|
||||||
|
let (ptr, _) = arena.alloc_slice(&s.param_values[i]);
|
||||||
|
ptr
|
||||||
|
});
|
||||||
|
|
||||||
|
PiecewiseLinear2D {
|
||||||
|
size: self.size,
|
||||||
|
inv_patch_size: self.inv_patch_size,
|
||||||
|
param_size: self.param_size,
|
||||||
|
param_strides: self.param_strides,
|
||||||
|
param_values: param_ptrs,
|
||||||
|
data: data_ptr,
|
||||||
|
marginal_cdf: marginal_ptr,
|
||||||
|
conditional_cdf: conditional_ptr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct AliasTableHost {
|
pub struct AliasTableHost {
|
||||||
pub view: AliasTable,
|
bins: Vec<Bin>,
|
||||||
_storage: Vec<Bin>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AliasTableHost {
|
impl AliasTableHost {
|
||||||
pub fn new(weights: &[Float]) -> Self {
|
pub fn new(weights: &[Float]) -> Self {
|
||||||
let n = weights.len();
|
let n = weights.len();
|
||||||
if n == 0 {
|
if n == 0 {
|
||||||
return Self {
|
return Self { bins: Vec::new() };
|
||||||
view: AliasTable {
|
|
||||||
bins: Ptr::null(),
|
|
||||||
size: 0,
|
|
||||||
},
|
|
||||||
_storage: Vec::new(),
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let sum: f64 = weights.iter().map(|&w| w as f64).sum();
|
let sum: f64 = weights.iter().map(|&w| w as f64).sum();
|
||||||
assert!(sum > 0.0, "Sum of weights must be positive");
|
assert!(sum > 0.0, "Sum of weights must be positive");
|
||||||
|
|
||||||
let mut bins = Vec::with_capacity(n);
|
let mut bins = Vec::with_capacity(n);
|
||||||
for &w in weights {
|
for &w in weights {
|
||||||
bins.push(Bin {
|
bins.push(Bin {
|
||||||
|
|
@ -384,10 +394,7 @@ impl AliasTableHost {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Outcome {
|
struct Outcome { p_hat: f64, index: usize }
|
||||||
p_hat: f64,
|
|
||||||
index: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut under = Vec::with_capacity(n);
|
let mut under = Vec::with_capacity(n);
|
||||||
let mut over = Vec::with_capacity(n);
|
let mut over = Vec::with_capacity(n);
|
||||||
|
|
@ -409,17 +416,10 @@ impl AliasTableHost {
|
||||||
bins[un.index].alias = ov.index as u32;
|
bins[un.index].alias = ov.index as u32;
|
||||||
|
|
||||||
let p_excess = un.p_hat + ov.p_hat - 1.0;
|
let p_excess = un.p_hat + ov.p_hat - 1.0;
|
||||||
|
|
||||||
if p_excess < 1.0 {
|
if p_excess < 1.0 {
|
||||||
under.push(Outcome {
|
under.push(Outcome { p_hat: p_excess, index: ov.index });
|
||||||
p_hat: p_excess,
|
|
||||||
index: ov.index,
|
|
||||||
});
|
|
||||||
} else {
|
} else {
|
||||||
over.push(Outcome {
|
over.push(Outcome { p_hat: p_excess, index: ov.index });
|
||||||
p_hat: p_excess,
|
|
||||||
index: ov.index,
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -427,51 +427,40 @@ impl AliasTableHost {
|
||||||
bins[ov.index].q = 1.0;
|
bins[ov.index].q = 1.0;
|
||||||
bins[ov.index].alias = ov.index as u32;
|
bins[ov.index].alias = ov.index as u32;
|
||||||
}
|
}
|
||||||
|
|
||||||
while let Some(un) = under.pop() {
|
while let Some(un) = under.pop() {
|
||||||
bins[un.index].q = 1.0;
|
bins[un.index].q = 1.0;
|
||||||
bins[un.index].alias = un.index as u32;
|
bins[un.index].alias = un.index as u32;
|
||||||
}
|
}
|
||||||
|
|
||||||
let view = AliasTable {
|
Self { bins }
|
||||||
bins: bins.as_ptr().into(),
|
|
||||||
size: bins.len() as u32,
|
|
||||||
};
|
|
||||||
|
|
||||||
Self {
|
|
||||||
view,
|
|
||||||
_storage: bins,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn to_device<A: GpuAllocator>(&self, arena: &Arena<A>) -> AliasTable {
|
pub fn size(&self) -> usize { self.bins.len() }
|
||||||
if self._storage.is_empty() {
|
pub fn is_empty(&self) -> bool { self.bins.is_empty() }
|
||||||
return AliasTable {
|
}
|
||||||
bins: Ptr::null(),
|
|
||||||
size: 0,
|
impl DeviceRepr for AliasTableHost {
|
||||||
};
|
type Target = AliasTable;
|
||||||
|
|
||||||
|
fn upload_value<A: GpuAllocator>(&self, arena: &Arena<A>) -> AliasTable {
|
||||||
|
if self.bins.is_empty() {
|
||||||
|
return AliasTable { bins: Ptr::null(), size: 0 };
|
||||||
}
|
}
|
||||||
let (bins_ptr, _) = arena.alloc_slice(&self._storage);
|
let (bins_ptr, _) = arena.alloc_slice(&self.bins);
|
||||||
AliasTable {
|
AliasTable {
|
||||||
bins: bins_ptr,
|
bins: bins_ptr,
|
||||||
size: self._storage.len() as u32,
|
size: self.bins.len() as u32,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug, DeviceRepr)]
|
||||||
|
#[device(name = "DeviceSummedAreaTable")]
|
||||||
pub struct SummedAreaTable {
|
pub struct SummedAreaTable {
|
||||||
pub device: DeviceSummedAreaTable,
|
#[device(flatten)]
|
||||||
sum: Array2D<f64>,
|
sum: Array2D<f64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::ops::Deref for SummedAreaTable {
|
|
||||||
type Target = DeviceSummedAreaTable;
|
|
||||||
fn deref(&self) -> &Self::Target {
|
|
||||||
&self.device
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl SummedAreaTable {
|
impl SummedAreaTable {
|
||||||
pub fn new(values: &Array2D<Float>) -> Self {
|
pub fn new(values: &Array2D<Float>) -> Self {
|
||||||
let width = values.x_size() as i32;
|
let width = values.x_size() as i32;
|
||||||
|
|
@ -483,46 +472,194 @@ impl SummedAreaTable {
|
||||||
for x in 1..width {
|
for x in 1..width {
|
||||||
sum[(x, 0)] = values[(x, 0)] as f64 + sum[(x - 1, 0)];
|
sum[(x, 0)] = values[(x, 0)] as f64 + sum[(x - 1, 0)];
|
||||||
}
|
}
|
||||||
|
|
||||||
for y in 1..height {
|
for y in 1..height {
|
||||||
sum[(0, y)] = values[(0, y)] as f64 + sum[(0, y - 1)];
|
sum[(0, y)] = values[(0, y)] as f64 + sum[(0, y - 1)];
|
||||||
}
|
}
|
||||||
|
|
||||||
for y in 1..height {
|
for y in 1..height {
|
||||||
for x in 1..width {
|
for x in 1..width {
|
||||||
let term = values[(x, y)] as f64;
|
sum[(x, y)] = values[(x, y)] as f64
|
||||||
let left = sum[(x - 1, y)];
|
+ sum[(x - 1, y)]
|
||||||
let up = sum[(x, y - 1)];
|
+ sum[(x, y - 1)]
|
||||||
let diag = sum[(x - 1, y - 1)];
|
- sum[(x - 1, y - 1)];
|
||||||
|
|
||||||
sum[(x, y)] = term + left + up - diag;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let device = DeviceSummedAreaTable { sum: *sum };
|
Self { sum }
|
||||||
Self { device, sum }
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug, DeviceRepr)]
|
||||||
|
#[device(name = "DeviceWindowedPiecewiseConstant2D")]
|
||||||
pub struct WindowedPiecewiseConstant2D {
|
pub struct WindowedPiecewiseConstant2D {
|
||||||
pub device: DeviceWindowedPiecewiseConstant2D,
|
#[device(flatten)]
|
||||||
sat: DeviceSummedAreaTable,
|
sat: SummedAreaTable,
|
||||||
|
#[device(flatten)]
|
||||||
func: Array2D<Float>,
|
func: Array2D<Float>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::ops::Deref for WindowedPiecewiseConstant2D {
|
|
||||||
type Target = DeviceWindowedPiecewiseConstant2D;
|
|
||||||
fn deref(&self) -> &Self::Target {
|
|
||||||
&self.device
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl WindowedPiecewiseConstant2D {
|
impl WindowedPiecewiseConstant2D {
|
||||||
pub fn new(func: Array2D<Float>) -> Self {
|
pub fn new(func: Array2D<Float>) -> Self {
|
||||||
let sat = *SummedAreaTable::new(&func);
|
let sat = SummedAreaTable::new(&func);
|
||||||
let device = DeviceWindowedPiecewiseConstant2D { sat, func: *func };
|
Self { sat, func }
|
||||||
|
}
|
||||||
Self { sat, func, device }
|
}
|
||||||
|
|
||||||
|
struct PiecewiseLinear2DStorage<const N: usize> {
|
||||||
|
data: Vec<Float>,
|
||||||
|
marginal_cdf: Vec<Float>,
|
||||||
|
conditional_cdf: Vec<Float>,
|
||||||
|
param_values: [Vec<Float>; N],
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct PiecewiseLinear2DHost<const N: usize> {
|
||||||
|
size: Vector2i,
|
||||||
|
inv_patch_size: Vector2f,
|
||||||
|
param_size: [u32; N],
|
||||||
|
param_strides: [u32; N],
|
||||||
|
storage: Arc<PiecewiseLinear2DStorage<N>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<const N: usize> PiecewiseLinear2DHost<N> {
|
||||||
|
pub fn new(
|
||||||
|
data: &[Float],
|
||||||
|
x_size: i32,
|
||||||
|
y_size: i32,
|
||||||
|
param_res: [usize; N],
|
||||||
|
param_values: [&[Float]; N],
|
||||||
|
normalize: bool,
|
||||||
|
build_cdf: bool,
|
||||||
|
) -> Self {
|
||||||
|
if build_cdf && !normalize {
|
||||||
|
panic!("PiecewiseLinear2D: build_cdf implies normalize=true");
|
||||||
|
}
|
||||||
|
|
||||||
|
let size = Vector2i::new(x_size, y_size);
|
||||||
|
let inv_patch_size = Vector2f::new(
|
||||||
|
1.0 / (x_size - 1) as Float,
|
||||||
|
1.0 / (y_size - 1) as Float,
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut param_size = [0u32; N];
|
||||||
|
let mut param_strides = [0u32; N];
|
||||||
|
let owned_param_values: [Vec<Float>; N] = gpu_array_from_fn(|i| param_values[i].to_vec());
|
||||||
|
|
||||||
|
let mut slices: u32 = 1;
|
||||||
|
for i in (0..N).rev() {
|
||||||
|
assert!(param_res[i] >= 1, "Parameter resolution must be >= 1");
|
||||||
|
param_size[i] = param_res[i] as u32;
|
||||||
|
param_strides[i] = if param_res[i] > 1 { slices } else { 0 };
|
||||||
|
slices *= param_size[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
let n_values = (x_size * y_size) as usize;
|
||||||
|
let mut new_data = vec![0.0; slices as usize * n_values];
|
||||||
|
let mut marginal_cdf = if build_cdf {
|
||||||
|
vec![0.0; slices as usize * y_size as usize]
|
||||||
|
} else {
|
||||||
|
Vec::new()
|
||||||
|
};
|
||||||
|
let mut conditional_cdf = if build_cdf {
|
||||||
|
vec![0.0; slices as usize * n_values]
|
||||||
|
} else {
|
||||||
|
Vec::new()
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut data_offset = 0;
|
||||||
|
for slice in 0..slices as usize {
|
||||||
|
let slice_offset = slice * n_values;
|
||||||
|
let current_data = &data[data_offset..data_offset + n_values];
|
||||||
|
let mut sum = 0.0_f64;
|
||||||
|
|
||||||
|
if normalize {
|
||||||
|
for y in 0..(y_size - 1) {
|
||||||
|
for x in 0..(x_size - 1) {
|
||||||
|
let i = (y * x_size + x) as usize;
|
||||||
|
let v00 = current_data[i] as f64;
|
||||||
|
let v10 = current_data[i + 1] as f64;
|
||||||
|
let v01 = current_data[i + x_size as usize] as f64;
|
||||||
|
let v11 = current_data[i + 1 + x_size as usize] as f64;
|
||||||
|
sum += 0.25 * (v00 + v10 + v01 + v11);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let normalization = if normalize && sum > 0.0 {
|
||||||
|
1.0 / sum as Float
|
||||||
|
} else {
|
||||||
|
1.0
|
||||||
|
};
|
||||||
|
for k in 0..n_values {
|
||||||
|
new_data[slice_offset + k] = current_data[k] * normalization;
|
||||||
|
}
|
||||||
|
|
||||||
|
if build_cdf {
|
||||||
|
let marginal_slice_offset = slice * y_size as usize;
|
||||||
|
for y in 0..y_size as usize {
|
||||||
|
let mut cdf_sum = 0.0;
|
||||||
|
let i_base = y * x_size as usize;
|
||||||
|
conditional_cdf[slice_offset + i_base] = 0.0;
|
||||||
|
for x in 0..(x_size - 1) as usize {
|
||||||
|
let i = i_base + x;
|
||||||
|
cdf_sum += 0.5
|
||||||
|
* (new_data[slice_offset + i] + new_data[slice_offset + i + 1]);
|
||||||
|
conditional_cdf[slice_offset + i + 1] = cdf_sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
marginal_cdf[marginal_slice_offset] = 0.0;
|
||||||
|
let mut marginal_sum = 0.0;
|
||||||
|
for y in 0..(y_size - 1) as usize {
|
||||||
|
let cdf1 =
|
||||||
|
conditional_cdf[slice_offset + (y + 1) * x_size as usize - 1];
|
||||||
|
let cdf2 =
|
||||||
|
conditional_cdf[slice_offset + (y + 2) * x_size as usize - 1];
|
||||||
|
marginal_sum += 0.5 * (cdf1 + cdf2);
|
||||||
|
marginal_cdf[marginal_slice_offset + y + 1] = marginal_sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
data_offset += n_values;
|
||||||
|
}
|
||||||
|
|
||||||
|
let storage = Arc::new(PiecewiseLinear2DStorage {
|
||||||
|
data: new_data,
|
||||||
|
marginal_cdf,
|
||||||
|
conditional_cdf,
|
||||||
|
param_values: owned_param_values,
|
||||||
|
});
|
||||||
|
|
||||||
|
Self {
|
||||||
|
size,
|
||||||
|
inv_patch_size,
|
||||||
|
param_size,
|
||||||
|
param_strides,
|
||||||
|
storage,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<const N: usize> DeviceRepr for PiecewiseLinear2DHost<N> {
|
||||||
|
type Target = PiecewiseLinear2D<N>;
|
||||||
|
|
||||||
|
fn upload_value<A: GpuAllocator>(&self, arena: &Arena<A>) -> PiecewiseLinear2D<N> {
|
||||||
|
let s = &self.storage;
|
||||||
|
|
||||||
|
let (data_ptr, _) = arena.alloc_slice(&s.data);
|
||||||
|
let (marginal_ptr, _) = arena.alloc_slice(&s.marginal_cdf);
|
||||||
|
let (conditional_ptr, _) = arena.alloc_slice(&s.conditional_cdf);
|
||||||
|
|
||||||
|
let param_ptrs: [Ptr<Float>; N] = std::array::from_fn(|i| {
|
||||||
|
let (ptr, _) = arena.alloc_slice(&s.param_values[i]);
|
||||||
|
ptr
|
||||||
|
});
|
||||||
|
|
||||||
|
PiecewiseLinear2D {
|
||||||
|
size: self.size,
|
||||||
|
inv_patch_size: self.inv_patch_size,
|
||||||
|
param_size: self.param_size,
|
||||||
|
param_strides: self.param_strides,
|
||||||
|
param_values: param_ptrs,
|
||||||
|
data: data_ptr,
|
||||||
|
marginal_cdf: marginal_ptr,
|
||||||
|
conditional_cdf: conditional_ptr,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue