1525 lines
43 KiB
Rust
1525 lines
43 KiB
Rust
use super::color::{RGB, XYZ};
|
|
use super::error::{InversionError, LlsError};
|
|
use crate::core::pbrt::{
|
|
Float, FloatBitOps, FloatBits, ONE_MINUS_EPSILON, PI, PI_OVER_4, clamp_t, evaluate_polynomial,
|
|
lerp,
|
|
};
|
|
use crate::geometry::{Point, Point2f, Point2i, Vector, Vector3f, VectorLike};
|
|
use crate::utils::hash::{hash_buffer, mix_bits};
|
|
use crate::utils::sobol::{SOBOL_MATRICES_32, VDC_SOBOL_MATRICES, VDC_SOBOL_MATRICES_INV};
|
|
|
|
use num_traits::{Float as NumFloat, Num, One, Signed, Zero};
|
|
use rayon::prelude::*;
|
|
use std::error::Error;
|
|
use std::fmt::{self, Display};
|
|
use std::iter::{Product, Sum};
|
|
use std::mem;
|
|
use std::ops::{Add, Div, Index, IndexMut, Mul, Neg, Rem};
|
|
|
|
#[inline]
|
|
pub fn degrees(a: Float) -> Float {
|
|
a * 180.0 / PI
|
|
}
|
|
|
|
#[inline]
|
|
pub fn radians(a: Float) -> Float {
|
|
a * PI / 180.0
|
|
}
|
|
|
|
/// Returns `n * n` for any copyable type with multiplication.
#[inline]
pub fn square<T>(n: T) -> T
where
    T: Mul<Output = T> + Copy,
{
    let (lhs, rhs) = (n, n);
    lhs * rhs
}
|
|
|
|
/// Computes `a * b + c`.
///
/// NOTE: this is *not* a fused multiply-add; the product and the sum are each
/// rounded separately. Callers that rely on the single rounding of a hardware
/// FMA (e.g. error-free transformations) do not get that guarantee here.
#[inline]
fn fma<T>(a: T, b: T, c: T) -> T
where
    T: Mul<Output = T> + Add<Output = T> + Copy,
{
    let product = a * b;
    product + c
}
|
|
|
|
#[inline]
|
|
pub fn difference_of_products<T>(a: T, b: T, c: T, d: T) -> T
|
|
where
|
|
T: Mul<Output = T> + Add<Output = T> + Neg<Output = T> + Copy,
|
|
{
|
|
let cd = c * d;
|
|
let difference_of_products = fma(a, b, -cd);
|
|
let error = fma(-c, d, cd);
|
|
difference_of_products + error
|
|
}
|
|
|
|
#[inline]
|
|
pub fn sum_of_products<T>(a: T, b: T, c: T, d: T) -> T
|
|
where
|
|
T: Mul<Output = T> + Add<Output = T> + Neg<Output = T> + Copy,
|
|
{
|
|
let cd = c * d;
|
|
let sum_of_products = fma(a, b, cd);
|
|
let error = fma(c, d, -cd);
|
|
sum_of_products + error
|
|
}
|
|
|
|
#[inline]
|
|
pub fn safe_sqrt(x: Float) -> Float {
|
|
assert!(x > -1e-3);
|
|
0.0_f32.max(x).sqrt()
|
|
}
|
|
|
|
#[inline]
|
|
pub fn safe_asin<T: NumFloat>(x: T) -> T {
|
|
let epsilon = T::from(0.0001).unwrap();
|
|
let one = T::one();
|
|
if x >= -(one + epsilon) && x <= one + epsilon {
|
|
clamp_t(x, -one, one).asin()
|
|
} else {
|
|
panic!("Not valid value for asin")
|
|
}
|
|
}
|
|
|
|
#[inline]
|
|
pub fn safe_acos(x: Float) -> Float {
|
|
if (-1.001..1.001).contains(&x) {
|
|
clamp_t(x, -1., 1.).asin()
|
|
} else {
|
|
panic!("Not valid value for acos")
|
|
}
|
|
}
|
|
|
|
#[inline]
|
|
pub fn sinx_over_x(x: Float) -> Float {
|
|
if 1. - x * x == 1. {
|
|
return 1.;
|
|
}
|
|
x.sin() / x
|
|
}
|
|
|
|
#[inline]
|
|
pub fn sinc(x: Float) -> Float {
|
|
sinx_over_x(PI * x)
|
|
}
|
|
|
|
#[inline]
|
|
pub fn windowed_sinc(x: Float, radius: Float, tau: Float) -> Float {
|
|
if x.abs() > radius {
|
|
return 0.;
|
|
}
|
|
sinc(x) * sinc(x / tau)
|
|
}
|
|
|
|
/// Fast approximation of `e^x`.
///
/// Rewrites `e^x` as `2^(x * log2(e))` and splits that exponent into an integer
/// part `i` and a fraction `f` in `[0, 1)`. It evaluates a cubic polynomial
/// for `2^f`, then adds `i` directly into the biased exponent field of the
/// result's bit pattern.
///
/// Returns `0` when the resulting exponent is below -126 and `Float::INFINITY`
/// when it exceeds 127.
///
/// NOTE(review): the 32-bit mask, the `<< 23` and the `+ 127` bias assume
/// `Float` is `f32` — confirm before building with a different `Float`.
#[inline]
pub fn fast_exp(x: Float) -> Float {
    // 1.442695041 ≈ log2(e), so 2^xp == e^x.
    let xp = x * 1.442695041;
    let fxp = xp.floor();
    let f = xp - fxp;
    let i = fxp as i32;

    // Cubic polynomial approximation of 2^f for f in [0, 1).
    let two_to_f = evaluate_polynomial(f, &[1., 0.695556856, 0.226173572, 0.0781455737])
        .expect("Could not evaluate polynomial");

    // Exponent of the final result = exponent of 2^f plus the integer part.
    let exponent = exponent(two_to_f) + i;
    if exponent < -126 {
        return 0.;
    }
    if exponent > 127 {
        return Float::INFINITY;
    }

    // Keep the sign and mantissa bits of 2^f, then overwrite the biased exponent.
    let mut bits = float_to_bits(two_to_f);
    bits &= 0b10000000011111111111111111111111;
    bits |= ((exponent + 127) as u32) << 23;
    bits_to_float(bits)
}
|
|
|
|
#[inline]
|
|
pub fn i0(x: Float) -> Float {
|
|
let mut val: Float = 0.0;
|
|
let mut x2i: Float = 1.0;
|
|
let mut ifact: i64 = 1;
|
|
let mut i4: i32 = 1;
|
|
|
|
for i in 0..10 {
|
|
if i > 1 {
|
|
ifact *= i as i64;
|
|
}
|
|
|
|
let denominator = (i4 as Float) * (square(ifact) as Float);
|
|
val += x2i / denominator;
|
|
|
|
x2i *= x * x;
|
|
i4 *= 4;
|
|
}
|
|
|
|
val
|
|
}
|
|
|
|
#[inline]
|
|
pub fn logistic(x: Float, s: Float) -> Float {
|
|
let y = x.abs();
|
|
(-y / s).exp() / (s * square(1. + (-y / s).exp()))
|
|
}
|
|
|
|
#[inline]
|
|
pub fn logistic_cdf(x: Float, s: Float) -> Float {
|
|
1. / (1. + (-x / s).exp())
|
|
}
|
|
|
|
#[inline]
|
|
pub fn trimmed_logistic(x: Float, s: Float, a: Float, b: Float) -> Float {
|
|
logistic(x, s) / (logistic_cdf(b, s) - logistic_cdf(a, s))
|
|
}
|
|
|
|
#[inline]
|
|
pub fn log_i0(x: Float) -> Float {
|
|
if x > 12.0 {
|
|
let inv_x = 1.0 / x;
|
|
let two_pi = 2.0 * PI as Float;
|
|
|
|
x + 0.5 * (-two_pi.ln() + inv_x.ln() + 1.0 / (8.0 * x))
|
|
} else {
|
|
i0(x).ln()
|
|
}
|
|
}
|
|
|
|
/// Returns `log2(v)` rounded to the nearest integer, for `v > 0`.
///
/// Rounding happens in the log domain: the result is bumped up once the
/// significand reaches `sqrt(2)`. Inputs below 1 are handled through
/// `-log2_int(1 / v)`.
#[inline]
pub fn log2_int(v: f32) -> i32 {
    debug_assert!(v > 0.0, "log2_int requires positive input");
    if v < 1.0 {
        return -log2_int(1.0 / v);
    }

    // https://graphics.stanford.edu/~seander/bithacks.html#IntegerLog
    // Fraction bits of 2^1.5 (i.e. of sqrt(2)): the rounding threshold.
    const MID_SIGNIF: u32 = 0b00000000001101010000010011110011;

    let raw = v.to_bits();
    let unbiased_exponent = ((raw >> 23) & 0xFF) as i32 - 127;
    let rounds_up = (raw & 0x7FFFFF) >= MID_SIGNIF;
    unbiased_exponent + i32::from(rounds_up)
}
|
|
|
|
pub fn quadratic(a: Float, b: Float, c: Float) -> Option<(Float, Float)> {
|
|
if a == 0. {
|
|
if b == 0. {
|
|
return None;
|
|
}
|
|
let t0 = -c / b;
|
|
let t1 = -c / b;
|
|
return Some((t0, t1));
|
|
}
|
|
|
|
let discrim = difference_of_products(b, b, 4. * a, c);
|
|
if discrim < 0. {
|
|
return None;
|
|
}
|
|
let root_discrim = discrim.sqrt();
|
|
|
|
let q = -0.5 * (b + root_discrim.copysign(b));
|
|
let mut t0 = q / a;
|
|
let mut t1 = c / q;
|
|
if t0 > t1 {
|
|
mem::swap(&mut t0, &mut t1);
|
|
}
|
|
|
|
Some((t0, t1))
|
|
}
|
|
|
|
pub fn smooth_step(x: Float, a: Float, b: Float) -> Float {
|
|
if a == b {
|
|
if x < a { return 0. } else { return 1. }
|
|
}
|
|
|
|
let t = clamp_t((x - a) / (b - a), 0., 1.);
|
|
t * t * (3. - 2. * t)
|
|
}
|
|
|
|
pub fn linear_least_squares<const R: usize, const N: usize>(
|
|
a: [[Float; R]; N],
|
|
b: [[Float; R]; N],
|
|
) -> Result<SquareMatrix<Float, R>, Box<dyn Error>> {
|
|
let am = Matrix::from(a);
|
|
let bm = Matrix::from(b);
|
|
let ata = am.transpose() * am;
|
|
let atb = am.transpose() * bm;
|
|
let at_ai = ata.inverse()?;
|
|
Ok((at_ai * atb).transpose())
|
|
}
|
|
|
|
pub fn newton_bisection<P>(mut x0: Float, mut x1: Float, mut f: P) -> Float
|
|
where
|
|
P: FnMut(Float) -> (Float, Float),
|
|
{
|
|
assert!(x0 < x1);
|
|
let f_eps = 1e-6;
|
|
let x_eps = 1e-6;
|
|
|
|
let (fx0, _) = f(x0);
|
|
let (fx1, _) = f(x1);
|
|
if fx0.abs() < f_eps {
|
|
return x0;
|
|
}
|
|
if fx1.abs() < f_eps {
|
|
return x1;
|
|
}
|
|
let start_is_negative = fx0 < 0.0;
|
|
|
|
let mut x_mid = x0 + (x1 - x0) * -fx0 / (fx1 - fx0);
|
|
|
|
loop {
|
|
if !(x0 < x_mid && x_mid < x1) {
|
|
x_mid = (x0 + x1) / 2.0;
|
|
}
|
|
|
|
let (fx_mid, dfx_mid) = f(x_mid);
|
|
assert!(!fx_mid.is_nan());
|
|
if start_is_negative == (fx_mid < 0.0) {
|
|
x0 = x_mid;
|
|
} else {
|
|
x1 = x_mid;
|
|
}
|
|
|
|
if (x1 - x0) < x_eps || fx_mid.abs() < f_eps {
|
|
return x_mid;
|
|
}
|
|
|
|
x_mid -= fx_mid / dfx_mid;
|
|
}
|
|
}
|
|
|
|
/// Folds a point that lies just outside `[0, 1]²` back into the square, using
/// the edge symmetry of the equal-area square parameterization.
///
/// It performs at most one reflection per axis. Crossing a `u` edge also
/// mirrors `v`, and crossing a `v` edge also mirrors `u`. `uv` is updated in
/// place, and the folded value is returned.
pub fn wrap_equal_area_square(uv: &mut Point2f) -> Point2f {
    if uv[0] < 0. || uv[0] > 1. {
        uv[0] = if uv[0] < 0. { -uv[0] } else { 2. - uv[0] };
        uv[1] = 1. - uv[1];
    }
    if uv[1] < 0. || uv[1] > 1. {
        uv[0] = 1. - uv[0];
        uv[1] = if uv[1] < 0. { -uv[1] } else { 2. - uv[1] };
    }
    *uv
}
|
|
|
|
/// Computes Catmull-Rom spline weights for evaluating, at `x`, a function
/// tabulated at the sorted, strictly increasing positions `nodes`.
///
/// Returns `Some((offset, weights))` such that the interpolated value is
/// `sum_i weights[i] * values[offset + i]` for `i in 0..4`.
///
/// Returns `None` when `nodes` has fewer than four entries or when `x` is
/// outside `[nodes[0], nodes[len - 1]]` (NaN included).
///
/// When `x` falls in the last interval, `offset + 3 == nodes.len()` and
/// `weights[3] == 0`. Callers should therefore skip zero weights rather than
/// index unconditionally.
pub fn catmull_rom_weights(nodes: &[Float], x: Float) -> Option<(usize, [Float; 4])> {
    if nodes.len() < 4 {
        return None;
    }

    let (first, last) = (nodes[0], nodes[nodes.len() - 1]);
    // Written as a negated conjunction so that NaN is rejected as well.
    if !(x >= first && x <= last) {
        return None;
    }

    // Interval index: largest idx with nodes[idx] <= x, clamped to [0, len - 2].
    let idx = nodes
        .partition_point(|&n| n <= x)
        .saturating_sub(1)
        .min(nodes.len() - 2);
    let x0 = nodes[idx];
    let x1 = nodes[idx + 1];
    let dx = x1 - x0;

    // Compute t parameter and powers
    let t = (x - x0) / dx;
    let t2 = t * t;
    let t3 = t2 * t;

    // Initial Hermite weights for the two interval endpoints.
    let mut w = [0.0; 4];
    w[1] = 2.0 * t3 - 3.0 * t2 + 1.0;
    w[2] = -2.0 * t3 + 3.0 * t2;

    // First node weight (one-sided derivative at the left boundary).
    if idx > 0 {
        let w0 = (t3 - 2.0 * t2 + t) * dx / (x1 - nodes[idx - 1]);
        w[0] = -w0;
        w[2] += w0;
    } else {
        let w0 = t3 - 2.0 * t2 + t;
        w[0] = 0.0;
        w[1] -= w0;
        w[2] += w0;
    }

    // Last node weight (one-sided derivative at the right boundary).
    if idx + 2 < nodes.len() {
        let w3 = (t3 - t2) * dx / (nodes[idx + 2] - x0);
        w[1] -= w3;
        w[3] = w3;
    } else {
        let w3 = t3 - t2;
        w[1] -= w3;
        w[2] += w3;
        w[3] = 0.0;
    }

    if idx == 0 {
        // The natural offset here is idx - 1 == -1, and w[0] (always 0) would
        // belong to a node before the table. A usize cannot hold -1, so the
        // window is shifted one slot right instead: weight i applies to nodes[i].
        Some((0, [w[1], w[2], w[3], 0.0]))
    } else {
        Some((idx - 1, w))
    }
}
|
|
|
|
pub fn equal_area_sphere_to_square(d: Vector3f) -> Point2f {
|
|
debug_assert!(d.norm_squared() > 0.999 && d.norm_squared() < 1.001);
|
|
let x = d.x().abs();
|
|
let y = d.y().abs();
|
|
let z = d.z().abs();
|
|
let r = safe_sqrt(1. - z);
|
|
let a = x.max(y);
|
|
let b = if a == 0. { 0. } else { x.min(y) / a };
|
|
let t1 = 0.406758566246788489601959989e-5;
|
|
let t2 = 0.636226545274016134946890922156;
|
|
let t3 = 0.61572017898280213493197203466e-2;
|
|
let t4 = -0.247333733281268944196501420480;
|
|
let t5 = 0.881770664775316294736387951347e-1;
|
|
let t6 = 0.419038818029165735901852432784e-1;
|
|
let t7 = -0.251390972343483509333252996350e-1;
|
|
let mut phi = evaluate_polynomial(b, &[t1, t2, t3, t4, t5, t6, t7])
|
|
.expect("Could not evaluate polynomial");
|
|
|
|
if x < y {
|
|
phi = 1. - phi;
|
|
}
|
|
|
|
let mut v = phi * r;
|
|
let mut u = r - v;
|
|
|
|
if d.z() < 0. {
|
|
mem::swap(&mut u, &mut v);
|
|
u = 1. - u;
|
|
v = 1. - v;
|
|
}
|
|
|
|
u = u.copysign(d.x());
|
|
v = v.copysign(d.y());
|
|
|
|
Point2f::new(0.5 * (u + 1.), 0.5 * (v + 1.))
|
|
}
|
|
|
|
pub fn equal_area_square_to_sphere(p: Point2f) -> Vector3f {
|
|
assert!(p.x() >= 0. && p.x() <= 1. && p.y() >= 0. && p.y() <= 1.);
|
|
|
|
// Transform p to [-1,1]^2 and compute absolute values
|
|
let u = 2. * p.x() - 1.;
|
|
let v = 2. * p.y() - 1.;
|
|
let up = u.abs();
|
|
let vp = v.abs();
|
|
|
|
// Compute radius as signed distance from diagonal
|
|
let signed_distance = 1. - (up + vp);
|
|
let d = signed_distance.abs();
|
|
let r = 1. - d;
|
|
|
|
// Compute angle \phi for square to sphere mapping
|
|
let mut phi = if r == 0. { 1. } else { (vp - up) / r + 1. };
|
|
phi /= PI_OVER_4;
|
|
|
|
// Find z for spherical direction
|
|
let z = (1. - square(r)).copysign(signed_distance);
|
|
|
|
// Compute $\cos\phi$ and $\sin\phi$ for original quadrant and return vector
|
|
let cos_phi = phi.cos().copysign(u);
|
|
let sin_phi = phi.sin().copysign(v);
|
|
Vector3f::new(
|
|
cos_phi * r * (2. - square(r)).sqrt(),
|
|
sin_phi * r * (2. - square(r)).sqrt(),
|
|
z,
|
|
)
|
|
}
|
|
|
|
/// Unnormalized isotropic 2D Gaussian: `exp(-(x² + y²) / (2 sigma²))`.
pub fn gaussian(x: Float, y: Float, sigma: Float) -> Float {
    let r_squared = x * x + y * y;
    let two_var = 2. * sigma * sigma;
    (-r_squared / two_var).exp()
}
|
|
|
|
/// Integral over `[x0, x1]` of the normal density with mean `mu` and standard
/// deviation `sigma`, computed from the error function.
///
/// # Panics
/// Panics if `sigma <= 0`.
pub fn gaussian_integral(x0: Float, x1: Float, mu: Float, sigma: Float) -> Float {
    assert!(sigma > 0.);
    // sigma * sqrt(2)
    let denom = sigma * 1.414213562373095;
    let erf_at = |x: Float| ((mu - x) / denom).erf();
    0.5 * (erf_at(x0) - erf_at(x1))
}
|
|
|
|
pub fn sample_linear(u: Float, a: Float, b: Float) -> Float {
|
|
assert!(a >= 0. && b >= 0.);
|
|
if u == 0. && a == 0. {
|
|
return 0.;
|
|
}
|
|
let x = u * (a + b) / (a + (lerp(u, square(a), square(b))));
|
|
x.min(ONE_MINUS_EPSILON)
|
|
}
|
|
|
|
/// Builds a `Float` from its raw bit pattern.
///
/// Delegates to `from_bits_val`, which is expected to be a plain bit-cast like
/// `f32::from_bits` (NOTE(review): confirm in `FloatBitOps`). It is the
/// inverse of `float_to_bits`.
#[inline(always)]
pub fn bits_to_float(bits: FloatBits) -> Float {
    Float::from_bits_val(bits)
}
|
|
|
|
/// Returns the raw bit pattern of `v` (delegates to `to_bits_val`).
#[inline(always)]
pub fn float_to_bits(v: Float) -> FloatBits {
    v.to_bits_val()
}
|
|
|
|
/// Returns the binary exponent of `v` (delegates to `exponent_val`).
///
/// `fast_exp` treats this as the *unbiased* exponent, since it re-adds the
/// bias of 127 when writing the bits back.
#[inline(always)]
pub fn exponent(v: Float) -> i32 {
    v.exponent_val()
}
|
|
|
|
/// Returns the significand (mantissa) bits of `v` (delegates to `significand_val`).
///
/// NOTE(review): presumably only the stored fraction bits, without the implicit
/// leading 1 — confirm in `FloatBitOps`.
#[inline(always)]
pub fn significand(v: Float) -> FloatBits {
    v.significand_val()
}
|
|
|
|
/// Returns the sign bit of `v` (delegates to `sign_bit_val`).
///
/// NOTE(review): whether the bit is returned in place or shifted down to bit 0
/// is defined by `FloatBitOps` — confirm before relying on it.
#[inline(always)]
pub fn sign_bit(v: Float) -> FloatBits {
    v.sign_bit_val()
}
|
|
|
|
#[inline]
|
|
pub fn next_float_up(v: Float) -> Float {
|
|
if v.is_infinite() && v > 0.0 {
|
|
return v;
|
|
}
|
|
let v = if v == -0.0 { 0.0 } else { v };
|
|
|
|
let mut ui = float_to_bits(v);
|
|
if v >= 0.0 {
|
|
ui += 1;
|
|
} else {
|
|
ui -= 1;
|
|
}
|
|
bits_to_float(ui)
|
|
}
|
|
|
|
#[inline]
|
|
pub fn next_float_down(v: Float) -> Float {
|
|
if v.is_infinite() && v < 0.0 {
|
|
return v;
|
|
}
|
|
let v = if v == 0.0 { -0.0 } else { v };
|
|
|
|
let mut ui = float_to_bits(v);
|
|
if v > 0.0 {
|
|
ui -= 1;
|
|
} else {
|
|
ui += 1;
|
|
}
|
|
bits_to_float(ui)
|
|
}
|
|
|
|
pub fn sample_discrete(
|
|
weights: &[f32],
|
|
u: f32,
|
|
pmf: Option<&mut f32>,
|
|
u_remapped: Option<&mut f32>,
|
|
) -> Option<usize> {
|
|
// Handle empty weights for discrete sampling.
|
|
if weights.is_empty() {
|
|
if let Some(p) = pmf {
|
|
*p = 0.0;
|
|
}
|
|
return None;
|
|
}
|
|
|
|
// If the total weight is zero, sampling is not possible.
|
|
let sum_weights: f32 = weights.iter().sum();
|
|
if sum_weights == 0.0 {
|
|
if let Some(p) = pmf {
|
|
*p = 0.0;
|
|
}
|
|
return None;
|
|
}
|
|
|
|
let mut up = u * sum_weights;
|
|
if up >= sum_weights {
|
|
up = next_float_down(up);
|
|
}
|
|
|
|
// Find the offset in weights corresponding to the rescaled sample u'.
|
|
let mut offset = 0;
|
|
let mut sum = 0.0;
|
|
while sum + weights[offset] <= up {
|
|
sum += weights[offset];
|
|
offset += 1;
|
|
debug_assert!(offset < weights.len());
|
|
}
|
|
|
|
if let Some(p) = pmf {
|
|
*p = weights[offset] / sum_weights;
|
|
}
|
|
if let Some(ur) = u_remapped {
|
|
let weight = weights[offset];
|
|
*ur = ((up - sum) / weight).min(ONE_MINUS_EPSILON);
|
|
}
|
|
|
|
Some(offset)
|
|
}
|
|
|
|
pub fn sample_tent(u: Float, r: Float) -> Float {
|
|
let mut u_remapped = 0.0;
|
|
let offset = sample_discrete(&[0.5, 0.5], u, None, Some(&mut u_remapped))
|
|
.expect("Discrete sampling shouldn't fail");
|
|
if offset == 0 {
|
|
-r + r * sample_linear(u, 0., 1.)
|
|
} else {
|
|
r * sample_linear(u, 1., 0.)
|
|
}
|
|
}
|
|
|
|
/// Spreads the low 32 bits of `x` into the even bit positions of a `u64`
/// (bit `i` moves to bit `2i`), as needed for 2D Morton codes.
fn left_shift2(x: u64) -> u64 {
    // (shift, mask) pairs of the classic "magic bits" interleave.
    const STEPS: [(u32, u64); 5] = [
        (16, 0x0000ffff0000ffff),
        (8, 0x00ff00ff00ff00ff),
        (4, 0x0f0f0f0f0f0f0f0f),
        (2, 0x3333333333333333),
        (1, 0x5555555555555555),
    ];
    STEPS
        .iter()
        .fold(x & 0xffffffff, |acc, &(shift, mask)| (acc ^ (acc << shift)) & mask)
}
|
|
|
|
/// Spreads the low 10 bits of `x` to every third bit position (bit `i` moves
/// to bit `3i`), as needed for 3D Morton codes.
///
/// The value `1024` (which needs an 11th bit) is clamped to `1023`. Higher
/// input bits are discarded by the masks.
fn left_shift3(x: u32) -> u32 {
    const STEPS: [(u32, u32); 4] = [
        (16, 0b00000011000000000000000011111111),
        (8, 0b00000011000000001111000000001111),
        (4, 0b00000011000011000011000011000011),
        (2, 0b00001001001001001001001001001001),
    ];
    let start = if x == (1 << 10) { x - 1 } else { x };
    STEPS
        .iter()
        .fold(start, |acc, &(shift, mask)| (acc | (acc << shift)) & mask)
}
|
|
|
|
pub fn encode_morton_2(x: u32, y: u32) -> u64 {
|
|
left_shift2(y as u64) << 1 | left_shift2(x as u64)
|
|
}
|
|
|
|
pub fn encode_morton_3(x: Float, y: Float, z: Float) -> u32 {
|
|
(left_shift3(x as u32) << 2) | (left_shift3(y as u32) << 1) | left_shift3(z as u32)
|
|
}
|
|
|
|
/// Rounds `n` up to the next power of two. Returns 1 for `n <= 0`.
///
/// Inputs above `2^30` overflow `i32` (a panic in debug builds).
pub fn round_up_pow2(n: i32) -> i32 {
    if n <= 0 {
        return 1;
    }
    // Smear the highest set bit of n - 1 into every lower position.
    let mut v = n - 1;
    for shift in [1, 2, 4, 8, 16] {
        v |= v >> shift;
    }
    v + 1
}
|
|
|
|
/// Number of entries in the prime table used by the radical-inverse functions.
pub const PRIME_TABLE_SIZE: usize = 1000;

