102 lines
2.9 KiB
Rust
102 lines
2.9 KiB
Rust
use half::f16;
|
|
use shared::utils::hash::hash_buffer;
|
|
use shared::utils::math::{permutation_element, DeviceDigitPermutation, PRIMES};
|
|
use shared::Float;
|
|
|
|
#[inline(always)]
|
|
pub fn f16_to_f32(bits: u16) -> f32 {
|
|
#[cfg(feature = "cuda")]
|
|
{
|
|
// Use hardware intrinsic on CUDA
|
|
// Cast bits to cuda_f16, then cast to f32
|
|
let half_val = unsafe { core::mem::transmute::<u16, cuda_std::f16>(bits) };
|
|
half_val.to_f32()
|
|
}
|
|
|
|
#[cfg(feature = "vulkan")]
|
|
{
|
|
// Use shared logic or spirv-std intrinsics if available.
|
|
// Sadly, f16 support in rust-gpu is still maturing.
|
|
// A manual bit-conversion function is often safest here.
|
|
f16_to_f32_software(bits)
|
|
}
|
|
|
|
#[cfg(not(any(feature = "cuda", feature = "vulkan")))]
|
|
{
|
|
f16::from_bits(bits).to_f32()
|
|
}
|
|
}
|
|
|
|
pub struct DigitPermutation {
|
|
pub permutations: Vec<u16>,
|
|
pub device: DeviceDigitPermutation,
|
|
}
|
|
|
|
impl DigitPermutation {
|
|
pub fn new(base: i32, seed: u64) -> Self {
|
|
assert!(base < 65536);
|
|
let mut n_digits: u32 = 0;
|
|
let inv_base = 1. / base as Float;
|
|
let mut inv_base_m = 1.;
|
|
|
|
while 1.0 - ((base as Float - 1.0) * inv_base_m) < 1.0 {
|
|
n_digits += 1;
|
|
inv_base_m *= inv_base;
|
|
}
|
|
|
|
let mut permutations = vec![0u16; n_digits as usize * base as usize];
|
|
|
|
for digit_index in 0..n_digits {
|
|
let hash_input = [base as u64, digit_index as u64, seed];
|
|
let dseed = hash_buffer(&hash_input, 0);
|
|
|
|
for digit_value in 0..base {
|
|
let index = (digit_index as i32 * base + digit_value) as usize;
|
|
|
|
permutations[index] =
|
|
permutation_element(digit_value as u32, base as u32, dseed as u32) as u16;
|
|
}
|
|
}
|
|
|
|
let device = DeviceDigitPermutation {
|
|
base,
|
|
n_digits,
|
|
permutations: permutations.as_ptr().into(),
|
|
};
|
|
|
|
Self {
|
|
device,
|
|
permutations,
|
|
}
|
|
}
|
|
}
|
|
|
|
pub fn compute_radical_inverse_permutations(seed: u64) -> (Vec<u16>, Vec<DeviceDigitPermutation>) {
|
|
let temp_data: Vec<Vec<u16>> = PRIMES
|
|
.iter()
|
|
.map(|&base| DigitPermutation::new(base as i32, seed).permutations)
|
|
.collect();
|
|
let mut storage: Vec<u16> = Vec::with_capacity(temp_data.iter().map(|v| v.len()).sum());
|
|
|
|
for vec in &temp_data {
|
|
storage.extend_from_slice(vec);
|
|
}
|
|
|
|
let mut views = Vec::with_capacity(PRIMES.len());
|
|
// let mut current_offset = 0;
|
|
|
|
// let storage_base_ptr = storage.as_ptr();
|
|
|
|
for (i, &base) in PRIMES.iter().enumerate() {
|
|
let len = temp_data[i].len();
|
|
let n_digits = len as u32 / base as u32;
|
|
|
|
// let ptr_to_data = storage_base_ptr.add(current_offset);
|
|
|
|
views.push(DigitPermutation::new(base as i32, n_digits as u64).device);
|
|
|
|
// current_offset += len;
|
|
}
|
|
|
|
(storage, views)
|
|
}
|