use half::f16; use shared::utils::hash::hash_buffer; use shared::utils::math::{permutation_element, DeviceDigitPermutation, PRIMES}; use shared::Float; #[inline(always)] pub fn f16_to_f32(bits: u16) -> f32 { #[cfg(feature = "cuda")] { // Use hardware intrinsic on CUDA // Cast bits to cuda_f16, then cast to f32 let half_val = unsafe { core::mem::transmute::(bits) }; half_val.to_f32() } #[cfg(feature = "vulkan")] { // Use shared logic or spirv-std intrinsics if available. // Sadly, f16 support in rust-gpu is still maturing. // A manual bit-conversion function is often safest here. f16_to_f32_software(bits) } #[cfg(not(any(feature = "cuda", feature = "vulkan")))] { f16::from_bits(bits).to_f32() } } pub struct DigitPermutation { pub permutations: Vec, pub device: DeviceDigitPermutation, } impl DigitPermutation { pub fn new(base: i32, seed: u64) -> Self { assert!(base < 65536); let mut n_digits: u32 = 0; let inv_base = 1. / base as Float; let mut inv_base_m = 1.; while 1.0 - ((base as Float - 1.0) * inv_base_m) < 1.0 { n_digits += 1; inv_base_m *= inv_base; } let mut permutations = vec![0u16; n_digits as usize * base as usize]; for digit_index in 0..n_digits { let hash_input = [base as u64, digit_index as u64, seed]; let dseed = hash_buffer(&hash_input, 0); for digit_value in 0..base { let index = (digit_index as i32 * base + digit_value) as usize; permutations[index] = permutation_element(digit_value as u32, base as u32, dseed as u32) as u16; } } let device = DeviceDigitPermutation { base, n_digits, permutations: permutations.as_ptr().into(), }; Self { device, permutations, } } } pub fn compute_radical_inverse_permutations(seed: u64) -> (Vec, Vec) { let temp_data: Vec> = PRIMES .iter() .map(|&base| DigitPermutation::new(base as i32, seed).permutations) .collect(); let mut storage: Vec = Vec::with_capacity(temp_data.iter().map(|v| v.len()).sum()); for vec in &temp_data { storage.extend_from_slice(vec); } let mut views = Vec::with_capacity(PRIMES.len()); // let mut current_offset = 0; // let storage_base_ptr = storage.as_ptr(); for (i, &base) in PRIMES.iter().enumerate() { let len = temp_data[i].len(); let n_digits = len as u32 / base as u32; // let ptr_to_data = storage_base_ptr.add(current_offset); views.push(DigitPermutation::new(base as i32, n_digits as u64).device); // current_offset += len; } (storage, views) }