/// Returns the first `N` primes in increasing order, using trial division.
///
/// This is a `const fn`, so the table below is built at compile time.
const fn first_primes<const N: usize>() -> [i32; N] {
    let mut table = [0i32; N];
    let mut found = 0;
    let mut candidate: i32 = 2;
    while found < N {
        let mut is_prime = true;
        let mut k = 0;
        // Only primes up to sqrt(candidate) need to be tried.
        while k < found {
            let p = table[k];
            if p * p > candidate {
                break;
            }
            if candidate % p == 0 {
                is_prime = false;
                break;
            }
            k += 1;
        }
        if is_prime {
            table[found] = candidate;
            found += 1;
        }
        candidate += 1;
    }
    table
}

/// The first [`PRIME_TABLE_SIZE`] primes (2, 3, 5, ..., 7919), indexed by
/// `base_index` in the radical-inverse functions.
const PRIMES: [i32; PRIME_TABLE_SIZE] = first_primes::<PRIME_TABLE_SIZE>();
|
|
|
|
#[inline]
|
|
pub fn radical_inverse(base_index: usize, mut a: u64) -> Float {
|
|
let base = PRIMES[base_index] as u64;
|
|
|
|
let limit = (u64::MAX / base).saturating_sub(base);
|
|
|
|
let inv_base = (1.0 as Float) / (base as Float);
|
|
let mut inv_base_m = 1.0 as Float;
|
|
let mut reversed_digits = 0u64;
|
|
|
|
// Loop until we run out of digits or hit the overflow safety limit
|
|
while a > 0 && reversed_digits < limit {
|
|
// Rust's div_rem optimization handles / and % efficiently together
|
|
let next = a / base;
|
|
let digit = a % base;
|
|
|
|
reversed_digits = reversed_digits.wrapping_mul(base).wrapping_add(digit);
|
|
inv_base_m *= inv_base;
|
|
a = next;
|
|
}
|
|
|
|
// Ensure result is strictly less than 1.0
|
|
(reversed_digits as Float * inv_base_m).min(ONE_MINUS_EPSILON)
|
|
}
|
|
|
|
/// Reverses the lowest `n_digits` base-`base` digits of `inverse`.
///
/// This undoes the digit mirroring of a radical inverse that was scaled to an
/// integer. Arithmetic overflow panics in debug builds.
pub fn inverse_radical_inverse(inverse: u64, base: u64, n_digits: u64) -> u64 {
    let (index, _) = (0..n_digits).fold((0u64, inverse), |(index, rest), _| {
        (index * base + rest % base, rest / base)
    });
    index
}
|
|
|
|
// Digit scrambling

