diff --git a/Cargo.lock b/Cargo.lock index f43fe0b8314..9f951496659 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10023,6 +10023,7 @@ dependencies = [ "mimalloc", "parquet 58.0.0", "rand 0.10.0", + "rand_distr 0.6.0", "serde_json", "tokio", "tracing", @@ -10052,6 +10053,7 @@ dependencies = [ "vortex-sequence", "vortex-session", "vortex-sparse", + "vortex-turboquant", "vortex-utils", "vortex-zigzag", "vortex-zstd", @@ -10212,6 +10214,7 @@ dependencies = [ "vortex-runend", "vortex-sequence", "vortex-sparse", + "vortex-turboquant", "vortex-utils", "vortex-zigzag", "vortex-zstd", @@ -10524,6 +10527,7 @@ dependencies = [ "vortex-sequence", "vortex-session", "vortex-sparse", + "vortex-turboquant", "vortex-utils", "vortex-zigzag", "vortex-zstd", @@ -10923,6 +10927,22 @@ dependencies = [ "web-sys", ] +[[package]] +name = "vortex-turboquant" +version = "0.1.0" +dependencies = [ + "prost 0.14.3", + "rand 0.10.0", + "rand_distr 0.6.0", + "rstest", + "vortex-array", + "vortex-buffer", + "vortex-error", + "vortex-fastlanes", + "vortex-session", + "vortex-utils", +] + [[package]] name = "vortex-utils" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 66cbff64fcf..6951d54d7f7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,6 +49,7 @@ members = [ "encodings/zigzag", "encodings/zstd", "encodings/bytebool", + "encodings/turboquant", # Benchmarks "benchmarks/lance-bench", "benchmarks/compress-bench", @@ -282,6 +283,7 @@ vortex-sequence = { version = "0.1.0", path = "encodings/sequence", default-feat vortex-session = { version = "0.1.0", path = "./vortex-session", default-features = false } vortex-sparse = { version = "0.1.0", path = "./encodings/sparse", default-features = false } vortex-tensor = { version = "0.1.0", path = "./vortex-tensor", default-features = false } +vortex-turboquant = { version = "0.1.0", path = "./encodings/turboquant", default-features = false } vortex-utils = { version = "0.1.0", path = "./vortex-utils", default-features = false } vortex-zigzag = { 
version = "0.1.0", path = "./encodings/zigzag", default-features = false } vortex-zstd = { version = "0.1.0", path = "./encodings/zstd", default-features = false } diff --git a/encodings/turboquant/Cargo.toml b/encodings/turboquant/Cargo.toml new file mode 100644 index 00000000000..4a93be69df3 --- /dev/null +++ b/encodings/turboquant/Cargo.toml @@ -0,0 +1,32 @@ +[package] +name = "vortex-turboquant" +authors = { workspace = true } +categories = { workspace = true } +description = "Vortex TurboQuant vector quantization encoding" +edition = { workspace = true } +homepage = { workspace = true } +include = { workspace = true } +keywords = { workspace = true } +license = { workspace = true } +readme = { workspace = true } +repository = { workspace = true } +rust-version = { workspace = true } +version = { workspace = true } + +[lints] +workspace = true + +[dependencies] +prost = { workspace = true } +rand = { workspace = true } +vortex-array = { workspace = true } +vortex-buffer = { workspace = true } +vortex-error = { workspace = true } +vortex-fastlanes = { workspace = true } +vortex-session = { workspace = true } +vortex-utils = { workspace = true } + +[dev-dependencies] +rand_distr = { workspace = true } +rstest = { workspace = true } +vortex-array = { workspace = true, features = ["_test-harness"] } diff --git a/encodings/turboquant/public-api.lock b/encodings/turboquant/public-api.lock new file mode 100644 index 00000000000..0d0c6018435 --- /dev/null +++ b/encodings/turboquant/public-api.lock @@ -0,0 +1,343 @@ +pub mod vortex_turboquant + +pub mod vortex_turboquant::centroids + +pub fn vortex_turboquant::centroids::compute_boundaries(centroids: &[f32]) -> alloc::vec::Vec + +pub fn vortex_turboquant::centroids::find_nearest_centroid(value: f32, boundaries: &[f32]) -> u8 + +pub fn vortex_turboquant::centroids::get_centroids(dimension: u32, bit_width: u8) -> vortex_error::VortexResult> + +pub mod vortex_turboquant::rotation + +pub struct 
vortex_turboquant::rotation::RotationMatrix + +impl vortex_turboquant::rotation::RotationMatrix + +pub fn vortex_turboquant::rotation::RotationMatrix::dimension(&self) -> usize + +pub fn vortex_turboquant::rotation::RotationMatrix::export_inverse_signs_bool_array(&self) -> vortex_array::arrays::bool::array::BoolArray + +pub fn vortex_turboquant::rotation::RotationMatrix::from_bool_array(signs_array: &vortex_array::arrays::bool::array::BoolArray, dim: usize) -> vortex_error::VortexResult + +pub fn vortex_turboquant::rotation::RotationMatrix::inverse_rotate(&self, input: &[f32], output: &mut [f32]) + +pub fn vortex_turboquant::rotation::RotationMatrix::norm_factor(&self) -> f32 + +pub fn vortex_turboquant::rotation::RotationMatrix::padded_dim(&self) -> usize + +pub fn vortex_turboquant::rotation::RotationMatrix::rotate(&self, input: &[f32], output: &mut [f32]) + +pub fn vortex_turboquant::rotation::RotationMatrix::try_new(seed: u64, dimension: usize) -> vortex_error::VortexResult + +pub struct vortex_turboquant::TurboQuantConfig + +pub vortex_turboquant::TurboQuantConfig::bit_width: u8 + +pub vortex_turboquant::TurboQuantConfig::seed: core::option::Option + +impl core::clone::Clone for vortex_turboquant::TurboQuantConfig + +pub fn vortex_turboquant::TurboQuantConfig::clone(&self) -> vortex_turboquant::TurboQuantConfig + +impl core::fmt::Debug for vortex_turboquant::TurboQuantConfig + +pub fn vortex_turboquant::TurboQuantConfig::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +pub struct vortex_turboquant::TurboQuantMSE + +impl vortex_turboquant::TurboQuantMSE + +pub const vortex_turboquant::TurboQuantMSE::ID: vortex_array::vtable::dyn_::ArrayId + +impl core::clone::Clone for vortex_turboquant::TurboQuantMSE + +pub fn vortex_turboquant::TurboQuantMSE::clone(&self) -> vortex_turboquant::TurboQuantMSE + +impl core::fmt::Debug for vortex_turboquant::TurboQuantMSE + +pub fn vortex_turboquant::TurboQuantMSE::fmt(&self, f: &mut core::fmt::Formatter<'_>) 
-> core::fmt::Result + +impl vortex_array::vtable::VTable for vortex_turboquant::TurboQuantMSE + +pub type vortex_turboquant::TurboQuantMSE::Array = vortex_turboquant::TurboQuantMSEArray + +pub type vortex_turboquant::TurboQuantMSE::Metadata = vortex_array::metadata::ProstMetadata + +pub type vortex_turboquant::TurboQuantMSE::OperationsVTable = vortex_array::vtable::NotSupported + +pub type vortex_turboquant::TurboQuantMSE::ValidityVTable = vortex_array::vtable::validity::ValidityVTableFromChild + +pub fn vortex_turboquant::TurboQuantMSE::array_eq(array: &vortex_turboquant::TurboQuantMSEArray, other: &vortex_turboquant::TurboQuantMSEArray, precision: vortex_array::hash::Precision) -> bool + +pub fn vortex_turboquant::TurboQuantMSE::array_hash(array: &vortex_turboquant::TurboQuantMSEArray, state: &mut H, precision: vortex_array::hash::Precision) + +pub fn vortex_turboquant::TurboQuantMSE::buffer(_array: &vortex_turboquant::TurboQuantMSEArray, idx: usize) -> vortex_array::buffer::BufferHandle + +pub fn vortex_turboquant::TurboQuantMSE::buffer_name(_array: &vortex_turboquant::TurboQuantMSEArray, _idx: usize) -> core::option::Option + +pub fn vortex_turboquant::TurboQuantMSE::build(dtype: &vortex_array::dtype::DType, len: usize, metadata: &Self::Metadata, _buffers: &[vortex_array::buffer::BufferHandle], children: &dyn vortex_array::serde::ArrayChildren) -> vortex_error::VortexResult + +pub fn vortex_turboquant::TurboQuantMSE::child(array: &vortex_turboquant::TurboQuantMSEArray, idx: usize) -> vortex_array::array::ArrayRef + +pub fn vortex_turboquant::TurboQuantMSE::child_name(_array: &vortex_turboquant::TurboQuantMSEArray, idx: usize) -> alloc::string::String + +pub fn vortex_turboquant::TurboQuantMSE::deserialize(bytes: &[u8], _dtype: &vortex_array::dtype::DType, _len: usize, _buffers: &[vortex_array::buffer::BufferHandle], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_turboquant::TurboQuantMSE::dtype(array: 
&vortex_turboquant::TurboQuantMSEArray) -> &vortex_array::dtype::DType + +pub fn vortex_turboquant::TurboQuantMSE::execute(array: alloc::sync::Arc>, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult + +pub fn vortex_turboquant::TurboQuantMSE::id(&self) -> vortex_array::vtable::dyn_::ArrayId + +pub fn vortex_turboquant::TurboQuantMSE::len(array: &vortex_turboquant::TurboQuantMSEArray) -> usize + +pub fn vortex_turboquant::TurboQuantMSE::metadata(array: &vortex_turboquant::TurboQuantMSEArray) -> vortex_error::VortexResult + +pub fn vortex_turboquant::TurboQuantMSE::nbuffers(_array: &vortex_turboquant::TurboQuantMSEArray) -> usize + +pub fn vortex_turboquant::TurboQuantMSE::nchildren(_array: &vortex_turboquant::TurboQuantMSEArray) -> usize + +pub fn vortex_turboquant::TurboQuantMSE::serialize(metadata: Self::Metadata) -> vortex_error::VortexResult>> + +pub fn vortex_turboquant::TurboQuantMSE::stats(array: &vortex_turboquant::TurboQuantMSEArray) -> vortex_array::stats::array::StatsSetRef<'_> + +pub fn vortex_turboquant::TurboQuantMSE::vtable(_array: &Self::Array) -> &Self + +pub fn vortex_turboquant::TurboQuantMSE::with_children(array: &mut Self::Array, children: alloc::vec::Vec) -> vortex_error::VortexResult<()> + +impl vortex_array::vtable::validity::ValidityChild for vortex_turboquant::TurboQuantMSE + +pub fn vortex_turboquant::TurboQuantMSE::validity_child(array: &vortex_turboquant::TurboQuantMSEArray) -> &vortex_array::array::ArrayRef + +pub struct vortex_turboquant::TurboQuantMSEArray + +impl vortex_turboquant::TurboQuantMSEArray + +pub fn vortex_turboquant::TurboQuantMSEArray::bit_width(&self) -> u8 + +pub fn vortex_turboquant::TurboQuantMSEArray::centroids(&self) -> &vortex_array::array::ArrayRef + +pub fn vortex_turboquant::TurboQuantMSEArray::codes(&self) -> &vortex_array::array::ArrayRef + +pub fn vortex_turboquant::TurboQuantMSEArray::dimension(&self) -> u32 + +pub fn vortex_turboquant::TurboQuantMSEArray::norms(&self) -> 
&vortex_array::array::ArrayRef + +pub fn vortex_turboquant::TurboQuantMSEArray::padded_dim(&self) -> u32 + +pub fn vortex_turboquant::TurboQuantMSEArray::rotation_seed(&self) -> u64 + +pub fn vortex_turboquant::TurboQuantMSEArray::rotation_signs(&self) -> &vortex_array::array::ArrayRef + +pub fn vortex_turboquant::TurboQuantMSEArray::try_new(dtype: vortex_array::dtype::DType, codes: vortex_array::array::ArrayRef, norms: vortex_array::array::ArrayRef, centroids: vortex_array::array::ArrayRef, rotation_signs: vortex_array::array::ArrayRef, dimension: u32, bit_width: u8, padded_dim: u32, rotation_seed: u64) -> vortex_error::VortexResult + +impl vortex_turboquant::TurboQuantMSEArray + +pub fn vortex_turboquant::TurboQuantMSEArray::to_array(&self) -> vortex_array::array::ArrayRef + +impl core::clone::Clone for vortex_turboquant::TurboQuantMSEArray + +pub fn vortex_turboquant::TurboQuantMSEArray::clone(&self) -> vortex_turboquant::TurboQuantMSEArray + +impl core::convert::AsRef for vortex_turboquant::TurboQuantMSEArray + +pub fn vortex_turboquant::TurboQuantMSEArray::as_ref(&self) -> &dyn vortex_array::array::DynArray + +impl core::convert::From for vortex_array::array::ArrayRef + +pub fn vortex_array::array::ArrayRef::from(value: vortex_turboquant::TurboQuantMSEArray) -> vortex_array::array::ArrayRef + +impl core::fmt::Debug for vortex_turboquant::TurboQuantMSEArray + +pub fn vortex_turboquant::TurboQuantMSEArray::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::ops::deref::Deref for vortex_turboquant::TurboQuantMSEArray + +pub type vortex_turboquant::TurboQuantMSEArray::Target = dyn vortex_array::array::DynArray + +pub fn vortex_turboquant::TurboQuantMSEArray::deref(&self) -> &Self::Target + +impl vortex_array::array::IntoArray for vortex_turboquant::TurboQuantMSEArray + +pub fn vortex_turboquant::TurboQuantMSEArray::into_array(self) -> vortex_array::array::ArrayRef + +pub struct vortex_turboquant::TurboQuantMSEMetadata + +pub 
vortex_turboquant::TurboQuantMSEMetadata::bit_width: u32 + +pub vortex_turboquant::TurboQuantMSEMetadata::dimension: u32 + +pub vortex_turboquant::TurboQuantMSEMetadata::padded_dim: u32 + +pub vortex_turboquant::TurboQuantMSEMetadata::rotation_seed: u64 + +impl core::clone::Clone for vortex_turboquant::TurboQuantMSEMetadata + +pub fn vortex_turboquant::TurboQuantMSEMetadata::clone(&self) -> vortex_turboquant::TurboQuantMSEMetadata + +impl core::default::Default for vortex_turboquant::TurboQuantMSEMetadata + +pub fn vortex_turboquant::TurboQuantMSEMetadata::default() -> Self + +impl core::fmt::Debug for vortex_turboquant::TurboQuantMSEMetadata + +pub fn vortex_turboquant::TurboQuantMSEMetadata::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl prost::message::Message for vortex_turboquant::TurboQuantMSEMetadata + +pub fn vortex_turboquant::TurboQuantMSEMetadata::clear(&mut self) + +pub fn vortex_turboquant::TurboQuantMSEMetadata::encoded_len(&self) -> usize + +pub struct vortex_turboquant::TurboQuantQJL + +impl vortex_turboquant::TurboQuantQJL + +pub const vortex_turboquant::TurboQuantQJL::ID: vortex_array::vtable::dyn_::ArrayId + +impl core::clone::Clone for vortex_turboquant::TurboQuantQJL + +pub fn vortex_turboquant::TurboQuantQJL::clone(&self) -> vortex_turboquant::TurboQuantQJL + +impl core::fmt::Debug for vortex_turboquant::TurboQuantQJL + +pub fn vortex_turboquant::TurboQuantQJL::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl vortex_array::vtable::VTable for vortex_turboquant::TurboQuantQJL + +pub type vortex_turboquant::TurboQuantQJL::Array = vortex_turboquant::TurboQuantQJLArray + +pub type vortex_turboquant::TurboQuantQJL::Metadata = vortex_array::metadata::ProstMetadata + +pub type vortex_turboquant::TurboQuantQJL::OperationsVTable = vortex_array::vtable::NotSupported + +pub type vortex_turboquant::TurboQuantQJL::ValidityVTable = vortex_array::vtable::validity::ValidityVTableFromChild + +pub fn 
vortex_turboquant::TurboQuantQJL::array_eq(array: &vortex_turboquant::TurboQuantQJLArray, other: &vortex_turboquant::TurboQuantQJLArray, precision: vortex_array::hash::Precision) -> bool + +pub fn vortex_turboquant::TurboQuantQJL::array_hash(array: &vortex_turboquant::TurboQuantQJLArray, state: &mut H, precision: vortex_array::hash::Precision) + +pub fn vortex_turboquant::TurboQuantQJL::buffer(_array: &vortex_turboquant::TurboQuantQJLArray, idx: usize) -> vortex_array::buffer::BufferHandle + +pub fn vortex_turboquant::TurboQuantQJL::buffer_name(_array: &vortex_turboquant::TurboQuantQJLArray, _idx: usize) -> core::option::Option + +pub fn vortex_turboquant::TurboQuantQJL::build(dtype: &vortex_array::dtype::DType, len: usize, metadata: &Self::Metadata, _buffers: &[vortex_array::buffer::BufferHandle], children: &dyn vortex_array::serde::ArrayChildren) -> vortex_error::VortexResult + +pub fn vortex_turboquant::TurboQuantQJL::child(array: &vortex_turboquant::TurboQuantQJLArray, idx: usize) -> vortex_array::array::ArrayRef + +pub fn vortex_turboquant::TurboQuantQJL::child_name(_array: &vortex_turboquant::TurboQuantQJLArray, idx: usize) -> alloc::string::String + +pub fn vortex_turboquant::TurboQuantQJL::deserialize(bytes: &[u8], _dtype: &vortex_array::dtype::DType, _len: usize, _buffers: &[vortex_array::buffer::BufferHandle], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_turboquant::TurboQuantQJL::dtype(array: &vortex_turboquant::TurboQuantQJLArray) -> &vortex_array::dtype::DType + +pub fn vortex_turboquant::TurboQuantQJL::execute(array: alloc::sync::Arc>, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult + +pub fn vortex_turboquant::TurboQuantQJL::id(&self) -> vortex_array::vtable::dyn_::ArrayId + +pub fn vortex_turboquant::TurboQuantQJL::len(array: &vortex_turboquant::TurboQuantQJLArray) -> usize + +pub fn vortex_turboquant::TurboQuantQJL::metadata(array: &vortex_turboquant::TurboQuantQJLArray) -> 
vortex_error::VortexResult + +pub fn vortex_turboquant::TurboQuantQJL::nbuffers(_array: &vortex_turboquant::TurboQuantQJLArray) -> usize + +pub fn vortex_turboquant::TurboQuantQJL::nchildren(_array: &vortex_turboquant::TurboQuantQJLArray) -> usize + +pub fn vortex_turboquant::TurboQuantQJL::serialize(metadata: Self::Metadata) -> vortex_error::VortexResult>> + +pub fn vortex_turboquant::TurboQuantQJL::stats(array: &vortex_turboquant::TurboQuantQJLArray) -> vortex_array::stats::array::StatsSetRef<'_> + +pub fn vortex_turboquant::TurboQuantQJL::vtable(_array: &Self::Array) -> &Self + +pub fn vortex_turboquant::TurboQuantQJL::with_children(array: &mut Self::Array, children: alloc::vec::Vec) -> vortex_error::VortexResult<()> + +impl vortex_array::vtable::validity::ValidityChild for vortex_turboquant::TurboQuantQJL + +pub fn vortex_turboquant::TurboQuantQJL::validity_child(array: &vortex_turboquant::TurboQuantQJLArray) -> &vortex_array::array::ArrayRef + +pub struct vortex_turboquant::TurboQuantQJLArray + +impl vortex_turboquant::TurboQuantQJLArray + +pub fn vortex_turboquant::TurboQuantQJLArray::bit_width(&self) -> u8 + +pub fn vortex_turboquant::TurboQuantQJLArray::mse_inner(&self) -> &vortex_array::array::ArrayRef + +pub fn vortex_turboquant::TurboQuantQJLArray::padded_dim(&self) -> u32 + +pub fn vortex_turboquant::TurboQuantQJLArray::qjl_signs(&self) -> &vortex_array::array::ArrayRef + +pub fn vortex_turboquant::TurboQuantQJLArray::residual_norms(&self) -> &vortex_array::array::ArrayRef + +pub fn vortex_turboquant::TurboQuantQJLArray::rotation_seed(&self) -> u64 + +pub fn vortex_turboquant::TurboQuantQJLArray::rotation_signs(&self) -> &vortex_array::array::ArrayRef + +pub fn vortex_turboquant::TurboQuantQJLArray::try_new(dtype: vortex_array::dtype::DType, mse_inner: vortex_array::array::ArrayRef, qjl_signs: vortex_array::array::ArrayRef, residual_norms: vortex_array::array::ArrayRef, rotation_signs: vortex_array::array::ArrayRef, bit_width: u8, padded_dim: u32, 
rotation_seed: u64) -> vortex_error::VortexResult + +impl vortex_turboquant::TurboQuantQJLArray + +pub fn vortex_turboquant::TurboQuantQJLArray::to_array(&self) -> vortex_array::array::ArrayRef + +impl core::clone::Clone for vortex_turboquant::TurboQuantQJLArray + +pub fn vortex_turboquant::TurboQuantQJLArray::clone(&self) -> vortex_turboquant::TurboQuantQJLArray + +impl core::convert::AsRef for vortex_turboquant::TurboQuantQJLArray + +pub fn vortex_turboquant::TurboQuantQJLArray::as_ref(&self) -> &dyn vortex_array::array::DynArray + +impl core::convert::From for vortex_array::array::ArrayRef + +pub fn vortex_array::array::ArrayRef::from(value: vortex_turboquant::TurboQuantQJLArray) -> vortex_array::array::ArrayRef + +impl core::fmt::Debug for vortex_turboquant::TurboQuantQJLArray + +pub fn vortex_turboquant::TurboQuantQJLArray::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::ops::deref::Deref for vortex_turboquant::TurboQuantQJLArray + +pub type vortex_turboquant::TurboQuantQJLArray::Target = dyn vortex_array::array::DynArray + +pub fn vortex_turboquant::TurboQuantQJLArray::deref(&self) -> &Self::Target + +impl vortex_array::array::IntoArray for vortex_turboquant::TurboQuantQJLArray + +pub fn vortex_turboquant::TurboQuantQJLArray::into_array(self) -> vortex_array::array::ArrayRef + +pub struct vortex_turboquant::TurboQuantQJLMetadata + +pub vortex_turboquant::TurboQuantQJLMetadata::bit_width: u32 + +pub vortex_turboquant::TurboQuantQJLMetadata::padded_dim: u32 + +pub vortex_turboquant::TurboQuantQJLMetadata::rotation_seed: u64 + +impl core::clone::Clone for vortex_turboquant::TurboQuantQJLMetadata + +pub fn vortex_turboquant::TurboQuantQJLMetadata::clone(&self) -> vortex_turboquant::TurboQuantQJLMetadata + +impl core::default::Default for vortex_turboquant::TurboQuantQJLMetadata + +pub fn vortex_turboquant::TurboQuantQJLMetadata::default() -> Self + +impl core::fmt::Debug for vortex_turboquant::TurboQuantQJLMetadata + +pub fn 
vortex_turboquant::TurboQuantQJLMetadata::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl prost::message::Message for vortex_turboquant::TurboQuantQJLMetadata + +pub fn vortex_turboquant::TurboQuantQJLMetadata::clear(&mut self) + +pub fn vortex_turboquant::TurboQuantQJLMetadata::encoded_len(&self) -> usize + +pub const vortex_turboquant::FIXED_SHAPE_TENSOR_EXT_ID: &str + +pub const vortex_turboquant::VECTOR_EXT_ID: &str + +pub fn vortex_turboquant::initialize(session: &mut vortex_session::VortexSession) + +pub fn vortex_turboquant::turboquant_encode_mse(fsl: &vortex_array::arrays::fixed_size_list::array::FixedSizeListArray, config: &vortex_turboquant::TurboQuantConfig) -> vortex_error::VortexResult + +pub fn vortex_turboquant::turboquant_encode_qjl(fsl: &vortex_array::arrays::fixed_size_list::array::FixedSizeListArray, config: &vortex_turboquant::TurboQuantConfig) -> vortex_error::VortexResult diff --git a/encodings/turboquant/src/centroids.rs b/encodings/turboquant/src/centroids.rs new file mode 100644 index 00000000000..6d316aeff75 --- /dev/null +++ b/encodings/turboquant/src/centroids.rs @@ -0,0 +1,271 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Max-Lloyd centroid computation for TurboQuant scalar quantizers. +//! +//! Pre-computes optimal scalar quantizer centroids for the marginal distribution of coordinates +//! after random rotation of a unit-norm vector. In high dimensions, each coordinate of a randomly +//! rotated unit vector follows a distribution proportional to `(1 - x^2)^((d-3)/2)` on `[-1, 1]`, +//! which converges to `N(0, 1/d)`. The Max-Lloyd algorithm finds optimal quantization centroids +//! that minimize MSE for this distribution. 
+ +use std::sync::LazyLock; + +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_utils::aliases::dash_map::DashMap; + +/// Number of numerical integration points for computing conditional expectations. +const INTEGRATION_POINTS: usize = 1000; + +/// Max-Lloyd convergence threshold. +const CONVERGENCE_EPSILON: f64 = 1e-12; + +/// Maximum iterations for Max-Lloyd algorithm. +const MAX_ITERATIONS: usize = 200; + +/// Global centroid cache keyed by (dimension, bit_width). +static CENTROID_CACHE: LazyLock<DashMap<(u32, u8), Vec<f32>>> = LazyLock::new(DashMap::default); + +/// Get or compute cached centroids for the given dimension and bit width. +/// +/// Returns `2^bit_width` centroids sorted in ascending order, representing +/// optimal scalar quantization levels for the coordinate distribution after +/// random rotation in `dimension`-dimensional space. +pub fn get_centroids(dimension: u32, bit_width: u8) -> VortexResult<Vec<f32>> { + if !(1..=8).contains(&bit_width) { + vortex_bail!("TurboQuant bit_width must be 1-8, got {bit_width}"); + } + if dimension < 2 { + vortex_bail!("TurboQuant dimension must be >= 2, got {dimension}"); + } + + if let Some(centroids) = CENTROID_CACHE.get(&(dimension, bit_width)) { + return Ok(centroids.clone()); + } + + let centroids = max_lloyd_centroids(dimension, bit_width); + CENTROID_CACHE.insert((dimension, bit_width), centroids.clone()); + Ok(centroids) +} + +/// Compute optimal centroids via the Max-Lloyd (Lloyd-Max) algorithm. +/// +/// Operates on the marginal distribution of a single coordinate of a randomly +/// rotated unit vector in d dimensions. The PDF is: +/// `f(x) = C_d * (1 - x^2)^((d-3)/2)` on `[-1, 1]` +/// where `C_d` is the normalizing constant. +fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Vec<f32> { + let num_centroids = 1usize << bit_width; + let dim = dimension as f64; + + // For the marginal distribution on [-1, 1], we use the exponent (d-3)/2.
+ let exponent = (dim - 3.0) / 2.0; + + // Initialize centroids uniformly on [-1, 1]. + let mut centroids: Vec<f64> = (0..num_centroids) + .map(|idx| -1.0 + (2.0 * (idx as f64) + 1.0) / (num_centroids as f64)) + .collect(); + + for _ in 0..MAX_ITERATIONS { + // Compute decision boundaries (midpoints between adjacent centroids). + let mut boundaries = Vec::with_capacity(num_centroids + 1); + boundaries.push(-1.0); + for idx in 0..num_centroids - 1 { + boundaries.push((centroids[idx] + centroids[idx + 1]) / 2.0); + } + boundaries.push(1.0); + + // Update each centroid to the conditional mean within its Voronoi cell. + let mut max_change = 0.0f64; + for idx in 0..num_centroids { + let lo = boundaries[idx]; + let hi = boundaries[idx + 1]; + let new_centroid = conditional_mean(lo, hi, exponent); + max_change = max_change.max((new_centroid - centroids[idx]).abs()); + centroids[idx] = new_centroid; + } + + if max_change < CONVERGENCE_EPSILON { + break; + } + } + + #[allow(clippy::cast_possible_truncation)] + centroids.iter().map(|&val| val as f32).collect() +} + +/// Compute the conditional mean of the coordinate distribution on interval [lo, hi]. +/// +/// Returns `E[X | lo <= X <= hi]` where X has PDF proportional to `(1 - x^2)^exponent` +/// on [-1, 1].
+fn conditional_mean(lo: f64, hi: f64, exponent: f64) -> f64 { + if (hi - lo).abs() < 1e-15 { + return (lo + hi) / 2.0; + } + + let num_points = INTEGRATION_POINTS; + let dx = (hi - lo) / num_points as f64; + + let mut numerator = 0.0; + let mut denominator = 0.0; + + for step in 0..=num_points { + let x_val = lo + (step as f64) * dx; + let weight = pdf_unnormalized(x_val, exponent); + + let trap_weight = if step == 0 || step == num_points { + 0.5 + } else { + 1.0 + }; + + numerator += trap_weight * x_val * weight; + denominator += trap_weight * weight; + } + + if denominator.abs() < 1e-30 { + (lo + hi) / 2.0 + } else { + numerator / denominator + } +} + +/// Unnormalized PDF of the coordinate distribution: `(1 - x^2)^exponent`. +#[inline] +fn pdf_unnormalized(x_val: f64, exponent: f64) -> f64 { + (1.0 - x_val * x_val).max(0.0).powf(exponent) +} + +/// Precompute decision boundaries (midpoints between adjacent centroids). +/// +/// For `k` centroids, returns `k-1` boundaries. A value at or below `boundaries[0]` maps +/// to centroid 0, a value in `(boundaries[i-1], boundaries[i]]` maps to centroid `i`, +/// and a value above `boundaries[k-2]` maps to centroid `k-1`. +pub fn compute_boundaries(centroids: &[f32]) -> Vec<f32> { + centroids.windows(2).map(|w| (w[0] + w[1]) * 0.5).collect() +} + +/// Find the index of the nearest centroid using precomputed decision boundaries. +/// +/// `boundaries` must be the output of [`compute_boundaries`] for the corresponding +/// centroids. Uses binary search on the midpoints, avoiding distance comparisons +/// in the inner loop.
+#[inline] +#[allow(clippy::cast_possible_truncation)] +pub fn find_nearest_centroid(value: f32, boundaries: &[f32]) -> u8 { + debug_assert!( + boundaries.windows(2).all(|w| w[0] <= w[1]), + "boundaries must be sorted" + ); + boundaries.partition_point(|&b| b < value) as u8 +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + use vortex_error::VortexResult; + + use super::*; + + #[rstest] + #[case(128, 1, 2)] + #[case(128, 2, 4)] + #[case(128, 3, 8)] + #[case(128, 4, 16)] + #[case(768, 2, 4)] + #[case(1536, 3, 8)] + fn centroids_have_correct_count( + #[case] dim: u32, + #[case] bits: u8, + #[case] expected: usize, + ) -> VortexResult<()> { + let centroids = get_centroids(dim, bits)?; + assert_eq!(centroids.len(), expected); + Ok(()) + } + + #[rstest] + #[case(128, 1)] + #[case(128, 2)] + #[case(128, 3)] + #[case(128, 4)] + #[case(768, 2)] + fn centroids_are_sorted(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> { + let centroids = get_centroids(dim, bits)?; + for window in centroids.windows(2) { + assert!( + window[0] < window[1], + "centroids not sorted: {:?}", + centroids + ); + } + Ok(()) + } + + #[rstest] + #[case(128, 1)] + #[case(128, 2)] + #[case(256, 2)] + #[case(768, 2)] + fn centroids_are_symmetric(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> { + let centroids = get_centroids(dim, bits)?; + let count = centroids.len(); + for idx in 0..count / 2 { + let diff = (centroids[idx] + centroids[count - 1 - idx]).abs(); + assert!( + diff < 1e-5, + "centroids not symmetric: c[{idx}]={}, c[{}]={}", + centroids[idx], + count - 1 - idx, + centroids[count - 1 - idx] + ); + } + Ok(()) + } + + #[rstest] + #[case(128, 1)] + #[case(128, 4)] + fn centroids_within_bounds(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> { + let centroids = get_centroids(dim, bits)?; + for &val in &centroids { + assert!( + (-1.0..=1.0).contains(&val), + "centroid out of [-1, 1]: {val}", + ); + } + Ok(()) + } + + #[test] + fn centroids_cached() -> VortexResult<()> { +
let c1 = get_centroids(128, 2)?; + let c2 = get_centroids(128, 2)?; + assert_eq!(c1, c2); + Ok(()) + } + + #[test] + fn find_nearest_basic() -> VortexResult<()> { + let centroids = get_centroids(128, 2)?; + let boundaries = compute_boundaries(¢roids); + assert_eq!(find_nearest_centroid(-1.0, &boundaries), 0); + #[allow(clippy::cast_possible_truncation)] + let last_idx = (centroids.len() - 1) as u8; + assert_eq!(find_nearest_centroid(1.0, &boundaries), last_idx); + for (idx, &cv) in centroids.iter().enumerate() { + #[allow(clippy::cast_possible_truncation)] + let expected = idx as u8; + assert_eq!(find_nearest_centroid(cv, &boundaries), expected); + } + Ok(()) + } + + #[test] + fn rejects_invalid_params() { + assert!(get_centroids(128, 0).is_err()); + assert!(get_centroids(128, 9).is_err()); + assert!(get_centroids(1, 2).is_err()); + } +} diff --git a/encodings/turboquant/src/compress.rs b/encodings/turboquant/src/compress.rs new file mode 100644 index 00000000000..b1a31c54cf4 --- /dev/null +++ b/encodings/turboquant/src/compress.rs @@ -0,0 +1,350 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant encoding (quantization) logic. 
+
+use vortex_array::IntoArray;
+use vortex_array::arrays::BoolArray;
+use vortex_array::arrays::FixedSizeListArray;
+use vortex_array::arrays::PrimitiveArray;
+use vortex_array::dtype::Nullability;
+use vortex_array::dtype::PType;
+use vortex_array::validity::Validity;
+use vortex_buffer::BitBufferMut;
+use vortex_buffer::BufferMut;
+use vortex_error::VortexResult;
+use vortex_error::vortex_bail;
+use vortex_error::vortex_ensure;
+use vortex_fastlanes::bitpack_compress::bitpack_encode;
+
+use crate::centroids::compute_boundaries;
+use crate::centroids::find_nearest_centroid;
+use crate::centroids::get_centroids;
+use crate::mse::array::TurboQuantMSEArray;
+use crate::qjl::array::TurboQuantQJLArray;
+use crate::rotation::RotationMatrix;
+
+/// Configuration for TurboQuant encoding.
+#[derive(Clone, Debug)]
+pub struct TurboQuantConfig {
+    /// Bits per coordinate.
+    ///
+    /// For MSE encoding: 1-8.
+    /// For QJL encoding: 2-9 (the MSE inner uses `bit_width - 1`).
+    pub bit_width: u8,
+    /// Optional seed for the rotation matrix. If None, a random seed is generated.
+    pub seed: Option<u64>,
+}
+
+/// Extract elements from a FixedSizeListArray as a flat f32 vec.
+#[allow(clippy::cast_possible_truncation)]
+fn extract_f32_elements(fsl: &FixedSizeListArray) -> VortexResult<Vec<f32>> {
+    let elements = fsl.elements();
+    let primitive = elements.to_canonical()?.into_primitive();
+    let ptype = primitive.ptype();
+
+    match ptype {
+        PType::F32 => Ok(primitive.as_slice::<f32>().to_vec()),
+        PType::F64 => Ok(primitive
+            .as_slice::<f64>()
+            .iter()
+            .map(|&v| v as f32)
+            .collect()),
+        _ => vortex_bail!("TurboQuant requires f32 or f64 elements, got {ptype:?}"),
+    }
+}
+
+/// Compute the L2 norm of a vector.
+#[inline]
+fn l2_norm(x: &[f32]) -> f32 {
+    x.iter().map(|&v| v * v).sum::<f32>().sqrt()
+}
+
+/// Encode a FixedSizeListArray into a `TurboQuantMSEArray`.
+///
+/// The input must be non-nullable. TurboQuant is a lossy encoding that does not
+/// preserve null positions; callers must handle validity externally.
+pub fn turboquant_encode_mse(
+    fsl: &FixedSizeListArray,
+    config: &TurboQuantConfig,
+) -> VortexResult<TurboQuantMSEArray> {
+    vortex_ensure!(
+        fsl.dtype().nullability() == Nullability::NonNullable,
+        "TurboQuant requires non-nullable input, got nullable FixedSizeListArray"
+    );
+    vortex_ensure!(
+        config.bit_width >= 1 && config.bit_width <= 8,
+        "MSE bit_width must be 1-8, got {}",
+        config.bit_width
+    );
+    let dimension = fsl.list_size();
+    vortex_ensure!(
+        dimension >= 2,
+        "TurboQuant requires dimension >= 2, got {dimension}"
+    );
+
+    let seed = config.seed.unwrap_or_else(rand::random);
+    let dim = dimension as usize;
+    let num_rows = fsl.len();
+
+    let rotation = RotationMatrix::try_new(seed, dim)?;
+    let padded_dim = rotation.padded_dim();
+
+    if num_rows == 0 {
+        return build_empty_mse_array(fsl, config.bit_width, padded_dim, seed);
+    }
+
+    let f32_elements = extract_f32_elements(fsl)?;
+    #[allow(clippy::cast_possible_truncation)]
+    let centroids = get_centroids(padded_dim as u32, config.bit_width)?;
+    let boundaries = compute_boundaries(&centroids);
+
+    let mut all_indices = BufferMut::<u8>::with_capacity(num_rows * padded_dim);
+    let mut norms_buf = BufferMut::<f32>::with_capacity(num_rows);
+    let mut padded = vec![0.0f32; padded_dim];
+    let mut rotated = vec![0.0f32; padded_dim];
+
+    for row in 0..num_rows {
+        let x = &f32_elements[row * dim..(row + 1) * dim];
+        let norm = l2_norm(x);
+        norms_buf.push(norm);
+
+        // Normalize and write into [..dim]; tail [dim..padded_dim] stays zero
+        // from initialization and is never overwritten.
+ if norm > 0.0 { + let inv_norm = 1.0 / norm; + for (dst, &src) in padded[..dim].iter_mut().zip(x.iter()) { + *dst = src * inv_norm; + } + } else { + padded[..dim].fill(0.0); + } + rotation.rotate(&padded, &mut rotated); + + for j in 0..padded_dim { + all_indices.push(find_nearest_centroid(rotated[j], &boundaries)); + } + } + + // Pack indices: bitpack for 1-7 bits, store raw u8 for 8 bits. + let indices_array = PrimitiveArray::new::(all_indices.freeze(), Validity::NonNullable); + let codes = if config.bit_width < 8 { + bitpack_encode(&indices_array, config.bit_width, None)?.into_array() + } else { + indices_array.into_array() + }; + + let norms_array = PrimitiveArray::new::(norms_buf.freeze(), Validity::NonNullable); + + // Store centroids as a child array. + let mut centroids_buf = BufferMut::::with_capacity(centroids.len()); + centroids_buf.extend_from_slice(¢roids); + let centroids_array = PrimitiveArray::new::(centroids_buf.freeze(), Validity::NonNullable); + + // Store rotation signs as a BoolArray child. + let rotation_signs = rotation.export_inverse_signs_bool_array(); + + #[allow(clippy::cast_possible_truncation)] + TurboQuantMSEArray::try_new( + fsl.dtype().clone(), + codes, + norms_array.into_array(), + centroids_array.into_array(), + rotation_signs.into_array(), + dimension, + config.bit_width, + padded_dim as u32, + seed, + ) +} + +/// Encode a FixedSizeListArray into a `TurboQuantQJLArray`. +/// +/// Produces a cascaded structure: QJLArray wrapping an MSEArray at `bit_width - 1`. +/// The input must be non-nullable. TurboQuant is a lossy encoding that does not +/// preserve null positions; callers must handle validity externally. 
+pub fn turboquant_encode_qjl(
+    fsl: &FixedSizeListArray,
+    config: &TurboQuantConfig,
+) -> VortexResult<TurboQuantQJLArray> {
+    vortex_ensure!(
+        fsl.dtype().nullability() == Nullability::NonNullable,
+        "TurboQuant requires non-nullable input, got nullable FixedSizeListArray"
+    );
+    vortex_ensure!(
+        config.bit_width >= 2 && config.bit_width <= 9,
+        "QJL bit_width must be 2-9, got {}",
+        config.bit_width
+    );
+    let dimension = fsl.list_size();
+    vortex_ensure!(
+        dimension >= 2,
+        "TurboQuant requires dimension >= 2, got {dimension}"
+    );
+
+    let seed = config.seed.unwrap_or_else(rand::random);
+    let dim = dimension as usize;
+    let num_rows = fsl.len();
+    let mse_bit_width = config.bit_width - 1;
+
+    // First, encode the MSE inner at (bit_width - 1).
+    let mse_config = TurboQuantConfig {
+        bit_width: mse_bit_width,
+        seed: Some(seed),
+    };
+    let mse_inner = turboquant_encode_mse(fsl, &mse_config)?;
+
+    let rotation = RotationMatrix::try_new(seed, dim)?;
+    let padded_dim = rotation.padded_dim();
+
+    if num_rows == 0 {
+        return build_empty_qjl_array(fsl, config.bit_width, padded_dim, seed);
+    }
+
+    // TODO(perf): `turboquant_encode_mse` above already extracts f32 elements
+    // internally. Refactor to share the buffer to avoid double materialization.
+    let f32_elements = extract_f32_elements(fsl)?;
+    #[allow(clippy::cast_possible_truncation)]
+    let centroids = get_centroids(padded_dim as u32, mse_bit_width)?;
+    let boundaries = compute_boundaries(&centroids);
+
+    // QJL uses a different rotation than the MSE stage to ensure statistical
+    // independence between the quantization noise and the sign projection.
+    let qjl_rotation = RotationMatrix::try_new(seed.wrapping_add(1), dim)?;
+
+    let mut residual_norms_buf = BufferMut::<f32>::with_capacity(num_rows);
+    let total_sign_bits = num_rows * padded_dim;
+    let mut qjl_sign_bits = BitBufferMut::new_unset(total_sign_bits);
+
+    let mut padded = vec![0.0f32; padded_dim];
+    let mut rotated = vec![0.0f32; padded_dim];
+    let mut dequantized_rotated = vec![0.0f32; padded_dim];
+    let mut dequantized = vec![0.0f32; padded_dim];
+    let mut residual = vec![0.0f32; padded_dim];
+    let mut projected = vec![0.0f32; padded_dim];
+
+    for row in 0..num_rows {
+        let x = &f32_elements[row * dim..(row + 1) * dim];
+        let norm = l2_norm(x);
+
+        // Reproduce the same quantization as MSE encoding.
+        if norm > 0.0 {
+            let inv_norm = 1.0 / norm;
+            for (dst, &src) in padded[..dim].iter_mut().zip(x.iter()) {
+                *dst = src * inv_norm;
+            }
+        } else {
+            padded[..dim].fill(0.0);
+        }
+        rotation.rotate(&padded, &mut rotated);
+
+        for j in 0..padded_dim {
+            let idx = find_nearest_centroid(rotated[j], &boundaries);
+            dequantized_rotated[j] = centroids[idx as usize];
+        }
+
+        rotation.inverse_rotate(&dequantized_rotated, &mut dequantized);
+        if norm > 0.0 {
+            for val in dequantized.iter_mut() {
+                *val *= norm;
+            }
+        }
+
+        // Compute residual.
+        residual.fill(0.0);
+        for j in 0..dim {
+            residual[j] = x[j] - dequantized[j];
+        }
+        let residual_norm = l2_norm(&residual[..dim]);
+        residual_norms_buf.push(residual_norm);
+
+        // QJL: sign(S * r).
+ projected.fill(0.0); + if residual_norm > 0.0 { + qjl_rotation.rotate(&residual, &mut projected); + } + + let bit_offset = row * padded_dim; + for j in 0..padded_dim { + if projected[j] >= 0.0 { + qjl_sign_bits.set(bit_offset + j); + } + } + } + + let residual_norms_array = + PrimitiveArray::new::(residual_norms_buf.freeze(), Validity::NonNullable); + let qjl_signs = BoolArray::new(qjl_sign_bits.freeze(), Validity::NonNullable); + let qjl_rotation_signs = qjl_rotation.export_inverse_signs_bool_array(); + + #[allow(clippy::cast_possible_truncation)] + TurboQuantQJLArray::try_new( + fsl.dtype().clone(), + mse_inner.into_array(), + qjl_signs.into_array(), + residual_norms_array.into_array(), + qjl_rotation_signs.into_array(), + config.bit_width, + padded_dim as u32, + seed.wrapping_add(1), + ) +} + +fn build_empty_mse_array( + fsl: &FixedSizeListArray, + bit_width: u8, + padded_dim: usize, + seed: u64, +) -> VortexResult { + let rotation = RotationMatrix::try_new(seed, fsl.list_size() as usize)?; + let codes = PrimitiveArray::empty::(fsl.dtype().nullability()); + let norms = PrimitiveArray::empty::(fsl.dtype().nullability()); + #[allow(clippy::cast_possible_truncation)] + let centroids_vec = get_centroids(padded_dim as u32, bit_width)?; + let mut centroids_buf = BufferMut::::with_capacity(centroids_vec.len()); + centroids_buf.extend_from_slice(¢roids_vec); + let centroids = PrimitiveArray::new::(centroids_buf.freeze(), Validity::NonNullable); + let rotation_signs = rotation.export_inverse_signs_bool_array(); + + #[allow(clippy::cast_possible_truncation)] + TurboQuantMSEArray::try_new( + fsl.dtype().clone(), + codes.into_array(), + norms.into_array(), + centroids.into_array(), + rotation_signs.into_array(), + fsl.list_size(), + bit_width, + padded_dim as u32, + seed, + ) +} + +fn build_empty_qjl_array( + fsl: &FixedSizeListArray, + bit_width: u8, + padded_dim: usize, + seed: u64, +) -> VortexResult { + let mse_config = TurboQuantConfig { + bit_width: bit_width - 1, + 
seed: Some(seed), + }; + let mse_inner = turboquant_encode_mse(fsl, &mse_config)?; + let qjl_rotation = RotationMatrix::try_new(seed.wrapping_add(1), fsl.list_size() as usize)?; + let residual_norms = PrimitiveArray::empty::(fsl.dtype().nullability()); + let qjl_signs = BoolArray::new(BitBufferMut::new_unset(0).freeze(), Validity::NonNullable); + let qjl_rotation_signs = qjl_rotation.export_inverse_signs_bool_array(); + + #[allow(clippy::cast_possible_truncation)] + TurboQuantQJLArray::try_new( + fsl.dtype().clone(), + mse_inner.into_array(), + qjl_signs.into_array(), + residual_norms.into_array(), + qjl_rotation_signs.into_array(), + bit_width, + padded_dim as u32, + seed.wrapping_add(1), + ) +} diff --git a/encodings/turboquant/src/decompress.rs b/encodings/turboquant/src/decompress.rs new file mode 100644 index 00000000000..e2a92d45f0c --- /dev/null +++ b/encodings/turboquant/src/decompress.rs @@ -0,0 +1,174 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant decoding (dequantization) logic. + +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::BoolArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::validity::Validity; +use vortex_buffer::BufferMut; +use vortex_error::VortexResult; + +use crate::mse::array::TurboQuantMSEArray; +use crate::qjl::array::TurboQuantQJLArray; +use crate::rotation::RotationMatrix; + +/// QJL correction scale factor: `sqrt(π/2) / padded_dim`. +/// +/// Accounts for the SRHT normalization (`1/padded_dim^{3/2}` per transform) +/// combined with `E[|z|] = sqrt(2/π)` for half-normal sign expectations. +/// Verified empirically via the `qjl_inner_product_bias` test suite. 
+#[inline]
+fn qjl_correction_scale(padded_dim: usize) -> f32 {
+    (std::f32::consts::FRAC_PI_2).sqrt() / (padded_dim as f32)
+}
+
+/// Decompress a `TurboQuantMSEArray` into a `FixedSizeListArray` of floats.
+///
+/// Reads stored centroids and rotation signs from the array's children,
+/// avoiding any recomputation.
+pub fn execute_decompress_mse(
+    array: TurboQuantMSEArray,
+    ctx: &mut ExecutionCtx,
+) -> VortexResult<ArrayRef> {
+    let dim = array.dimension() as usize;
+    let padded_dim = array.padded_dim() as usize;
+    let num_rows = array.norms.len();
+
+    if num_rows == 0 {
+        let elements = PrimitiveArray::empty::<f32>(array.dtype.nullability());
+        return Ok(FixedSizeListArray::try_new(
+            elements.into_array(),
+            array.dimension(),
+            Validity::NonNullable,
+            0,
+        )?
+        .into_array());
+    }
+
+    // Read stored centroids — no recomputation.
+    let centroids_prim = array.centroids.clone().execute::<PrimitiveArray>(ctx)?;
+    let centroids = centroids_prim.as_slice::<f32>();
+
+    // Expand stored rotation signs into f32 ±1.0 vectors once (amortized over all rows).
+    // This costs 3 × padded_dim × 4 bytes of temporary memory (e.g. 12KB for dim=1024)
+    // but enables autovectorized f32 multiply in the per-row SRHT hot loop.
+    let signs_bool = array.rotation_signs.clone().execute::<BoolArray>(ctx)?;
+    let rotation = RotationMatrix::from_bool_array(&signs_bool, dim)?;
+
+    // Unpack codes.
+ let codes_prim = array.codes.clone().execute::(ctx)?; + let indices = codes_prim.as_slice::(); + + let norms_prim = array.norms.clone().execute::(ctx)?; + let norms = norms_prim.as_slice::(); + + let mut output = BufferMut::::with_capacity(num_rows * dim); + let mut dequantized = vec![0.0f32; padded_dim]; + let mut unrotated = vec![0.0f32; padded_dim]; + + for row in 0..num_rows { + let row_indices = &indices[row * padded_dim..(row + 1) * padded_dim]; + let norm = norms[row]; + + for idx in 0..padded_dim { + dequantized[idx] = centroids[row_indices[idx] as usize]; + } + + rotation.inverse_rotate(&dequantized, &mut unrotated); + + for idx in 0..dim { + unrotated[idx] *= norm; + } + + output.extend_from_slice(&unrotated[..dim]); + } + + let elements = PrimitiveArray::new::(output.freeze(), Validity::NonNullable); + Ok(FixedSizeListArray::try_new( + elements.into_array(), + array.dimension(), + Validity::NonNullable, + num_rows, + )? + .into_array()) +} + +/// Decompress a `TurboQuantQJLArray` into a `FixedSizeListArray` of floats. +/// +/// First decodes the inner MSE array, then applies QJL residual correction. +pub fn execute_decompress_qjl( + array: TurboQuantQJLArray, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let padded_dim = array.padded_dim() as usize; + let num_rows = array.residual_norms.len(); + + if num_rows == 0 { + return Ok(array + .mse_inner + .execute::(ctx)? + .into_array()); + } + + // Decode MSE inner → FixedSizeListArray. + let mse_decoded = array.mse_inner.clone().execute::(ctx)?; + let mse_elements_prim = mse_decoded.elements().to_canonical()?.into_primitive(); + let mse_elements = mse_elements_prim.as_slice::(); + let dim = mse_decoded.list_size() as usize; + + // Read QJL signs. + let qjl_signs_bool = array.qjl_signs.clone().execute::(ctx)?; + let qjl_bit_buf = qjl_signs_bool.to_bit_buffer(); + + // Read residual norms. 
+ let residual_norms_prim = array + .residual_norms + .clone() + .execute::(ctx)?; + let residual_norms = residual_norms_prim.as_slice::(); + + // Read QJL rotation signs and reconstruct the rotation matrix. + let qjl_rot_signs_bool = array.rotation_signs.clone().execute::(ctx)?; + let qjl_rot = RotationMatrix::from_bool_array(&qjl_rot_signs_bool, dim)?; + + let qjl_scale = qjl_correction_scale(padded_dim); + + let mut output = BufferMut::::with_capacity(num_rows * dim); + let mut qjl_signs_vec = vec![0.0f32; padded_dim]; + let mut qjl_projected = vec![0.0f32; padded_dim]; + + for row in 0..num_rows { + let mse_row = &mse_elements[row * dim..(row + 1) * dim]; + let residual_norm = residual_norms[row]; + + let bit_offset = row * padded_dim; + for idx in 0..padded_dim { + qjl_signs_vec[idx] = if qjl_bit_buf.value(bit_offset + idx) { + 1.0 + } else { + -1.0 + }; + } + + qjl_rot.inverse_rotate(&qjl_signs_vec, &mut qjl_projected); + let scale = qjl_scale * residual_norm; + + for idx in 0..dim { + output.push(mse_row[idx] + scale * qjl_projected[idx]); + } + } + + let elements = PrimitiveArray::new::(output.freeze(), Validity::NonNullable); + Ok(FixedSizeListArray::try_new( + elements.into_array(), + mse_decoded.list_size(), + Validity::NonNullable, + num_rows, + )? + .into_array()) +} diff --git a/encodings/turboquant/src/lib.rs b/encodings/turboquant/src/lib.rs new file mode 100644 index 00000000000..7a845ff175e --- /dev/null +++ b/encodings/turboquant/src/lib.rs @@ -0,0 +1,648 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant vector quantization encoding for Vortex. +//! +//! Implements the TurboQuant algorithm ([arXiv:2504.19874]) for lossy compression of +//! high-dimensional vector data. The encoding operates on `FixedSizeList` arrays of floats +//! (the storage format of `Vector` and `FixedShapeTensor` extension types). +//! +//! [arXiv:2504.19874]: https://arxiv.org/abs/2504.19874 +//! +//! 
# Variants +//! +//! - **MSE** (`TurboQuantVariant::Mse`): Minimizes mean-squared reconstruction error +//! (1-8 bits per coordinate). +//! - **Prod** (`TurboQuantVariant::Prod`): Preserves inner products with an unbiased +//! estimator (uses `b-1` bits for MSE + 1-bit QJL residual correction, 2-9 bits). +//! At `b=9`, the MSE codes are raw int8 values suitable for direct use with +//! tensor core int8 GEMM kernels. +//! +//! # Theoretical error bounds +//! +//! For unit-norm vectors quantized at `b` bits per coordinate, the paper's Theorem 1 +//! guarantees normalized MSE distortion: +//! +//! > `E[||x - x̂||² / ||x||²] ≤ (√3 · π / 2) / 4^b` +//! +//! | Bits | MSE bound | Quality | +//! |------|------------|-------------------| +//! | 1 | 6.80e-01 | Poor | +//! | 2 | 1.70e-01 | Usable for ANN | +//! | 3 | 4.25e-02 | Good | +//! | 4 | 1.06e-02 | Very good | +//! | 5 | 2.66e-03 | Excellent | +//! | 6 | 6.64e-04 | Near-lossless | +//! | 7 | 1.66e-04 | Near-lossless | +//! | 8 | 4.15e-05 | Near-lossless | +//! +//! # Compression ratios +//! +//! Each vector is stored as `padded_dim × bit_width / 8` bytes of quantized codes plus a +//! 4-byte f32 norm. Non-power-of-2 dimensions are padded to the next power of 2 for the +//! Walsh-Hadamard transform, which reduces the effective ratio for those dimensions. +//! +//! | dim | padded | bits | f32 bytes | TQ bytes | ratio | +//! |------|--------|------|-----------|----------|--------| +//! | 768 | 1024 | 2 | 3072 | 260 | 11.8x | +//! | 1024 | 1024 | 2 | 4096 | 260 | 15.8x | +//! | 768 | 1024 | 4 | 3072 | 516 | 6.0x | +//! | 1024 | 1024 | 4 | 4096 | 516 | 7.9x | +//! | 768 | 1024 | 8 | 3072 | 1028 | 3.0x | +//! | 1024 | 1024 | 8 | 4096 | 1028 | 4.0x | +//! +//! # Example +//! +//! ``` +//! use vortex_array::IntoArray; +//! use vortex_array::arrays::FixedSizeListArray; +//! use vortex_array::arrays::PrimitiveArray; +//! use vortex_array::validity::Validity; +//! use vortex_buffer::BufferMut; +//! 
use vortex_turboquant::{TurboQuantConfig, turboquant_encode_mse}; +//! +//! // Create a FixedSizeListArray of 100 random 128-d vectors. +//! let num_rows = 100; +//! let dim = 128; +//! let mut buf = BufferMut::::with_capacity(num_rows * dim); +//! for i in 0..(num_rows * dim) { +//! buf.push((i as f32 * 0.001).sin()); +//! } +//! let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); +//! let fsl = FixedSizeListArray::try_new( +//! elements.into_array(), dim as u32, Validity::NonNullable, num_rows, +//! ).unwrap(); +//! +//! // Quantize at 2 bits per coordinate using MSE-optimal encoding. +//! let config = TurboQuantConfig { bit_width: 2, seed: Some(42) }; +//! let encoded = turboquant_encode_mse(&fsl, &config).unwrap(); +//! +//! // Verify compression: 100 vectors × 128 dims × 4 bytes = 51200 bytes input. +//! assert!(encoded.codes().nbytes() + encoded.norms().nbytes() < 51200); +//! ``` + +pub use compress::TurboQuantConfig; +pub use compress::turboquant_encode_mse; +pub use compress::turboquant_encode_qjl; +pub use mse::*; +pub use qjl::*; + +pub mod centroids; +mod compress; +pub(crate) mod decompress; +mod mse; +mod qjl; +pub mod rotation; + +/// Extension ID for the `Vector` type from `vortex-tensor`. +pub const VECTOR_EXT_ID: &str = "vortex.tensor.vector"; + +/// Extension ID for the `FixedShapeTensor` type from `vortex-tensor`. +pub const FIXED_SHAPE_TENSOR_EXT_ID: &str = "vortex.tensor.fixed_shape_tensor"; + +use vortex_array::session::ArraySessionExt; +use vortex_session::VortexSession; + +/// Initialize the TurboQuant encodings in the given session. 
+pub fn initialize(session: &mut VortexSession) { + session.arrays().register(TurboQuantMSE); + session.arrays().register(TurboQuantQJL); +} + +#[cfg(test)] +#[allow(clippy::cast_possible_truncation)] +mod tests { + use std::sync::LazyLock; + + use rand::RngExt; + use rand::SeedableRng; + use rand::rngs::StdRng; + use rand_distr::Distribution; + use rand_distr::Normal; + use rstest::rstest; + use vortex_array::ArrayRef; + use vortex_array::IntoArray; + use vortex_array::VortexSessionExecute; + use vortex_array::arrays::FixedSizeListArray; + use vortex_array::arrays::PrimitiveArray; + use vortex_array::session::ArraySession; + use vortex_array::validity::Validity; + use vortex_buffer::BufferMut; + use vortex_error::VortexResult; + use vortex_session::VortexSession; + + use crate::TurboQuantConfig; + use crate::rotation::RotationMatrix; + use crate::turboquant_encode_mse; + use crate::turboquant_encode_qjl; + + static SESSION: LazyLock = + LazyLock::new(|| VortexSession::empty().with::()); + + /// Create a FixedSizeListArray of random f32 vectors (i.i.d. standard normal). 
+ fn make_fsl(num_rows: usize, dim: usize, seed: u64) -> FixedSizeListArray { + let mut rng = StdRng::seed_from_u64(seed); + let normal = Normal::new(0.0f32, 1.0).unwrap(); + + let mut buf = BufferMut::::with_capacity(num_rows * dim); + for _ in 0..(num_rows * dim) { + buf.push(normal.sample(&mut rng)); + } + + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + FixedSizeListArray::try_new( + elements.into_array(), + dim as u32, + Validity::NonNullable, + num_rows, + ) + .unwrap() + } + + fn theoretical_mse_bound(bit_width: u8) -> f32 { + let sqrt3_pi_over_2 = (3.0f32).sqrt() * std::f32::consts::PI / 2.0; + sqrt3_pi_over_2 / (4.0f32).powi(bit_width as i32) + } + + fn per_vector_normalized_mse( + original: &[f32], + reconstructed: &[f32], + dim: usize, + num_rows: usize, + ) -> f32 { + let mut total = 0.0f32; + for row in 0..num_rows { + let orig = &original[row * dim..(row + 1) * dim]; + let recon = &reconstructed[row * dim..(row + 1) * dim]; + let norm_sq: f32 = orig.iter().map(|&v| v * v).sum(); + if norm_sq < 1e-10 { + continue; + } + let err_sq: f32 = orig + .iter() + .zip(recon.iter()) + .map(|(&a, &b)| (a - b) * (a - b)) + .sum(); + total += err_sq / norm_sq; + } + total / num_rows as f32 + } + + /// Encode and decode, returning (original, decoded) flat f32 slices. 
+ fn encode_decode( + fsl: &FixedSizeListArray, + encode_fn: impl FnOnce(&FixedSizeListArray) -> VortexResult, + ) -> VortexResult<(Vec, Vec)> { + let original: Vec = { + let prim = fsl.elements().to_canonical().unwrap().into_primitive(); + prim.as_slice::().to_vec() + }; + let encoded = encode_fn(fsl)?; + let mut ctx = SESSION.create_execution_ctx(); + let decoded = encoded.execute::(&mut ctx)?; + let decoded_elements: Vec = { + let prim = decoded.elements().to_canonical().unwrap().into_primitive(); + prim.as_slice::().to_vec() + }; + Ok((original, decoded_elements)) + } + + fn encode_decode_mse( + fsl: &FixedSizeListArray, + config: &TurboQuantConfig, + ) -> VortexResult<(Vec, Vec)> { + let config = config.clone(); + encode_decode(fsl, |fsl| { + Ok(turboquant_encode_mse(fsl, &config)?.into_array()) + }) + } + + fn encode_decode_qjl( + fsl: &FixedSizeListArray, + config: &TurboQuantConfig, + ) -> VortexResult<(Vec, Vec)> { + let config = config.clone(); + encode_decode(fsl, |fsl| { + Ok(turboquant_encode_qjl(fsl, &config)?.into_array()) + }) + } + + // ----------------------------------------------------------------------- + // MSE encoding tests + // ----------------------------------------------------------------------- + + #[rstest] + #[case(32, 1)] + #[case(32, 2)] + #[case(32, 3)] + #[case(32, 4)] + #[case(128, 2)] + #[case(128, 4)] + #[case(128, 6)] + #[case(128, 8)] + #[case(256, 2)] + fn roundtrip_mse(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let fsl = make_fsl(10, dim, 42); + let config = TurboQuantConfig { + bit_width, + seed: Some(123), + }; + let (original, decoded) = encode_decode_mse(&fsl, &config)?; + assert_eq!(decoded.len(), original.len()); + Ok(()) + } + + #[rstest] + #[case(128, 1)] + #[case(128, 2)] + #[case(128, 3)] + #[case(128, 4)] + #[case(256, 2)] + #[case(256, 4)] + fn mse_within_theoretical_bound(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let num_rows = 200; + let fsl = make_fsl(num_rows, 
dim, 42); + let config = TurboQuantConfig { + bit_width, + seed: Some(123), + }; + let (original, decoded) = encode_decode_mse(&fsl, &config)?; + + let normalized_mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); + let bound = theoretical_mse_bound(bit_width); + + assert!( + normalized_mse < bound, + "Normalized MSE {normalized_mse:.6} exceeds bound {bound:.6} for dim={dim}, bits={bit_width}", + ); + Ok(()) + } + + #[rstest] + #[case(128, 6)] + #[case(128, 8)] + #[case(256, 6)] + #[case(256, 8)] + fn high_bitwidth_mse_is_small(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let num_rows = 200; + let fsl = make_fsl(num_rows, dim, 42); + + let config_4bit = TurboQuantConfig { + bit_width: 4, + seed: Some(123), + }; + let (original_4, decoded_4) = encode_decode_mse(&fsl, &config_4bit)?; + let mse_4bit = per_vector_normalized_mse(&original_4, &decoded_4, dim, num_rows); + + let config = TurboQuantConfig { + bit_width, + seed: Some(123), + }; + let (original, decoded) = encode_decode_mse(&fsl, &config)?; + let mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); + + assert!( + mse < mse_4bit, + "{bit_width}-bit MSE ({mse:.6}) should be < 4-bit MSE ({mse_4bit:.6})" + ); + assert!(mse < 0.01, "{bit_width}-bit MSE ({mse:.6}) should be < 1%"); + Ok(()) + } + + #[test] + fn mse_decreases_with_bits() -> VortexResult<()> { + let dim = 128; + let num_rows = 50; + let fsl = make_fsl(num_rows, dim, 99); + + let mut prev_mse = f32::MAX; + for bit_width in 1..=8u8 { + let config = TurboQuantConfig { + bit_width, + seed: Some(123), + }; + let (original, decoded) = encode_decode_mse(&fsl, &config)?; + let mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); + assert!( + mse <= prev_mse * 1.01, + "MSE should decrease: {bit_width}-bit={mse:.6} > prev={prev_mse:.6}" + ); + prev_mse = mse; + } + Ok(()) + } + + // ----------------------------------------------------------------------- + // QJL encoding tests + // 
----------------------------------------------------------------------- + + #[rstest] + #[case(32, 2)] + #[case(32, 3)] + #[case(128, 2)] + #[case(128, 4)] + #[case(128, 6)] + #[case(128, 8)] + #[case(128, 9)] + #[case(768, 3)] + fn roundtrip_qjl(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let fsl = make_fsl(10, dim, 42); + let config = TurboQuantConfig { + bit_width, + seed: Some(456), + }; + let (original, decoded) = encode_decode_qjl(&fsl, &config)?; + assert_eq!(decoded.len(), original.len()); + Ok(()) + } + + #[rstest] + #[case(128, 2)] + #[case(128, 3)] + #[case(128, 4)] + #[case(128, 6)] + #[case(128, 8)] + #[case(128, 9)] + #[case(768, 3)] + #[case(768, 4)] + fn qjl_inner_product_bias(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let num_rows = 100; + let fsl = make_fsl(num_rows, dim, 42); + let config = TurboQuantConfig { + bit_width, + seed: Some(789), + }; + let (original, decoded) = encode_decode_qjl(&fsl, &config)?; + + let num_pairs = 500; + let mut rng = StdRng::seed_from_u64(0); + let mut signed_errors = Vec::with_capacity(num_pairs); + + for _ in 0..num_pairs { + let qi = rng.random_range(0..num_rows); + let xi = rng.random_range(0..num_rows); + if qi == xi { + continue; + } + + let query = &original[qi * dim..(qi + 1) * dim]; + let orig_vec = &original[xi * dim..(xi + 1) * dim]; + let quant_vec = &decoded[xi * dim..(xi + 1) * dim]; + + let true_ip: f32 = query.iter().zip(orig_vec).map(|(&a, &b)| a * b).sum(); + let quant_ip: f32 = query.iter().zip(quant_vec).map(|(&a, &b)| a * b).sum(); + + if true_ip.abs() > 1e-6 { + signed_errors.push((quant_ip - true_ip) / true_ip.abs()); + } + } + + if signed_errors.is_empty() { + return Ok(()); + } + + let mean_rel_error: f32 = signed_errors.iter().sum::() / signed_errors.len() as f32; + assert!( + mean_rel_error.abs() < 0.3, + "QJL inner product bias too high: {mean_rel_error:.4} for dim={dim}, bits={bit_width}" + ); + Ok(()) + } + + #[test] + fn 
qjl_mse_decreases_with_bits() -> VortexResult<()> { + let dim = 128; + let num_rows = 50; + let fsl = make_fsl(num_rows, dim, 99); + + let mut prev_mse = f32::MAX; + for bit_width in 2..=9u8 { + let config = TurboQuantConfig { + bit_width, + seed: Some(123), + }; + let (original, decoded) = encode_decode_qjl(&fsl, &config)?; + let mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); + assert!( + mse <= prev_mse * 1.01, + "QJL MSE should decrease: {bit_width}-bit={mse:.6} > prev={prev_mse:.6}" + ); + prev_mse = mse; + } + Ok(()) + } + + // ----------------------------------------------------------------------- + // Edge cases + // ----------------------------------------------------------------------- + + #[rstest] + #[case(0)] + #[case(1)] + fn roundtrip_mse_edge_cases(#[case] num_rows: usize) -> VortexResult<()> { + let fsl = make_fsl(num_rows, 128, 42); + let config = TurboQuantConfig { + bit_width: 2, + seed: Some(123), + }; + let encoded = turboquant_encode_mse(&fsl, &config)?; + let mut ctx = SESSION.create_execution_ctx(); + let decoded = encoded + .into_array() + .execute::(&mut ctx)?; + assert_eq!(decoded.len(), num_rows); + Ok(()) + } + + #[rstest] + #[case(0)] + #[case(1)] + fn roundtrip_qjl_edge_cases(#[case] num_rows: usize) -> VortexResult<()> { + let fsl = make_fsl(num_rows, 128, 42); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(456), + }; + let encoded = turboquant_encode_qjl(&fsl, &config)?; + let mut ctx = SESSION.create_execution_ctx(); + let decoded = encoded + .into_array() + .execute::(&mut ctx)?; + assert_eq!(decoded.len(), num_rows); + Ok(()) + } + + #[test] + fn mse_rejects_dimension_below_2() { + let fsl = make_fsl_dim1(); + let config = TurboQuantConfig { + bit_width: 2, + seed: Some(0), + }; + assert!(turboquant_encode_mse(&fsl, &config).is_err()); + } + + #[test] + fn qjl_rejects_dimension_below_2() { + let fsl = make_fsl_dim1(); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(0), + }; + 
assert!(turboquant_encode_qjl(&fsl, &config).is_err()); + } + + fn make_fsl_dim1() -> FixedSizeListArray { + let mut buf = BufferMut::::with_capacity(1); + buf.push(1.0); + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + FixedSizeListArray::try_new(elements.into_array(), 1, Validity::NonNullable, 1).unwrap() + } + + // ----------------------------------------------------------------------- + // Verification tests for stored metadata + // ----------------------------------------------------------------------- + + /// Verify that the centroids stored in the MSE array match what get_centroids() computes. + #[test] + fn stored_centroids_match_computed() -> VortexResult<()> { + let fsl = make_fsl(10, 128, 42); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let encoded = turboquant_encode_mse(&fsl, &config)?; + + let mut ctx = SESSION.create_execution_ctx(); + let stored_centroids_prim = encoded + .centroids() + .clone() + .execute::(&mut ctx)?; + let stored = stored_centroids_prim.as_slice::(); + + let padded_dim = encoded.padded_dim(); + let computed = crate::centroids::get_centroids(padded_dim, 3)?; + + assert_eq!(stored.len(), computed.len()); + for i in 0..stored.len() { + assert_eq!(stored[i], computed[i], "Centroid mismatch at {i}"); + } + Ok(()) + } + + /// Verify that stored rotation signs produce identical decode to seed-based decode. + /// + /// Encodes the same data twice: once with the new path (stored signs), and + /// once by manually recomputing the rotation from the seed. Both should + /// produce identical output. + #[test] + fn stored_rotation_signs_produce_correct_decode() -> VortexResult<()> { + let fsl = make_fsl(20, 128, 42); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let encoded = turboquant_encode_mse(&fsl, &config)?; + + // Decode via the stored-signs path (normal decode). 
+ let mut ctx = SESSION.create_execution_ctx(); + let decoded_fsl = encoded + .clone() + .into_array() + .execute::(&mut ctx)?; + let decoded = decoded_fsl.elements().to_canonical()?.into_primitive(); + let decoded_slice = decoded.as_slice::(); + + // Verify stored signs match seed-derived signs. + let rot_from_seed = RotationMatrix::try_new(123, 128)?; + let exported = rot_from_seed.export_inverse_signs_bool_array(); + let stored_signs = encoded + .rotation_signs() + .clone() + .execute::(&mut ctx)?; + + assert_eq!(exported.len(), stored_signs.len()); + let exp_buf = exported.to_bit_buffer(); + let stored_buf = stored_signs.to_bit_buffer(); + for i in 0..exported.len() { + assert_eq!( + exp_buf.value(i), + stored_buf.value(i), + "Sign mismatch at bit {i}" + ); + } + + // Also verify decode output is non-empty and has expected size. + assert_eq!(decoded_slice.len(), 20 * 128); + Ok(()) + } + + // ----------------------------------------------------------------------- + // QJL-specific quality tests + // ----------------------------------------------------------------------- + + /// Verify that QJL's MSE component (at bit_width-1) satisfies the theoretical bound. + #[rstest] + #[case(128, 3)] + #[case(128, 4)] + #[case(256, 3)] + fn qjl_mse_within_theoretical_bound( + #[case] dim: usize, + #[case] bit_width: u8, + ) -> VortexResult<()> { + let num_rows = 200; + let fsl = make_fsl(num_rows, dim, 42); + let config = TurboQuantConfig { + bit_width, + seed: Some(789), + }; + let (original, decoded) = encode_decode_qjl(&fsl, &config)?; + + let normalized_mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); + + // QJL at b bits uses (b-1)-bit MSE plus a correction term. + // The MSE should be at most the (b-1)-bit theoretical bound, though + // in practice the QJL correction often improves it further. 
+ let mse_bound = theoretical_mse_bound(bit_width - 1); + assert!( + normalized_mse < mse_bound, + "QJL MSE {normalized_mse:.6} exceeds (b-1)-bit bound {mse_bound:.6} \ + for dim={dim}, bits={bit_width}", + ); + Ok(()) + } + + /// Verify that high-bitwidth QJL (8-9 bits) achieves very low distortion. + #[rstest] + #[case(128, 8)] + #[case(128, 9)] + fn high_bitwidth_qjl_is_small(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let num_rows = 200; + let fsl = make_fsl(num_rows, dim, 42); + + // Compare against 4-bit QJL as reference ceiling. + let config_4bit = TurboQuantConfig { + bit_width: 4, + seed: Some(789), + }; + let (original_4, decoded_4) = encode_decode_qjl(&fsl, &config_4bit)?; + let mse_4bit = per_vector_normalized_mse(&original_4, &decoded_4, dim, num_rows); + + let config = TurboQuantConfig { + bit_width, + seed: Some(789), + }; + let (original, decoded) = encode_decode_qjl(&fsl, &config)?; + let mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); + + assert!( + mse < mse_4bit, + "{bit_width}-bit QJL MSE ({mse:.6}) should be < 4-bit ({mse_4bit:.6})" + ); + assert!( + mse < 0.01, + "{bit_width}-bit QJL MSE ({mse:.6}) should be < 1%" + ); + Ok(()) + } +} diff --git a/encodings/turboquant/src/mse/array/mod.rs b/encodings/turboquant/src/mse/array/mod.rs new file mode 100644 index 00000000000..b2517ff2e17 --- /dev/null +++ b/encodings/turboquant/src/mse/array/mod.rs @@ -0,0 +1,127 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant MSE array definition: stores quantized coordinate codes, norms, +//! centroids (codebook), and rotation signs. + +use vortex_array::ArrayRef; +use vortex_array::dtype::DType; +use vortex_array::stats::ArrayStats; +use vortex_array::vtable; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; + +use super::TurboQuantMSE; + +vtable!(TurboQuantMSE); + +/// Protobuf metadata for TurboQuant MSE encoding. 
+#[derive(Clone, prost::Message)] +pub struct TurboQuantMSEMetadata { + /// Vector dimension d. + #[prost(uint32, tag = "1")] + pub dimension: u32, + /// Bits per coordinate (1-8). + #[prost(uint32, tag = "2")] + pub bit_width: u32, + /// Padded dimension (next power of 2 >= dimension). + #[prost(uint32, tag = "3")] + pub padded_dim: u32, + /// Deterministic seed for rotation matrix (kept for reproducibility). + #[prost(uint64, tag = "4")] + pub rotation_seed: u64, +} + +/// TurboQuant MSE array. +/// +/// Children: +/// - 0: `codes` — `BitPackedArray` or `PrimitiveArray` (quantized indices) +/// - 1: `norms` — `PrimitiveArray` (one per vector row) +/// - 2: `centroids` — `PrimitiveArray` (codebook, length 2^bit_width) +/// - 3: `rotation_signs` — `BoolArray` (3 * padded_dim bits, inverse application order) +#[derive(Clone, Debug)] +pub struct TurboQuantMSEArray { + pub(crate) dtype: DType, + pub(crate) codes: ArrayRef, + pub(crate) norms: ArrayRef, + pub(crate) centroids: ArrayRef, + pub(crate) rotation_signs: ArrayRef, + pub(crate) dimension: u32, + pub(crate) bit_width: u8, + pub(crate) padded_dim: u32, + pub(crate) rotation_seed: u64, + pub(crate) stats_set: ArrayStats, +} + +impl TurboQuantMSEArray { + /// Build a new TurboQuantMSEArray. + #[allow(clippy::too_many_arguments)] + pub fn try_new( + dtype: DType, + codes: ArrayRef, + norms: ArrayRef, + centroids: ArrayRef, + rotation_signs: ArrayRef, + dimension: u32, + bit_width: u8, + padded_dim: u32, + rotation_seed: u64, + ) -> VortexResult { + vortex_ensure!( + (1..=8).contains(&bit_width), + "MSE bit_width must be 1-8, got {bit_width}" + ); + Ok(Self { + dtype, + codes, + norms, + centroids, + rotation_signs, + dimension, + bit_width, + padded_dim, + rotation_seed, + stats_set: Default::default(), + }) + } + + /// The vector dimension d. + pub fn dimension(&self) -> u32 { + self.dimension + } + + /// Bits per coordinate. 
+ pub fn bit_width(&self) -> u8 { + self.bit_width + } + + /// Padded dimension (next power of 2 >= dimension). + pub fn padded_dim(&self) -> u32 { + self.padded_dim + } + + /// The rotation matrix seed. + pub fn rotation_seed(&self) -> u64 { + self.rotation_seed + } + + /// The bit-packed codes child. + pub fn codes(&self) -> &ArrayRef { + &self.codes + } + + /// The norms child. + pub fn norms(&self) -> &ArrayRef { + &self.norms + } + + /// The centroids (codebook) child. + pub fn centroids(&self) -> &ArrayRef { + &self.centroids + } + + /// The rotation signs child (BoolArray, length 3 * padded_dim). + pub fn rotation_signs(&self) -> &ArrayRef { + &self.rotation_signs + } +} diff --git a/encodings/turboquant/src/mse/mod.rs b/encodings/turboquant/src/mse/mod.rs new file mode 100644 index 00000000000..60ffe0bc59e --- /dev/null +++ b/encodings/turboquant/src/mse/mod.rs @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant MSE encoding: MSE-optimal scalar quantization of rotated unit vectors. + +pub use array::TurboQuantMSEArray; +pub use array::TurboQuantMSEMetadata; + +pub(crate) mod array; +mod vtable; + +use vortex_array::vtable::ArrayId; + +/// Encoding marker type for TurboQuant MSE. +#[derive(Clone, Debug)] +pub struct TurboQuantMSE; + +impl TurboQuantMSE { + pub const ID: ArrayId = ArrayId::new_ref("vortex.turboquant.mse"); +} diff --git a/encodings/turboquant/src/mse/vtable/mod.rs b/encodings/turboquant/src/mse/vtable/mod.rs new file mode 100644 index 00000000000..da1956e4cf1 --- /dev/null +++ b/encodings/turboquant/src/mse/vtable/mod.rs @@ -0,0 +1,226 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! VTable implementation for TurboQuant MSE encoding. 
+ +use std::hash::Hash; +use std::ops::Deref; +use std::sync::Arc; + +use vortex_array::ArrayEq; +use vortex_array::ArrayHash; +use vortex_array::ArrayRef; +use vortex_array::DeserializeMetadata; +use vortex_array::DynArray; +use vortex_array::ExecutionCtx; +use vortex_array::ExecutionResult; +use vortex_array::Precision; +use vortex_array::ProstMetadata; +use vortex_array::SerializeMetadata; +use vortex_array::buffer::BufferHandle; +use vortex_array::dtype::DType; +use vortex_array::dtype::Nullability; +use vortex_array::dtype::PType; +use vortex_array::serde::ArrayChildren; +use vortex_array::stats::StatsSetRef; +use vortex_array::vtable::Array; +use vortex_array::vtable::ArrayId; +use vortex_array::vtable::NotSupported; +use vortex_array::vtable::VTable; +use vortex_array::vtable::ValidityChild; +use vortex_array::vtable::ValidityVTableFromChild; +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; +use vortex_error::vortex_panic; +use vortex_session::VortexSession; + +use super::TurboQuantMSE; +use super::array::TurboQuantMSEArray; +use super::array::TurboQuantMSEMetadata; +use crate::decompress::execute_decompress_mse; + +impl VTable for TurboQuantMSE { + type Array = TurboQuantMSEArray; + type Metadata = ProstMetadata; + type OperationsVTable = NotSupported; + type ValidityVTable = ValidityVTableFromChild; + + fn vtable(_array: &Self::Array) -> &Self { + &TurboQuantMSE + } + + fn id(&self) -> ArrayId { + Self::ID + } + + fn len(array: &TurboQuantMSEArray) -> usize { + array.norms.len() + } + + fn dtype(array: &TurboQuantMSEArray) -> &DType { + &array.dtype + } + + fn stats(array: &TurboQuantMSEArray) -> StatsSetRef<'_> { + array.stats_set.to_ref(array.as_ref()) + } + + fn array_hash( + array: &TurboQuantMSEArray, + state: &mut H, + precision: Precision, + ) { + array.dtype.hash(state); + array.dimension.hash(state); + array.bit_width.hash(state); + array.padded_dim.hash(state); + array.rotation_seed.hash(state); 
+ array.codes.array_hash(state, precision); + array.norms.array_hash(state, precision); + array.centroids.array_hash(state, precision); + array.rotation_signs.array_hash(state, precision); + } + + fn array_eq( + array: &TurboQuantMSEArray, + other: &TurboQuantMSEArray, + precision: Precision, + ) -> bool { + array.dtype == other.dtype + && array.dimension == other.dimension + && array.bit_width == other.bit_width + && array.padded_dim == other.padded_dim + && array.rotation_seed == other.rotation_seed + && array.codes.array_eq(&other.codes, precision) + && array.norms.array_eq(&other.norms, precision) + && array.centroids.array_eq(&other.centroids, precision) + && array + .rotation_signs + .array_eq(&other.rotation_signs, precision) + } + + fn nbuffers(_array: &TurboQuantMSEArray) -> usize { + 0 + } + + fn buffer(_array: &TurboQuantMSEArray, idx: usize) -> BufferHandle { + vortex_panic!("TurboQuantMSEArray buffer index {idx} out of bounds") + } + + fn buffer_name(_array: &TurboQuantMSEArray, _idx: usize) -> Option { + None + } + + fn nchildren(_array: &TurboQuantMSEArray) -> usize { + 4 + } + + fn child(array: &TurboQuantMSEArray, idx: usize) -> ArrayRef { + match idx { + 0 => array.codes.clone(), + 1 => array.norms.clone(), + 2 => array.centroids.clone(), + 3 => array.rotation_signs.clone(), + _ => vortex_panic!("TurboQuantMSEArray child index {idx} out of bounds"), + } + } + + fn child_name(_array: &TurboQuantMSEArray, idx: usize) -> String { + match idx { + 0 => "codes".to_string(), + 1 => "norms".to_string(), + 2 => "centroids".to_string(), + 3 => "rotation_signs".to_string(), + _ => vortex_panic!("TurboQuantMSEArray child_name index {idx} out of bounds"), + } + } + + fn metadata(array: &TurboQuantMSEArray) -> VortexResult { + Ok(ProstMetadata(TurboQuantMSEMetadata { + dimension: array.dimension, + bit_width: array.bit_width as u32, + padded_dim: array.padded_dim, + rotation_seed: array.rotation_seed, + })) + } + + fn serialize(metadata: Self::Metadata) -> 
VortexResult>> { + Ok(Some(metadata.serialize())) + } + + fn deserialize( + bytes: &[u8], + _dtype: &DType, + _len: usize, + _buffers: &[BufferHandle], + _session: &VortexSession, + ) -> VortexResult { + Ok(ProstMetadata( + as DeserializeMetadata>::deserialize(bytes)?, + )) + } + + fn build( + dtype: &DType, + len: usize, + metadata: &Self::Metadata, + _buffers: &[BufferHandle], + children: &dyn ArrayChildren, + ) -> VortexResult { + let bit_width = u8::try_from(metadata.bit_width)?; + let padded_dim = metadata.padded_dim as usize; + let num_centroids = 1usize << bit_width; + + let codes_dtype = DType::Primitive(PType::U8, Nullability::NonNullable); + let codes = children.get(0, &codes_dtype, len * padded_dim)?; + + let norms_dtype = DType::Primitive(PType::F32, Nullability::NonNullable); + let norms = children.get(1, &norms_dtype, len)?; + + let centroids = children.get(2, &norms_dtype, num_centroids)?; + + let signs_dtype = DType::Bool(Nullability::NonNullable); + let rotation_signs = children.get(3, &signs_dtype, 3 * padded_dim)?; + + Ok(TurboQuantMSEArray { + dtype: dtype.clone(), + codes, + norms, + centroids, + rotation_signs, + dimension: metadata.dimension, + bit_width, + padded_dim: metadata.padded_dim, + rotation_seed: metadata.rotation_seed, + stats_set: Default::default(), + }) + } + + fn with_children(array: &mut Self::Array, children: Vec) -> VortexResult<()> { + vortex_ensure!( + children.len() == 4, + "TurboQuantMSEArray expects 4 children, got {}", + children.len() + ); + let mut iter = children.into_iter(); + array.codes = iter.next().vortex_expect("codes child"); + array.norms = iter.next().vortex_expect("norms child"); + array.centroids = iter.next().vortex_expect("centroids child"); + array.rotation_signs = iter.next().vortex_expect("rotation_signs child"); + Ok(()) + } + + fn execute(array: Arc>, ctx: &mut ExecutionCtx) -> VortexResult { + let inner = Arc::try_unwrap(array) + .map(|a| a.into_inner()) + .unwrap_or_else(|arc| 
arc.as_ref().deref().clone()); + Ok(ExecutionResult::done(execute_decompress_mse(inner, ctx)?)) + } +} + +impl ValidityChild for TurboQuantMSE { + fn validity_child(array: &TurboQuantMSEArray) -> &ArrayRef { + array.codes() + } +} diff --git a/encodings/turboquant/src/qjl/array/mod.rs b/encodings/turboquant/src/qjl/array/mod.rs new file mode 100644 index 00000000000..9b6883dcdd5 --- /dev/null +++ b/encodings/turboquant/src/qjl/array/mod.rs @@ -0,0 +1,116 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant QJL array definition: wraps a TurboQuantMSEArray with 1-bit QJL +//! residual correction for unbiased inner product estimation. + +use vortex_array::ArrayRef; +use vortex_array::dtype::DType; +use vortex_array::stats::ArrayStats; +use vortex_array::vtable; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; + +use super::TurboQuantQJL; + +vtable!(TurboQuantQJL); + +/// Protobuf metadata for TurboQuant QJL encoding. +#[derive(Clone, prost::Message)] +pub struct TurboQuantQJLMetadata { + /// Total bit width (2-9, including QJL bit; MSE child uses bit_width - 1). + #[prost(uint32, tag = "1")] + pub bit_width: u32, + /// Padded dimension (next power of 2 >= dimension). + #[prost(uint32, tag = "2")] + pub padded_dim: u32, + /// QJL rotation seed (for debugging/reproducibility). + #[prost(uint64, tag = "3")] + pub rotation_seed: u64, +} + +/// TurboQuant QJL array. 
+/// +/// Children: +/// - 0: `mse_inner` — `TurboQuantMSEArray` (at `bit_width - 1`) +/// - 1: `qjl_signs` — `BoolArray` (num_rows * padded_dim bits) +/// - 2: `residual_norms` — `PrimitiveArray` (one per row) +/// - 3: `rotation_signs` — `BoolArray` (3 * padded_dim bits, QJL rotation, inverse order) +#[derive(Clone, Debug)] +pub struct TurboQuantQJLArray { + pub(crate) dtype: DType, + pub(crate) mse_inner: ArrayRef, + pub(crate) qjl_signs: ArrayRef, + pub(crate) residual_norms: ArrayRef, + pub(crate) rotation_signs: ArrayRef, + pub(crate) bit_width: u8, + pub(crate) padded_dim: u32, + pub(crate) rotation_seed: u64, + pub(crate) stats_set: ArrayStats, +} + +impl TurboQuantQJLArray { + /// Build a new TurboQuantQJLArray. + #[allow(clippy::too_many_arguments)] + pub fn try_new( + dtype: DType, + mse_inner: ArrayRef, + qjl_signs: ArrayRef, + residual_norms: ArrayRef, + rotation_signs: ArrayRef, + bit_width: u8, + padded_dim: u32, + rotation_seed: u64, + ) -> VortexResult { + vortex_ensure!( + (2..=9).contains(&bit_width), + "QJL bit_width must be 2-9, got {bit_width}" + ); + Ok(Self { + dtype, + mse_inner, + qjl_signs, + residual_norms, + rotation_signs, + bit_width, + padded_dim, + rotation_seed, + stats_set: Default::default(), + }) + } + + /// Total bit width (including QJL bit). + pub fn bit_width(&self) -> u8 { + self.bit_width + } + + /// Padded dimension. + pub fn padded_dim(&self) -> u32 { + self.padded_dim + } + + /// QJL rotation seed. + pub fn rotation_seed(&self) -> u64 { + self.rotation_seed + } + + /// The inner MSE array child. + pub fn mse_inner(&self) -> &ArrayRef { + &self.mse_inner + } + + /// The QJL sign bits child (BoolArray). + pub fn qjl_signs(&self) -> &ArrayRef { + &self.qjl_signs + } + + /// The residual norms child. + pub fn residual_norms(&self) -> &ArrayRef { + &self.residual_norms + } + + /// The QJL rotation signs child (BoolArray). 
+ pub fn rotation_signs(&self) -> &ArrayRef { + &self.rotation_signs + } +} diff --git a/encodings/turboquant/src/qjl/mod.rs b/encodings/turboquant/src/qjl/mod.rs new file mode 100644 index 00000000000..4885f7c9ddb --- /dev/null +++ b/encodings/turboquant/src/qjl/mod.rs @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant QJL encoding: inner-product-preserving quantization (MSE + QJL residual). + +pub use array::TurboQuantQJLArray; +pub use array::TurboQuantQJLMetadata; + +pub(crate) mod array; +mod vtable; + +use vortex_array::vtable::ArrayId; + +/// Encoding marker type for TurboQuant QJL. +#[derive(Clone, Debug)] +pub struct TurboQuantQJL; + +impl TurboQuantQJL { + pub const ID: ArrayId = ArrayId::new_ref("vortex.turboquant.qjl"); +} diff --git a/encodings/turboquant/src/qjl/vtable/mod.rs b/encodings/turboquant/src/qjl/vtable/mod.rs new file mode 100644 index 00000000000..b1020e6e2d2 --- /dev/null +++ b/encodings/turboquant/src/qjl/vtable/mod.rs @@ -0,0 +1,221 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! VTable implementation for TurboQuant QJL encoding. 
+ +use std::hash::Hash; +use std::ops::Deref; +use std::sync::Arc; + +use vortex_array::ArrayEq; +use vortex_array::ArrayHash; +use vortex_array::ArrayRef; +use vortex_array::DeserializeMetadata; +use vortex_array::DynArray; +use vortex_array::ExecutionCtx; +use vortex_array::ExecutionResult; +use vortex_array::Precision; +use vortex_array::ProstMetadata; +use vortex_array::SerializeMetadata; +use vortex_array::buffer::BufferHandle; +use vortex_array::dtype::DType; +use vortex_array::dtype::Nullability; +use vortex_array::dtype::PType; +use vortex_array::serde::ArrayChildren; +use vortex_array::stats::StatsSetRef; +use vortex_array::vtable::Array; +use vortex_array::vtable::ArrayId; +use vortex_array::vtable::NotSupported; +use vortex_array::vtable::VTable; +use vortex_array::vtable::ValidityChild; +use vortex_array::vtable::ValidityVTableFromChild; +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; +use vortex_error::vortex_panic; +use vortex_session::VortexSession; + +use super::TurboQuantQJL; +use super::array::TurboQuantQJLArray; +use super::array::TurboQuantQJLMetadata; +use crate::decompress::execute_decompress_qjl; + +impl VTable for TurboQuantQJL { + type Array = TurboQuantQJLArray; + type Metadata = ProstMetadata; + type OperationsVTable = NotSupported; + type ValidityVTable = ValidityVTableFromChild; + + fn vtable(_array: &Self::Array) -> &Self { + &TurboQuantQJL + } + + fn id(&self) -> ArrayId { + Self::ID + } + + fn len(array: &TurboQuantQJLArray) -> usize { + array.residual_norms.len() + } + + fn dtype(array: &TurboQuantQJLArray) -> &DType { + &array.dtype + } + + fn stats(array: &TurboQuantQJLArray) -> StatsSetRef<'_> { + array.stats_set.to_ref(array.as_ref()) + } + + fn array_hash( + array: &TurboQuantQJLArray, + state: &mut H, + precision: Precision, + ) { + array.dtype.hash(state); + array.bit_width.hash(state); + array.padded_dim.hash(state); + array.rotation_seed.hash(state); + 
array.mse_inner.array_hash(state, precision); + array.qjl_signs.array_hash(state, precision); + array.residual_norms.array_hash(state, precision); + array.rotation_signs.array_hash(state, precision); + } + + fn array_eq( + array: &TurboQuantQJLArray, + other: &TurboQuantQJLArray, + precision: Precision, + ) -> bool { + array.dtype == other.dtype + && array.bit_width == other.bit_width + && array.padded_dim == other.padded_dim + && array.rotation_seed == other.rotation_seed + && array.mse_inner.array_eq(&other.mse_inner, precision) + && array.qjl_signs.array_eq(&other.qjl_signs, precision) + && array + .residual_norms + .array_eq(&other.residual_norms, precision) + && array + .rotation_signs + .array_eq(&other.rotation_signs, precision) + } + + fn nbuffers(_array: &TurboQuantQJLArray) -> usize { + 0 + } + + fn buffer(_array: &TurboQuantQJLArray, idx: usize) -> BufferHandle { + vortex_panic!("TurboQuantQJLArray buffer index {idx} out of bounds") + } + + fn buffer_name(_array: &TurboQuantQJLArray, _idx: usize) -> Option { + None + } + + fn nchildren(_array: &TurboQuantQJLArray) -> usize { + 4 + } + + fn child(array: &TurboQuantQJLArray, idx: usize) -> ArrayRef { + match idx { + 0 => array.mse_inner.clone(), + 1 => array.qjl_signs.clone(), + 2 => array.residual_norms.clone(), + 3 => array.rotation_signs.clone(), + _ => vortex_panic!("TurboQuantQJLArray child index {idx} out of bounds"), + } + } + + fn child_name(_array: &TurboQuantQJLArray, idx: usize) -> String { + match idx { + 0 => "mse_inner".to_string(), + 1 => "qjl_signs".to_string(), + 2 => "residual_norms".to_string(), + 3 => "rotation_signs".to_string(), + _ => vortex_panic!("TurboQuantQJLArray child_name index {idx} out of bounds"), + } + } + + fn metadata(array: &TurboQuantQJLArray) -> VortexResult { + Ok(ProstMetadata(TurboQuantQJLMetadata { + bit_width: array.bit_width as u32, + padded_dim: array.padded_dim, + rotation_seed: array.rotation_seed, + })) + } + + fn serialize(metadata: Self::Metadata) -> 
VortexResult>> { + Ok(Some(metadata.serialize())) + } + + fn deserialize( + bytes: &[u8], + _dtype: &DType, + _len: usize, + _buffers: &[BufferHandle], + _session: &VortexSession, + ) -> VortexResult { + Ok(ProstMetadata( + as DeserializeMetadata>::deserialize(bytes)?, + )) + } + + fn build( + dtype: &DType, + len: usize, + metadata: &Self::Metadata, + _buffers: &[BufferHandle], + children: &dyn ArrayChildren, + ) -> VortexResult { + let padded_dim = metadata.padded_dim as usize; + + let mse_inner = children.get(0, dtype, len)?; + + let signs_dtype = DType::Bool(Nullability::NonNullable); + let qjl_signs = children.get(1, &signs_dtype, len * padded_dim)?; + + let norms_dtype = DType::Primitive(PType::F32, Nullability::NonNullable); + let residual_norms = children.get(2, &norms_dtype, len)?; + + let rotation_signs = children.get(3, &signs_dtype, 3 * padded_dim)?; + + Ok(TurboQuantQJLArray { + dtype: dtype.clone(), + mse_inner, + qjl_signs, + residual_norms, + rotation_signs, + bit_width: u8::try_from(metadata.bit_width)?, + padded_dim: metadata.padded_dim, + rotation_seed: metadata.rotation_seed, + stats_set: Default::default(), + }) + } + + fn with_children(array: &mut Self::Array, children: Vec) -> VortexResult<()> { + vortex_ensure!( + children.len() == 4, + "TurboQuantQJLArray expects 4 children, got {}", + children.len() + ); + let mut iter = children.into_iter(); + array.mse_inner = iter.next().vortex_expect("mse_inner child"); + array.qjl_signs = iter.next().vortex_expect("qjl_signs child"); + array.residual_norms = iter.next().vortex_expect("residual_norms child"); + array.rotation_signs = iter.next().vortex_expect("rotation_signs child"); + Ok(()) + } + + fn execute(array: Arc>, ctx: &mut ExecutionCtx) -> VortexResult { + let inner = Arc::try_unwrap(array) + .map(|a| a.into_inner()) + .unwrap_or_else(|arc| arc.as_ref().deref().clone()); + Ok(ExecutionResult::done(execute_decompress_qjl(inner, ctx)?)) + } +} + +impl ValidityChild for TurboQuantQJL { + fn 
validity_child(array: &TurboQuantQJLArray) -> &ArrayRef { + array.mse_inner() + } +} diff --git a/encodings/turboquant/src/rotation.rs b/encodings/turboquant/src/rotation.rs new file mode 100644 index 00000000000..18d41066cc0 --- /dev/null +++ b/encodings/turboquant/src/rotation.rs @@ -0,0 +1,392 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Deterministic random rotation for TurboQuant. +//! +//! Uses a Structured Random Hadamard Transform (SRHT) for O(d log d) rotation +//! instead of a full d×d matrix multiply. The SRHT applies the sequence +//! D₃ · H · D₂ · H · D₁ where H is the Walsh-Hadamard transform and Dₖ are +//! random diagonal ±1 sign matrices. Three rounds of HD provide sufficient +//! randomness for near-uniform distribution on the sphere. +//! +//! For dimensions that are not powers of 2, the input is zero-padded to the +//! next power of 2 before the transform and truncated afterward. + +use rand::RngExt; +use rand::SeedableRng; +use rand::rngs::StdRng; +use vortex_array::arrays::BoolArray; +use vortex_array::validity::Validity; +use vortex_buffer::BitBufferMut; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; + +/// A structured random Hadamard transform for O(d log d) pseudo-random rotation. +pub struct RotationMatrix { + /// Random ±1 signs for each of the 3 diagonal matrices, each of length `padded_dim`. + signs: [Vec; 3], + /// The original (unpadded) dimension. + dim: usize, + /// The padded dimension (next power of 2 >= dim). + padded_dim: usize, + /// Normalization factor: 1/padded_dim per Hadamard, applied once at the end. + norm_factor: f32, +} + +impl RotationMatrix { + /// Create a new SRHT rotation from a deterministic seed. + pub fn try_new(seed: u64, dimension: usize) -> VortexResult { + let padded_dim = dimension.next_power_of_two(); + let mut rng = StdRng::seed_from_u64(seed); + + // Generate 3 random sign vectors (±1). 
+ let signs = std::array::from_fn(|_| gen_random_signs(&mut rng, padded_dim)); + + // Each Hadamard transform has a normalization factor of 1/sqrt(padded_dim). + // With 3 Hadamard transforms: (1/sqrt(n))^3 = 1/(n * sqrt(n)). + // But we want an orthogonal-like transform that preserves norms. The + // standard WHT without normalization scales by sqrt(n) each time. With 3 + // applications: output ~ n^(3/2) * input. To normalize: divide by n^(3/2). + // Equivalently, divide by n after each WHT (making each one orthonormal). + // We fold all normalization into a single factor applied at the end. + let norm_factor = 1.0 / (padded_dim as f32 * (padded_dim as f32).sqrt()); + + Ok(Self { + signs, + dim: dimension, + padded_dim, + norm_factor, + }) + } + + /// Apply forward rotation: `output = SRHT(input)`. + /// + /// Both `input` and `output` must have length `padded_dim()`. The caller + /// is responsible for zero-padding input beyond `dim` positions. + pub fn rotate(&self, input: &[f32], output: &mut [f32]) { + let pd = self.padded_dim; + debug_assert_eq!(input.len(), pd); + debug_assert_eq!(output.len(), pd); + + output.copy_from_slice(input); + self.apply_srht(output); + } + + /// Apply inverse rotation: `output = SRHT⁻¹(input)`. + /// + /// Both `input` and `output` must have length `padded_dim()`. + pub fn inverse_rotate(&self, input: &[f32], output: &mut [f32]) { + let pd = self.padded_dim; + debug_assert_eq!(input.len(), pd); + debug_assert_eq!(output.len(), pd); + + output.copy_from_slice(input); + self.apply_inverse_srht(output); + } + + /// Returns the padded dimension (next power of 2 >= dim). + /// + /// All rotate/inverse_rotate buffers must be this length. + pub fn padded_dim(&self) -> usize { + self.padded_dim + } + + /// Apply the SRHT: D₃ · H · D₂ · H · D₁ · x, with normalization. 
+ fn apply_srht(&self, buf: &mut [f32]) { + // Round 1: D₁ then H + apply_signs(buf, &self.signs[0]); + walsh_hadamard_transform(buf); + + // Round 2: D₂ then H + apply_signs(buf, &self.signs[1]); + walsh_hadamard_transform(buf); + + // Round 3: D₃ then normalize + apply_signs(buf, &self.signs[2]); + walsh_hadamard_transform(buf); + + // Apply combined normalization factor. + let norm = self.norm_factor; + for val in buf.iter_mut() { + *val *= norm; + } + } + + /// Apply the inverse SRHT. + /// + /// Forward is: norm · H · D₃ · H · D₂ · H · D₁ + /// Inverse is: norm · D₁ · H · D₂ · H · D₃ · H + fn apply_inverse_srht(&self, buf: &mut [f32]) { + walsh_hadamard_transform(buf); + apply_signs(buf, &self.signs[2]); + + walsh_hadamard_transform(buf); + apply_signs(buf, &self.signs[1]); + + walsh_hadamard_transform(buf); + apply_signs(buf, &self.signs[0]); + + let norm = self.norm_factor; + for val in buf.iter_mut() { + *val *= norm; + } + } + + /// Returns the dimension of this rotation. + pub fn dimension(&self) -> usize { + self.dim + } + + /// Returns the normalization factor for this transform. + pub fn norm_factor(&self) -> f32 { + self.norm_factor + } + + /// Export the 3 sign vectors as a single `BoolArray` in inverse-application order. + /// + /// The output `BoolArray` has length `3 * padded_dim` and stores `[D₃ | D₂ | D₁]` + /// so that decompression (which applies the inverse transform) iterates sign arrays + /// 0→1→2 sequentially. Convention: `true` = +1, `false` = -1. 
+ pub fn export_inverse_signs_bool_array(&self) -> BoolArray { + let total_bits = 3 * self.padded_dim; + let mut bits = BitBufferMut::new_unset(total_bits); + + // Store in inverse order: signs[2] (D₃), signs[1] (D₂), signs[0] (D₁) + for (round, sign_idx) in [2, 1, 0].iter().enumerate() { + let offset = round * self.padded_dim; + for j in 0..self.padded_dim { + if self.signs[*sign_idx][j] > 0.0 { + bits.set(offset + j); + } + } + } + + BoolArray::new(bits.freeze(), Validity::NonNullable) + } + + /// Reconstruct a `RotationMatrix` from a stored `BoolArray` of signs. + /// + /// The `BoolArray` must have length `3 * padded_dim` with signs in inverse + /// application order `[D₃ | D₂ | D₁]` (as produced by + /// [`export_inverse_signs_bool_array`]). + pub fn from_bool_array(signs_array: &BoolArray, dim: usize) -> VortexResult { + let padded_dim = dim.next_power_of_two(); + vortex_ensure!( + signs_array.len() == 3 * padded_dim, + "Expected BoolArray of length {}, got {}", + 3 * padded_dim, + signs_array.len() + ); + + let bit_buf = signs_array.to_bit_buffer(); + + // Reconstruct in storage order (inverse): [D₃, D₂, D₁] → signs[2], signs[1], signs[0] + let mut signs: [Vec; 3] = std::array::from_fn(|_| Vec::with_capacity(padded_dim)); + + for (round, sign_idx) in [2, 1, 0].iter().enumerate() { + let offset = round * padded_dim; + signs[*sign_idx] = (0..padded_dim) + .map(|j| { + if bit_buf.value(offset + j) { + 1.0f32 + } else { + -1.0f32 + } + }) + .collect(); + } + + let norm_factor = 1.0 / (padded_dim as f32 * (padded_dim as f32).sqrt()); + + Ok(Self { + signs, + dim, + padded_dim, + norm_factor, + }) + } +} + +/// Generate a vector of random ±1 signs. +fn gen_random_signs(rng: &mut StdRng, len: usize) -> Vec { + (0..len) + .map(|_| { + if rng.random_bool(0.5) { + 1.0f32 + } else { + -1.0f32 + } + }) + .collect() +} + +/// Element-wise multiply by ±1 signs. 
+#[inline] +fn apply_signs(buf: &mut [f32], signs: &[f32]) { + for (val, &sign) in buf.iter_mut().zip(signs.iter()) { + *val *= sign; + } +} + +/// In-place Walsh-Hadamard Transform (unnormalized, iterative). +/// +/// Input length must be a power of 2. Runs in O(n log n). +fn walsh_hadamard_transform(buf: &mut [f32]) { + let len = buf.len(); + debug_assert!(len.is_power_of_two()); + + let mut half = 1; + while half < len { + for block_start in (0..len).step_by(half * 2) { + for idx in block_start..block_start + half { + let sum = buf[idx] + buf[idx + half]; + let diff = buf[idx] - buf[idx + half]; + buf[idx] = sum; + buf[idx + half] = diff; + } + } + half *= 2; + } +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + use vortex_error::VortexResult; + + use super::*; + + #[test] + fn deterministic_from_seed() -> VortexResult<()> { + let r1 = RotationMatrix::try_new(42, 64)?; + let r2 = RotationMatrix::try_new(42, 64)?; + let pd = r1.padded_dim(); + + let mut input = vec![0.0f32; pd]; + for i in 0..64 { + input[i] = i as f32; + } + let mut out1 = vec![0.0f32; pd]; + let mut out2 = vec![0.0f32; pd]; + + r1.rotate(&input, &mut out1); + r2.rotate(&input, &mut out2); + + assert_eq!(out1, out2); + Ok(()) + } + + /// Verify roundtrip is exact to f32 precision across many dimensions, + /// including non-power-of-two dimensions that require padding. 
+ #[rstest] + #[case(32)] + #[case(64)] + #[case(100)] + #[case(128)] + #[case(256)] + #[case(512)] + #[case(768)] + #[case(1024)] + fn roundtrip_exact(#[case] dim: usize) -> VortexResult<()> { + let rot = RotationMatrix::try_new(42, dim)?; + let padded_dim = rot.padded_dim(); + + let mut input = vec![0.0f32; padded_dim]; + for i in 0..dim { + input[i] = (i as f32 + 1.0) * 0.01; + } + let mut rotated = vec![0.0f32; padded_dim]; + let mut recovered = vec![0.0f32; padded_dim]; + + rot.rotate(&input, &mut rotated); + rot.inverse_rotate(&rotated, &mut recovered); + + let max_err: f32 = input + .iter() + .zip(recovered.iter()) + .map(|(a, b)| (a - b).abs()) + .fold(0.0f32, f32::max); + let max_val: f32 = input.iter().map(|x| x.abs()).fold(0.0f32, f32::max); + let rel_err = max_err / max_val; + + // SRHT roundtrip should be exact up to f32 precision (~1e-6). + assert!( + rel_err < 1e-5, + "roundtrip relative error too large for dim={dim}: {rel_err:.2e}" + ); + Ok(()) + } + + /// Verify norm preservation across dimensions. + #[rstest] + #[case(128)] + #[case(768)] + fn preserves_norm(#[case] dim: usize) -> VortexResult<()> { + let rot = RotationMatrix::try_new(7, dim)?; + let padded_dim = rot.padded_dim(); + + let mut input = vec![0.0f32; padded_dim]; + for i in 0..dim { + input[i] = (i as f32) * 0.01; + } + let input_norm: f32 = input.iter().map(|x| x * x).sum::().sqrt(); + + let mut rotated = vec![0.0f32; padded_dim]; + rot.rotate(&input, &mut rotated); + let rotated_norm: f32 = rotated.iter().map(|x| x * x).sum::().sqrt(); + + assert!( + (input_norm - rotated_norm).abs() / input_norm < 1e-5, + "norm not preserved for dim={dim}: {} vs {} (rel err: {:.2e})", + input_norm, + rotated_norm, + (input_norm - rotated_norm).abs() / input_norm + ); + Ok(()) + } + + /// Verify that export → from_bool_array produces identical rotation output. 
+ #[rstest] + #[case(64)] + #[case(128)] + #[case(768)] + fn sign_export_import_roundtrip(#[case] dim: usize) -> VortexResult<()> { + let rot = RotationMatrix::try_new(42, dim)?; + let padded_dim = rot.padded_dim(); + + let signs_array = rot.export_inverse_signs_bool_array(); + let rot2 = RotationMatrix::from_bool_array(&signs_array, dim)?; + + // Verify both produce identical rotation and inverse rotation. + let mut input = vec![0.0f32; padded_dim]; + for i in 0..dim { + input[i] = (i as f32 + 1.0) * 0.01; + } + + let mut out1 = vec![0.0f32; padded_dim]; + let mut out2 = vec![0.0f32; padded_dim]; + rot.rotate(&input, &mut out1); + rot2.rotate(&input, &mut out2); + assert_eq!(out1, out2, "Forward rotation mismatch after export/import"); + + rot.inverse_rotate(&out1, &mut out2); + let mut out3 = vec![0.0f32; padded_dim]; + rot2.inverse_rotate(&out1, &mut out3); + assert_eq!(out2, out3, "Inverse rotation mismatch after export/import"); + + Ok(()) + } + + #[test] + fn wht_basic() { + // WHT of [1, 0, 0, 0] should be [1, 1, 1, 1] + let mut buf = vec![1.0f32, 0.0, 0.0, 0.0]; + walsh_hadamard_transform(&mut buf); + assert_eq!(buf, vec![1.0, 1.0, 1.0, 1.0]); + + // WHT is self-inverse (up to scaling by n) + walsh_hadamard_transform(&mut buf); + // After two WHTs: each element multiplied by n=4 + assert_eq!(buf, vec![4.0, 0.0, 0.0, 0.0]); + } +} diff --git a/vortex-btrblocks/Cargo.toml b/vortex-btrblocks/Cargo.toml index 1c745306c4a..4e51ee33014 100644 --- a/vortex-btrblocks/Cargo.toml +++ b/vortex-btrblocks/Cargo.toml @@ -35,6 +35,7 @@ vortex-pco = { workspace = true, optional = true } vortex-runend = { workspace = true } vortex-sequence = { workspace = true } vortex-sparse = { workspace = true } +vortex-turboquant = { workspace = true } vortex-utils = { workspace = true } vortex-zigzag = { workspace = true } vortex-zstd = { workspace = true, optional = true } diff --git a/vortex-btrblocks/public-api.lock b/vortex-btrblocks/public-api.lock index 55d23a96a26..17564a21025 
100644 --- a/vortex-btrblocks/public-api.lock +++ b/vortex-btrblocks/public-api.lock @@ -194,6 +194,8 @@ pub vortex_btrblocks::BtrBlocksCompressor::int_schemes: alloc::vec::Vec<&'static pub vortex_btrblocks::BtrBlocksCompressor::string_schemes: alloc::vec::Vec<&'static dyn vortex_btrblocks::compressor::string::StringScheme> +pub vortex_btrblocks::BtrBlocksCompressor::turboquant_config: core::option::Option + impl vortex_btrblocks::BtrBlocksCompressor pub fn vortex_btrblocks::BtrBlocksCompressor::compress(&self, array: &vortex_array::array::ArrayRef) -> vortex_error::VortexResult @@ -236,6 +238,8 @@ pub fn vortex_btrblocks::BtrBlocksCompressorBuilder::include_int(self, codes: im pub fn vortex_btrblocks::BtrBlocksCompressorBuilder::include_string(self, codes: impl core::iter::traits::collect::IntoIterator) -> Self +pub fn vortex_btrblocks::BtrBlocksCompressorBuilder::with_turboquant(self, config: vortex_turboquant::compress::TurboQuantConfig) -> Self + impl core::clone::Clone for vortex_btrblocks::BtrBlocksCompressorBuilder pub fn vortex_btrblocks::BtrBlocksCompressorBuilder::clone(&self) -> vortex_btrblocks::BtrBlocksCompressorBuilder diff --git a/vortex-btrblocks/src/builder.rs b/vortex-btrblocks/src/builder.rs index d329ec8c139..851c4e6d986 100644 --- a/vortex-btrblocks/src/builder.rs +++ b/vortex-btrblocks/src/builder.rs @@ -46,6 +46,7 @@ pub struct BtrBlocksCompressorBuilder { int_schemes: HashSet<&'static dyn IntegerScheme>, float_schemes: HashSet<&'static dyn FloatScheme>, string_schemes: HashSet<&'static dyn StringScheme>, + turboquant_config: Option, } impl Default for BtrBlocksCompressorBuilder { @@ -66,6 +67,7 @@ impl Default for BtrBlocksCompressorBuilder { .copied() .filter(|s| s.code() != StringCode::Zstd && s.code() != StringCode::ZstdBuffers) .collect(), + turboquant_config: None, } } } @@ -77,6 +79,7 @@ impl BtrBlocksCompressorBuilder { int_schemes: Default::default(), float_schemes: Default::default(), string_schemes: Default::default(), + 
turboquant_config: None, } } @@ -134,6 +137,16 @@ impl BtrBlocksCompressorBuilder { self } + /// Enables TurboQuant lossy vector quantization for tensor extension types. + /// + /// When enabled, `Vector` and `FixedShapeTensor` extension columns will be + /// quantized at the configured bit-width instead of using the default + /// recursive storage compression. + pub fn with_turboquant(mut self, config: vortex_turboquant::TurboQuantConfig) -> Self { + self.turboquant_config = Some(config); + self + } + /// Builds the configured `BtrBlocksCompressor`. pub fn build(self) -> BtrBlocksCompressor { // Note we should apply the schemes in the same order, in case try conflict. @@ -153,6 +166,7 @@ impl BtrBlocksCompressorBuilder { .into_iter() .sorted_by_key(|s| s.code()) .collect_vec(), + turboquant_config: self.turboquant_config, } } } diff --git a/vortex-btrblocks/src/canonical_compressor.rs b/vortex-btrblocks/src/canonical_compressor.rs index 46252060a1f..203144912fe 100644 --- a/vortex-btrblocks/src/canonical_compressor.rs +++ b/vortex-btrblocks/src/canonical_compressor.rs @@ -40,6 +40,8 @@ use crate::compressor::float::FloatScheme; use crate::compressor::integer::IntegerScheme; use crate::compressor::string::StringScheme; use crate::compressor::temporal::compress_temporal; +use crate::compressor::turboquant::compress_turboquant; +use crate::compressor::turboquant::is_tensor_extension; /// Trait for compressors that can compress canonical arrays. /// @@ -101,6 +103,9 @@ pub struct BtrBlocksCompressor { /// String compressor with configured schemes. pub string_schemes: Vec<&'static dyn StringScheme>, + + /// Optional TurboQuant configuration for tensor extension types. + pub turboquant_config: Option, } impl Default for BtrBlocksCompressor { @@ -290,6 +295,15 @@ impl CanonicalCompressor for BtrBlocksCompressor { return compress_temporal(self, temporal_array); } + // Compress tensor extension types with TurboQuant if configured. 
+ // Falls through to default compression for nullable storage. + if let Some(tq_config) = &self.turboquant_config + && is_tensor_extension(&ext_array) + && let Some(compressed) = compress_turboquant(&ext_array, tq_config)? + { + return Ok(compressed); + } + // Compress the underlying storage array. let compressed_storage = self.compress(ext_array.storage_array())?; diff --git a/vortex-btrblocks/src/compressor/mod.rs b/vortex-btrblocks/src/compressor/mod.rs index 5c3a31271cd..e97c1d9b87b 100644 --- a/vortex-btrblocks/src/compressor/mod.rs +++ b/vortex-btrblocks/src/compressor/mod.rs @@ -34,6 +34,7 @@ mod patches; mod rle; pub(crate) mod string; pub(crate) mod temporal; +pub(crate) mod turboquant; /// Maximum cascade depth for compression. pub(crate) const MAX_CASCADE: usize = 3; diff --git a/vortex-btrblocks/src/compressor/turboquant.rs b/vortex-btrblocks/src/compressor/turboquant.rs new file mode 100644 index 00000000000..dc8e50d3c9b --- /dev/null +++ b/vortex-btrblocks/src/compressor/turboquant.rs @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Specialized compressor for TurboQuant vector quantization of tensor extension types. + +use vortex_array::ArrayRef; +use vortex_array::IntoArray; +use vortex_array::arrays::ExtensionArray; +use vortex_error::VortexResult; +use vortex_turboquant::FIXED_SHAPE_TENSOR_EXT_ID; +use vortex_turboquant::TurboQuantConfig; +use vortex_turboquant::VECTOR_EXT_ID; +use vortex_turboquant::turboquant_encode_qjl; + +/// Check if an extension array has a tensor extension type. +pub(crate) fn is_tensor_extension(ext_array: &ExtensionArray) -> bool { + let ext_id = ext_array.ext_dtype().id(); + ext_id.as_ref() == VECTOR_EXT_ID || ext_id.as_ref() == FIXED_SHAPE_TENSOR_EXT_ID +} + +/// Try to compress a tensor extension array using TurboQuant. 
+/// +/// Returns `Ok(Some(...))` on success, or `Ok(None)` if the storage is nullable +/// (TurboQuant requires non-nullable input). The caller should fall through to +/// default compression when `None` is returned. +/// +/// Produces a `TurboQuantQJLArray` wrapping a `TurboQuantMSEArray`, stored inside +/// the Extension wrapper. All children (codes, norms, centroids, rotation signs, +/// QJL signs, residual norms) are left for the standard BtrBlocks recursive +/// compression pipeline to handle during layout serialization. +pub(crate) fn compress_turboquant( + ext_array: &ExtensionArray, + config: &TurboQuantConfig, +) -> VortexResult> { + let storage = ext_array.storage_array(); + let fsl = storage.to_canonical()?.into_fixed_size_list(); + + if fsl.dtype().is_nullable() { + return Ok(None); + } + + // Produce the cascaded QJL(MSE) structure. The layout writer will + // recursively descend into children and compress each one. + let qjl_array = turboquant_encode_qjl(&fsl, config)?; + + Ok(Some( + ExtensionArray::new(ext_array.ext_dtype().clone(), qjl_array.into_array()).into_array(), + )) +} diff --git a/vortex-file/Cargo.toml b/vortex-file/Cargo.toml index d568328bb52..0752553c1e4 100644 --- a/vortex-file/Cargo.toml +++ b/vortex-file/Cargo.toml @@ -54,6 +54,7 @@ vortex-scan = { workspace = true } vortex-sequence = { workspace = true } vortex-session = { workspace = true } vortex-sparse = { workspace = true } +vortex-turboquant = { workspace = true } vortex-utils = { workspace = true, features = ["dashmap"] } vortex-zigzag = { workspace = true } vortex-zstd = { workspace = true, optional = true } diff --git a/vortex-file/public-api.lock b/vortex-file/public-api.lock index 84cca867cba..ffb19c25fb5 100644 --- a/vortex-file/public-api.lock +++ b/vortex-file/public-api.lock @@ -358,6 +358,8 @@ pub fn vortex_file::WriteStrategyBuilder::with_flat_strategy(self, flat: alloc:: pub fn vortex_file::WriteStrategyBuilder::with_row_block_size(self, row_block_size: usize) -> 
Self +pub fn vortex_file::WriteStrategyBuilder::with_vector_quantization(self, config: vortex_turboquant::compress::TurboQuantConfig) -> Self + impl core::default::Default for vortex_file::WriteStrategyBuilder pub fn vortex_file::WriteStrategyBuilder::default() -> Self diff --git a/vortex-file/src/lib.rs b/vortex-file/src/lib.rs index d888eb88def..b99ba26d9e9 100644 --- a/vortex-file/src/lib.rs +++ b/vortex-file/src/lib.rs @@ -178,4 +178,5 @@ pub fn register_default_encodings(session: &mut VortexSession) { vortex_fastlanes::initialize(session); vortex_runend::initialize(session); vortex_sequence::initialize(session); + vortex_turboquant::initialize(session); } diff --git a/vortex-file/src/strategy.rs b/vortex-file/src/strategy.rs index 4d6031a220c..56ce56fc755 100644 --- a/vortex-file/src/strategy.rs +++ b/vortex-file/src/strategy.rs @@ -28,7 +28,6 @@ use vortex_array::arrays::VarBinView; use vortex_array::dtype::FieldPath; use vortex_array::session::ArrayRegistry; use vortex_array::session::ArraySession; -#[cfg(feature = "zstd")] use vortex_btrblocks::BtrBlocksCompressorBuilder; #[cfg(feature = "zstd")] use vortex_btrblocks::FloatCode; @@ -61,6 +60,8 @@ use vortex_pco::Pco; use vortex_runend::RunEnd; use vortex_sequence::Sequence; use vortex_sparse::Sparse; +use vortex_turboquant::TurboQuantMSE; +use vortex_turboquant::TurboQuantQJL; use vortex_utils::aliases::hash_map::HashMap; use vortex_zigzag::ZigZag; #[cfg(feature = "zstd")] @@ -110,6 +111,8 @@ pub static ALLOWED_ENCODINGS: LazyLock = LazyLock::new(|| { session.register(Sequence); session.register(Sparse); session.register(ZigZag); + session.register(TurboQuantMSE); + session.register(TurboQuantQJL); #[cfg(feature = "zstd")] session.register(Zstd); @@ -126,6 +129,7 @@ pub static ALLOWED_ENCODINGS: LazyLock = LazyLock::new(|| { /// bulk decoding performance, and IOPS required to perform an indexed read. 
pub struct WriteStrategyBuilder { compressor: Option>, + turboquant_config: Option, row_block_size: usize, field_writers: HashMap>, allow_encodings: Option, @@ -138,6 +142,7 @@ impl Default for WriteStrategyBuilder { fn default() -> Self { Self { compressor: None, + turboquant_config: None, row_block_size: 8192, field_writers: HashMap::new(), allow_encodings: Some(ALLOWED_ENCODINGS.clone()), @@ -231,6 +236,29 @@ impl WriteStrategyBuilder { self } + /// Configure lossy vector quantization for tensor columns using TurboQuant. + /// + /// Columns with `Vector` or `FixedShapeTensor` extension types will be quantized at the + /// specified bit-width. All other columns use the default BtrBlocks compression strategy. + /// The TurboQuant array's children (norms, codes) are recursively compressed by the + /// BtrBlocks compressor. + /// + /// This can be combined with other builder methods. If a custom compressor is also set + /// via [`with_compressor`](Self::with_compressor), the custom compressor takes precedence + /// and the TurboQuant config is ignored. + /// + /// # Examples + /// + /// ```ignore + /// WriteStrategyBuilder::default() + /// .with_vector_quantization(TurboQuantConfig { bit_width: 3, seed: None }) + /// .build() + /// ``` + pub fn with_vector_quantization(mut self, config: vortex_turboquant::TurboQuantConfig) -> Self { + self.turboquant_config = Some(config); + self + } + /// Builds the canonical [`LayoutStrategy`] implementation, with the configured overrides /// applied. pub fn build(self) -> Arc { @@ -249,6 +277,14 @@ impl WriteStrategyBuilder { // 5. 
compress each chunk let compressing = if let Some(ref compressor) = self.compressor { CompressingStrategy::new_opaque(buffered, compressor.clone()) + } else if let Some(tq_config) = self.turboquant_config { + let btrblocks = BtrBlocksCompressorBuilder::default() + .with_turboquant(tq_config) + .build(); + CompressingStrategy::new_opaque( + buffered, + Arc::new(btrblocks) as Arc, + ) } else { CompressingStrategy::new_btrblocks(buffered, true) }; diff --git a/vortex/Cargo.toml b/vortex/Cargo.toml index d8dc89882b0..23af132784a 100644 --- a/vortex/Cargo.toml +++ b/vortex/Cargo.toml @@ -44,6 +44,7 @@ vortex-scan = { workspace = true } vortex-sequence = { workspace = true } vortex-session = { workspace = true } vortex-sparse = { workspace = true } +vortex-turboquant = { workspace = true } vortex-utils = { workspace = true } vortex-zigzag = { workspace = true } vortex-zstd = { workspace = true, optional = true } @@ -56,6 +57,7 @@ fastlanes = { workspace = true } mimalloc = { workspace = true } parquet = { workspace = true } rand = { workspace = true } +rand_distr = { workspace = true } serde_json = { workspace = true } tokio = { workspace = true, features = ["full"] } tracing = { workspace = true } diff --git a/vortex/benches/single_encoding_throughput.rs b/vortex/benches/single_encoding_throughput.rs index 4776afa4a52..dd3ec8a4b25 100644 --- a/vortex/benches/single_encoding_throughput.rs +++ b/vortex/benches/single_encoding_throughput.rs @@ -17,10 +17,12 @@ use rand::prelude::IndexedRandom; use rand::rngs::StdRng; use vortex::array::IntoArray; use vortex::array::ToCanonical; +use vortex::array::arrays::FixedSizeListArray; use vortex::array::arrays::PrimitiveArray; use vortex::array::arrays::VarBinViewArray; use vortex::array::builders::dict::dict_encode; use vortex::array::builtins::ArrayBuiltins; +use vortex::array::validity::Validity; use vortex::dtype::PType; use vortex::encodings::alp::RDEncoder; use vortex::encodings::alp::alp_encode; @@ -32,11 +34,14 @@ use 
vortex::encodings::fsst::fsst_train_compressor; use vortex::encodings::pco::PcoArray; use vortex::encodings::runend::RunEndArray; use vortex::encodings::sequence::sequence_encode; +use vortex::encodings::turboquant::TurboQuantConfig; +use vortex::encodings::turboquant::turboquant_encode_mse; use vortex::encodings::zigzag::zigzag_encode; use vortex::encodings::zstd::ZstdArray; use vortex_array::VortexSessionExecute; use vortex_array::dtype::Nullability; use vortex_array::session::ArraySession; +use vortex_buffer::BufferMut; use vortex_sequence::SequenceArray; use vortex_session::VortexSession; @@ -405,3 +410,83 @@ fn bench_zstd_decompress_string(bencher: Bencher) { .with_inputs(|| &compressed) .bench_refs(|a| a.to_canonical()); } + +// TurboQuant vector quantization benchmarks + +const NUM_VECTORS: usize = 1_000; + +/// Generate `num_vectors` random f32 vectors of the given dimension using i.i.d. +/// standard normal components. This is a conservative test distribution: real +/// neural network embeddings typically have structure (clustered, anisotropic) +/// that the SRHT exploits for better quantization, so Gaussian i.i.d. is a +/// worst-case baseline for TurboQuant. +fn setup_vector_fsl(dim: usize) -> FixedSizeListArray { + let mut rng = StdRng::seed_from_u64(42); + let normal = rand_distr::Normal::new(0.0f32, 1.0).unwrap(); + + let mut buf = BufferMut::::with_capacity(NUM_VECTORS * dim); + for _ in 0..(NUM_VECTORS * dim) { + buf.push(rand_distr::Distribution::sample(&normal, &mut rng)); + } + + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + FixedSizeListArray::try_new( + elements.into_array(), + dim as u32, + Validity::NonNullable, + NUM_VECTORS, + ) + .unwrap() +} + +fn turboquant_config(bit_width: u8) -> TurboQuantConfig { + TurboQuantConfig { + bit_width, + seed: Some(123), + } +} + +macro_rules! 
turboquant_bench { + (compress, $dim:literal, $bits:literal, $name:ident) => { + #[divan::bench(name = concat!("turboquant_compress_dim", stringify!($dim), "_", stringify!($bits), "bit"))] + fn $name(bencher: Bencher) { + let fsl = setup_vector_fsl($dim); + let config = turboquant_config($bits); + with_byte_counter(bencher, (NUM_VECTORS * $dim * 4) as u64) + .with_inputs(|| &fsl) + .bench_refs(|a| turboquant_encode_mse(a, &config).unwrap()); + } + }; + (decompress, $dim:literal, $bits:literal, $name:ident) => { + #[divan::bench(name = concat!("turboquant_decompress_dim", stringify!($dim), "_", stringify!($bits), "bit"))] + fn $name(bencher: Bencher) { + let fsl = setup_vector_fsl($dim); + let config = turboquant_config($bits); + let compressed = turboquant_encode_mse(&fsl, &config).unwrap(); + with_byte_counter(bencher, (NUM_VECTORS * $dim * 4) as u64) + .with_inputs(|| &compressed) + .bench_refs(|a| { + let mut ctx = SESSION.create_execution_ctx(); + a.clone() + .into_array() + .execute::(&mut ctx) + .unwrap() + }); + } + }; +} + +turboquant_bench!(compress, 128, 2, bench_tq_compress_128_2); +turboquant_bench!(decompress, 128, 2, bench_tq_decompress_128_2); +turboquant_bench!(compress, 128, 4, bench_tq_compress_128_4); +turboquant_bench!(decompress, 128, 4, bench_tq_decompress_128_4); +turboquant_bench!(compress, 768, 2, bench_tq_compress_768_2); +turboquant_bench!(decompress, 768, 2, bench_tq_decompress_768_2); +turboquant_bench!(compress, 1024, 2, bench_tq_compress_1024_2); +turboquant_bench!(decompress, 1024, 2, bench_tq_decompress_1024_2); +turboquant_bench!(compress, 1024, 4, bench_tq_compress_1024_4); +turboquant_bench!(decompress, 1024, 4, bench_tq_decompress_1024_4); +turboquant_bench!(compress, 1536, 2, bench_tq_compress_1536_2); +turboquant_bench!(decompress, 1536, 2, bench_tq_decompress_1536_2); +turboquant_bench!(compress, 1536, 4, bench_tq_compress_1536_4); +turboquant_bench!(decompress, 1536, 4, bench_tq_decompress_1536_4); diff --git 
a/vortex/public-api.lock b/vortex/public-api.lock index 0c8ce9d0cd9..325812fafc4 100644 --- a/vortex/public-api.lock +++ b/vortex/public-api.lock @@ -74,6 +74,10 @@ pub mod vortex::encodings::sparse pub use vortex::encodings::sparse::<> +pub mod vortex::encodings::turboquant + +pub use vortex::encodings::turboquant::<> + pub mod vortex::encodings::zigzag pub use vortex::encodings::zigzag::<> diff --git a/vortex/src/lib.rs b/vortex/src/lib.rs index a532fc1adad..454886077c3 100644 --- a/vortex/src/lib.rs +++ b/vortex/src/lib.rs @@ -143,6 +143,10 @@ pub mod encodings { pub use vortex_sparse::*; } + pub mod turboquant { + pub use vortex_turboquant::*; + } + pub mod zigzag { pub use vortex_zigzag::*; }