#![cfg_attr(target_arch = "spirv", no_std)]

use spirv_std::spirv;

/// Scales a single element: writes `input[idx] * scale` into `output[idx]`.
///
/// This is the per-invocation body of [`scale_kernel`], kept as a plain
/// function over slices so it can also be called directly with ordinary
/// slices.
///
/// * `idx` - index of the element to process.
/// * `input` - values to read from.
/// * `output` - destination for the scaled values.
/// * `scale` - factor each input value is multiplied by.
///
/// If `idx` is not a valid index into *both* `input` and `output`, the call
/// does nothing and `output` is left untouched.
pub fn scale_kernel_logic(idx: usize, input: &[f32], output: &mut [f32], scale: f32) {
    // The index must be checked against both slices: nothing guarantees that
    // `output` is at least as long as `input`, and checking `input` alone
    // would let `output[idx]` go out of bounds.
    if idx < input.len() && idx < output.len() {
        output[idx] = input[idx] * scale;
    }
}

/// Compute entry point (declared with `threads(64)`) that scales `input`
/// into `output`.
///
/// Each invocation handles the element selected by the `x` component of its
/// global invocation id and delegates to [`scale_kernel_logic`]; invocations
/// whose index falls outside either buffer do nothing.
///
/// Bindings, as declared on the parameters:
/// * `input` - storage buffer, descriptor set 0, binding 0.
/// * `output` - storage buffer, descriptor set 0, binding 1.
/// * `scale` - push constant holding the scale factor.
#[spirv(compute(threads(64)))]
pub fn scale_kernel(
    #[spirv(global_invocation_id)] id: spirv_std::glam::UVec3,
    #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] input: &[f32],
    #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] output: &mut [f32],
    #[spirv(push_constant)] scale: &f32,
) {
    scale_kernel_logic(id.x as usize, input, output, *scale);
}