/// Per-digit random permutations of the digit values `0..base`, consumed by
/// `scrambled_radical_inverse`.
///
/// `permutations` is stored digit-major: the permutation for digit index `d`
/// occupies `permutations[d * base .. (d + 1) * base]`.
#[derive(Default, Debug, Clone)]
pub struct DigitPermutation {
    // Base whose digits are permuted.
    base: usize,
    // Number of base-`base` digits that can still change a `Float` result.
    n_digits: usize,
    // `n_digits * base` entries, laid out as described above.
    permutations: Vec<u16>,
}
|
|
|
|
impl DigitPermutation {
|
|
pub fn new(base: usize, seed: u64) -> Self {
|
|
let mut n_digits = 0;
|
|
let inv_base = 1. / base as Float;
|
|
let mut inv_base_m = 1.;
|
|
|
|
while 1.0 - ((base as Float - 1.0) * inv_base_m) < 1.0 {
|
|
n_digits += 1;
|
|
inv_base_m *= inv_base;
|
|
}
|
|
|
|
let mut permutations = vec![0u16; n_digits * base];
|
|
|
|
for digit_index in 0..n_digits {
|
|
let hash_input = [base as u64, digit_index as u64, seed];
|
|
let dseed = hash_buffer(&hash_input, 0);
|
|
|
|
for digit_value in 0..base {
|
|
let index = digit_index * base + digit_value;
|
|
|
|
permutations[index] =
|
|
permutation_element(digit_value as u32, base as u32, dseed as u32) as u16;
|
|
}
|
|
}
|
|
|
|
Self {
|
|
base,
|
|
n_digits,
|
|
permutations,
|
|
}
|
|
}
|
|
|
|
#[inline(always)]
|
|
pub fn permute(&self, digit_index: i32, digit_value: i32) -> i32 {
|
|
let idx = (digit_index * self.base as i32 + digit_value) as usize;
|
|
self.permutations[idx] as i32
|
|
}
|
|
}
|
|
|
|
pub fn compute_radical_inverse_permutations(seed: u64) -> Vec<DigitPermutation> {
|
|
PRIMES
|
|
.par_iter()
|
|
.map(|&base| DigitPermutation::new(base as usize, seed))
|
|
.collect()
|
|
}
|
|
|
|
pub fn scrambled_radical_inverse(base_index: usize, mut a: u64, perm: &DigitPermutation) -> Float {
|
|
let base = PRIMES[base_index] as u64;
|
|
|
|
let limit = (u64::MAX / base).saturating_sub(base);
|
|
|
|
let inv_base = 1.0 / (base as Float);
|
|
let mut inv_base_m = 1.0;
|
|
let mut reversed_digits = 0u64;
|
|
let mut digit_index = 0;
|
|
|
|
while 1.0 - ((base as Float - 1.0) * inv_base_m) < 1.0 && reversed_digits < limit {
|
|
let next = a / base;
|
|
let digit_value = (a - next * base) as i32;
|
|
|
|
// Permute the digit
|
|
let permuted = perm.permute(digit_index, digit_value) as u64;
|
|
|
|
reversed_digits = reversed_digits.wrapping_mul(base).wrapping_add(permuted);
|
|
inv_base_m *= inv_base;
|
|
digit_index += 1;
|
|
a = next;
|
|
}
|
|
|
|
(inv_base_m * reversed_digits as Float).min(ONE_MINUS_EPSILON)
|
|
}
|
|
|
|
pub fn owen_scrambled_radical_inverse(base_index: usize, mut a: u64, hash: u32) -> Float {
|
|
let base = PRIMES[base_index] as u64;
|
|
|
|
let limit = (u64::MAX / base).saturating_sub(base);
|
|
let inv_base = 1.0 / (base as Float);
|
|
let mut inv_base_m = 1.0;
|
|
let mut reversed_digits = 0u64;
|
|
|
|
let mut _digit_index = 0;
|
|
|
|
while 1.0 - inv_base_m < 1.0 && reversed_digits < limit {
|
|
let next = a / base;
|
|
let digit_value = (a - next * base) as u32;
|
|
|
|
// Compute Owen-scrambled digit
|
|
// XOR the current seed (hash) with the accumulated reversed digits so far
|
|
let digit_hash = mix_bits((hash as u64) ^ reversed_digits) as u32;
|
|
|
|
let permuted = permutation_element(digit_value, base as u32, digit_hash) as u64;
|
|
|
|
reversed_digits = reversed_digits.wrapping_mul(base).wrapping_add(permuted);
|
|
inv_base_m *= inv_base;
|
|
_digit_index += 1;
|
|
a = next;
|
|
}
|
|
|
|
(inv_base_m * reversed_digits as Float).min(ONE_MINUS_EPSILON)
|
|
}
|
|
|
|
/// Returns the element at position `i` of a pseudo-random ordering of `0..l`
/// selected by the seed `p`, without building the whole ordering.
///
/// The body scrambles `i` with a fixed sequence of xors, odd multiplies and
/// masked xor-shifts, restricted to `w`, the smallest all-ones mask covering
/// `l - 1`. Results that land in `[l, w]` are scrambled again until they fall
/// inside `[0, l)`. The final `+ p` rotates the result.
///
/// NOTE(review): the constants presumably come from Kensler's permutation
/// hash (as in pbrt-v4's `PermutationElement`), which is designed to be a
/// bijection on `[0, w]` — confirm before editing any constant.
/// NOTE(review): `l == 0` underflows `l - 1` (a panic in debug builds).
pub fn permutation_element(mut i: u32, l: u32, p: u32) -> u32 {
    // Round l - 1 up to the next 2^k - 1 mask.
    let mut w = l - 1;
    w |= w >> 1;
    w |= w >> 2;
    w |= w >> 4;
    w |= w >> 8;
    w |= w >> 16;

    loop {
        i ^= p;
        i = i.wrapping_mul(0xe170893d);
        i ^= p >> 16;
        i ^= (i & w) >> 4;
        i ^= p >> 8;
        i = i.wrapping_mul(0x0929eb3f);
        i ^= p >> 23;
        i ^= (i & w) >> 1;
        i = i.wrapping_mul(1 | (p >> 27));
        i = i.wrapping_mul(0x6935fa69);
        i ^= (i & w) >> 11;
        i = i.wrapping_mul(0x74dcb303);
        i ^= (i & w) >> 2;
        i = i.wrapping_mul(0x9e501cc3);
        i ^= (i & w) >> 2;
        i = i.wrapping_mul(0xc860a3df);

        i &= w;
        i ^= i >> 5;

        // If index is within range [0, l), we are done.
        // Otherwise, we loop again.
        if i < l {
            break;
        }
    }

    (i.wrapping_add(p)) % l
}
|
|
|
|
/// Multiplies the bit vector `a` by the binary generator matrix `c` over GF(2).
///
/// Column `i` of the matrix is `c[i]`. The result is the XOR of every column
/// whose bit is set in `a`. Bits of `a` beyond `c.len()` are ignored.
pub fn multiply_generator(c: &[u32], a: u32) -> u32 {
    let mut remaining = a;
    let mut product = 0;
    for &column in c {
        if remaining == 0 {
            break;
        }
        if remaining & 1 == 1 {
            product ^= column;
        }
        remaining >>= 1;
    }
    product
}
|
|
|
|
/// Scrambler that leaves Sobol sample bits untouched.
#[derive(Debug, Clone, Copy)]
pub struct NoRandomizer;

