diff --git a/algorithms/linfa-tsne/Cargo.toml b/algorithms/linfa-tsne/Cargo.toml index ca5b53a1a..1cf57e622 100644 --- a/algorithms/linfa-tsne/Cargo.toml +++ b/algorithms/linfa-tsne/Cargo.toml @@ -16,14 +16,14 @@ categories = ["algorithms", "mathematics", "science"] [dependencies] thiserror = "2.0" ndarray = { version = "0.16" } -ndarray-rand = "0.15" -bhtsne = "0.4.0" -pdqselect = "=0.1.1" +bhtsne = { version = "0.5.4", default-features = false } linfa = { version = "0.8.1", path = "../.." } +linfa-nn = { version = "0.8.1", path = "../linfa-nn" } [dev-dependencies] rand = "0.8" +ndarray-rand = "0.15" approx = "0.5" linfa-datasets = { version = "0.8.1", path = "../../datasets", features = [ diff --git a/algorithms/linfa-tsne/src/hyperparams.rs b/algorithms/linfa-tsne/src/hyperparams.rs index ac2f8da7c..270961be7 100644 --- a/algorithms/linfa-tsne/src/hyperparams.rs +++ b/algorithms/linfa-tsne/src/hyperparams.rs @@ -1,5 +1,4 @@ use linfa::{Float, ParamGuard}; -use ndarray_rand::rand::{rngs::SmallRng, Rng, SeedableRng}; use crate::TSneError; @@ -32,16 +31,16 @@ use crate::TSneError; /// /// A verified hyper-parameter set ready for prediction #[derive(Debug, Clone, PartialEq)] -pub struct TSneValidParams { +pub struct TSneValidParams { embedding_size: usize, approx_threshold: F, perplexity: F, max_iter: usize, preliminary_iter: Option, - rng: R, + metric: D, } -impl TSneValidParams { +impl TSneValidParams { pub fn embedding_size(&self) -> usize { self.embedding_size } @@ -62,45 +61,46 @@ impl TSneValidParams { &self.preliminary_iter } - pub fn rng(&self) -> &R { - &self.rng + pub fn metric(&self) -> &D { + &self.metric } } #[derive(Debug, Clone, PartialEq)] -pub struct TSneParams(TSneValidParams); +pub struct TSneParams(TSneValidParams); -impl TSneParams { +impl TSneParams { /// Create a t-SNE param set with given embedding size /// /// # Defaults to: /// * `approx_threshold`: 0.5 /// * `perplexity`: 5.0 /// * `max_iter`: 2000 - /// * `rng`: SmallRng with seed 42 - pub fn embedding_size(embedding_size: usize) -> TSneParams { - Self::embedding_size_with_rng(embedding_size, SmallRng::seed_from_u64(42)) + pub fn embedding_size(embedding_size: usize) -> TSneParams { + Self::embedding_size_with_metric(embedding_size, linfa_nn::distance::L2Dist) } } -impl TSneParams { - /// Create a t-SNE param set with given embedding size and random number generator +impl> TSneParams { + /// Create a t-SNE param set with given embedding size and distance metric /// /// # Defaults to: /// * `approx_threshold`: 0.5 /// * `perplexity`: 5.0 /// * `max_iter`: 2000 - pub fn embedding_size_with_rng(embedding_size: usize, rng: R) -> TSneParams { + pub fn embedding_size_with_metric(embedding_size: usize, metric: D) -> Self { Self(TSneValidParams { embedding_size, - rng, approx_threshold: F::cast(0.5), perplexity: F::cast(5.0), max_iter: 2000, preliminary_iter: None, + metric, }) } +} +impl TSneParams { /// Set the approximation threshold of the Barnes Hut algorithm /// /// The threshold decides whether a cluster centroid can be used as a summary for the whole @@ -139,8 +139,8 @@ impl TSneParams { } } -impl ParamGuard for TSneParams { - type Checked = TSneValidParams; +impl ParamGuard for TSneParams { + type Checked = TSneValidParams; type Error = TSneError; /// Validates parameters diff --git a/algorithms/linfa-tsne/src/lib.rs b/algorithms/linfa-tsne/src/lib.rs index 889ec0b20..de3914565 100644 --- a/algorithms/linfa-tsne/src/lib.rs +++ b/algorithms/linfa-tsne/src/lib.rs @@ -1,8 +1,8 @@ #![doc = include_str!("../README.md")] +use std::convert::TryFrom; -use ndarray::Array2; -use ndarray_rand::rand::Rng; -use ndarray_rand::rand_distr::Normal; +use linfa_nn::distance::Distance; +use ndarray::{Array2, ArrayView1}; use linfa::{dataset::DatasetBase, traits::Transformer, Float, ParamGuard}; @@ -12,8 +12,8 @@ mod hyperparams; pub use error::{Result, TSneError}; pub use hyperparams::{TSneParams, TSneValidParams}; -impl Transformer, Result>> for TSneValidParams { - fn transform(&self, mut data: Array2) -> Result> { +impl> Transformer, Result>> for TSneValidParams { + fn transform(&self, data: Array2) -> Result> { let (nfeatures, nsamples) = (data.ncols(), data.nrows()); // validate parameter-data constraints @@ -21,6 +21,10 @@ impl Transformer, Result>> for TSn return Err(TSneError::EmbeddingSizeTooLarge); } + let Ok(embedding_size) = u8::try_from(self.embedding_size()) else { + return Err(TSneError::EmbeddingSizeTooLarge); + }; + if F::cast(nsamples - 1) < F::cast(3) * self.perplexity() { return Err(TSneError::PerplexityTooLarge); } @@ -31,43 +35,47 @@ impl Transformer, Result>> for TSn None => usize::min(self.max_iter() / 2, 250), }; - let data = data.as_slice_mut().unwrap(); - - let mut rng = self.rng().clone(); - let normal = Normal::new(0.0, 1e-4 * 10e-4).unwrap(); - - let mut embedding: Vec = (0..nsamples * self.embedding_size()) - .map(|_| rng.sample(normal)) - .map(F::cast) - .collect(); - - bhtsne::run( - data, - nsamples, - nfeatures, - &mut embedding, - self.embedding_size(), - self.perplexity(), - self.approx_threshold(), - true, - self.max_iter() as u64, - preliminary_iter as u64, - preliminary_iter as u64, - ); + let data: Vec<_> = data.as_slice().unwrap().chunks(nfeatures).collect(); + + let mut tsne = bhtsne::tSNE::new(&data); + let tsne = tsne + .embedding_dim(embedding_size) + .perplexity(self.perplexity()) + .epochs(self.max_iter()) + .stop_lying_epoch(preliminary_iter) + .momentum_switch_epoch(preliminary_iter); + + let tsne = if self.approx_threshold() <= F::zero() { + // compute exact t-SNE + tsne.exact(|a, b| { + let a = ArrayView1::from(a); + let b = ArrayView1::from(b); + self.metric().distance(a, b) + }) + } else { + // compute barnes-hut t-SNE + tsne.barnes_hut(self.approx_threshold(), |a, b| { + let a = ArrayView1::from(a); + let b = ArrayView1::from(b); + self.metric().distance(a, b) + }) + }; + + let embedding = tsne.embedding(); Array2::from_shape_vec((nsamples, self.embedding_size()), embedding).map_err(|e| e.into()) } } -impl Transformer, Result>> for TSneParams { +impl> Transformer, Result>> for TSneParams { fn transform(&self, x: Array2) -> Result> { self.check_ref()?.transform(x) } } -impl +impl> Transformer, T>, Result, T>>> - for TSneValidParams + for TSneValidParams { fn transform(&self, ds: DatasetBase, T>) -> Result, T>> { let DatasetBase { @@ -82,8 +90,8 @@ impl } } -impl - Transformer, T>, Result, T>>> for TSneParams +impl> + Transformer, T>, Result, T>>> for TSneParams { fn transform(&self, ds: DatasetBase, T>) -> Result, T>> { self.check_ref()?.transform(ds) @@ -103,17 +111,16 @@ mod tests { #[test] fn autotraits() { fn has_autotraits() {} - has_autotraits::>>(); - has_autotraits::>>(); + has_autotraits::>(); + has_autotraits::>(); has_autotraits::(); } #[test] fn iris_separate() -> Result<()> { let ds = linfa_datasets::iris(); - let rng = SmallRng::seed_from_u64(42); - let ds = TSneParams::embedding_size_with_rng(2, rng) + let ds = TSneParams::embedding_size(2) .perplexity(10.0) .approx_threshold(0.0) .transform(ds)?; @@ -123,6 +130,19 @@ mod tests { Ok(()) } + #[test] + fn iris_separate_bharnes_hut() -> Result<()> { + let ds = linfa_datasets::iris(); + + let ds = TSneParams::embedding_size(2) + .perplexity(10.0) + .transform(ds)?; + + assert!(ds.silhouette_score()? > 0.5); + + Ok(()) + } + #[test] fn blob_separate() -> Result<()> { let mut rng = SmallRng::seed_from_u64(42); @@ -137,7 +157,7 @@ mod tests { let targets = (0..200).map(|x| x < 100).collect::>(); let dataset = Dataset::new(entries, targets); - let ds = TSneParams::embedding_size_with_rng(2, rng) + let ds = TSneParams::embedding_size(2) .perplexity(60.0) .approx_threshold(0.0) .transform(dataset)?;