Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions diskann-quantization/src/__codegen/x86_64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,16 @@ pub fn bits_v4_ip_bu2_bu2(arch: V4, x: USlice<'_, 2>, y: USlice<'_, 2>) -> MR<u3
arch.run2_inline(distances::InnerProduct, x, y)
}

#[inline(never)]
pub fn bits_v4_l2_bu4_bu4(arch: V4, x: USlice<'_, 4>, y: USlice<'_, 4>) -> MR<u32> {
arch.run2_inline(distances::SquaredL2, x, y)
}

#[inline(never)]
pub fn bits_v4_ip_bu4_bu4(arch: V4, x: USlice<'_, 4>, y: USlice<'_, 4>) -> MR<u32> {
arch.run2_inline(distances::InnerProduct, x, y)
}

//------------//
// Transposed //
//------------//
Expand Down
271 changes: 269 additions & 2 deletions diskann-quantization/src/bits/distances.rs
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,154 @@ impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_,
}
}

/// Compute the squared L2 distance between `x` and `y`.
///
/// Returns an error if the arguments have different lengths.
///
/// # Implementation Notes
///
/// This implementation is optimized for x86 with the AVX-512 vector extension.
/// It scales the V3 approach to 512-bit registers: we load data as `u32x16`, shift and
/// mask to extract 4-bit nibbles at 16-bit granularity (`0x000f000f` mask), reinterpret
/// as `i16x32`, compute differences, and use `_mm512_madd_epi16` via `dot_simd` to
/// accumulate squared differences into `i32x16`.
///
/// AVX-512 does not have 16-bit integer bit-shift instructions, so we use 32-bit integer
/// shifts and then bit-cast to 16-bit intrinsics, which works because we apply the same
/// shift to all lanes.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V4, MathematicalResult<u32>, USlice<'_, 4>, USlice<'_, 4>>
for SquaredL2
{
#[expect(non_camel_case_types)]
#[inline(always)]
fn run(
self,
arch: diskann_wide::arch::x86_64::V4,
x: USlice<'_, 4>,
y: USlice<'_, 4>,
) -> MathematicalResult<u32> {
let len = check_lengths!(x, y)?;

type i32s = <diskann_wide::arch::x86_64::V4 as Architecture>::i32x16;
type u32s = <diskann_wide::arch::x86_64::V4 as Architecture>::u32x16;
type i16s = <diskann_wide::arch::x86_64::V4 as Architecture>::i16x32;

let px_u32: *const u32 = x.as_ptr().cast();
let py_u32: *const u32 = y.as_ptr().cast();

let mut i = 0;
let mut s: u32 = 0;

// The number of 32-bit blocks over the underlying slice.
let blocks = len / 8;
if i < blocks {
let mut s0 = i32s::default(arch);
let mut s1 = i32s::default(arch);
let mut s2 = i32s::default(arch);
let mut s3 = i32s::default(arch);
let mask = u32s::splat(arch, 0x000f000f);
while i + 16 < blocks {
// SAFETY: We have checked that `i + 16 < blocks` which means the address
// range `[px_u32 + i, px_u32 + i + 16 * std::mem::size_of::<u32>())` is
// valid.
//
// The load has no alignment requirements.
let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };

// SAFETY: The same logic applies to `y` because:
// 1. It has the same type as `x`.
// 2. We've verified that it has the same length as `x`.
let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };

let wx: i16s = (vx & mask).reinterpret_simd();
let wy: i16s = (vy & mask).reinterpret_simd();
let d = wx - wy;
s0 = s0.dot_simd(d, d);
Copy link
Copy Markdown
Contributor

@hildebrandmw hildebrandmw May 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for taking the initiative! You can do much better by targeting this method. Unlike the dot-product from 16-bit integers, it does the dot product of 64 8-bit numbers and the accumulation with a i32x16 in a single instruction 😄

Ignore me. It's too late in the day. Forgot this was L2.


let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
let d = wx - wy;
s1 = s1.dot_simd(d, d);

let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
let d = wx - wy;
s2 = s2.dot_simd(d, d);

let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
let d = wx - wy;
s3 = s3.dot_simd(d, d);

i += 16;
}

let remainder = blocks - i;

// SAFETY: At least one value of type `u32` is valid for an unaligned starting
// at offset `i`. The exact number is computed as `remainder`.
//
// The predicated load is guaranteed not to access memory after `remainder` and
// has no alignment requirements.
let vx = unsafe { u32s::load_simd_first(arch, px_u32.add(i), remainder) };

// SAFETY: The same logic applies to `y` because:
// 1. It has the same type as `x`.
// 2. We've verified that it has the same length as `x`.
let vy = unsafe { u32s::load_simd_first(arch, py_u32.add(i), remainder) };

let wx: i16s = (vx & mask).reinterpret_simd();
let wy: i16s = (vy & mask).reinterpret_simd();
let d = wx - wy;
s0 = s0.dot_simd(d, d);

let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
let d = wx - wy;
s1 = s1.dot_simd(d, d);

let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
let d = wx - wy;
s2 = s2.dot_simd(d, d);

let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
let d = wx - wy;
s3 = s3.dot_simd(d, d);

i += remainder;

s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
}