impl NoRandomizer {
    /// Returns `v` unchanged.
    pub fn scramble(&self, v: u32) -> u32 {
        std::convert::identity(v)
    }
}
|
|
|
|
/// Scrambler that XORs every sample with a fixed 32-bit pattern.
#[derive(Debug, Clone, Copy)]
pub struct BinaryPermuteScrambler {
    /// Bit pattern XORed into every scrambled value.
    pub permutation: u32,
}

impl BinaryPermuteScrambler {
    /// Creates a scrambler that XORs with `perm`.
    pub fn new(perm: u32) -> Self {
        BinaryPermuteScrambler { permutation: perm }
    }

    /// Returns `v ^ permutation`.
    #[inline(always)]
    pub fn scramble(&self, v: u32) -> u32 {
        v ^ self.permutation
    }
}
|
|
|
|
/// Hash-based fast approximation of Owen scrambling.
#[derive(Debug, Clone, Copy)]
pub struct FastOwenScrambler {
    /// Seed selecting the scramble.
    pub seed: u32,
}

impl FastOwenScrambler {
    /// Creates a scrambler for `seed`.
    pub fn new(seed: u32) -> Self {
        FastOwenScrambler { seed }
    }

    /// Scrambles `v`.
    ///
    /// Multiplication and addition only carry toward higher bits. Working on
    /// the bit-reversed value therefore means each original bit is influenced
    /// only by the seed and the bits above it.
    #[inline(always)]
    pub fn scramble(&self, v: u32) -> u32 {
        let seed = self.seed;
        let x = v.reverse_bits();
        let x = x ^ x.wrapping_mul(0x3d20adea);
        let x = x.wrapping_add(seed);
        let x = x.wrapping_mul((seed >> 16) | 1);
        let x = x ^ x.wrapping_mul(0x05526c56);
        let x = x ^ x.wrapping_mul(0x53a22864);
        x.reverse_bits()
    }
}
|
|
|
|
#[derive(Debug, Clone, Copy)]
|
|
pub struct OwenScrambler {
|
|
pub seed: u32,
|
|
}
|
|
|
|
impl OwenScrambler {
|
|
pub fn new(seed: u32) -> Self {
|
|
Self { seed }
|
|
}
|
|
|
|
#[inline]
|
|
pub fn scramble(&self, mut v: u32) -> u32 {
|
|
if (self.seed & 1) != 0 {
|
|
v ^= 1 << 31;
|
|
}
|
|
|
|
for b in 1..32 {
|
|
let mask = (!0u32) << (32 - b);
|
|
let input = ((v & mask) ^ self.seed) as u64;
|
|
if (mix_bits(input) as u32 & (1 << b)) != 0 {
|
|
v ^= 1 << (31 - b);
|
|
}
|
|
}
|
|
v
|
|
}
|
|
}
|
|
|
|
/// Randomization applied to the raw 32-bit Sobol value before `sobol_sample`
/// converts it to a float.
pub trait Scrambler {
    /// Maps a 32-bit sample value to its scrambled counterpart.
    fn scramble(&self, v: u32) -> u32;
}
|
|
|
|
impl Scrambler for NoRandomizer {
|
|
fn scramble(&self, v: u32) -> u32 {
|
|
self.scramble(v)
|
|
}
|
|
}
|
|
|
|
impl Scrambler for BinaryPermuteScrambler {
|
|
fn scramble(&self, v: u32) -> u32 {
|
|
self.scramble(v)
|
|
}
|
|
}
|
|
impl Scrambler for FastOwenScrambler {
|
|
fn scramble(&self, v: u32) -> u32 {
|
|
self.scramble(v)
|
|
}
|
|
}
|
|
impl Scrambler for OwenScrambler {
|
|
fn scramble(&self, v: u32) -> u32 {
|
|
self.scramble(v)
|
|
}
|
|
}
|
|
|
|
impl<F: Fn(u32) -> u32> Scrambler for F {
|
|
fn scramble(&self, v: u32) -> u32 {
|
|
(*self)(v)
|
|
}
|
|
}
|
|
|
|
const N_SOBOL_DIMENSIONS: usize = 1024;
|
|
const SOBOL_MATRIX_SIZE: usize = 52;
|
|
#[inline]
|
|
pub fn sobol_sample<S: Scrambler>(mut a: u64, dimension: usize, randomizer: S) -> Float {
|
|
debug_assert!(
|
|
dimension < N_SOBOL_DIMENSIONS,
|
|
"Sobol dimension out of bounds"
|
|
);
|
|
|
|
debug_assert!(a < (1u64 << SOBOL_MATRIX_SIZE), "Sobol index too large");
|
|
|
|
let mut v: u32 = 0;
|
|
let mut i = dimension * SOBOL_MATRIX_SIZE;
|
|
|
|
while a != 0 {
|
|
if (a & 1) != 0 {
|
|
v ^= SOBOL_MATRICES_32[i];
|
|
}
|
|
a >>= 1;
|
|
i += 1;
|
|
}
|
|
|
|
v = randomizer.scramble(v);
|
|
|
|
let float_val = (v as Float) * 2.3283064365386963e-10;
|
|
|
|
float_val.min(ONE_MINUS_EPSILON)
|
|
}
|
|
|
|
pub fn sobol_interval_to_index(m: u32, frame: u64, p: Point2i) -> u64 {
|
|
if m == 0 {
|
|
return frame;
|
|
}
|
|
|
|
let m2 = m << 1;
|
|
let mut index = frame << m2;
|
|
|
|
let mut delta = 0u64;
|
|
let mut current_frame = frame;
|
|
let mut c = 0;
|
|
|
|
while current_frame != 0 {
|
|
if (current_frame & 1) != 0 {
|
|
delta ^= VDC_SOBOL_MATRICES[(m - 1) as usize][c];
|
|
}
|
|
current_frame >>= 1;
|
|
c += 1;
|
|
}
|
|
|
|
let px = p.x() as u32 as u64;
|
|
let py = p.y() as u32 as u64;
|
|
|
|
let mut b = ((px << m) | py) ^ delta;
|
|
|
|
let mut c = 0;
|
|
while b != 0 {
|
|
if (b & 1) != 0 {
|
|
index ^= VDC_SOBOL_MATRICES_INV[(m - 1) as usize][c];
|
|
}
|
|
b >>= 1;
|
|
c += 1;
|
|
}
|
|
|
|
index
|
|
}
|
|
|
|
// Fixed-size matrix support.
// NOTE(review): the original author flagged this section as "TEST THOROUGHLY".

