use rayon::prelude::*; use shared::Float; use shared::core::geometry::{Bounds3f, Point3f, Ray, Vector3f}; use shared::core::primitive::PrimitiveTrait; use shared::core::shape::ShapeIntersection; use shared::utils::math::encode_morton_3; use shared::utils::{find_interval, partition_slice}; use std::cmp::Ordering; use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering}; #[repr(C)] #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum SplitMethod { AH, Hlbvh, Middle, EqualCounts, } #[repr(C)] #[derive(Debug, Default, Clone, Copy, PartialEq)] struct BVHSplitBucket { pub count: usize, pub bounds: Bounds3f, } #[derive(Debug, Clone, Default)] pub struct LinearBVHNode { pub bounds: Bounds3f, pub primitives_offset: usize, pub n_primitives: u16, pub axis: u8, pub pad: u8, } #[derive(Debug, Clone, Copy, Default)] struct MortonPrimitive { primitive_index: usize, morton_code: u32, } struct LBVHTreelet { start_index: usize, n_primitives: usize, } #[derive(Debug, Clone)] pub struct BVHPrimitiveInfo { primitive_number: usize, // Index into the original primitives vector bounds: Bounds3f, centroid: Point3f, } impl BVHPrimitiveInfo { fn new(primitive_number: usize, bounds: Bounds3f) -> Self { Self { primitive_number, bounds, centroid: bounds.centroid(), } } } #[derive(Clone, Debug)] pub enum BVHBuildNode { Leaf { first_prim_offset: usize, n_primitives: usize, bounds: Bounds3f, }, Interior { split_axis: u8, children: [Box; 2], bounds: Bounds3f, }, } impl Default for BVHBuildNode { fn default() -> Self { BVHBuildNode::Leaf { first_prim_offset: 0, n_primitives: 0, bounds: Bounds3f::default(), } } } impl BVHBuildNode { pub fn new_leaf(first_prim_offset: usize, n_primitives: usize, bounds: Bounds3f) -> Self { Self::Leaf { bounds, first_prim_offset, n_primitives, } } pub fn new_interior(axis: u8, c0: Box, c1: Box) -> Self { let bounds = c0.bounds().union(c1.bounds()); Self::Interior { bounds, children: [c0, c1], split_axis: axis, } } pub fn bounds(&self) -> Bounds3f { match self { Self::Leaf { bounds, .. } => *bounds, Self::Interior { bounds, .. } => *bounds, } } pub fn split_axis(&self) -> Option { match self { Self::Interior { split_axis, .. } => Some(*split_axis), _ => None, } } } pub struct SharedPrimitiveBuffer<'a> { ptr: *mut Arc, pub offset: &'a AtomicUsize, _marker: std::marker::PhantomData<&'a mut [Arc]>, } unsafe impl<'a> Sync for SharedPrimitiveBuffer<'a> {} unsafe impl<'a> Send for SharedPrimitiveBuffer<'a> {} impl<'a> SharedPrimitiveBuffer<'a> { pub fn new(slice: &'a mut [Arc], offset: &'a AtomicUsize) -> Self { Self { ptr: slice.as_mut_ptr(), offset, _marker: std::marker::PhantomData, } } pub fn append( &self, primitives: &[Arc], indices: &[BVHPrimitiveInfo], ) -> usize { let count = indices.len(); let start_index = self.offset.fetch_add(count, AtomicOrdering::Relaxed); unsafe { for (i, info) in indices.iter().enumerate() { let target_ptr = self.ptr.add(start_index + i); std::ptr::write(target_ptr, primitives[info.primitive_number].clone()); } } start_index } } pub struct BVHAggregate { max_prims_in_node: usize, primitives: Vec>, split_method: SplitMethod, nodes: Vec, } impl BVHAggregate { pub fn new( mut primitives: Vec>, max_prims_in_node: usize, split_method: SplitMethod, ) -> Self { let max_prims_in_node = std::cmp::min(255, max_prims_in_node); if primitives.is_empty() { return Self { max_prims_in_node, primitives, split_method, nodes: Vec::new(), }; } let mut primitive_info: Vec = primitives .iter() .enumerate() .map(|(i, p)| BVHPrimitiveInfo::new(i, p.bounds())) .collect(); let ordered_prims: Vec>; let total_nodes_count: usize; let root: Box; match split_method { SplitMethod::Hlbvh => { let nodes_counter = AtomicUsize::new(0); let ordered_prims_offset = AtomicUsize::new(0); let mut local_ordered = vec![primitives[0].clone(); primitives.len()]; let shared_buffer = SharedPrimitiveBuffer::new(&mut local_ordered, &ordered_prims_offset); root = Self::build_hlbvh(&primitive_info, &nodes_counter, &shared_buffer, &primitives); ordered_prims = local_ordered; total_nodes_count = nodes_counter.load(AtomicOrdering::Relaxed); } _ => { let nodes_counter = AtomicUsize::new(0); let ordered_prims_offset = AtomicUsize::new(0); let mut local_ordered = vec![primitives[0].clone(); primitives.len()]; let shared_buffer = SharedPrimitiveBuffer::new(&mut local_ordered, &ordered_prims_offset); root = Self::build_recursive( &mut primitive_info, &nodes_counter, &shared_buffer, &primitives, max_prims_in_node, split_method, ); ordered_prims = local_ordered; total_nodes_count = nodes_counter.load(AtomicOrdering::Relaxed); } }; primitives = ordered_prims; let mut nodes = vec![LinearBVHNode::default(); total_nodes_count]; let mut offset = 0; Self::flatten_bvh(&root, &mut nodes, &mut offset); Self { max_prims_in_node, primitives, split_method, nodes, } } fn flatten_bvh(node: &BVHBuildNode, nodes: &mut [LinearBVHNode], offset: &mut usize) -> usize { let local_offset = *offset; *offset += 1; match node { BVHBuildNode::Leaf { first_prim_offset, n_primitives, bounds, } => { let linear_node = &mut nodes[local_offset]; linear_node.bounds = *bounds; linear_node.n_primitives = *n_primitives as u16; linear_node.primitives_offset = *first_prim_offset; linear_node.axis = 0; // Irrelevant for leaves } BVHBuildNode::Interior { split_axis, children, bounds, } => { nodes[local_offset].bounds = *bounds; nodes[local_offset].axis = *split_axis; nodes[local_offset].n_primitives = 0; Self::flatten_bvh(&children[0], nodes, offset); let second_child_offset = Self::flatten_bvh(&children[1], nodes, offset); nodes[local_offset].primitives_offset = second_child_offset; } } local_offset } pub fn build_hlbvh( bvh_primitives: &[BVHPrimitiveInfo], total_nodes: &AtomicUsize, ordered_prims: &SharedPrimitiveBuffer, original_primitives: &[Arc], ) -> Box { let bounds = bvh_primitives .iter() .fold(Bounds3f::default(), |b, p| b.union(p.bounds)); let mut morton_prims: Vec = bvh_primitives .par_iter() .map(|prim| { const MORTON_BITS: i32 = 10; const MORTON_SCALE: i32 = 1 << MORTON_BITS; let centroid_offset = bounds.offset(&prim.centroid); let offset = centroid_offset * (MORTON_SCALE as Float); MortonPrimitive { primitive_index: prim.primitive_number, morton_code: encode_morton_3(offset.x(), offset.y(), offset.z()), } }) .collect(); morton_prims.par_sort_unstable_by_key(|p| p.morton_code); const TREELET_MASK: u32 = 0b00111111111111000000000000000000; let mut split_indices: Vec = morton_prims .par_windows(2) // Iterates over overlapping pairs [i, i+1] .enumerate() .filter_map(|(i, w)| { let m1 = w[0].morton_code & TREELET_MASK; let m2 = w[1].morton_code & TREELET_MASK; // If mask changes, the split is at index i + 1 if m1 != m2 { Some(i + 1) } else { None } }) .collect(); let mut boundaries = Vec::with_capacity(split_indices.len() + 2); boundaries.push(0); boundaries.append(&mut split_indices); boundaries.push(morton_prims.len()); let treelets_to_build: Vec = boundaries .windows(2) .map(|w| LBVHTreelet { start_index: w[0], n_primitives: w[1] - w[0], }) .collect(); let treelet_roots: Vec> = treelets_to_build .par_iter() .map(|tr| { let mut nodes_created = 0; const FIRST_BIT_INDEX: i32 = 29 - 12; let root = Self::emit_lbvh( bvh_primitives, &morton_prims[tr.start_index..tr.start_index + tr.n_primitives], &mut nodes_created, ordered_prims, original_primitives, FIRST_BIT_INDEX, 4, ); total_nodes.fetch_add(nodes_created, AtomicOrdering::Relaxed); root }) .collect(); let mut contiguous_nodes: Vec = treelet_roots .into_iter() .map(|node_box| *node_box) .collect(); Self::build_upper_sah(&mut contiguous_nodes, total_nodes) } fn emit_lbvh( bvh_primitives: &[BVHPrimitiveInfo], morton_prims: &[MortonPrimitive], total_nodes: &mut usize, ordered_prims: &SharedPrimitiveBuffer, original_primitives: &[Arc], bit_index: i32, max_prims_in_node: usize, ) -> Box { let n_primitives = morton_prims.len(); if bit_index == -1 || n_primitives <= max_prims_in_node { *total_nodes += 1; // Calculate bounds while collecting indices let mut bounds = Bounds3f::default(); let mut indices = Vec::with_capacity(n_primitives); for mp in morton_prims { let info = &bvh_primitives[mp.primitive_index]; bounds = bounds.union(info.bounds); indices.push(info.clone()); } let first_prim_offset = ordered_prims.append(original_primitives, &indices); return Box::new(BVHBuildNode::new_leaf( first_prim_offset, n_primitives, bounds, )); } let mask = 1 << bit_index; let first_code = morton_prims[0].morton_code; let last_match_index = find_interval(n_primitives, |index| { let current_code = morton_prims[index].morton_code; (current_code & mask) == (first_code & mask) }); let split_offset = (last_match_index + 1) as usize; if split_offset >= n_primitives { return Self::emit_lbvh( bvh_primitives, morton_prims, total_nodes, ordered_prims, original_primitives, bit_index - 1, max_prims_in_node, ); } let (left_morton, right_morton) = morton_prims.split_at(split_offset); *total_nodes += 1; let child0 = Self::emit_lbvh( bvh_primitives, left_morton, total_nodes, ordered_prims, original_primitives, bit_index - 1, max_prims_in_node, ); let child1 = Self::emit_lbvh( bvh_primitives, right_morton, total_nodes, ordered_prims, original_primitives, bit_index - 1, max_prims_in_node, ); let axis = (bit_index % 3) as u8; Box::new(BVHBuildNode::new_interior(axis, child0, child1)) } fn build_upper_sah(nodes: &mut [BVHBuildNode], total_nodes: &AtomicUsize) -> Box { let n_nodes = nodes.len(); if n_nodes == 1 { return Box::new(nodes[0].clone()); } total_nodes.fetch_add(1, AtomicOrdering::Relaxed); let bounds = nodes .iter() .fold(Bounds3f::default(), |b, node| b.union(node.bounds())); let centroid_bounds = nodes.iter().fold(Bounds3f::default(), |b, node| { b.union_point(node.bounds().centroid()) }); let dim = centroid_bounds.max_dimension(); if centroid_bounds.p_max[dim] == centroid_bounds.p_min[dim] { let mid = n_nodes / 2; let (left_part, right_part) = nodes.split_at_mut(mid); return Box::new(BVHBuildNode::new_interior( dim as u8, Self::build_upper_sah(left_part, total_nodes), Self::build_upper_sah(right_part, total_nodes), )); } const N_BUCKETS: usize = 12; #[derive(Copy, Clone, Default)] struct Bucket { count: usize, bounds: Bounds3f, } let mut buckets = [Bucket::default(); N_BUCKETS]; let get_bucket_idx = |node: &BVHBuildNode| -> usize { let offset = centroid_bounds.offset(&node.bounds().centroid())[dim]; let mut b = (N_BUCKETS as Float * offset) as usize; if b == N_BUCKETS { b = N_BUCKETS - 1; } b }; // Initialize _Bucket_ for HLBVH SAH partition buckets for node in nodes.iter() { let b = get_bucket_idx(node); buckets[b].count += 1; buckets[b].bounds = buckets[b].bounds.union(node.bounds()); } // Compute costs for splitting after each bucket let mut cost = [0.0; N_BUCKETS - 1]; // Forward Pass: Accumulate Left side (0 -> N-1) let mut left_area = [0.0; N_BUCKETS]; let mut left_count = [0; N_BUCKETS]; let mut b_left = Bounds3f::default(); let mut c_left = 0; for i in 0..N_BUCKETS { b_left = b_left.union(buckets[i].bounds); c_left += buckets[i].count; left_area[i] = b_left.surface_area(); left_count[i] = c_left; } // Backward Pass: Accumulate Right side (N-1 -> 0) and compute cost let mut b_right = Bounds3f::default(); let mut c_right = 0; let inv_total_sa = 1.0 / bounds.surface_area(); for i in (0..N_BUCKETS - 1).rev() { b_right = b_right.union(buckets[i + 1].bounds); c_right += buckets[i + 1].count; let count_left = left_count[i]; let sa_left = left_area[i]; let sa_right = b_right.surface_area(); cost[i] = 0.125 + (count_left as Float * sa_left + c_right as Float * sa_right) * inv_total_sa; } // Find bucket to split at that minimizes SAH metric let mut min_cost = cost[0]; let mut min_cost_split_bucket = 0; for (i, &c) in cost.iter().enumerate().skip(1) { if c < min_cost { min_cost = c; min_cost_split_bucket = i; } } // Split nodes and create interior HLBVH SAH node let mid = { let mut left = 0; for i in 0..n_nodes { let b = get_bucket_idx(&nodes[i]); if b <= min_cost_split_bucket { nodes.swap(left, i); left += 1; } } left }; if mid == 0 || mid == n_nodes { let mid = n_nodes / 2; // Partially sort so the median is in the middle and elements are partitioned around it nodes.select_nth_unstable_by(mid, |a, b| { a.bounds().centroid()[dim] .partial_cmp(&b.bounds().centroid()[dim]) .unwrap_or(std::cmp::Ordering::Equal) }); let (left_part, right_part) = nodes.split_at_mut(mid); Box::new(BVHBuildNode::new_interior( dim as u8, Self::build_upper_sah(left_part, total_nodes), Self::build_upper_sah(right_part, total_nodes), )) } else { // Standard SAH Split let (left_part, right_part) = nodes.split_at_mut(mid); Box::new(BVHBuildNode::new_interior( dim as u8, Self::build_upper_sah(left_part, total_nodes), Self::build_upper_sah(right_part, total_nodes), )) } } fn build_recursive( bvh_primitives: &mut [BVHPrimitiveInfo], total_nodes: &AtomicUsize, ordered_prims: &SharedPrimitiveBuffer, original_primitives: &[Arc], max_prims_in_node: usize, split_method: SplitMethod, ) -> Box { total_nodes.fetch_add(1, AtomicOrdering::Relaxed); let bounds = bvh_primitives .iter() .fold(Bounds3f::default(), |b, p| b.union(p.bounds)); let n_primitives = bvh_primitives.len(); if bounds.surface_area() == 0.0 || n_primitives == 1 || n_primitives <= max_prims_in_node { let first_prim_offset = ordered_prims.append(original_primitives, bvh_primitives); return Box::new(BVHBuildNode::new_leaf( first_prim_offset, n_primitives, bounds, )); } let centroid_bounds = bvh_primitives.iter().fold(Bounds3f::default(), |b, p| { b.union_point(p.bounds.centroid()) }); let dim = centroid_bounds.max_dimension(); if centroid_bounds.p_max[dim] == centroid_bounds.p_min[dim] { let first_prim_offset = ordered_prims.append(original_primitives, bvh_primitives); return Box::new(BVHBuildNode::new_leaf( first_prim_offset, n_primitives, bounds, )); } let mut mid: usize; match split_method { SplitMethod::Middle => { let pmid = (centroid_bounds.p_min[dim] + centroid_bounds.p_max[dim]) / 2.; mid = partition_slice(bvh_primitives, |p| p.centroid[dim] < pmid); if mid != 0 && mid != n_primitives { } else { mid = n_primitives / 2; bvh_primitives.select_nth_unstable_by(mid, |a, b| { a.centroid[dim].partial_cmp(&b.centroid[dim]).unwrap() }); } } SplitMethod::EqualCounts => { mid = n_primitives / 2; bvh_primitives.select_nth_unstable_by(mid, |a, b| { a.centroid[dim].partial_cmp(&b.centroid[dim]).unwrap() }); } SplitMethod::SAH | _ => { if n_primitives < 2 { mid = n_primitives / 2; bvh_primitives.select_nth_unstable_by(mid, |a, b| { a.centroid[dim] .partial_cmp(&b.centroid[dim]) .unwrap_or(Ordering::Equal) }); } else { const N_BUCKETS: usize = 12; let mut buckets = [BVHSplitBucket::default(); N_BUCKETS]; for prim in bvh_primitives.iter() { let mut b = (N_BUCKETS as Float * centroid_bounds.offset(&prim.centroid)[dim]) as usize; if b == N_BUCKETS { b = N_BUCKETS - 1; } buckets[b].count += 1; buckets[b].bounds = buckets[b].bounds.union(prim.bounds); } // Compute costs for splitting after each bucket> const N_SPLITS: usize = N_BUCKETS - 1; let mut costs = [0.0 as Float; N_SPLITS]; let mut count_below = 0; let mut bound_below = Bounds3f::default(); for i in 0..N_SPLITS { bound_below = bound_below.union(buckets[i].bounds); count_below += buckets[i].count; costs[i] += count_below as Float * bound_below.surface_area(); } // Finish initializing costs using a backward scan over splits let mut count_above = 0; let mut bound_above = Bounds3f::default(); for i in (0..N_SPLITS).rev() { bound_above = bound_above.union(buckets[i + 1].bounds); count_above += buckets[i + 1].count; costs[i] += count_above as Float * bound_above.surface_area(); } // Find bucket to split at that minimizes SAH metric> let mut min_cost = Float::INFINITY; let mut min_cost_split_bucket = 0; for (i, &cost) in costs.iter().enumerate().take(N_SPLITS) { if cost < min_cost { min_cost = cost; min_cost_split_bucket = i; } } // Compute leaf cost and SAH split cost for chosen split let leaf_cost = n_primitives as Float; min_cost = 0.5 + min_cost / bounds.surface_area(); // Either create leaf or split primitives at selected SAH bucket> if n_primitives > max_prims_in_node || min_cost < leaf_cost { mid = partition_slice(bvh_primitives, |bp| { let mut b = (N_BUCKETS as Float * centroid_bounds.offset(&bp.centroid)[dim]) as usize; if b == N_BUCKETS { b = N_BUCKETS - 1; } b <= min_cost_split_bucket }); } else { let first_prim_offset = ordered_prims.append(original_primitives, bvh_primitives); return Box::new(BVHBuildNode::new_leaf( first_prim_offset, n_primitives, bounds, )); } } } }; let (left_prims, right_prims) = bvh_primitives.split_at_mut(mid); if n_primitives > 128 * 1024 { let (child0, child1) = rayon::join( || { Self::build_recursive( left_prims, total_nodes, ordered_prims, original_primitives, max_prims_in_node, split_method, ) }, || { Self::build_recursive( right_prims, total_nodes, ordered_prims, original_primitives, max_prims_in_node, split_method, ) }, ); let axis = dim as u8; Box::new(BVHBuildNode::new_interior(axis, child0, child1)) } else { let child0 = Self::build_recursive( left_prims, total_nodes, ordered_prims, original_primitives, max_prims_in_node, split_method, ); let child1 = Self::build_recursive( right_prims, total_nodes, ordered_prims, original_primitives, max_prims_in_node, split_method, ); let axis = dim as u8; Box::new(BVHBuildNode::new_interior(axis, child0, child1)) } } pub fn intersect(&self, r: &Ray, t_max: Option) -> Option { if self.nodes.is_empty() { return None; } let mut best_si: Option = None; let mut hit_t = t_max.unwrap_or(Float::INFINITY); let inv_dir = Vector3f::new(1.0 / r.d.x(), 1.0 / r.d.y(), 1.0 / r.d.z()); let dir_is_neg = [ if inv_dir.x() < 0.0 { 1 } else { 0 }, if inv_dir.y() < 0.0 { 1 } else { 0 }, if inv_dir.z() < 0.0 { 1 } else { 0 }, ]; let mut to_visit_offset = 0; let mut current_node_index = 0; let mut nodes_to_visit = [0usize; 64]; loop { let node = &self.nodes[current_node_index]; // Check ray against BVH node bounds using the current closest hit_t if node .bounds .intersect_p(r.o, hit_t, inv_dir, &dir_is_neg) .is_some() { if node.n_primitives > 0 { // Intersect ray with all primitives in this leaf for i in 0..node.n_primitives { let prim_idx = node.primitives_offset + i as usize; let prim = &self.primitives[prim_idx]; if let Some(si) = prim.intersect(r, Some(hit_t)) { hit_t = si.t_hit(); best_si = Some(si); } } if to_visit_offset == 0 { break; } to_visit_offset -= 1; current_node_index = nodes_to_visit[to_visit_offset]; } else { // Check the sign of the ray direction against the split axis if dir_is_neg[node.axis as usize] == 1 { // Ray is negative (Right -> Left). // Near child is Second Child (stored in primitives_offset). // Far child is First Child (current + 1). // Push Far nodes_to_visit[to_visit_offset] = current_node_index + 1; to_visit_offset += 1; // Visit Near immediately current_node_index = node.primitives_offset; } else { // Ray is positive (Left -> Right). // Push Far nodes_to_visit[to_visit_offset] = node.primitives_offset; to_visit_offset += 1; current_node_index += 1; } } } else { // The ray missed the AABB of this node. Pop stack to try the next node. if to_visit_offset == 0 { break; } to_visit_offset -= 1; current_node_index = nodes_to_visit[to_visit_offset]; } } best_si } fn intersect_p(&self, r: &Ray, t_max: Option) -> bool { if self.nodes.is_empty() { return false; } let t_max = t_max.unwrap_or(Float::INFINITY); let inv_dir = Vector3f::new(1.0 / r.d.x(), 1.0 / r.d.y(), 1.0 / r.d.z()); let dir_is_neg = [ if inv_dir.x() < 0.0 { 1 } else { 0 }, if inv_dir.y() < 0.0 { 1 } else { 0 }, if inv_dir.z() < 0.0 { 1 } else { 0 }, ]; let mut to_visit_offset = 0; let mut current_node_index = 0; let mut nodes_to_visit = [0usize; 64]; loop { let node = &self.nodes[current_node_index]; // Check AABB if node .bounds .intersect_p(r.o, t_max, inv_dir, &dir_is_neg) .is_some() { if node.n_primitives > 0 { for i in 0..node.n_primitives { let prim_idx = node.primitives_offset + i as usize; let prim = &self.primitives[prim_idx]; if prim.intersect_p(r, Some(t_max)) { return true; } } // No intersection in this leaf, try next node in stack if to_visit_offset == 0 { break; } to_visit_offset -= 1; current_node_index = nodes_to_visit[to_visit_offset]; } else { // Standard front-to-back traversal order helps find an occlusion // closer to the origin faster, potentially saving work. if dir_is_neg[node.axis as usize] == 1 { nodes_to_visit[to_visit_offset] = current_node_index + 1; to_visit_offset += 1; current_node_index = node.primitives_offset; } else { nodes_to_visit[to_visit_offset] = node.primitives_offset; to_visit_offset += 1; current_node_index += 1; } } } else { if to_visit_offset == 0 { break; } to_visit_offset -= 1; current_node_index = nodes_to_visit[to_visit_offset]; } } false } }