// Convert blocks to indexes.
i *= 8;

// Deal with the remainder the slow way.
if i != len {
// Outline the fallback routine to keep code-generation at this level cleaner.
#[inline(never)]
fn fallback(x: USlice<'_, 4>, y: USlice<'_, 4>, from: usize) -> u32 {
let mut s: i32 = 0;
for i in from..x.len() {
// SAFETY: `i` is guaranteed to be less than `x.len()`.
let ix = unsafe { x.get_unchecked(i) } as i32;
// SAFETY: `i` is guaranteed to be less than `y.len()`.
let iy = unsafe { y.get_unchecked(i) } as i32;
let d = ix - iy;
s += d * d;
}
s as u32
}
s += fallback(x, y, i);
}

Ok(MV::new(s))
}
}

/// Compute the squared L2 distance between `x` and `y`.
///
/// Returns an error if the arguments have different lengths.
Expand Down Expand Up @@ -797,7 +945,7 @@ impl_fallback_l2!(7, 6, 5, 4, 3, 2);
retarget!(diskann_wide::arch::x86_64::V3, SquaredL2, 7, 6, 5, 3);

#[cfg(target_arch = "x86_64")]
retarget!(diskann_wide::arch::x86_64::V4, SquaredL2, 7, 6, 5, 4, 3, 2);
retarget!(diskann_wide::arch::x86_64::V4, SquaredL2, 7, 6, 5, 3, 2);

dispatch_pure!(SquaredL2, 1, 2, 3, 4, 5, 6, 7, 8);
#[cfg(target_arch = "aarch64")]
Expand Down Expand Up @@ -1115,6 +1263,126 @@ impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_,
}
}

