simd in faer

lifetime branding

before we explain what lifetime branding is, we'll go through a few code examples to show the motivation behind it.

assume we have a vector of length N, where N is known at compile time. we can represent such an object in rust's type system as [T; N]. if we want to index into such an array without bound checks, we need an index type that maintains the invariant that it holds a valid index. we check that invariant once, when the index is constructed.

#[derive(Copy, Clone)]
pub struct Idx<const N: usize> {
    raw: usize,
}

impl<const N: usize> Idx<N> {
    pub fn new(index: usize) -> Self {
        assert!(index < N);
        Self { raw: index }
    }

    /// # Safety
    /// `index` must be less than `N`.
    pub unsafe fn new_unchecked(index: usize) -> Self {
        debug_assert!(index < N);
        Self { raw: index }
    }

    pub fn get(self) -> usize {
        self.raw
    }
}

what can we do with such a type? since we know that the stored value is always a valid index for arrays of length N, we can implement the following unchecked indexing api.

use std::ops::Index;

impl<const N: usize, T> Index<Idx<N>> for [T; N] {
    type Output = T;

    fn index(&self, index: Idx<N>) -> &T {
        // SAFETY: this is sound because the value
        // inside `Idx<N>` is guaranteed to be less than `N`.
        unsafe { self.get_unchecked(index.get()) }
    }
}

now we can write

let data = [0, 1, 2];
let idx = Idx::<3>::new(1);

data[idx]; // look ma! no bound checks

you may notice that we didn't actually remove the bound check, we just moved it elsewhere. so what was the point of doing all this?

well, moving the bound check to the constructor lets us reuse the index multiple times with a single check.

for _ in 0..100 {
    println!("{}", data[idx]); // no bound check inside the loop
}

we can also make use of the structure that our invariant carries to compose it in interesting ways. for example, we can write

impl<const N: usize> Idx<N> {
    pub fn sequence() -> impl Iterator<Item = Self> {
        // SAFETY: `i < N`
        (0..N).map(|i| unsafe { Self::new_unchecked(i) })
    }
}

which lets us write

for i in Idx::sequence() {
    println!("{}", data[i]); // still no bound checks
}

sure, we could have just written

for d in &data {
    println!("{}", d); // iterators also allow us to skip bound checks
}

but this becomes cumbersome once the number of arrays starts growing.

for ((((d0, d1), d2), d3), d4) in
        data0.iter()
            .zip(&data1)
            .zip(&data2)
            .zip(&data3)
            .zip(&data4) {
    println!("{}", d0);
    println!("{}", d1);
    println!("{}", d2);
    println!("{}", d3);
    println!("{}", d4);
}

compare that to the following, which feels a lot nicer to write and read.

for i in Idx::sequence() {
    println!("{}", data0[i]);
    println!("{}", data1[i]);
    println!("{}", data2[i]);
    println!("{}", data3[i]);
    println!("{}", data4[i]);
}
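the pieces above can be put together into one self-contained sketch. the `dot` function is a hypothetical helper added for illustration and is not part of faer. the index type and the `Index` impl are the ones from the text, with the iterator return type spelled `impl Iterator<Item = Self>`.

```rust
use std::ops::Index;

#[derive(Copy, Clone)]
pub struct Idx<const N: usize> {
    raw: usize,
}

impl<const N: usize> Idx<N> {
    /// checked constructor: the only bound check happens here.
    pub fn new(index: usize) -> Self {
        assert!(index < N);
        Self { raw: index }
    }

    pub fn get(self) -> usize {
        self.raw
    }

    /// yields every valid index for arrays of length `N`.
    pub fn sequence() -> impl Iterator<Item = Self> {
        // every `i` produced by `0..N` satisfies `i < N`
        (0..N).map(|i| Self { raw: i })
    }
}

impl<T, const N: usize> Index<Idx<N>> for [T; N] {
    type Output = T;

    fn index(&self, index: Idx<N>) -> &T {
        // SAFETY: the value inside `Idx<N>` is always less than `N`
        unsafe { self.get_unchecked(index.get()) }
    }
}

/// hypothetical example: one index drives two arrays of the same length.
pub fn dot<const N: usize>(a: &[f64; N], b: &[f64; N]) -> f64 {
    let mut acc = 0.0;
    for i in Idx::<N>::sequence() {
        acc += a[i] * b[i];
    }
    acc
}
```

since `N` is part of the type of `Idx<N>`, trying to index a `[f64; 4]` with an `Idx<3>` is rejected at compile time.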

of course, all of this is only doable because the length is known at compile time. faer doesn't really care about those, and focuses on runtime-shaped matrices instead. if only we could do the same thing with a runtime length...

🥁🥁🥁
introducing: lifetime branding, which lets us do nearly the exact same thing we just described, but with runtime shapes.

the first thing we need is a type that is similar to the const generic. something that we can guarantee carries the same shape value every time.

this type is

use std::marker::PhantomData;

#[derive(Copy, Clone)]
pub struct Dim<'N> {
    raw: usize,
    __marker__: InvariantLifetime<'N>,
}

type InvariantLifetime<'a> = PhantomData<fn(&'a ()) -> &'a ()>;

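now, how can we create an instance of this type in a controlled way? the answer is inversion of control: instead of handing a `Dim` back to the caller, we pass it to a closure that has to work for every possible lifetime `'N`, so the caller never gets to pick the lifetime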

impl Dim<'_> {
    pub fn with_new<R>(len: usize, f: impl for<'N> FnOnce(Dim<'N>) -> R) -> R {
        f(unsafe { Self::new_unchecked(len) })
    }

    pub unsafe fn new_unchecked(len: usize) -> Self {
        Self { raw: len, __marker__: PhantomData }
    }

    pub fn get(self) -> usize {
        self.raw
    }
}

how does this guarantee uniqueness? let's try to break it and see if that gives us any insights

Dim::with_new(3, |len3| {
    let mut copy = len3; // type invariant claims that this always holds the value `3`

    Dim::with_new(4, |len4| {
        copy = len4; // did we break it?
    });
});

turns out that no, we did not break it.
the borrow checker tells us (in a somewhat roundabout way) that we're not allowed to do this for some reason

error[E0521]: borrowed data escapes outside of closure
  --> src/main.rs:31:9
   |
28 |     let mut copy = len3; // type invariant claims that this always holds the value `3`
   |         -------- `copy` declared here, outside of the closure body
29 |
30 |     Dim::with_new(4, |len4| {
   |                       ---- `len4` is a reference that is only valid in the closure body
31 |         copy = len4; // did we break it?
   |         ^^^^^^^^^^^ `len4` escapes the closure body here
   |
   = note: requirement occurs because of the type `Dim<'_>`, which makes the generic argument `'_` invariant
   = note: the struct `Dim<'N>` is invariant over the parameter `'N`
   = help: see <https://doc.rust-lang.org/nomicon/subtyping.html> for more information about variance

error[E0521]: borrowed data escapes outside of closure
  --> src/main.rs:31:9
   |
27 |     Dim::with_new(3, |len3| {
   |                       ----
   |                       |
   |                       `len3` is a reference that is only valid in the closure body
   |                       has type `Dim<'1>`
...
31 |         copy = len4; // did we break it?
   |         ^^^^^^^^^^^
   |         |
   |         `len3` escapes the closure body here
   |         assignment requires that `'1` must outlive `'static`

what does all of this mean? ¯\_(ツ)_/¯

the rustc error message is unusually unhelpful, because we used lifetimes for something they weren't built for, but the important thing is that it works.

the actual details of the mechanisms behind this can be found in Aria Desires' thesis, section 6.3.
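to see how the brand carries over to indexing, here is a minimal sketch that only uses the standard library. the names `with_dim`, `Branded`, `brand`, `at` and `dot` are hypothetical and chosen for this example, they are not faer's api. the idea is the same as in the const generic version: the length is checked once when a slice is branded, and every `Idx<'n>` that comes out of `Dim<'n>::indices` is known to be in bounds for slices carrying the same brand.

```rust
use std::marker::PhantomData;

type InvariantLifetime<'a> = PhantomData<fn(&'a ()) -> &'a ()>;

#[derive(Copy, Clone)]
pub struct Dim<'n> {
    raw: usize,
    marker: InvariantLifetime<'n>,
}

/// an index that is known to be less than the value held by `Dim<'n>`.
#[derive(Copy, Clone)]
pub struct Idx<'n> {
    raw: usize,
    marker: InvariantLifetime<'n>,
}

/// a slice whose length is known to equal the value held by `Dim<'n>`.
pub struct Branded<'n, 'a, T> {
    data: &'a [T],
    marker: InvariantLifetime<'n>,
}

/// the closure must accept any `'n`, so each call gets its own brand.
pub fn with_dim<R>(len: usize, f: impl for<'n> FnOnce(Dim<'n>) -> R) -> R {
    f(Dim { raw: len, marker: PhantomData })
}

impl<'n> Dim<'n> {
    pub fn get(self) -> usize {
        self.raw
    }

    pub fn indices(self) -> impl Iterator<Item = Idx<'n>> {
        (0..self.raw).map(|i| Idx { raw: i, marker: PhantomData })
    }

    /// the single runtime check: the slice length must match this dimension.
    pub fn brand<'a, T>(self, data: &'a [T]) -> Option<Branded<'n, 'a, T>> {
        if data.len() == self.raw {
            Some(Branded { data, marker: PhantomData })
        } else {
            None
        }
    }
}

impl<'n, T> Branded<'n, '_, T> {
    pub fn at(&self, i: Idx<'n>) -> &T {
        // SAFETY: `i.raw < dim` and `self.data.len() == dim` for the same brand `'n`
        unsafe { self.data.get_unchecked(i.raw) }
    }
}

/// hypothetical example: returns `None` if the lengths differ.
pub fn dot(a: &[f64], b: &[f64]) -> Option<f64> {
    with_dim(a.len(), |n| match (n.brand(a), n.brand(b)) {
        (Some(a), Some(b)) => Some(n.indices().map(|i| a.at(i) * b.at(i)).sum::<f64>()),
        _ => None,
    })
}
```

an `Idx` obtained inside one `with_dim` call can't be passed to a `Branded` slice from another call, because the two brands are distinct lifetimes and the invariance of `InvariantLifetime` stops the compiler from unifying them.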

one thing that is still a bit annoying is having to create a nested scope with closures whenever we want to do this. it's a small price to pay, but it doesn't play well with control flow and error propagation. luckily, there's another technique, implemented by the generativity crate, which is also a bit cryptic but lets us write the previous code like this:

make_guard!(len3);
make_guard!(len4);

let len3 = Dim::new(3, len3);
let len4 = Dim::new(4, len4);

what does this look like in practice?

well, the faer code based on these ideas looks roughly like this

fn sum(A: MatRef<'_, f64>) -> f64 {
    // `nrows` has type `Dim<'a>`
    // `ncols` has type `Dim<'b>`
    dims!({
        let nrows = A.nrows();
        let ncols = A.ncols();
    });

    let A = A.with_shape(nrows, ncols);

    let mut sum = 0.0;
    for j in ncols.indices() {
        for i in nrows.indices() {
            sum += A[(i, j)]; // no bound checks
        }
    }

    sum
}
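the same loop can be approximated without faer. the sketch below is an assumption-laden stand-in: `MatView`, `with_mat` and `at` are hypothetical names, the storage is a plain column-major slice, and the brands come from a closure instead of the `dims!` macro. what it is meant to show is that rows and columns get separate brands, so a row index can't be used as a column index by accident.

```rust
use std::marker::PhantomData;

type Brand<'a> = PhantomData<fn(&'a ()) -> &'a ()>;

#[derive(Copy, Clone)]
pub struct Dim<'n>(usize, Brand<'n>);

#[derive(Copy, Clone)]
pub struct Idx<'n>(usize, Brand<'n>);

impl<'n> Dim<'n> {
    pub fn indices(self) -> impl Iterator<Item = Idx<'n>> {
        (0..self.0).map(|i| Idx(i, PhantomData))
    }
}

/// column-major view whose shape is tied to the brands `'r` and `'c`.
pub struct MatView<'r, 'c, 'a> {
    data: &'a [f64],
    nrows: Dim<'r>,
    _ncols: Dim<'c>,
}

impl<'r, 'c> MatView<'r, 'c, '_> {
    pub fn at(&self, i: Idx<'r>, j: Idx<'c>) -> f64 {
        // SAFETY: `i < nrows`, `j < ncols` and `data.len() == nrows * ncols`,
        // so `i + j * nrows < data.len()`
        unsafe { *self.data.get_unchecked(i.0 + j.0 * self.nrows.0) }
    }
}

/// checks the shape once, then hands out a branded view and its dimensions.
pub fn with_mat<R>(
    data: &[f64],
    nrows: usize,
    ncols: usize,
    f: impl for<'r, 'c, 'a> FnOnce(MatView<'r, 'c, 'a>, Dim<'r>, Dim<'c>) -> R,
) -> R {
    assert_eq!(nrows.checked_mul(ncols), Some(data.len()));
    let nrows = Dim(nrows, PhantomData);
    let ncols = Dim(ncols, PhantomData);
    f(MatView { data, nrows, _ncols: ncols }, nrows, ncols)
}

pub fn sum(data: &[f64], nrows: usize, ncols: usize) -> f64 {
    with_mat(data, nrows, ncols, |a, nrows, ncols| {
        let mut sum = 0.0;
        for j in ncols.indices() {
            for i in nrows.indices() {
                sum += a.at(i, j); // no bound checks
            }
        }
        sum
    })
}
```

writing `a.at(j, i)` in the inner loop would not compile, since `j` has type `Idx<'c>` and the first argument expects `Idx<'r>`.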

and now for a simd example:

generic simd

use pulp::Simd;
use faer::utils::simd::SimdCtx;
use faer::ComplexField;

#[inline(always)]
fn sum_simd<T: ComplexField, S: Simd>(simd: S, A: MatRef<'_, T>) -> T {
    if let Some(A) = A.try_as_col_major() {
        // `nrows` has type `Dim<'a>`
        // `ncols` has type `Dim<'b>`
        dims!({
            let nrows = A.nrows();
            let ncols = A.ncols();
        });

        let simd = SimdCtx::<'_, T, S>::new(T::simd_ctx(simd), nrows);

        // instead of splitting the slice into a vector chunk and scalar chunk,
        // faer takes a different approach and slices it into three chunks for performance reasons.
        // elements of the head and tail that correspond to out of bound indices are replaced with
        // zeros and their corresponding memory addresses are not accessed.
        let (head, body, tail) = simd.indices();

        let A = A.with_shape(nrows, ncols);
        let mut sum = simd.zero();

        for j in ncols.indices() {
            let Aj = A.col(j);

            if let Some(i) = head {
                sum = simd.add(sum, simd.read(Aj, i));
            }
            for i in body.clone() {
                sum = simd.add(sum, simd.read(Aj, i));
            }
            if let Some(i) = tail {
                sum = simd.add(sum, simd.read(Aj, i));
            }
        }
        simd.reduce_sum(sum)
    } else {
        // skipping scalar fallback since it's not
        // interesting for this example
        todo!()
    }
}
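the head/body/tail split described in the comments can be emulated with plain arrays. this sketch does not use pulp or faer: `LANES`, `Vector`, `load_masked`, `load_full` and `sum_three_chunks` are hypothetical names, and the `offset` parameter is an assumption standing in for "the data starts `offset` elements after a vector boundary". the head and tail go through a masked load that fills out-of-range lanes with zeros without touching memory, while the body uses full loads.

```rust
const LANES: usize = 4;
type Vector = [f64; LANES];

/// masked load: lanes that fall outside `data` stay zero and are never read.
fn load_masked(data: &[f64], start: isize) -> Vector {
    let mut v = [0.0; LANES];
    for (lane, slot) in v.iter_mut().enumerate() {
        let i = start + lane as isize;
        if i >= 0 && (i as usize) < data.len() {
            *slot = data[i as usize];
        }
    }
    v
}

/// full load: all `LANES` elements are in range.
fn load_full(data: &[f64], start: usize) -> Vector {
    let mut v = [0.0; LANES];
    v.copy_from_slice(&data[start..start + LANES]);
    v
}

fn add(a: Vector, b: Vector) -> Vector {
    let mut out = [0.0; LANES];
    for lane in 0..LANES {
        out[lane] = a[lane] + b[lane];
    }
    out
}

/// sums `data`, viewed as chunks of a padded array that starts `offset` lanes earlier.
pub fn sum_three_chunks(data: &[f64], offset: usize) -> f64 {
    assert!(offset < LANES);
    let total = offset + data.len();

    // chunk `k` covers padded positions `k * LANES..(k + 1) * LANES`
    let head = if offset > 0 { Some(0usize) } else { None };
    let body_start = if offset > 0 { 1 } else { 0 };
    let body_end = total / LANES;
    let tail = if total % LANES != 0 && body_end >= body_start {
        Some(body_end)
    } else {
        None
    };

    // position in `data` of the first lane of chunk `k` (negative inside the head)
    let chunk_start = |k: usize| (k * LANES) as isize - offset as isize;

    let mut acc = [0.0; LANES];
    if let Some(k) = head {
        acc = add(acc, load_masked(data, chunk_start(k)));
    }
    for k in body_start..body_end {
        acc = add(acc, load_full(data, chunk_start(k) as usize));
    }
    if let Some(k) = tail {
        acc = add(acc, load_masked(data, chunk_start(k)));
    }

    // horizontal reduction of the accumulator lanes
    acc.iter().sum()
}
```

because the masked lanes contribute zeros, the loop body is the same `add` for all three chunks, which mirrors the shape of `sum_simd` above.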