/// Dense `R x C` matrix stored row-major as `[[T; C]; R]`, so `m[row][col]`.
#[derive(Debug, Copy, Clone)]
pub struct Matrix<T, const R: usize, const C: usize> {
    // Row-major storage: `m[i][j]` is row `i`, column `j`.
    m: [[T; C]; R],
}
|
|
|
|
impl<T, const R: usize, const C: usize> Matrix<T, R, C> {
|
|
pub const fn new(data: [[T; C]; R]) -> Self {
|
|
Self { m: data }
|
|
}
|
|
|
|
pub fn zero() -> Self
|
|
where
|
|
T: Clone + Zero,
|
|
{
|
|
let m: [[T; C]; R] = std::array::from_fn(|_| std::array::from_fn(|_| T::zero()));
|
|
Self { m }
|
|
}
|
|
|
|
pub fn transpose(&self) -> Matrix<T, C, R>
|
|
where
|
|
T: Clone + Zero,
|
|
{
|
|
let mut result = Matrix::<T, C, R>::zero();
|
|
for i in 0..R {
|
|
for j in 0..C {
|
|
result.m[j][i] = self.m[i][j].clone();
|
|
}
|
|
}
|
|
result
|
|
}
|
|
}
|
|
|
|
impl<T: Zero + Clone, const R: usize, const C: usize> Default for Matrix<T, R, C> {
|
|
fn default() -> Self {
|
|
Self::zero()
|
|
}
|
|
}
|
|
|
|
impl<T, const R: usize, const C: usize> From<[[T; C]; R]> for Matrix<T, R, C> {
|
|
fn from(m: [[T; C]; R]) -> Self {
|
|
Self::new(m)
|
|
}
|
|
}
|
|
|
|
impl<T, const R: usize, const C: usize> Index<(usize, usize)> for Matrix<T, R, C> {
|
|
type Output = T;
|
|
fn index(&self, index: (usize, usize)) -> &Self::Output {
|
|
&self.m[index.0][index.1]
|
|
}
|
|
}
|
|
|
|
impl<T, const R: usize, const C: usize> IndexMut<(usize, usize)> for Matrix<T, R, C> {
|
|
fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output {
|
|
&mut self.m[index.0][index.1]
|
|
}
|
|
}
|
|
|
|
impl<T: PartialEq, const R: usize, const C: usize> PartialEq for Matrix<T, R, C> {
|
|
fn eq(&self, other: &Self) -> bool {
|
|
self.m == other.m
|
|
}
|
|
}
|
|
|
|
/// Matrix equality is a full equivalence relation whenever `T`'s is.
impl<T, const R: usize, const C: usize> Eq for Matrix<T, R, C> where T: Eq {}
|
|
|
|
impl<T, const R: usize, const C: usize, const N: usize> Mul<Matrix<T, N, C>> for Matrix<T, R, N>
|
|
where
|
|
T: Mul<Output = T> + Add<Output = T> + Clone + Zero,
|
|
{
|
|
type Output = Matrix<T, R, C>;
|
|
fn mul(self, rhs: Matrix<T, N, C>) -> Self::Output {
|
|
let mut result = Matrix::<T, R, C>::zero();
|
|
for i in 0..R {
|
|
for j in 0..C {
|
|
let mut sum = T::zero();
|
|
for k in 0..N {
|
|
sum = sum + self.m[i][k].clone() * rhs.m[k][j].clone();
|
|
}
|
|
result.m[i][j] = sum;
|
|
}
|
|
}
|
|
result
|
|
}
|
|
}
|
|
|
|
impl<T, const R: usize, const C: usize> Mul<T> for Matrix<T, R, C>
|
|
where
|
|
T: Mul<Output = T> + Clone + Zero,
|
|
{
|
|
type Output = Self;
|
|
fn mul(self, rhs: T) -> Self::Output {
|
|
let mut result = Self::zero();
|
|
for i in 0..R {
|
|
for j in 0..C {
|
|
result.m[i][j] = self.m[i][j].clone() * rhs.clone();
|
|
}
|
|
}
|
|
result
|
|
}
|
|
}
|
|
|
|
/// Lets a scalar appear on the left: `s * m` is defined as `m * s`.
impl<const R: usize, const C: usize> Mul<Matrix<Float, R, C>> for Float {
    type Output = Matrix<Float, R, C>;

