pbrt/shared/src/utils/atomic.rs

263 lines
6.8 KiB
Rust

use crate::Float;
// ---- SPIR-V memory scopes (used as const-generic arguments to `spirv_std` atomics) ----

/// Memory scope `Device` (1 in the SPIR-V `Scope` enumeration).
pub const SCOPE_DEVICE: u32 = 1;

/// Memory scope `Workgroup` (2 in the SPIR-V `Scope` enumeration).
// Not referenced by the atomics in this file; kept available for other callers.
#[allow(dead_code)]
pub const SCOPE_WORKGROUP: u32 = 2;

// ---- SPIR-V memory-semantics masks ----

/// No ordering constraints beyond atomicity of the operation itself.
pub const SEMANTICS_RELAXED: u32 = 0;

/// `AcquireRelease` bit (0x8 in the SPIR-V `Memory Semantics` mask).
// Not referenced by the atomics in this file; kept available for other callers.
#[allow(dead_code)]
pub const SEMANTICS_ACQUIRE_RELEASE: u32 = 0x8;
/// A 32-bit unsigned integer that is read and updated atomically through `&self`.
///
/// Backends, selected at compile time:
/// * host: `core::sync::atomic::AtomicU32`,
/// * `target_arch = "spirv"`: `spirv_std::arch` atomic intrinsics,
/// * `feature = "cuda"`: volatile loads/stores plus PTX `atom.*` instructions.
///
/// All operations use relaxed memory ordering.
///
/// The value lives in an `UnsafeCell` because every backend mutates it through a
/// shared reference; doing that to a plain `u32` field is undefined behaviour in Rust.
/// `UnsafeCell<u32>` is `#[repr(transparent)]`, so the `#[repr(C)]` layout is still
/// exactly one `u32`.
#[repr(C)]
pub struct AtomicU32 {
    value: core::cell::UnsafeCell<u32>,
}

// SAFETY: after construction, `value` is only accessed through atomic operations
// (host atomics, SPIR-V atomic intrinsics, PTX `atom.*`) or, for CUDA load/store,
// through volatile 32-bit accesses as in the original design. This restores the
// `Sync` impl the type had automatically when the field was a plain `u32`.
unsafe impl Sync for AtomicU32 {}

impl core::fmt::Debug for AtomicU32 {
    // Produces the same text as the previous `#[derive(Debug)]` on a `u32` field.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("AtomicU32")
            .field("value", &self.load())
            .finish()
    }
}

impl Default for AtomicU32 {
    fn default() -> Self {
        Self::new(0)
    }
}

impl Clone for AtomicU32 {
    /// Snapshots the current value into a new, independent atomic.
    fn clone(&self) -> Self {
        Self::new(self.load())
    }
}

impl AtomicU32 {
    /// Creates a new atomic holding `val`.
    pub const fn new(val: u32) -> Self {
        Self {
            value: core::cell::UnsafeCell::new(val),
        }
    }

    // -- host backend --

    /// Views the cell contents as a host `core::sync::atomic::AtomicU32`.
    #[cfg(not(any(target_arch = "spirv", feature = "cuda")))]
    #[inline(always)]
    fn as_host_atomic(&self) -> &core::sync::atomic::AtomicU32 {
        // The host atomic type is always aligned to its size, while `u32` may be less
        // aligned on some targets; refuse to build there instead of creating a
        // misaligned reference.
        const _: () = assert!(
            core::mem::align_of::<u32>()
                >= core::mem::align_of::<core::sync::atomic::AtomicU32>()
        );
        // SAFETY: the host atomic has the same size and bit validity as `u32`,
        // alignment is checked above, and the pointee is inside an `UnsafeCell`,
        // so mutation through this shared reference is permitted.
        unsafe { &*self.value.get().cast::<core::sync::atomic::AtomicU32>() }
    }

    /// Returns the current value.
    #[cfg(not(any(target_arch = "spirv", feature = "cuda")))]
    #[inline(always)]
    pub fn load(&self) -> u32 {
        self.as_host_atomic()
            .load(core::sync::atomic::Ordering::Relaxed)
    }

    /// Replaces the current value with `val`.
    #[cfg(not(any(target_arch = "spirv", feature = "cuda")))]
    #[inline(always)]
    pub fn store(&self, val: u32) {
        self.as_host_atomic()
            .store(val, core::sync::atomic::Ordering::Relaxed);
    }

