19 lines
616 B
Rust
19 lines
616 B
Rust
#![cfg_attr(target_arch = "spirv", no_std)]
|
|
|
|
use spirv_std::spirv;
|
|
|
|
/// Writes `input[idx] * scale` into `output[idx]` for a single element.
///
/// This is the per-invocation body of the scale kernel, kept free of any
/// shader-specific types so it can also be called from ordinary host code.
///
/// # Parameters
/// - `idx`: index of the element to process.
/// - `input`: source values.
/// - `output`: destination slice; only `output[idx]` is written.
/// - `scale`: factor each input value is multiplied by.
///
/// The call is a no-op when `idx` is out of range for *either* slice, so a
/// dispatch that is larger than the buffers (or an `output` shorter than
/// `input`) never indexes out of bounds.
pub fn scale_kernel_logic(idx: usize, input: &[f32], output: &mut [f32], scale: f32) {
    // Both slices are indexed below, so both lengths must be checked;
    // guarding on `input.len()` alone allows an out-of-bounds write.
    if idx < input.len() && idx < output.len() {
        output[idx] = input[idx] * scale;
    }
}
|
|
|
|
#[spirv(compute(threads(64)))]
|
|
pub fn scale_kernel(
|
|
#[spirv(global_invocation_id)] id: spirv_std::glam::UVec3,
|
|
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] input: &[f32],
|
|
#[spirv(storage_buffer, descriptor_set = 0, binding = 1)] output: &mut [f32],
|
|
#[spirv(push_constant)] scale: &f32,
|
|
) {
|
|
scale_kernel_logic(id.x as usize, input, output, *scale);
|
|
}
|