/// Compute the inner product between `x` and `y`.
///
/// Returns an error if the arguments have different lengths.
///
/// # Implementation Notes
///
/// This is optimized around the `__mm512_dpbusd_epi32` VNNI instruction, which computes the
/// pairwise dot product between vectors of 8-bit integers and accumulates groups of 4 with
/// an `i32` accumulation vector.
///
/// For 4-bit values, each byte holds 2 nibbles. We load data as `u32x16`, mask with
/// `0x0f0f0f0f` to extract the low nibbles as bytes, and shift right by 4 then mask to
/// extract the high nibbles. This gives us `u8x64` / `i8x64` operands for VNNI, requiring
/// only 2 shift positions instead of 4 for the V3 `madd_epi16` approach.
///
/// Since AVX-512 does not have an 8-bit shift instruction, we load data as `u32x16`
/// (which has a native shift) and bit-cast to `u8x64` as needed.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V4, MathematicalResult<u32>, USlice<'_, 4>, USlice<'_, 4>>
for InnerProduct
{
#[expect(non_camel_case_types)]
#[inline(always)]
fn run(
self,
arch: diskann_wide::arch::x86_64::V4,
x: USlice<'_, 4>,
y: USlice<'_, 4>,
) -> MathematicalResult<u32> {
let len = check_lengths!(x, y)?;

type i32s = <diskann_wide::arch::x86_64::V4 as Architecture>::i32x16;
type u32s = <diskann_wide::arch::x86_64::V4 as Architecture>::u32x16;
type u8s = <diskann_wide::arch::x86_64::V4 as Architecture>::u8x64;
type i8s = <diskann_wide::arch::x86_64::V4 as Architecture>::i8x64;

let px_u32: *const u32 = x.as_ptr().cast();
let py_u32: *const u32 = y.as_ptr().cast();

let mut i = 0;
let mut s: u32 = 0;

// The number of 32-bit blocks over the underlying slice.
// Each u32 holds 8 nibbles = 8 four-bit values.
let blocks = len.div_ceil(8);
if i < blocks {
let mut s0 = i32s::default(arch);
let mut s1 = i32s::default(arch);
let mask = u32s::splat(arch, 0x0f0f0f0f);
while i + 16 < blocks {
// SAFETY: We have checked that `i + 16 < blocks` which means the address
// range `[px_u32 + i, px_u32 + i + 16 * std::mem::size_of::<u32>())` is
// valid.
//
// The load has no alignment requirements.
let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };

// SAFETY: The same logic applies to `y` because:
// 1. It has the same type as `x`.
// 2. We've verified that it has the same length as `x`.
let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };

let wx: u8s = (vx & mask).reinterpret_simd();
let wy: i8s = (vy & mask).reinterpret_simd();
s0 = s0.dot_simd(wx, wy);

let wx: u8s = ((vx >> 4) & mask).reinterpret_simd();
let wy: i8s = ((vy >> 4) & mask).reinterpret_simd();
s1 = s1.dot_simd(wx, wy);

i += 16;
}

// Here
// * `len / 2` gives the number of full bytes (2 nibbles per byte)
// * `4 * i` gives the number of bytes processed (4 bytes per u32 × i u32s).
let remainder = len / 2 - 4 * i;

// SAFETY: At least `remainder` bytes are valid starting at an offset of `i`.
//
// The predicated load is guaranteed not to access memory after `remainder` and
// has no alignment requirements.
let vx = unsafe { u8s::load_simd_first(arch, px_u32.add(i).cast::<u8>(), remainder) };
let vx: u32s = vx.reinterpret_simd();

// SAFETY: The same logic applies to `y` because:
// 1. It has the same type as `x`.
// 2. We've verified that it has the same length as `x`.
let vy = unsafe { u8s::load_simd_first(arch, py_u32.add(i).cast::<u8>(), remainder) };
let vy: u32s = vy.reinterpret_simd();

let wx: u8s = (vx & mask).reinterpret_simd();
let wy: i8s = (vy & mask).reinterpret_simd();
s0 = s0.dot_simd(wx, wy);

let wx: u8s = ((vx >> 4) & mask).reinterpret_simd();
let wy: i8s = ((vy >> 4) & mask).reinterpret_simd();
s1 = s1.dot_simd(wx, wy);

s = (s0 + s1).sum_tree() as u32;
i = (4 * i) + remainder;
}

// Convert bytes to nibble indexes.
i *= 2;

// Deal with the remainder the slow way (at most 1 element).
debug_assert!(len - i <= 1);
if i != len {
// SAFETY: `i` is guaranteed to be less than `x.len()`.
let ix = unsafe { x.get_unchecked(i) } as u32;
// SAFETY: `i` is guaranteed to be less than `y.len()`.
let iy = unsafe { y.get_unchecked(i) } as u32;
s += ix * iy;
}

Ok(MV::new(s))
}
}

/// Compute the inner product between `x` and `y`.
///
/// Returns an error if the arguments have different lengths.
Expand Down Expand Up @@ -1346,7 +1614,6 @@ retarget!(
7,
6,
5,
4,
3,
(8, 4),
(8, 2),
Expand Down
Loading
Loading