263 lines
6.8 KiB
Rust
263 lines
6.8 KiB
Rust
use crate::Float;
|
|
|
|
/// Scope operand supplied as the `SCOPE` const generic to the `spirv_std::arch`
/// atomic intrinsics used in this file.
///
/// The value `1` corresponds to `Device` in the SPIR-V `Scope` enumeration.
pub const SCOPE_DEVICE: u32 = 1;

/// Workgroup-level scope (`2` is `Workgroup` in the SPIR-V `Scope` enumeration).
// Not referenced by any of the atomic wrappers in this file, hence the lint
// suppression; kept so callers can name the scope.
#[allow(dead_code)]
pub const SCOPE_WORKGROUP: u32 = 2;

/// Memory-semantics operand supplied as the `SEMANTICS` const generic to the
/// `spirv_std::arch` atomic intrinsics used in this file.
///
/// `0x0` is the SPIR-V "none" (relaxed) memory-semantics mask.
pub const SEMANTICS_RELAXED: u32 = 0x0;

/// Acquire-release memory semantics (`0x8` is the `AcquireRelease` bit of the
/// SPIR-V memory-semantics mask).
// Not referenced by any of the atomic wrappers in this file, hence the lint
// suppression.
#[allow(dead_code)]
pub const SEMANTICS_ACQUIRE_RELEASE: u32 = 0x8;
|
|
|
|
/// A 32-bit unsigned integer that is read and modified through `&self` using
/// the atomic primitives of the active backend (host, SPIR-V, or CUDA).
///
/// The struct is `#[repr(C)]` with a single `u32` field, so it occupies exactly
/// the storage of one `u32`.
///
/// NOTE(review): `value` is a plain `u32`, not an `UnsafeCell<u32>`, even though
/// the methods of this type write to it through a shared reference. Confirm that
/// this is acceptable on every target, or migrate the field to `UnsafeCell`.
#[repr(C)]
#[derive(Debug)]
pub struct AtomicU32 {
    /// The stored word; accessed only through the backend-specific methods.
    value: u32,
}
|
|
|
|
impl Default for AtomicU32 {
|
|
fn default() -> Self {
|
|
Self::new(0)
|
|
}
|
|
}
|
|
|
|
impl Clone for AtomicU32 {
|
|
fn clone(&self) -> Self {
|
|
Self::new(self.load())
|
|
}
|
|
}
|
|
|
|
impl AtomicU32 {
|
|
pub fn new(val: u32) -> Self {
|
|
Self { value: val }
|
|
}
|
|
|
|
#[cfg(not(any(target_arch = "spirv", feature = "cuda")))]
|
|
#[inline(always)]
|
|
pub fn load(&self) -> u32 {
|
|
let atomic = unsafe {
|
|
&*(core::ptr::addr_of!(self.value) as *const core::sync::atomic::AtomicU32)
|
|
};
|
|
atomic.load(core::sync::atomic::Ordering::Relaxed)
|
|
}
|
|
|
|
#[cfg(not(any(target_arch = "spirv", feature = "cuda")))]
|
|
#[inline(always)]
|
|
pub fn store(&self, val: u32) {
|
|
let atomic = unsafe {
|
|
&*(core::ptr::addr_of!(self.value) as *const core::sync::atomic::AtomicU32)
|
|
};
|
|
atomic.store(val, core::sync::atomic::Ordering::Relaxed);
|
|
}
|
|
|
|
#[cfg(not(any(target_arch = "spirv", feature = "cuda")))]
|
|
#[inline(always)]
|
|
pub fn fetch_add(&self, val: u32) -> u32 {
|
|
let atomic = unsafe {
|
|
&*(core::ptr::addr_of!(self.value) as *const core::sync::atomic::AtomicU32)
|
|
};
|
|
atomic.fetch_add(val, core::sync::atomic::Ordering::Relaxed)
|
|
}
|
|
|
|
#[cfg(not(any(target_arch = "spirv", feature = "cuda")))]
|
|
#[inline(always)]
|
|
pub fn compare_exchange(&self, expected: u32, desired: u32) -> Result<u32, u32> {
|
|
let atomic = unsafe {
|
|
&*(core::ptr::addr_of!(self.value) as *const core::sync::atomic::AtomicU32)
|
|
};
|
|
atomic.compare_exchange_weak(
|
|
expected,
|
|
desired,
|
|
core::sync::atomic::Ordering::Relaxed,
|
|
core::sync::atomic::Ordering::Relaxed,
|
|
)
|
|
}
|
|
|
|
#[cfg(target_arch = "spirv")]
|
|
#[inline(always)]
|
|
pub fn load(&self) -> u32 {
|
|
unsafe {
|
|
spirv_std::arch::atomic_load::<u32, SCOPE_DEVICE, SEMANTICS_RELAXED>(
|
|
&self.value,
|
|
)
|
|
}
|
|
}
|
|
|
|
#[cfg(target_arch = "spirv")]
|
|
#[inline(always)]
|
|
pub fn store(&self, val: u32) {
|
|
unsafe {
|
|
spirv_std::arch::atomic_store::<u32, SCOPE_DEVICE, SEMANTICS_RELAXED>(
|
|
&mut *core::ptr::addr_of!(self.value).cast_mut(),
|
|
val,
|
|
);
|
|
}
|
|
}
|
|
|
|
#[cfg(target_arch = "spirv")]
|
|
#[inline(always)]
|
|
pub fn fetch_add(&self, val: u32) -> u32 {
|
|
unsafe {
|
|
spirv_std::arch::atomic_i_add::<u32, SCOPE_DEVICE, SEMANTICS_RELAXED>(
|
|
&mut *core::ptr::addr_of!(self.value).cast_mut(),
|
|
val,
|
|
)
|
|
}
|
|
}
|
|
|
|
#[cfg(target_arch = "spirv")]
|
|
#[inline(always)]
|
|
pub fn compare_exchange(&self, expected: u32, desired: u32) -> Result<u32, u32> {
|
|
let old = unsafe {
|
|
spirv_std::arch::atomic_compare_exchange::<
|
|
u32,
|
|
SCOPE_DEVICE,
|
|
SEMANTICS_RELAXED,
|
|
SEMANTICS_RELAXED,
|
|
>(
|
|
&mut *core::ptr::addr_of!(self.value).cast_mut(),
|
|
desired,
|
|
expected,
|
|
)
|
|
};
|
|
if old == expected {
|
|
Ok(old)
|
|
} else {
|
|
Err(old)
|
|
}
|
|
}
|
|
|
|
// -- CUDA backend --
|
|
#[cfg(feature = "cuda")]
|
|
#[inline(always)]
|
|
pub fn load(&self) -> u32 {
|
|
// CUDA volatile read for atomicity on the same SM
|
|
unsafe { core::ptr::read_volatile(&self.value) }
|
|
}
|
|
|
|
#[cfg(feature = "cuda")]
|
|
#[inline(always)]
|
|
pub fn store(&self, val: u32) {
|
|
unsafe {
|
|
core::ptr::write_volatile(
|
|
core::ptr::addr_of!(self.value).cast_mut(),
|
|
val,
|
|
);
|
|
}
|
|
}
|
|
|
|
#[cfg(feature = "cuda")]
|
|
#[inline(always)]
|
|
pub fn fetch_add(&self, val: u32) -> u32 {
|
|
let ptr = core::ptr::addr_of!(self.value).cast_mut();
|
|
let mut old: u32;
|
|
unsafe {
|
|
core::arch::asm!(
|
|
"atom.add.u32 {old}, [{ptr}], {val};",
|
|
old = out(reg32) old,
|
|
ptr = in(reg64) ptr,
|
|
val = in(reg32) val,
|
|
);
|
|
}
|
|
old
|
|
}
|
|
|
|
#[cfg(feature = "cuda")]
|
|
#[inline(always)]
|
|
pub fn compare_exchange(&self, expected: u32, desired: u32) -> Result<u32, u32> {
|
|
let ptr = core::ptr::addr_of!(self.value).cast_mut();
|
|
let mut old: u32;
|
|
unsafe {
|
|
core::arch::asm!(
|
|
"atom.cas.b32 {old}, [{ptr}], {expected}, {desired};",
|
|
old = out(reg32) old,
|
|
ptr = in(reg64) ptr,
|
|
expected = in(reg32) expected,
|
|
desired = in(reg32) desired,
|
|
);
|
|
}
|
|
if old == expected {
|
|
Ok(old)
|
|
} else {
|
|
Err(old)
|
|
}
|
|
}
|
|
}
|
|
|
|
#[repr(C)]
|
|
#[derive(Debug)]
|
|
pub struct AtomicF32 {
|
|
bits: AtomicU32,
|
|
}
|
|
|
|
impl Default for AtomicF32 {
|
|
fn default() -> Self {
|
|
Self::new(0.0)
|
|
}
|
|
}
|
|
|
|
impl Clone for AtomicF32 {
|
|
fn clone(&self) -> Self {
|
|
Self::new(self.get())
|
|
}
|
|
}
|
|
|
|
impl AtomicF32 {
|
|
pub fn new(val: Float) -> Self {
|
|
Self {
|
|
bits: AtomicU32::new(val.to_bits()),
|
|
}
|
|
}
|
|
|
|
pub fn get(&self) -> Float {
|
|
Float::from_bits(self.bits.load())
|
|
}
|
|
|
|
pub fn set(&self, val: Float) {
|
|
self.bits.store(val.to_bits());
|
|
}
|
|
|
|
#[cfg(not(any(target_arch = "spirv", feature = "cuda")))]
|
|
#[inline(always)]
|
|
pub fn add(&self, val: Float) {
|
|
let mut current_bits = self.bits.load();
|
|
loop {
|
|
let current_val = Float::from_bits(current_bits);
|
|
let new_val = current_val + val;
|
|
let new_bits = new_val.to_bits();
|
|
match self.bits.compare_exchange(current_bits, new_bits) {
|
|
Ok(_) => break,
|
|
Err(x) => current_bits = x,
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(target_arch = "spirv")]
|
|
#[inline(always)]
|
|
pub fn add(&self, val: Float) {
|
|
unsafe {
|
|
let float_ptr = core::ptr::addr_of!(self.bits.value) as *mut Float;
|
|
spirv_std::arch::atomic_f_add::<Float, SCOPE_DEVICE, SEMANTICS_RELAXED>(
|
|
&mut *float_ptr,
|
|
val,
|
|
);
|
|
}
|
|
}
|
|
|
|
#[cfg(feature = "cuda")]
|
|
#[inline(always)]
|
|
pub fn add(&self, val: Float) {
|
|
let ptr = core::ptr::addr_of!(self.bits.value) as *mut Float;
|
|
unsafe {
|
|
core::arch::asm!(
|
|
"atom.add.f32 {old}, [{ptr}], {val};",
|
|
old = out(reg32) _,
|
|
ptr = in(reg64) ptr,
|
|
val = in(reg32) val.to_bits(),
|
|
);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Atomic cell for the crate's `Float` scalar; currently an alias for [`AtomicF32`].
pub type AtomicFloat = AtomicF32;
|