feat: turboquant encoding for vectors #7167
Draft: lwwmanning wants to merge 25 commits into develop from claude/admiring-lichterman.
+3,181 −1

Changes from all 25 commits:
All commits by lwwmanning:

- a208542 feat[turboquant]: add TurboQuant vector quantization encoding
- 6c32a2e feat[turboquant]: add TurboQuantCompressor and WriteStrategyBuilder i…
- 76b63a7 refactor[turboquant]: integrate into BtrBlocks compressor directly
- a31a0a0 bench[turboquant]: add compression/decompression throughput benchmarks
- 3be5c3d perf[turboquant]: replace dense rotation with randomized Hadamard tra…
- b01d9f6 test[turboquant]: add theoretical error bound and inner product bias …
- 339df6e chore[turboquant]: fix review issues and generate public-api.lock
- fae6c61 chore[turboquant]: review cleanup — tighter tests, naming, validation
- 6b7303f docs[turboquant]: add crate-level docs with compression ratios and er…
- 66e31f9 feat[turboquant]: support 1-8 bit quantization
- de07e4c feat[turboquant]: support 9-bit Prod for tensor core int8 GEMM
- 38a24ab bench[turboquant]: add dim 1024 and 1536 benchmarks
- ed56bc6 feat[turboquant]: add rotation sign export/import and hot-path inverse
- fb6e118 feat[turboquant]: define TurboQuantMSEArray and TurboQuantQJLArray
- cfd9374 feat[turboquant]: add new compression functions for cascaded arrays
- 47ee19e refactor[btrblocks]: simplify TurboQuant compressor for cascaded arrays
- 3901305 chore[turboquant]: regenerate public-api.lock for new array types
- 5590fee refactor[turboquant]: restructure into subdirectory modules, delete d…
- 2d84cbf test[turboquant]: improve test coverage and add explanatory comments
- 334d31e perf[turboquant]: restore fast SIMD-friendly decode by expanding stor…
- 8281972 fix[turboquant]: address PR review findings
- 2d6d8dc fix[turboquant]: second-round review fixes and merge conflict resolution
- c51db31 refactor[turboquant]: simplify code from review findings
- 77438ca fix[turboquant]: address PR review comments from AdamGS
- b75e1ac chore[turboquant]: cleanup from second simplify pass
New file, +32 lines (Cargo manifest for the new `vortex-turboquant` crate):

```toml
[package]
name = "vortex-turboquant"
authors = { workspace = true }
categories = { workspace = true }
description = "Vortex TurboQuant vector quantization encoding"
edition = { workspace = true }
homepage = { workspace = true }
include = { workspace = true }
keywords = { workspace = true }
license = { workspace = true }
readme = { workspace = true }
repository = { workspace = true }
rust-version = { workspace = true }
version = { workspace = true }

[lints]
workspace = true

[dependencies]
prost = { workspace = true }
rand = { workspace = true }
vortex-array = { workspace = true }
vortex-buffer = { workspace = true }
vortex-error = { workspace = true }
vortex-fastlanes = { workspace = true }
vortex-session = { workspace = true }
vortex-utils = { workspace = true }

[dev-dependencies]
rand_distr = { workspace = true }
rstest = { workspace = true }
vortex-array = { workspace = true, features = ["_test-harness"] }
```
New file, +271 lines (Max-Lloyd centroid computation):

```rust
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

//! Max-Lloyd centroid computation for TurboQuant scalar quantizers.
//!
//! Pre-computes optimal scalar quantizer centroids for the marginal distribution of coordinates
//! after random rotation of a unit-norm vector. In high dimensions, each coordinate of a randomly
//! rotated unit vector follows a distribution proportional to `(1 - x^2)^((d-3)/2)` on `[-1, 1]`,
//! which converges to `N(0, 1/d)`. The Max-Lloyd algorithm finds optimal quantization centroids
//! that minimize MSE for this distribution.

use std::sync::LazyLock;

use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_utils::aliases::dash_map::DashMap;

/// Number of numerical integration points for computing conditional expectations.
const INTEGRATION_POINTS: usize = 1000;

/// Max-Lloyd convergence threshold.
const CONVERGENCE_EPSILON: f64 = 1e-12;

/// Maximum iterations for Max-Lloyd algorithm.
const MAX_ITERATIONS: usize = 200;

/// Global centroid cache keyed by (dimension, bit_width).
static CENTROID_CACHE: LazyLock<DashMap<(u32, u8), Vec<f32>>> = LazyLock::new(DashMap::default);

/// Get or compute cached centroids for the given dimension and bit width.
///
/// Returns `2^bit_width` centroids sorted in ascending order, representing
/// optimal scalar quantization levels for the coordinate distribution after
/// random rotation in `dimension`-dimensional space.
pub fn get_centroids(dimension: u32, bit_width: u8) -> VortexResult<Vec<f32>> {
    if !(1..=8).contains(&bit_width) {
        vortex_bail!("TurboQuant bit_width must be 1-8, got {bit_width}");
    }
    if dimension < 2 {
        vortex_bail!("TurboQuant dimension must be >= 2, got {dimension}");
    }

    if let Some(centroids) = CENTROID_CACHE.get(&(dimension, bit_width)) {
        return Ok(centroids.clone());
    }

    let centroids = max_lloyd_centroids(dimension, bit_width);
    CENTROID_CACHE.insert((dimension, bit_width), centroids.clone());
    Ok(centroids)
}

/// Compute optimal centroids via the Max-Lloyd (Lloyd-Max) algorithm.
///
/// Operates on the marginal distribution of a single coordinate of a randomly
/// rotated unit vector in d dimensions. The PDF is:
/// `f(x) = C_d * (1 - x^2)^((d-3)/2)` on `[-1, 1]`
/// where `C_d` is the normalizing constant.
fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Vec<f32> {
    let num_centroids = 1usize << bit_width;
    let dim = dimension as f64;

    // For the marginal distribution on [-1, 1], we use the exponent (d-3)/2.
    let exponent = (dim - 3.0) / 2.0;

    // Initialize centroids uniformly on [-1, 1].
    let mut centroids: Vec<f64> = (0..num_centroids)
        .map(|idx| -1.0 + (2.0 * (idx as f64) + 1.0) / (num_centroids as f64))
        .collect();

    for _ in 0..MAX_ITERATIONS {
        // Compute decision boundaries (midpoints between adjacent centroids).
        let mut boundaries = Vec::with_capacity(num_centroids + 1);
        boundaries.push(-1.0);
        for idx in 0..num_centroids - 1 {
            boundaries.push((centroids[idx] + centroids[idx + 1]) / 2.0);
        }
        boundaries.push(1.0);

        // Update each centroid to the conditional mean within its Voronoi cell.
        let mut max_change = 0.0f64;
        for idx in 0..num_centroids {
            let lo = boundaries[idx];
            let hi = boundaries[idx + 1];
            let new_centroid = conditional_mean(lo, hi, exponent);
            max_change = max_change.max((new_centroid - centroids[idx]).abs());
            centroids[idx] = new_centroid;
        }

        if max_change < CONVERGENCE_EPSILON {
            break;
        }
    }

    #[allow(clippy::cast_possible_truncation)]
    centroids.iter().map(|&val| val as f32).collect()
}

/// Compute the conditional mean of the coordinate distribution on interval [lo, hi].
///
/// Returns `E[X | lo <= X <= hi]` where X has PDF proportional to `(1 - x^2)^exponent`
/// on [-1, 1].
fn conditional_mean(lo: f64, hi: f64, exponent: f64) -> f64 {
    if (hi - lo).abs() < 1e-15 {
        return (lo + hi) / 2.0;
    }

    let num_points = INTEGRATION_POINTS;
    let dx = (hi - lo) / num_points as f64;

    let mut numerator = 0.0;
    let mut denominator = 0.0;

    for step in 0..=num_points {
        let x_val = lo + (step as f64) * dx;
        let weight = pdf_unnormalized(x_val, exponent);

        let trap_weight = if step == 0 || step == num_points {
            0.5
        } else {
            1.0
        };

        numerator += trap_weight * x_val * weight;
        denominator += trap_weight * weight;
    }

    if denominator.abs() < 1e-30 {
        (lo + hi) / 2.0
    } else {
        numerator / denominator
    }
}

/// Unnormalized PDF of the coordinate distribution: `(1 - x^2)^exponent`.
#[inline]
fn pdf_unnormalized(x_val: f64, exponent: f64) -> f64 {
    (1.0 - x_val * x_val).max(0.0).powf(exponent)
}

/// Precompute decision boundaries (midpoints between adjacent centroids).
///
/// For `k` centroids, returns `k-1` boundaries. A value below `boundaries[0]` maps
/// to centroid 0, a value in `[boundaries[i-1], boundaries[i])` maps to centroid `i`,
/// and a value >= `boundaries[k-2]` maps to centroid `k-1`.
pub fn compute_boundaries(centroids: &[f32]) -> Vec<f32> {
    centroids.windows(2).map(|w| (w[0] + w[1]) * 0.5).collect()
}

/// Find the index of the nearest centroid using precomputed decision boundaries.
///
/// `boundaries` must be the output of [`compute_boundaries`] for the corresponding
/// centroids. Uses binary search on the midpoints, avoiding distance comparisons
/// in the inner loop.
#[inline]
#[allow(clippy::cast_possible_truncation)]
pub fn find_nearest_centroid(value: f32, boundaries: &[f32]) -> u8 {
    debug_assert!(
        boundaries.windows(2).all(|w| w[0] <= w[1]),
        "boundaries must be sorted"
    );
    boundaries.partition_point(|&b| b < value) as u8
}

#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_error::VortexResult;

    use super::*;

    #[rstest]
    #[case(128, 1, 2)]
    #[case(128, 2, 4)]
    #[case(128, 3, 8)]
    #[case(128, 4, 16)]
    #[case(768, 2, 4)]
    #[case(1536, 3, 8)]
    fn centroids_have_correct_count(
        #[case] dim: u32,
        #[case] bits: u8,
        #[case] expected: usize,
    ) -> VortexResult<()> {
        let centroids = get_centroids(dim, bits)?;
        assert_eq!(centroids.len(), expected);
        Ok(())
    }

    #[rstest]
    #[case(128, 1)]
    #[case(128, 2)]
    #[case(128, 3)]
    #[case(128, 4)]
    #[case(768, 2)]
    fn centroids_are_sorted(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> {
        let centroids = get_centroids(dim, bits)?;
        for window in centroids.windows(2) {
            assert!(
                window[0] < window[1],
                "centroids not sorted: {:?}",
                centroids
            );
        }
        Ok(())
    }

    #[rstest]
    #[case(128, 1)]
    #[case(128, 2)]
    #[case(256, 2)]
    #[case(768, 2)]
    fn centroids_are_symmetric(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> {
        let centroids = get_centroids(dim, bits)?;
        let count = centroids.len();
        for idx in 0..count / 2 {
            let diff = (centroids[idx] + centroids[count - 1 - idx]).abs();
            assert!(
                diff < 1e-5,
                "centroids not symmetric: c[{idx}]={}, c[{}]={}",
                centroids[idx],
                count - 1 - idx,
                centroids[count - 1 - idx]
            );
        }
        Ok(())
    }

    #[rstest]
    #[case(128, 1)]
    #[case(128, 4)]
    fn centroids_within_bounds(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> {
        let centroids = get_centroids(dim, bits)?;
        for &val in &centroids {
            assert!(
                (-1.0..=1.0).contains(&val),
                "centroid out of [-1, 1]: {val}",
            );
        }
        Ok(())
    }

    #[test]
    fn centroids_cached() -> VortexResult<()> {
        let c1 = get_centroids(128, 2)?;
        let c2 = get_centroids(128, 2)?;
        assert_eq!(c1, c2);
        Ok(())
    }

    #[test]
    fn find_nearest_basic() -> VortexResult<()> {
        let centroids = get_centroids(128, 2)?;
        let boundaries = compute_boundaries(&centroids);
        assert_eq!(find_nearest_centroid(-1.0, &boundaries), 0);
        #[allow(clippy::cast_possible_truncation)]
        let last_idx = (centroids.len() - 1) as u8;
        assert_eq!(find_nearest_centroid(1.0, &boundaries), last_idx);
        for (idx, &cv) in centroids.iter().enumerate() {
            #[allow(clippy::cast_possible_truncation)]
            let expected = idx as u8;
            assert_eq!(find_nearest_centroid(cv, &boundaries), expected);
        }
        Ok(())
    }

    #[test]
    fn rejects_invalid_params() {
        assert!(get_centroids(128, 0).is_err());
        assert!(get_centroids(128, 9).is_err());
        assert!(get_centroids(1, 2).is_err());
    }
}
```
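For intuition, here is a dependency-free sketch of the same Lloyd-Max loop over the unnormalized density `(1 - x^2)^((d-3)/2)`, plus the `partition_point`-based nearest-cell lookup. The helper names (`lloyd_max`, `nearest`) are illustrative, not the crate's API; the constants mirror the ones in the file above:

```rust
// Illustrative sketch of Lloyd-Max centroid training for the marginal
// coordinate distribution of a randomly rotated unit vector. Not the
// crate's API; names and structure are simplified for exposition.

fn pdf(x: f64, exponent: f64) -> f64 {
    // Unnormalized marginal density (1 - x^2)^exponent, clamped at 0.
    (1.0 - x * x).max(0.0).powf(exponent)
}

/// E[X | lo <= X <= hi] via trapezoidal integration of x*f(x) over f(x).
fn conditional_mean(lo: f64, hi: f64, exponent: f64, steps: usize) -> f64 {
    let dx = (hi - lo) / steps as f64;
    let (mut num, mut den) = (0.0f64, 0.0f64);
    for i in 0..=steps {
        let x = lo + i as f64 * dx;
        let w = if i == 0 || i == steps { 0.5 } else { 1.0 };
        num += w * x * pdf(x, exponent);
        den += w * pdf(x, exponent);
    }
    if den.abs() < 1e-30 { (lo + hi) / 2.0 } else { num / den }
}

fn lloyd_max(dim: u32, bits: u8) -> Vec<f64> {
    let k = 1usize << bits;
    let exponent = (dim as f64 - 3.0) / 2.0;
    // Initialize centroids uniformly on [-1, 1].
    let mut c: Vec<f64> = (0..k)
        .map(|i| -1.0 + (2.0 * i as f64 + 1.0) / k as f64)
        .collect();
    for _ in 0..200 {
        // Decision boundaries are midpoints between adjacent centroids.
        let mut b = vec![-1.0];
        b.extend(c.windows(2).map(|w| (w[0] + w[1]) / 2.0));
        b.push(1.0);
        // Move each centroid to the conditional mean of its cell.
        let mut max_change = 0.0f64;
        for i in 0..k {
            let new_c = conditional_mean(b[i], b[i + 1], exponent, 1000);
            max_change = max_change.max((new_c - c[i]).abs());
            c[i] = new_c;
        }
        if max_change < 1e-12 {
            break;
        }
    }
    c
}

/// Binary-search quantization: index of the Voronoi cell containing `value`.
fn nearest(value: f64, boundaries: &[f64]) -> usize {
    boundaries.partition_point(|&b| b < value)
}

fn main() {
    let c = lloyd_max(128, 2); // 4 centroids for 2-bit quantization
    let b: Vec<f64> = c.windows(2).map(|w| (w[0] + w[1]) / 2.0).collect();
    println!("centroids: {c:?}");
    println!("quantize(0.1) -> cell {}", nearest(0.1, &b));
}
```

Because the density is symmetric about zero and the initialization is symmetric, the trained centroids come out symmetric as well, which is what the `centroids_are_symmetric` test above checks.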
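The doc comment's claim that each coordinate behaves like `N(0, 1/d)` can be spot-checked numerically: for a point uniform on the unit sphere, symmetry gives `E[X^2] = 1/d` exactly, so integrating the marginal density `(1 - x^2)^((d-3)/2)` should recover that value. A minimal sketch (illustrative helper names, not crate code):

```rust
// Numerical spot-check: the second moment of the marginal coordinate
// density f(x) ∝ (1 - x^2)^((d-3)/2) on [-1, 1] equals 1/d, matching
// the N(0, 1/d) limit stated in the module docs. Illustrative only.

fn pdf(x: f64, exponent: f64) -> f64 {
    (1.0 - x * x).max(0.0).powf(exponent)
}

/// E[X^2] under the unnormalized density, via trapezoidal integration.
fn second_moment(dim: u32, steps: usize) -> f64 {
    let e = (dim as f64 - 3.0) / 2.0;
    let dx = 2.0 / steps as f64;
    let (mut num, mut den) = (0.0f64, 0.0f64);
    for i in 0..=steps {
        let x = -1.0 + i as f64 * dx;
        let w = if i == 0 || i == steps { 0.5 } else { 1.0 };
        num += w * x * x * pdf(x, e);
        den += w * pdf(x, e);
    }
    num / den
}

fn main() {
    for &d in &[16u32, 128, 768] {
        let m2 = second_moment(d, 200_000);
        println!("d = {d:4}: E[X^2] = {m2:.6}, 1/d = {:.6}", 1.0 / d as f64);
    }
}
```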
Review comment (on `pdf_unnormalized`): I think avoiding the branch with `max(0.0)` is faster here? Compiler Explorer says it's slightly more instructions.
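To illustrate the trade-off the comment raises, the two formulations below compute the same clamped density; which one compiles to fewer instructions depends on the target (Rust's `f64::max` has NaN-aware semantics, which can cost extra instructions over a plain compare), so checking the actual codegen in Compiler Explorer is the right call. Purely a hedged sketch, not code from the PR:

```rust
// Two equivalent ways to clamp (1 - x^2) at zero before powf.
// Illustrative only; codegen differences are target-dependent.

fn pdf_max(x: f64, exponent: f64) -> f64 {
    // f64::max must handle NaN per IEEE maxNum semantics, so it is not
    // always a single branchless instruction.
    (1.0 - x * x).max(0.0).powf(exponent)
}

fn pdf_branch(x: f64, exponent: f64) -> f64 {
    // Explicit branch; only taken for |x| >= 1, which is rare on this
    // domain, so it should predict well.
    let t = 1.0 - x * x;
    if t > 0.0 { t.powf(exponent) } else { 0.0 }
}

fn main() {
    // Both formulations agree on in-range and out-of-range inputs.
    for &x in &[-1.5, -1.0, 0.0, 0.3, 1.0, 1.5] {
        assert_eq!(pdf_max(x, 62.5), pdf_branch(x, 62.5));
    }
    println!("both formulations agree");
}
```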