    /// Adds `val` (wrapping on overflow) and returns the previous value.
    #[cfg(not(any(target_arch = "spirv", feature = "cuda")))]
    #[inline(always)]
    pub fn fetch_add(&self, val: u32) -> u32 {
        self.as_host_atomic()
            .fetch_add(val, core::sync::atomic::Ordering::Relaxed)
    }

    /// Stores `desired` if the current value equals `expected`.
    ///
    /// Returns `Ok(previous)` on success and `Err(actual)` otherwise. This is a
    /// *strong* compare-exchange: it never fails spuriously, matching the SPIR-V
    /// and CUDA backends.
    #[cfg(not(any(target_arch = "spirv", feature = "cuda")))]
    #[inline(always)]
    pub fn compare_exchange(&self, expected: u32, desired: u32) -> Result<u32, u32> {
        self.as_host_atomic().compare_exchange(
            expected,
            desired,
            core::sync::atomic::Ordering::Relaxed,
            core::sync::atomic::Ordering::Relaxed,
        )
    }

    // -- shared device helper --

    /// Raw pointer to the cell contents, handed to the device intrinsics.
    #[cfg(any(target_arch = "spirv", feature = "cuda"))]
    #[inline(always)]
    fn as_mut_ptr(&self) -> *mut u32 {
        self.value.get()
    }

    // -- SPIR-V backend --

    /// Returns the current value.
    #[cfg(target_arch = "spirv")]
    #[inline(always)]
    pub fn load(&self) -> u32 {
        // SAFETY: the pointer comes from our own `UnsafeCell` and is valid for reads.
        unsafe {
            spirv_std::arch::atomic_load::<u32, SCOPE_DEVICE, SEMANTICS_RELAXED>(
                &*self.as_mut_ptr(),
            )
        }
    }

    /// Replaces the current value with `val`.
    #[cfg(target_arch = "spirv")]
    #[inline(always)]
    pub fn store(&self, val: u32) {
        // SAFETY: the pointee is inside an `UnsafeCell`; the write is atomic.
        unsafe {
            spirv_std::arch::atomic_store::<u32, SCOPE_DEVICE, SEMANTICS_RELAXED>(
                &mut *self.as_mut_ptr(),
                val,
            );
        }
    }

    /// Adds `val` and returns the previous value.
    #[cfg(target_arch = "spirv")]
    #[inline(always)]
    pub fn fetch_add(&self, val: u32) -> u32 {
        // SAFETY: the pointee is inside an `UnsafeCell`; the update is atomic.
        unsafe {
            spirv_std::arch::atomic_i_add::<u32, SCOPE_DEVICE, SEMANTICS_RELAXED>(
                &mut *self.as_mut_ptr(),
                val,
            )
        }
    }

    /// Stores `desired` if the current value equals `expected`.
    ///
    /// Returns `Ok(previous)` on success and `Err(actual)` otherwise.
    #[cfg(target_arch = "spirv")]
    #[inline(always)]
    pub fn compare_exchange(&self, expected: u32, desired: u32) -> Result<u32, u32> {
        // SAFETY: the pointee is inside an `UnsafeCell`; the update is atomic.
        // Argument order is (pointer, new value, comparator).
        let old = unsafe {
            spirv_std::arch::atomic_compare_exchange::<
                u32,
                SCOPE_DEVICE,
                SEMANTICS_RELAXED,
                SEMANTICS_RELAXED,
            >(&mut *self.as_mut_ptr(), desired, expected)
        };
        if old == expected {
            Ok(old)
        } else {
            Err(old)
        }
    }

    // -- CUDA backend --

    /// Returns the current value.
    #[cfg(feature = "cuda")]
    #[inline(always)]
    pub fn load(&self) -> u32 {
        // Volatile 32-bit read so the compiler neither elides nor splits it.
        // NOTE(review): relies on aligned 32-bit device loads being single-copy
        // atomic; this is not a Rust-level atomic operation.
        // SAFETY: the pointer comes from our own `UnsafeCell` and is valid for reads.
        unsafe { core::ptr::read_volatile(self.as_mut_ptr()) }
    }

    /// Replaces the current value with `val`.
    #[cfg(feature = "cuda")]
    #[inline(always)]
    pub fn store(&self, val: u32) {
        // SAFETY: the pointee is inside an `UnsafeCell` and valid for writes.
        unsafe {
            core::ptr::write_volatile(self.as_mut_ptr(), val);
        }
    }