    fn mul(self, rhs: Matrix<Float, R, C>) -> Self::Output {
        <Matrix<Float, R, C> as Mul<Float>>::mul(rhs, self)
    }
}
|
|
|
|
impl<T, const R: usize, const C: usize> Div<T> for Matrix<T, R, C>
|
|
where
|
|
T: Div<Output = T> + Clone + Zero,
|
|
{
|
|
type Output = Self;
|
|
fn div(self, rhs: T) -> Self::Output {
|
|
let mut result = Self::zero();
|
|
for i in 0..R {
|
|
for j in 0..C {
|
|
result.m[i][j] = self.m[i][j].clone() / rhs.clone();
|
|
}
|
|
}
|
|
result
|
|
}
|
|
}
|
|
|
|
impl<T: Display, const R: usize, const C: usize> Display for Matrix<T, R, C> {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
let mut col_widths = [0; C];
|
|
for row in self.m.iter() {
|
|
for (j, element) in row.iter().enumerate() {
|
|
let width = format!("{}", element).len();
|
|
if width > col_widths[j] {
|
|
col_widths[j] = width;
|
|
}
|
|
}
|
|
}
|
|
|
|
for i in 0..R {
|
|
write!(f, "[")?;
|
|
#[allow(clippy::needless_range_loop)]
|
|
for j in 0..C {
|
|
write!(f, "{: >width$} ", self.m[i][j], width = col_widths[j])?;
|
|
}
|
|
writeln!(f, "]")?;
|
|
}
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
impl<T, const R: usize, const C: usize> Add for Matrix<T, R, C>
|
|
where
|
|
T: Copy + Add<Output = T> + Zero,
|
|
{
|
|
type Output = Self;
|
|
fn add(self, rhs: Self) -> Self::Output {
|
|
let mut result = Matrix::<T, R, C>::zero();
|
|
for i in 0..R {
|
|
for j in 0..C {
|
|
result.m[i][j] = self.m[i][j] + rhs.m[i][j];
|
|
}
|
|
}
|
|
result
|
|
}
|
|
}
|
|
|
|
/// An `N x N` square matrix; shorthand for `Matrix<T, N, N>`.
pub type SquareMatrix<T, const N: usize> = Matrix<T, N, N>;
|
|
|
|
impl<T, const N: usize> SquareMatrix<T, N> {
|
|
pub fn identity() -> Self
|
|
where
|
|
T: Copy + Zero + One,
|
|
{
|
|
Self {
|
|
m: std::array::from_fn(|i| {
|
|
std::array::from_fn(|j| if i == j { T::one() } else { T::zero() })
|
|
}),
|
|
}
|
|
}
|
|
|
|
pub fn diag(v: &[T]) -> Self
|
|
where
|
|
T: Zero + Copy,
|
|
{
|
|
let mut m = [[T::zero(); N]; N];
|
|
for i in 0..N {
|
|
m[i][i] = v[i];
|
|
}
|
|
Self { m }
|
|
}
|
|
|
|
pub fn is_identity(&self) -> bool
|
|
where
|
|
T: Zero + One + PartialEq,
|
|
{
|
|
for i in 0..N {
|
|
for j in 0..N {
|
|
if i == j {
|
|
if self.m[i][j] != T::one() {
|
|
return false;
|
|
}
|
|
} else if self.m[i][j] != T::zero() {
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
true
|
|
}
|
|
|
|
pub fn trace(&self) -> T
|
|
where
|
|
T: Zero + Copy,
|
|
{
|
|
let mut sum = T::zero();
|
|
for i in 0..N {
|
|
sum = sum + self.m[i][i];
|
|
}
|
|
sum
|
|
}
|
|
}
|
|
|
|
impl<T, const N: usize> SquareMatrix<T, N>
|
|
where
|
|
T: NumFloat + Sum + Product + Copy,
|
|
{
|
|
pub fn inverse(&self) -> Result<Self, InversionError> {
|
|
if N == 0 {
|
|
return Err(InversionError::EmptyMatrix);
|
|
}
|
|
|
|
let mut mat = self.m;
|
|
let mut inv = Self::identity();
|
|
|
|
for i in 0..N {
|
|
let pivot_row = (i..N)
|
|
.max_by(|&a, &b| mat[a][i].abs().partial_cmp(&mat[b][i].abs()).unwrap())
|
|
.unwrap_or(i);
|
|
|
|
if pivot_row != i {
|
|
mat.swap(i, pivot_row);
|
|
inv.m.swap(i, pivot_row);
|
|
}
|
|
|
|
let pivot = mat[i][i];
|
|
if pivot.is_zero() {
|
|
return Err(InversionError::SingularMatrix);
|
|
}
|
|
|
|
let inv_pivot = T::one() / pivot;
|
|
mat[i][i..].iter_mut().for_each(|x| *x = *x * inv_pivot);
|
|
inv[i].iter_mut().for_each(|x| *x = *x * inv_pivot);
|
|
|
|
for j in 0..N {
|
|
if i != j {
|
|
let factor = mat[j][i];
|
|
#[allow(clippy::needless_range_loop)]
|
|
for k in i..N {
|
|
mat[j][k] = mat[j][k] - factor * mat[i][k];
|
|
}
|
|
|
|
#[allow(clippy::needless_range_loop)]
|
|
for k in 0..N {
|
|
inv.m[j][k] = inv.m[j][k] - factor * inv.m[i][k];
|
|
}
|
|
}
|
|
}
|
|
}
|
|
Ok(inv)
|
|
}
|
|
|
|
pub fn determinant(&self) -> T {
|
|
let m = &self.m;
|
|
|
|
match N {
|
|
0 => T::one(),
|
|
1 => m[0][0],
|
|
2 => m[0][0] * m[1][1] - m[0][1] * m[1][0],
|
|
3 => {
|
|
let a = m[0][0];
|
|
let b = m[0][1];
|
|
let c = m[0][2];
|
|
let d = m[1][0];
|
|
let e = m[1][1];
|
|
let f = m[1][2];
|
|
let g = m[2][0];
|
|
let h = m[2][1];
|
|
let i = m[2][2];
|
|
|
|
a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g)
|
|
}
|
|
4 => {
|
|
let det3 =
|
|
|m11: T, m12: T, m13: T, m21: T, m22: T, m23: T, m31: T, m32: T, m33: T| -> T {
|
|
m11 * (m22 * m33 - m23 * m32) - m12 * (m21 * m33 - m23 * m31)
|
|
+ m13 * (m21 * m32 - m22 * m31)
|
|
};
|
|
|
|
let c0 = det3(
|
|
m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], m[3][3],
|
|
);
|
|
let c1 = det3(
|
|
m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], m[3][3],
|
|
);
|
|
let c2 = det3(
|
|
m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], m[3][3],
|
|
);
|
|
let c3 = det3(
|
|
m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], m[3][2],
|
|
);
|
|
|
|
m[0][0] * c0 - m[0][1] * c1 + m[0][2] * c2 - m[0][3] * c3
|
|
}
|
|
_ => {
|
|
// Fallback to LU decomposition for N > 4
|
|
let mut lum = self.m;
|
|
let mut parity = 0;
|
|
|
|
for i in 0..N {
|
|
let max_row = (i..N)
|
|
.max_by(|&a, &b| lum[a][i].abs().partial_cmp(&lum[b][i].abs()).unwrap())
|
|
.unwrap_or(i);
|
|
|
|
if max_row != i {
|
|
lum.swap(i, max_row);
|
|
parity += 1;
|
|
}
|
|
|
|
// Singular matrix
|
|
if lum[i][i] == T::zero() {
|
|
return T::zero();
|
|
}
|
|
|
|
// Gaussian elimination
|
|
for j in (i + 1)..N {
|
|
let factor = lum[j][i] / lum[i][i];
|
|
|
|
#[allow(clippy::needless_range_loop)]
|
|
for k in i..N {
|
|
let val_i = lum[i][k];
|
|
lum[j][k] = lum[j][k] - factor * val_i;
|
|
}
|
|
}
|
|
}
|
|
|
|
let det: T = (0..N).map(|i| lum[i][i]).product();
|
|
if parity < 0 { -det } else { det }
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<T, const N: usize> Index<usize> for SquareMatrix<T, N> {
|
|
type Output = [T; N];
|
|
|
|
fn index(&self, index: usize) -> &Self::Output {
|
|
&self.m[index]
|
|
}
|
|
}
|
|
|
|
impl<T, const N: usize> IndexMut<usize> for SquareMatrix<T, N> {
|
|
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
|
|
&mut self.m[index]
|
|
}
|
|
}
|
|
|
|
impl<T, const N: usize> Mul<Vector<T, N>> for SquareMatrix<T, N>
|
|
where
|
|
T: Copy + Mul<Output = T> + Sum + Default,
|
|
{
|
|
type Output = Vector<T, N>;
|
|
fn mul(self, rhs: Vector<T, N>) -> Self::Output {
|
|
let arr = std::array::from_fn(|i| self.m[i].iter().zip(&rhs.0).map(|(m, v)| *m * *v).sum());
|
|
Vector(arr)
|
|
}
|
|
}
|
|
|
|
impl<T, const N: usize> Mul<Point<T, N>> for SquareMatrix<T, N>
|
|
where
|
|
T: Copy + Mul<Output = T> + Sum + Default,
|
|
{
|
|
type Output = Point<T, N>;
|
|
fn mul(self, rhs: Point<T, N>) -> Self::Output {
|
|
let arr = std::array::from_fn(|i| self.m[i].iter().zip(&rhs.0).map(|(m, v)| *m * *v).sum());
|
|
Point(arr)
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `a` and `b` agree element-wise to within `1e-9`, printing
    /// both matrices on failure.
    fn assert_matrix_approx_eq<const N: usize>(a: &SquareMatrix<f64, N>, b: &SquareMatrix<f64, N>) {
        const EPSILON: f64 = 1e-9;
        for i in 0..N {
            for j in 0..N {
                assert!(
                    (a[i][j] - b[i][j]).abs() < EPSILON,
                    "Matrices differ at ({},{}): {} vs {}\nLeft:\n{}\nRight:\n{}",
                    i,
                    j,
                    a[i][j],
                    b[i][j],
                    a,
                    b
                );
            }
        }
    }

    #[test]
    fn test_inverse_2x2() {
        let m = SquareMatrix {
            m: [[4.0, 7.0], [2.0, 6.0]],
        };
        let identity = SquareMatrix::identity();
        let inv = m.inverse().expect("Matrix should be invertible");
        assert_matrix_approx_eq(&(m * inv), &identity);
        assert_matrix_approx_eq(&(inv * m), &identity);
    }

    #[test]
    fn test_inverse_3x3() {
        let m = SquareMatrix::new([[1.0, 2.0, 3.0], [0.0, 1.0, 4.0], [5.0, 6.0, 0.0]]);
        let identity = SquareMatrix::identity();
        let inv = m.inverse().expect("Matrix should be invertible");
        // `Matrix<f64, _, _>` is `Copy`, so operands can be reused without cloning.
        assert_matrix_approx_eq(&(m * inv), &identity);
        assert_matrix_approx_eq(&(inv * m), &identity);
    }

    #[test]
    fn test_singular_inverse() {
        // Row 2 is twice row 1: the determinant is 0 and elimination hits an
        // exactly-zero pivot, so inversion must fail.
        let m = SquareMatrix {
            m: [[1.0, 2.0], [2.0, 4.0]],
        };
        assert!(matches!(m.inverse(), Err(InversionError::SingularMatrix)));
    }

    #[test]
    fn test_empty_inverse() {
        let m: SquareMatrix<f64, 0> = SquareMatrix::new([]);
        assert!(matches!(m.inverse(), Err(InversionError::EmptyMatrix)));
    }

    #[test]
    fn test_multiplication_2x2() {
        let a = SquareMatrix {
            m: [[1.0, 2.0], [3.0, 4.0]],
        };
        let b = SquareMatrix {
            m: [[2.0, 0.0], [1.0, 2.0]],
        };
        let expected = SquareMatrix {
            m: [[4.0, 4.0], [10.0, 8.0]],
        };
        let result = a * b;
        assert_matrix_approx_eq(&result, &expected);
    }

    #[test]
    fn test_scalar_and_elementwise_ops() {
        let a: SquareMatrix<f64, 2> = SquareMatrix::new([[1.0, 2.0], [3.0, 4.0]]);
        let doubled: SquareMatrix<f64, 2> = SquareMatrix::new([[2.0, 4.0], [6.0, 8.0]]);
        assert_eq!(a * 2.0_f64, doubled);
        assert_eq!(doubled / 2.0_f64, a);
        assert_eq!(a + a, doubled);
    }

    #[test]
    fn test_transpose() {
        let m: Matrix<f64, 2, 3> = Matrix::new([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let expected: Matrix<f64, 3, 2> = Matrix::new([[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]);
        assert_eq!(m.transpose(), expected);
        assert_eq!(m.transpose().transpose(), m);
    }

    #[test]
    fn test_identity_diag_and_trace() {
        let id: SquareMatrix<f64, 3> = SquareMatrix::identity();
        assert!(id.is_identity());
        assert_eq!(id.trace(), 3.0);

        let d: SquareMatrix<f64, 3> = SquareMatrix::diag(&[1.0, 2.0, 3.0]);
        assert!(!d.is_identity());
        assert_eq!(d.trace(), 6.0);
        assert_eq!(d[(1, 1)], 2.0);
        assert_eq!(d[(0, 1)], 0.0);
    }

    #[test]
    fn test_determinant_3x3() {
        let m = SquareMatrix {
            m: [[1.0, 2.0, 3.0], [0.0, 1.0, 4.0], [5.0, 6.0, 0.0]],
        };
        assert_eq!(m.determinant(), 1.0);
    }

    #[test]
    fn test_determinant_4x4() {
        // Upper triangular: the determinant is the product of the diagonal.
        let m: SquareMatrix<f64, 4> = SquareMatrix::new([
            [2.0, 1.0, 0.0, 3.0],
            [0.0, 3.0, 4.0, 1.0],
            [0.0, 0.0, 1.0, 5.0],
            [0.0, 0.0, 0.0, 2.0],
        ]);
        assert_eq!(m.determinant(), 12.0);
    }

    #[test]
    fn test_determinant_5x5_triangular() {
        // Exercises the elimination fallback (N > 4). Each pivot is already
        // the largest entry in its column, so no rows are swapped.
        let m: SquareMatrix<f64, 5> = SquareMatrix::new([
            [2.0, 1.0, 0.0, 3.0, 7.0],
            [0.0, 3.0, 4.0, 1.0, 2.0],
            [0.0, 0.0, 1.0, 5.0, 1.0],
            [0.0, 0.0, 0.0, 4.0, 6.0],
            [0.0, 0.0, 0.0, 0.0, 0.5],
        ]);
        assert_eq!(m.determinant(), 12.0);
    }
}
|