Generalizing GPU framework to use Vulkan or CUDA with spirv

This commit is contained in:
Wito Wiala 2026-02-19 15:41:05 +00:00
parent 0b04d54346
commit 8a92d7642d

147
src/utils/backend.rs Normal file
View file

@ -0,0 +1,147 @@
use std::alloc::Layout;
pub trait GpuAllocator: Send + Sync {
/// Allocate `size` bytes with given alignment.
/// Returns a host-mapped pointer.
unsafe fn alloc(&self, layout: Layout) -> *mut u8;
unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout);
}
/// CPU fallback — standard system allocator.
pub struct SystemAllocator;
impl Default for SystemAllocator {
fn default() -> Self {
Self
}
}
impl GpuAllocator for SystemAllocator {
unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
if layout.size() == 0 {
return layout.align() as *mut u8;
}
unsafe { std::alloc::alloc(layout) }
}
unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
if layout.size() > 0 {
unsafe {
std::alloc::dealloc(ptr, layout);
}
}
}
}
/// CUDA unified memory backend using CudaAllocator
#[cfg(feature = "cuda")]
pub struct CudaAllocator;
#[cfg(feature = "cuda")]
impl Default for CudaAllocator {
fn default() -> Self {
Self
}
}
#[cfg(feature = "cuda")]
impl GpuAllocator for CudaAllocator {
unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
use cust::memory::cuda_malloc_unified;
let size = layout.size().max(layout.align());
if size == 0 {
return layout.align() as *mut u8;
}
let mut unified_ptr =
unsafe { cuda_malloc_unified::<u8>(size).expect("cuda_malloc_unified failed") };
let raw = unified_ptr.as_raw_mut();
std::mem::forget(unified_ptr);
raw
}
unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
use cust::memory::{UnifiedPointer, cuda_free_unified};
if layout.size() > 0 {
let _ = unsafe { cuda_free_unified(UnifiedPointer::wrap(ptr)) };
}
}
}
/// Vulkan backend using gpu-allocator.
#[cfg(feature = "vulkan")]
pub mod vulkan {
use super::GpuAllocator;
use ash::vk;
use gpu_allocator::MemoryLocation;
use gpu_allocator::vulkan::{Allocation, AllocationCreateDesc, AllocationScheme, Allocator};
use parking_lot::Mutex;
use std::alloc::Layout;
use std::collections::HashMap;
pub struct VulkanAllocator {
allocator: Mutex<Allocator>,
/// Track pointer -> Allocation so we can free later.
allocations: Mutex<HashMap<usize, Allocation>>,
}
impl VulkanAllocator {
pub fn new(allocator: Allocator) -> Self {
Self {
allocator: Mutex::new(allocator),
allocations: Mutex::new(HashMap::new()),
}
}
}
impl Default for VulkanAllocator {
fn default() -> Self {
Self
}
}
impl GpuAllocator for VulkanAllocator {
unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
let size = layout.size().max(layout.align());
if size == 0 {
return layout.align() as *mut u8;
}
let mut alloc = self.allocator.lock();
let allocation = alloc
.allocate(&AllocationCreateDesc {
name: "arena",
requirements: vk::MemoryRequirements {
size: size as u64,
alignment: layout.align() as u64,
memory_type_bits: u32::MAX,
},
location: MemoryLocation::CpuToGpu,
linear: true,
allocation_scheme: AllocationScheme::GpuAllocatorManaged,
})
.expect("Vulkan allocation failed");
let ptr = allocation
.mapped_ptr()
.expect("Vulkan allocation not host-mapped")
.as_ptr() as *mut u8;
self.allocations.lock().insert(ptr as usize, allocation);
ptr
}
unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
if layout.size() == 0 {
return;
}
if let Some(allocation) = self.allocations.lock().remove(&(ptr as usize)) {
self.allocator
.lock()
.free(allocation)
.expect("Vulkan free failed");
}
}
}
}