    /// Adds `val` and returns the previous value.
    #[cfg(feature = "cuda")]
    #[inline(always)]
    pub fn fetch_add(&self, val: u32) -> u32 {
        let ptr = self.as_mut_ptr();
        let old: u32;
        // SAFETY: `ptr` is a valid, aligned address of our `UnsafeCell` contents.
        unsafe {
            core::arch::asm!(
                "atom.add.u32 {old}, [{ptr}], {val};",
                old = out(reg32) old,
                ptr = in(reg64) ptr,
                val = in(reg32) val,
            );
        }
        old
    }

    /// Stores `desired` if the current value equals `expected`.
    ///
    /// Returns `Ok(previous)` on success and `Err(actual)` otherwise.
    #[cfg(feature = "cuda")]
    #[inline(always)]
    pub fn compare_exchange(&self, expected: u32, desired: u32) -> Result<u32, u32> {
        let ptr = self.as_mut_ptr();
        let old: u32;
        // SAFETY: `ptr` is a valid, aligned address of our `UnsafeCell` contents.
        unsafe {
            core::arch::asm!(
                "atom.cas.b32 {old}, [{ptr}], {expected}, {desired};",
                old = out(reg32) old,
                ptr = in(reg64) ptr,
                expected = in(reg32) expected,
                desired = in(reg32) desired,
            );
        }
        if old == expected {
            Ok(old)
        } else {
            Err(old)
        }
    }
}
/// A `Float` that can be read and updated through `&self`, stored as its raw bit
/// pattern in an [`AtomicU32`].
///
/// Loads and stores delegate to `AtomicU32::load` / `AtomicU32::store`; `add`
/// uses a compare-and-swap loop on the host and native float atomics on the
/// SPIR-V and CUDA backends.
#[repr(C)]
#[derive(Debug)]
pub struct AtomicF32 {
    bits: AtomicU32,
}

impl Default for AtomicF32 {
    fn default() -> Self {
        AtomicF32::new(0.0)
    }
}

impl Clone for AtomicF32 {
    /// Snapshots the current value into a new, independent atomic.
    fn clone(&self) -> Self {
        let snapshot = self.get();
        AtomicF32::new(snapshot)
    }
}

impl AtomicF32 {
    /// Creates a new atomic float initialised to `val`.
    pub fn new(val: Float) -> Self {
        let bits = AtomicU32::new(val.to_bits());
        Self { bits }
    }

    /// Returns the current value.
    pub fn get(&self) -> Float {
        let raw = self.bits.load();
        Float::from_bits(raw)
    }

    /// Replaces the current value with `val`.
    pub fn set(&self, val: Float) {
        let raw = val.to_bits();
        self.bits.store(raw);
    }

    /// Atomically adds `val` to the stored value (host: compare-and-swap retry loop).
    #[cfg(not(any(target_arch = "spirv", feature = "cuda")))]
    #[inline(always)]
    pub fn add(&self, val: Float) {
        // The swap compares bit patterns, not floats, so a stored NaN still
        // matches itself and the loop can finish.
        let mut observed = self.bits.load();
        while let Err(actual) = self
            .bits
            .compare_exchange(observed, (Float::from_bits(observed) + val).to_bits())
        {
            observed = actual;
        }
    }

    /// Atomically adds `val` to the stored value (SPIR-V float atomic add).
    #[cfg(target_arch = "spirv")]
    #[inline(always)]
    pub fn add(&self, val: Float) {
        // Reinterpret the integer storage as a float for the native float atomic.
        let target = core::ptr::addr_of!(self.bits.value) as *mut Float;
        unsafe {
            spirv_std::arch::atomic_f_add::<Float, SCOPE_DEVICE, SEMANTICS_RELAXED>(
                &mut *target,
                val,
            );
        }
    }

    /// Atomically adds `val` to the stored value (PTX `atom.add.f32`).
    #[cfg(feature = "cuda")]
    #[inline(always)]
    pub fn add(&self, val: Float) {
        let target = core::ptr::addr_of!(self.bits.value) as *mut Float;
        // The addend is passed as its raw 32-bit pattern in a 32-bit register.
        let addend = val.to_bits();
        unsafe {
            core::arch::asm!(
                "atom.add.f32 {old}, [{ptr}], {val};",
                old = out(reg32) _,
                ptr = in(reg64) target,
                val = in(reg32) addend,
            );
        }
    }
}

/// The crate's default atomic floating-point type; an alias for [`AtomicF32`].
pub type AtomicFloat = AtomicF32;