use crate::Float; pub const SCOPE_DEVICE: u32 = 1; #[allow(dead_code)] pub const SCOPE_WORKGROUP: u32 = 2; pub const SEMANTICS_RELAXED: u32 = 0x0; #[allow(dead_code)] pub const SEMANTICS_ACQUIRE_RELEASE: u32 = 0x8; #[repr(C)] #[derive(Debug)] pub struct AtomicU32 { value: u32, } impl Default for AtomicU32 { fn default() -> Self { Self::new(0) } } impl Clone for AtomicU32 { fn clone(&self) -> Self { Self::new(self.load()) } } impl AtomicU32 { pub fn new(val: u32) -> Self { Self { value: val } } #[cfg(not(any(target_arch = "spirv", feature = "cuda")))] #[inline(always)] pub fn load(&self) -> u32 { let atomic = unsafe { &*(core::ptr::addr_of!(self.value) as *const core::sync::atomic::AtomicU32) }; atomic.load(core::sync::atomic::Ordering::Relaxed) } #[cfg(not(any(target_arch = "spirv", feature = "cuda")))] #[inline(always)] pub fn store(&self, val: u32) { let atomic = unsafe { &*(core::ptr::addr_of!(self.value) as *const core::sync::atomic::AtomicU32) }; atomic.store(val, core::sync::atomic::Ordering::Relaxed); } #[cfg(not(any(target_arch = "spirv", feature = "cuda")))] #[inline(always)] pub fn fetch_add(&self, val: u32) -> u32 { let atomic = unsafe { &*(core::ptr::addr_of!(self.value) as *const core::sync::atomic::AtomicU32) }; atomic.fetch_add(val, core::sync::atomic::Ordering::Relaxed) } #[cfg(not(any(target_arch = "spirv", feature = "cuda")))] #[inline(always)] pub fn compare_exchange(&self, expected: u32, desired: u32) -> Result { let atomic = unsafe { &*(core::ptr::addr_of!(self.value) as *const core::sync::atomic::AtomicU32) }; atomic.compare_exchange_weak( expected, desired, core::sync::atomic::Ordering::Relaxed, core::sync::atomic::Ordering::Relaxed, ) } #[cfg(target_arch = "spirv")] #[inline(always)] pub fn load(&self) -> u32 { unsafe { spirv_std::arch::atomic_load::( &self.value, ) } } #[cfg(target_arch = "spirv")] #[inline(always)] pub fn store(&self, val: u32) { unsafe { spirv_std::arch::atomic_store::( &mut *core::ptr::addr_of!(self.value).cast_mut(), val, ); } } #[cfg(target_arch = "spirv")] #[inline(always)] pub fn fetch_add(&self, val: u32) -> u32 { unsafe { spirv_std::arch::atomic_i_add::( &mut *core::ptr::addr_of!(self.value).cast_mut(), val, ) } } #[cfg(target_arch = "spirv")] #[inline(always)] pub fn compare_exchange(&self, expected: u32, desired: u32) -> Result { let old = unsafe { spirv_std::arch::atomic_compare_exchange::< u32, SCOPE_DEVICE, SEMANTICS_RELAXED, SEMANTICS_RELAXED, >( &mut *core::ptr::addr_of!(self.value).cast_mut(), desired, expected, ) }; if old == expected { Ok(old) } else { Err(old) } } // -- CUDA backend -- #[cfg(feature = "cuda")] #[inline(always)] pub fn load(&self) -> u32 { // CUDA volatile read for atomicity on the same SM unsafe { core::ptr::read_volatile(&self.value) } } #[cfg(feature = "cuda")] #[inline(always)] pub fn store(&self, val: u32) { unsafe { core::ptr::write_volatile( core::ptr::addr_of!(self.value).cast_mut(), val, ); } } #[cfg(feature = "cuda")] #[inline(always)] pub fn fetch_add(&self, val: u32) -> u32 { let ptr = core::ptr::addr_of!(self.value).cast_mut(); let mut old: u32; unsafe { core::arch::asm!( "atom.add.u32 {old}, [{ptr}], {val};", old = out(reg32) old, ptr = in(reg64) ptr, val = in(reg32) val, ); } old } #[cfg(feature = "cuda")] #[inline(always)] pub fn compare_exchange(&self, expected: u32, desired: u32) -> Result { let ptr = core::ptr::addr_of!(self.value).cast_mut(); let mut old: u32; unsafe { core::arch::asm!( "atom.cas.b32 {old}, [{ptr}], {expected}, {desired};", old = out(reg32) old, ptr = in(reg64) ptr, expected = in(reg32) expected, desired = in(reg32) desired, ); } if old == expected { Ok(old) } else { Err(old) } } } #[repr(C)] #[derive(Debug)] pub struct AtomicF32 { bits: AtomicU32, } impl Default for AtomicF32 { fn default() -> Self { Self::new(0.0) } } impl Clone for AtomicF32 { fn clone(&self) -> Self { Self::new(self.get()) } } impl AtomicF32 { pub fn new(val: Float) -> Self { Self { bits: AtomicU32::new(val.to_bits()), } } pub fn get(&self) -> Float { Float::from_bits(self.bits.load()) } pub fn set(&self, val: Float) { self.bits.store(val.to_bits()); } #[cfg(not(any(target_arch = "spirv", feature = "cuda")))] #[inline(always)] pub fn add(&self, val: Float) { let mut current_bits = self.bits.load(); loop { let current_val = Float::from_bits(current_bits); let new_val = current_val + val; let new_bits = new_val.to_bits(); match self.bits.compare_exchange(current_bits, new_bits) { Ok(_) => break, Err(x) => current_bits = x, } } } #[cfg(target_arch = "spirv")] #[inline(always)] pub fn add(&self, val: Float) { unsafe { let float_ptr = core::ptr::addr_of!(self.bits.value) as *mut Float; spirv_std::arch::atomic_f_add::( &mut *float_ptr, val, ); } } #[cfg(feature = "cuda")] #[inline(always)] pub fn add(&self, val: Float) { let ptr = core::ptr::addr_of!(self.bits.value) as *mut Float; unsafe { core::arch::asm!( "atom.add.f32 {old}, [{ptr}], {val};", old = out(reg32) _, ptr = in(reg64) ptr, val = in(reg32) val.to_bits(), ); } } } pub type AtomicFloat = AtomicF32;