diff --git a/datasketches/src/error.rs b/datasketches/src/error.rs index 31559d3..59ff32b 100644 --- a/datasketches/src/error.rs +++ b/datasketches/src/error.rs @@ -124,6 +124,12 @@ impl Error { "invalid preamble longs: expected {expected}, got {actual}" )) } + + pub(crate) fn invalid_preamble_ints(expected: u8, actual: u8) -> Self { + Self::deserial(format!( + "invalid preamble ints: expected {expected}, got {actual}" + )) + } } impl fmt::Debug for Error { diff --git a/datasketches/src/kll/helper.rs b/datasketches/src/kll/helper.rs new file mode 100644 index 0000000..4004150 --- /dev/null +++ b/datasketches/src/kll/helper.rs @@ -0,0 +1,126 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::cell::Cell; +use std::time::SystemTime; +use std::time::UNIX_EPOCH; + +const POWERS_OF_THREE: [u64; 31] = [ + 1, + 3, + 9, + 27, + 81, + 243, + 729, + 2187, + 6561, + 19683, + 59049, + 177147, + 531441, + 1594323, + 4782969, + 14348907, + 43046721, + 129140163, + 387420489, + 1162261467, + 3486784401, + 10460353203, + 31381059609, + 94143178827, + 282429536481, + 847288609443, + 2541865828329, + 7625597484987, + 22876792454961, + 68630377364883, + 205891132094649, +]; + +pub fn compute_total_capacity(k: u16, m: u8, num_levels: usize) -> u32 { + let mut total: u32 = 0; + for level in 0..num_levels { + total += level_capacity(k, num_levels, level, m); + } + total +} + +pub fn level_capacity(k: u16, num_levels: usize, height: usize, min_wid: u8) -> u32 { + assert!(height < num_levels, "height must be < num_levels"); + let depth = num_levels - height - 1; + let cap = int_cap_aux(k, depth as u8); + std::cmp::max(min_wid as u32, cap as u32) +} + +pub fn int_cap_aux(k: u16, depth: u8) -> u16 { + if depth > 60 { + panic!("depth must be <= 60"); + } + if depth <= 30 { + return int_cap_aux_aux(k, depth); + } + let half = depth / 2; + let rest = depth - half; + let tmp = int_cap_aux_aux(k, half); + int_cap_aux_aux(tmp, rest) +} + +pub fn int_cap_aux_aux(k: u16, depth: u8) -> u16 { + if depth > 30 { + panic!("depth must be <= 30"); + } + let twok = (k as u64) << 1; + let tmp = (twok << depth) / POWERS_OF_THREE[depth as usize]; + let result = (tmp + 1) >> 1; + assert!(result <= k as u64, "capacity result exceeds k"); + result as u16 +} + +pub fn sum_the_sample_weights(level_sizes: &[usize]) -> u64 { + let mut total = 0u64; + let mut weight = 1u64; + for &size in level_sizes { + total += weight * size as u64; + weight <<= 1; + } + total +} + +fn seed() -> u64 { + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_nanos(); + nanos as u64 +} + +pub fn random_bit() -> u32 { + thread_local! { + static RNG_STATE: Cell = Cell::new(seed()); + } + + RNG_STATE.with(|state| { + let mut x = state.get(); + x ^= x << 13; + x ^= x >> 7; + x ^= x << 17; + state.set(x); + (x & 1) as u32 + }) +} diff --git a/datasketches/src/kll/mod.rs b/datasketches/src/kll/mod.rs new file mode 100644 index 0000000..5c27c0f --- /dev/null +++ b/datasketches/src/kll/mod.rs @@ -0,0 +1,62 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! KLL sketch implementation for estimating quantiles and ranks. +//! +//! KLL is a compact, streaming quantiles sketch with lazy compaction and +//! near-optimal accuracy per retained item. It supports one-pass updates, +//! approximate quantiles, ranks, PMF, and CDF queries. +//! +//! This implementation follows Apache DataSketches semantics (Java KllSketch +//! / KllPreambleUtil, C++ kll_sketch) and uses the same binary serialization +//! format as those implementations. +//! +//! # Usage +//! +//! ```rust +//! # use datasketches::kll::KllSketch; +//! let mut sketch = KllSketch::::new(200); +//! sketch.update(1.0); +//! sketch.update(2.0); +//! let q = sketch.quantile(0.5, true).unwrap(); +//! assert!(q >= 1.0 && q <= 2.0); +//! ``` + +mod helper; +mod serialization; +mod sketch; +mod sorted_view; + +pub use self::sketch::KllSketch; + +/// KLL sketch specialized for `f64`. +pub type KllSketchF64 = KllSketch; +/// KLL sketch specialized for `f32`. +pub type KllSketchF32 = KllSketch; +/// KLL sketch specialized for `i64`. +pub type KllSketchI64 = KllSketch; +/// KLL sketch specialized for `String`. +pub type KllSketchString = KllSketch; + +/// Default value of parameter k. +pub const DEFAULT_K: u16 = 200; +/// Default value of parameter m. +pub const DEFAULT_M: u8 = 8; +/// Minimum value of parameter k. +pub const MIN_K: u16 = DEFAULT_M as u16; +/// Maximum value of parameter k. +pub const MAX_K: u16 = u16::MAX; diff --git a/datasketches/src/kll/serialization.rs b/datasketches/src/kll/serialization.rs new file mode 100644 index 0000000..998add5 --- /dev/null +++ b/datasketches/src/kll/serialization.rs @@ -0,0 +1,48 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Binary serialization format constants for KLL sketches. +//! +//! Naming and layout follow the Apache DataSketches Java implementation +//! (`KllPreambleUtil`) and the C++ `kll_sketch` serialization format. + +/// Family ID for KLL sketches in DataSketches format (KllPreambleUtil.KLL_FAMILY). +pub const KLL_FAMILY_ID: u8 = 15; + +/// Serialization version for empty or full sketches (KllPreambleUtil.SERIAL_VERSION_EMPTY_FULL). +pub const SERIAL_VERSION_1: u8 = 1; +/// Serialization version for single-item sketches (KllPreambleUtil.SERIAL_VERSION_SINGLE). +pub const SERIAL_VERSION_2: u8 = 2; + +/// Preamble ints for empty and single-item sketches (KllPreambleUtil.PREAMBLE_INTS_EMPTY_SINGLE). +pub const PREAMBLE_INTS_SHORT: u8 = 2; +/// Preamble ints for sketches with more than one item (KllPreambleUtil.PREAMBLE_INTS_FULL). +pub const PREAMBLE_INTS_FULL: u8 = 5; + +/// Flag indicating the sketch is empty (KllPreambleUtil.EMPTY_BIT_MASK). +pub const FLAG_EMPTY: u8 = 1 << 0; +/// Flag indicating level zero is sorted (KllPreambleUtil.LEVEL_ZERO_SORTED_BIT_MASK). +pub const FLAG_LEVEL_ZERO_SORTED: u8 = 1 << 1; +/// Flag indicating the sketch has a single item (KllPreambleUtil.SINGLE_ITEM_BIT_MASK). +pub const FLAG_SINGLE_ITEM: u8 = 1 << 2; + +/// Serialized size for an empty sketch in bytes (KllPreambleUtil.DATA_START_ADR_SINGLE_ITEM). +pub const EMPTY_SIZE_BYTES: usize = 8; +/// Data offset for single-item sketches (KllPreambleUtil.DATA_START_ADR_SINGLE_ITEM). +pub const DATA_START_SINGLE_ITEM: usize = 8; +/// Data offset for sketches with more than one item (KllPreambleUtil.DATA_START_ADR). +pub const DATA_START: usize = 20; diff --git a/datasketches/src/kll/sketch.rs b/datasketches/src/kll/sketch.rs new file mode 100644 index 0000000..ace1476 --- /dev/null +++ b/datasketches/src/kll/sketch.rs @@ -0,0 +1,933 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::cmp::Ordering; + +use super::DEFAULT_K; +use super::DEFAULT_M; +use super::MAX_K; +use super::MIN_K; +use super::helper::compute_total_capacity; +use super::helper::level_capacity; +use super::helper::random_bit; +use super::helper::sum_the_sample_weights; +use super::serialization::DATA_START; +use super::serialization::DATA_START_SINGLE_ITEM; +use super::serialization::EMPTY_SIZE_BYTES; +use super::serialization::FLAG_EMPTY; +use super::serialization::FLAG_LEVEL_ZERO_SORTED; +use super::serialization::FLAG_SINGLE_ITEM; +use super::serialization::KLL_FAMILY_ID; +use super::serialization::PREAMBLE_INTS_FULL; +use super::serialization::PREAMBLE_INTS_SHORT; +use super::serialization::SERIAL_VERSION_1; +use super::serialization::SERIAL_VERSION_2; +use super::sorted_view::build_sorted_view; +use crate::codec::SketchBytes; +use crate::codec::SketchSlice; +use crate::error::Error; + +/// Trait implemented by item types supported by [`KllSketch`]. +/// +/// Implementations must provide a total ordering via `cmp`. +/// For floating-point types, ensure `cmp` handles NaN consistently and `is_nan` +/// returns true for values that should be ignored by updates. +pub trait KllItem: Clone { + /// Compare two items. + fn cmp(a: &Self, b: &Self) -> Ordering; + + /// Returns true if the item is NaN. + fn is_nan(_value: &Self) -> bool { + false + } +} + +pub(crate) trait KllSerde: KllItem { + /// Serialized size in bytes. + fn serialized_size(value: &Self) -> usize; + + /// Serialize a single item into the buffer. + fn serialize(value: &Self, bytes: &mut SketchBytes); + + /// Deserialize a single item from the input. + fn deserialize(input: &mut SketchSlice<'_>) -> Result; +} + +/// KLL sketch for estimating quantiles and ranks. +/// +/// See the [kll module level documentation](crate::kll) for more. +#[derive(Debug, Clone, PartialEq)] +pub struct KllSketch { + k: u16, + m: u8, + min_k: u16, + n: u64, + is_level_zero_sorted: bool, + levels: Vec>, + min_item: Option, + max_item: Option, +} + +impl Default for KllSketch { + fn default() -> Self { + Self::new(DEFAULT_K) + } +} + +impl KllSketch { + /// Creates a new sketch with the given value of k. + /// + /// # Panics + /// + /// Panics if k is not in [MIN_K, MAX_K]. + /// + /// # Examples + /// + /// ``` + /// # use datasketches::kll::KllSketch; + /// let sketch = KllSketch::::new(200); + /// assert_eq!(sketch.k(), 200); + /// ``` + pub fn new(k: u16) -> Self { + assert!( + (MIN_K..=MAX_K).contains(&k), + "k must be in [{MIN_K}, {MAX_K}], got {k}" + ); + Self { + k, + m: DEFAULT_M, + min_k: k, + n: 0, + is_level_zero_sorted: false, + levels: vec![Vec::new()], + min_item: None, + max_item: None, + } + } + + /// Returns parameter k used to configure this sketch. + pub fn k(&self) -> u16 { + self.k + } + + /// Returns the minimum k used when merging sketches. + pub fn min_k(&self) -> u16 { + self.min_k + } + + /// Returns total weight of the stream. + pub fn n(&self) -> u64 { + self.n + } + + /// Returns true if the sketch has not seen any data. + pub fn is_empty(&self) -> bool { + self.n == 0 + } + + /// Returns the number of retained items. + pub fn num_retained(&self) -> usize { + self.levels.iter().map(|level| level.len()).sum() + } + + /// Returns true if the sketch is in estimation mode. + pub fn is_estimation_mode(&self) -> bool { + self.levels.len() > 1 + } + + /// Returns the minimum item seen by the sketch. + pub fn min_item(&self) -> Option<&T> { + self.min_item.as_ref() + } + + /// Returns the maximum item seen by the sketch. + pub fn max_item(&self) -> Option<&T> { + self.max_item.as_ref() + } + + /// Updates the sketch with a new item. + /// + /// NaN values are ignored for floating-point types. + pub fn update(&mut self, item: T) { + if T::is_nan(&item) { + return; + } + self.update_min_max(&item); + self.internal_update(item); + } + + /// Merges another sketch into this one. + /// + /// # Panics + /// + /// Panics if the sketches have incompatible parameters. + pub fn merge(&mut self, other: &KllSketch) { + if other.is_empty() { + return; + } + + assert_eq!( + self.m, other.m, + "incompatible m values: {} and {}", + self.m, other.m + ); + + self.update_min_max_from_other(other); + + let final_n = self.n + other.n; + for item in &other.levels[0] { + self.internal_update(item.clone()); + } + + if other.levels.len() >= 2 { + self.merge_higher_levels(other); + } + + self.n = final_n; + if other.is_estimation_mode() { + self.min_k = self.min_k.min(other.min_k); + } + + debug_assert_eq!(self.total_weight(), self.n, "total weight does not match n"); + } + + /// Returns the normalized rank of the given item. + pub fn rank(&self, item: &T, inclusive: bool) -> Option { + if self.is_empty() { + return None; + } + let view = build_sorted_view(&self.levels); + Some(view.rank(item, inclusive)) + } + + /// Returns the quantile for the given normalized rank. + /// + /// # Panics + /// + /// Panics if rank is not in [0.0, 1.0]. + pub fn quantile(&self, rank: f64, inclusive: bool) -> Option { + if self.is_empty() { + return None; + } + assert!((0.0..=1.0).contains(&rank), "rank must be in [0.0, 1.0]"); + let view = build_sorted_view(&self.levels); + Some(view.quantile(rank, inclusive)) + } + + /// Returns the approximate CDF for the given split points. + pub fn cdf(&self, split_points: &[T], inclusive: bool) -> Option> { + if self.is_empty() { + return None; + } + let view = build_sorted_view(&self.levels); + Some(view.cdf(split_points, inclusive)) + } + + /// Returns the approximate PMF for the given split points. + pub fn pmf(&self, split_points: &[T], inclusive: bool) -> Option> { + if self.is_empty() { + return None; + } + let view = build_sorted_view(&self.levels); + Some(view.pmf(split_points, inclusive)) + } + + /// Returns normalized rank error for the configured k. + pub fn normalized_rank_error(&self, pmf: bool) -> f64 { + normalized_rank_error(self.min_k, pmf) + } +} + +fn serialized_size(sketch: &KllSketch) -> usize { + if sketch.is_empty() { + return EMPTY_SIZE_BYTES; + } + if sketch.n == 1 { + let item = &sketch.levels[0][0]; + return DATA_START_SINGLE_ITEM + T::serialized_size(item); + } + + let mut size = DATA_START + sketch.levels.len() * 4; + if let Some(min_item) = &sketch.min_item { + size += T::serialized_size(min_item); + } + if let Some(max_item) = &sketch.max_item { + size += T::serialized_size(max_item); + } + for level in &sketch.levels { + for item in level { + size += T::serialized_size(item); + } + } + size +} + +fn serialize_with_serde(sketch: &KllSketch) -> Vec { + let size = serialized_size(sketch); + let mut bytes = SketchBytes::with_capacity(size); + + let is_empty = sketch.is_empty(); + let is_single_item = sketch.n == 1; + + let preamble_ints = if is_empty || is_single_item { + PREAMBLE_INTS_SHORT + } else { + PREAMBLE_INTS_FULL + }; + let serial_version = if is_single_item { + SERIAL_VERSION_2 + } else { + SERIAL_VERSION_1 + }; + + let flags = (if is_empty { FLAG_EMPTY } else { 0 }) + | (if sketch.is_level_zero_sorted { + FLAG_LEVEL_ZERO_SORTED + } else { + 0 + }) + | (if is_single_item { FLAG_SINGLE_ITEM } else { 0 }); + + bytes.write_u8(preamble_ints); + bytes.write_u8(serial_version); + bytes.write_u8(KLL_FAMILY_ID); + bytes.write_u8(flags); + bytes.write_u16_le(sketch.k); + bytes.write_u8(sketch.m); + bytes.write_u8(0); + + if is_empty { + return bytes.into_bytes(); + } + + if !is_single_item { + bytes.write_u64_le(sketch.n); + bytes.write_u16_le(sketch.min_k); + bytes.write_u8(sketch.levels.len() as u8); + bytes.write_u8(0); + + let level_offsets = sketch.level_offsets(); + for offset in level_offsets.iter().take(sketch.levels.len()) { + bytes.write_u32_le(*offset); + } + + if let Some(min_item) = &sketch.min_item { + T::serialize(min_item, &mut bytes); + } + if let Some(max_item) = &sketch.max_item { + T::serialize(max_item, &mut bytes); + } + } + + for level in &sketch.levels { + for item in level { + T::serialize(item, &mut bytes); + } + } + + bytes.into_bytes() +} + +fn deserialize_with_serde(bytes: &[u8]) -> Result, Error> { + fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> Error { + move |_| Error::insufficient_data(tag) + } + + let mut cursor = SketchSlice::new(bytes); + + let preamble_ints = cursor.read_u8().map_err(make_error("preamble_ints"))?; + let serial_version = cursor.read_u8().map_err(make_error("serial_version"))?; + let family_id = cursor.read_u8().map_err(make_error("family_id"))?; + let flags = cursor.read_u8().map_err(make_error("flags"))?; + let k = cursor.read_u16_le().map_err(make_error("k"))?; + let m = cursor.read_u8().map_err(make_error("m"))?; + let _unused = cursor.read_u8().map_err(make_error("unused"))?; + + if m != DEFAULT_M { + return Err(Error::deserial(format!( + "invalid m: expected {DEFAULT_M}, got {m}" + ))); + } + if family_id != KLL_FAMILY_ID { + return Err(Error::invalid_family(KLL_FAMILY_ID, family_id, "KLL")); + } + let is_empty = (flags & FLAG_EMPTY) != 0; + let is_single_item = (flags & FLAG_SINGLE_ITEM) != 0; + let is_level_zero_sorted = (flags & FLAG_LEVEL_ZERO_SORTED) != 0; + if is_empty || is_single_item { + if preamble_ints != PREAMBLE_INTS_SHORT { + return Err(Error::invalid_preamble_ints( + PREAMBLE_INTS_SHORT, + preamble_ints, + )); + } + } else if preamble_ints != PREAMBLE_INTS_FULL { + return Err(Error::invalid_preamble_ints( + PREAMBLE_INTS_FULL, + preamble_ints, + )); + } + let expected_version = if is_single_item { + SERIAL_VERSION_2 + } else { + SERIAL_VERSION_1 + }; + if serial_version != expected_version { + return Err(Error::unsupported_serial_version( + expected_version, + serial_version, + )); + } + + if !(MIN_K..=MAX_K).contains(&k) { + return Err(Error::deserial(format!("k out of range: {k}"))); + } + + if is_empty { + return Ok(KllSketch::make( + k, + k, + 0, + vec![Vec::new()], + None, + None, + is_level_zero_sorted, + )); + } + + let (n, min_k, num_levels) = if is_single_item { + (1u64, k, 1usize) + } else { + let n = cursor.read_u64_le().map_err(make_error("n"))?; + let min_k = cursor.read_u16_le().map_err(make_error("min_k"))?; + let num_levels = cursor.read_u8().map_err(make_error("num_levels"))?; + let _unused = cursor.read_u8().map_err(make_error("unused2"))?; + (n, min_k, num_levels as usize) + }; + + if num_levels == 0 { + return Err(Error::deserial("num_levels must be > 0")); + } + if min_k < MIN_K || min_k > k { + return Err(Error::deserial(format!( + "min_k must be in [{MIN_K}, {k}], got {min_k}" + ))); + } + + let capacity = compute_total_capacity(k, m, num_levels) as u32; + let mut level_offsets = Vec::with_capacity(num_levels + 1); + if !is_single_item { + for _ in 0..num_levels { + let offset = cursor.read_u32_le().map_err(make_error("levels"))?; + level_offsets.push(offset); + } + } else { + level_offsets.push(capacity - 1); + } + level_offsets.push(capacity); + + if level_offsets.is_empty() { + return Err(Error::deserial("levels array is empty")); + } + if level_offsets[0] > capacity { + return Err(Error::deserial("levels[0] exceeds capacity")); + } + for window in level_offsets.windows(2) { + if window[1] < window[0] { + return Err(Error::deserial("levels array must be non-decreasing")); + } + } + let last = *level_offsets.last().unwrap(); + if last != capacity { + return Err(Error::deserial("levels last offset must equal capacity")); + } + + let min_item = if is_single_item { + None + } else { + Some(T::deserialize(&mut cursor)?) + }; + let max_item = if is_single_item { + None + } else { + Some(T::deserialize(&mut cursor)?) + }; + + let mut levels = Vec::with_capacity(num_levels); + for level in 0..num_levels { + let size = (level_offsets[level + 1] - level_offsets[level]) as usize; + let mut items = Vec::with_capacity(size); + for _ in 0..size { + items.push(T::deserialize(&mut cursor)?); + } + levels.push(items); + } + + let mut sketch = KllSketch::make( + k, + min_k, + n, + levels, + min_item, + max_item, + is_level_zero_sorted, + ); + + if is_single_item { + if let Some(item) = sketch.levels[0].first().cloned() { + sketch.min_item = Some(item.clone()); + sketch.max_item = Some(item); + } + } + + Ok(sketch) +} + +impl KllSketch { + /// Serializes the sketch to bytes. + pub fn serialize(&self) -> Vec { + serialize_with_serde(self) + } + + /// Deserializes a sketch from bytes. + pub fn deserialize(bytes: &[u8]) -> Result { + deserialize_with_serde(bytes) + } +} + +impl KllSketch { + /// Serializes the sketch to bytes. + pub fn serialize(&self) -> Vec { + serialize_with_serde(self) + } + + /// Deserializes a sketch from bytes. + pub fn deserialize(bytes: &[u8]) -> Result { + deserialize_with_serde(bytes) + } +} + +impl KllSketch { + /// Serializes the sketch to bytes. + pub fn serialize(&self) -> Vec { + serialize_with_serde(self) + } + + /// Deserializes a sketch from bytes. + pub fn deserialize(bytes: &[u8]) -> Result { + deserialize_with_serde(bytes) + } +} + +impl KllSketch { + /// Serializes the sketch to bytes. + pub fn serialize(&self) -> Vec { + serialize_with_serde(self) + } + + /// Deserializes a sketch from bytes. + pub fn deserialize(bytes: &[u8]) -> Result { + deserialize_with_serde(bytes) + } +} + +impl KllSketch { + fn make( + k: u16, + min_k: u16, + n: u64, + levels: Vec>, + min_item: Option, + max_item: Option, + is_level_zero_sorted: bool, + ) -> Self { + Self { + k, + m: DEFAULT_M, + min_k, + n, + is_level_zero_sorted, + levels, + min_item, + max_item, + } + } + + fn capacity(&self) -> usize { + compute_total_capacity(self.k, self.m, self.levels.len()) as usize + } + + fn level_offsets(&self) -> Vec { + let capacity = self.capacity() as u32; + let retained = self.num_retained() as u32; + assert!(capacity >= retained, "capacity must be >= retained"); + + let mut offsets = Vec::with_capacity(self.levels.len() + 1); + let mut offset = capacity - retained; + offsets.push(offset); + for level in &self.levels { + offset += level.len() as u32; + offsets.push(offset); + } + offsets + } + + fn update_min_max(&mut self, item: &T) { + match self.min_item.as_ref() { + None => { + self.min_item = Some(item.clone()); + self.max_item = Some(item.clone()); + } + Some(min) => { + if T::cmp(item, min) == Ordering::Less { + self.min_item = Some(item.clone()); + } + if let Some(max) = &self.max_item { + if T::cmp(max, item) == Ordering::Less { + self.max_item = Some(item.clone()); + } + } + } + } + } + + fn update_min_max_from_other(&mut self, other: &KllSketch) { + match (&self.min_item, &self.max_item) { + (None, None) => { + self.min_item = other.min_item.clone(); + self.max_item = other.max_item.clone(); + } + (Some(min), Some(max)) => { + if let Some(other_min) = &other.min_item { + if T::cmp(other_min, min) == Ordering::Less { + self.min_item = Some(other_min.clone()); + } + } + if let Some(other_max) = &other.max_item { + if T::cmp(max, other_max) == Ordering::Less { + self.max_item = Some(other_max.clone()); + } + } + } + _ => { + self.min_item = other.min_item.clone(); + self.max_item = other.max_item.clone(); + } + } + } + + fn internal_update(&mut self, item: T) { + if self.num_retained() >= self.capacity() { + self.compress_while_updating(); + } + self.n += 1; + self.is_level_zero_sorted = false; + self.levels[0].insert(0, item); + } + + fn compress_while_updating(&mut self) { + let level = self.find_level_to_compact(); + if level + 1 == self.levels.len() { + self.levels.push(Vec::new()); + } + + let mut current = std::mem::take(&mut self.levels[level]); + let mut above = std::mem::take(&mut self.levels[level + 1]); + + let odd = current.len() % 2 == 1; + let mut leftover = None; + if odd { + leftover = Some(current.remove(0)); + } + + if level == 0 && !self.is_level_zero_sorted { + current.sort_by(T::cmp); + } + + let use_up = above.is_empty(); + let promoted = downsample(current, random_bit(), use_up); + if above.is_empty() { + above = promoted; + } else { + above = merge_sorted_vec(promoted, above); + } + self.levels[level + 1] = above; + + let mut new_level = Vec::new(); + if let Some(item) = leftover { + new_level.push(item); + } + self.levels[level] = new_level; + } + + fn find_level_to_compact(&self) -> usize { + let num_levels = self.levels.len(); + for level in 0..num_levels { + let pop = self.levels[level].len() as u32; + let cap = level_capacity(self.k, num_levels, level, self.m); + if pop >= cap { + return level; + } + } + panic!("no level to compact"); + } + + fn merge_higher_levels(&mut self, other: &KllSketch) { + let provisional_levels = self.levels.len().max(other.levels.len()); + let mut self_levels = std::mem::take(&mut self.levels); + let mut work_levels = vec![Vec::new(); provisional_levels]; + work_levels[0] = std::mem::take(&mut self_levels[0]); + + for level in 1..provisional_levels { + let left = if level < self_levels.len() { + std::mem::take(&mut self_levels[level]) + } else { + Vec::new() + }; + let right = other.levels.get(level).cloned().unwrap_or_default(); + + work_levels[level] = if left.is_empty() { + right + } else if right.is_empty() { + left + } else { + merge_sorted_vec(left, right) + }; + } + + self.levels = general_compress(work_levels, self.k, self.m, self.is_level_zero_sorted); + } + + fn total_weight(&self) -> u64 { + let sizes: Vec = self.levels.iter().map(|level| level.len()).collect(); + sum_the_sample_weights(&sizes) + } +} + +fn normalized_rank_error(k: u16, pmf: bool) -> f64 { + let k = k as f64; + if pmf { + 2.446 / k.powf(0.9433) + } else { + 2.296 / k.powf(0.9723) + } +} + +fn downsample(items: Vec, offset: u32, use_up: bool) -> Vec { + let len = items.len(); + debug_assert!(len % 2 == 0, "length must be even"); + let offset = (offset & 1) as usize; + let parity = if use_up { + (len - 1 - offset) % 2 + } else { + offset + }; + + items + .into_iter() + .enumerate() + .filter_map(|(idx, item)| if idx % 2 == parity { Some(item) } else { None }) + .collect() +} + +fn merge_sorted_vec(left: Vec, right: Vec) -> Vec { + let mut merged = Vec::with_capacity(left.len() + right.len()); + let mut left_iter = left.into_iter().peekable(); + let mut right_iter = right.into_iter().peekable(); + + while let (Some(l), Some(r)) = (left_iter.peek(), right_iter.peek()) { + if T::cmp(l, r) == Ordering::Less { + merged.push(left_iter.next().unwrap()); + } else { + merged.push(right_iter.next().unwrap()); + } + } + merged.extend(left_iter); + merged.extend(right_iter); + merged +} + +fn general_compress( + mut levels_in: Vec>, + k: u16, + m: u8, + is_level_zero_sorted: bool, +) -> Vec> { + let mut current_num_levels = levels_in.len(); + let mut current_item_count: usize = levels_in.iter().map(|level| level.len()).sum(); + let mut target_item_count = compute_total_capacity(k, m, current_num_levels) as usize; + let mut levels_out = Vec::with_capacity(current_num_levels + 1); + + let mut current_level = 0usize; + while current_level < current_num_levels { + if current_level + 1 >= levels_in.len() { + levels_in.push(Vec::new()); + } + + let raw_pop = levels_in[current_level].len(); + let cap = level_capacity(k, current_num_levels, current_level, m) as usize; + + if current_item_count < target_item_count || raw_pop < cap { + levels_out.push(std::mem::take(&mut levels_in[current_level])); + } else { + let mut current = std::mem::take(&mut levels_in[current_level]); + let mut above = std::mem::take(&mut levels_in[current_level + 1]); + + let odd = current.len() % 2 == 1; + let mut leftover = None; + if odd { + leftover = Some(current.remove(0)); + } + + if current_level == 0 && !is_level_zero_sorted { + current.sort_by(T::cmp); + } + + let use_up = above.is_empty(); + let promoted = downsample(current, random_bit(), use_up); + let promoted_len = promoted.len(); + if above.is_empty() { + above = promoted; + } else { + above = merge_sorted_vec(promoted, above); + } + levels_in[current_level + 1] = above; + + let mut out_level = Vec::new(); + if let Some(item) = leftover { + out_level.push(item); + } + levels_out.push(out_level); + + current_item_count = current_item_count.saturating_sub(promoted_len); + + if current_level == current_num_levels - 1 { + current_num_levels += 1; + target_item_count += level_capacity(k, current_num_levels, 0, m) as usize; + if levels_in.len() < current_num_levels + 1 { + levels_in.resize_with(current_num_levels + 1, Vec::new); + } + } + } + current_level += 1; + } + + levels_out.truncate(current_num_levels); + levels_out +} + +impl KllItem for f32 { + fn cmp(a: &Self, b: &Self) -> Ordering { + a.partial_cmp(b).unwrap_or(Ordering::Greater) + } + + fn is_nan(value: &Self) -> bool { + value.is_nan() + } +} + +impl KllSerde for f32 { + fn serialized_size(_value: &Self) -> usize { + 4 + } + + fn serialize(value: &Self, bytes: &mut SketchBytes) { + bytes.write_f32_le(*value); + } + + fn deserialize(input: &mut SketchSlice<'_>) -> Result { + input + .read_f32_le() + .map_err(|_| Error::insufficient_data("f32")) + } +} + +impl KllItem for f64 { + fn cmp(a: &Self, b: &Self) -> Ordering { + a.partial_cmp(b).unwrap_or(Ordering::Greater) + } + + fn is_nan(value: &Self) -> bool { + value.is_nan() + } +} + +impl KllSerde for f64 { + fn serialized_size(_value: &Self) -> usize { + 8 + } + + fn serialize(value: &Self, bytes: &mut SketchBytes) { + bytes.write_f64_le(*value); + } + + fn deserialize(input: &mut SketchSlice<'_>) -> Result { + input + .read_f64_le() + .map_err(|_| Error::insufficient_data("f64")) + } +} + +impl KllItem for i64 { + fn cmp(a: &Self, b: &Self) -> Ordering { + a.cmp(b) + } +} + +impl KllSerde for i64 { + fn serialized_size(_value: &Self) -> usize { + 8 + } + + fn serialize(value: &Self, bytes: &mut SketchBytes) { + bytes.write_i64_le(*value); + } + + fn deserialize(input: &mut SketchSlice<'_>) -> Result { + input + .read_i64_le() + .map_err(|_| Error::insufficient_data("i64")) + } +} + +impl KllItem for String { + fn cmp(a: &Self, b: &Self) -> Ordering { + a.cmp(b) + } +} + +impl KllSerde for String { + fn serialized_size(value: &Self) -> usize { + 4 + value.len() + } + + fn serialize(value: &Self, bytes: &mut SketchBytes) { + bytes.write_u32_le(value.len() as u32); + bytes.write(value.as_bytes()); + } + + fn deserialize(input: &mut SketchSlice<'_>) -> Result { + let len = input + .read_u32_le() + .map_err(|_| Error::insufficient_data("string_len"))? as usize; + let mut buf = vec![0u8; len]; + input + .read_exact(&mut buf) + .map_err(|_| Error::insufficient_data("string_bytes"))?; + String::from_utf8(buf).map_err(|_| Error::deserial("invalid utf-8 string")) + } +} diff --git a/datasketches/src/kll/sorted_view.rs b/datasketches/src/kll/sorted_view.rs new file mode 100644 index 0000000..655fd05 --- /dev/null +++ b/datasketches/src/kll/sorted_view.rs @@ -0,0 +1,192 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::cmp::Ordering; + +use super::sketch::KllItem; + +#[derive(Debug, Clone)] +pub(crate) struct SortedView { + entries: Vec>, + total_weight: u64, +} + +#[derive(Debug, Clone)] +struct Entry { + item: T, + weight: u64, +} + +impl SortedView { + fn new(mut entries: Vec>) -> Self { + entries.sort_by(|a, b| T::cmp(&a.item, &b.item)); + let mut total_weight = 0u64; + for entry in &mut entries { + total_weight += entry.weight; + entry.weight = total_weight; + } + Self { + entries, + total_weight, + } + } + + pub fn rank(&self, item: &T, inclusive: bool) -> f64 { + if self.entries.is_empty() { + return 0.0; + } + + let idx = if inclusive { + upper_bound(&self.entries, item) + } else { + lower_bound(&self.entries, item) + }; + + if idx == 0 { + return 0.0; + } + let weight = self.entries[idx - 1].weight; + weight as f64 / self.total_weight as f64 + } + + pub fn quantile(&self, rank: f64, inclusive: bool) -> T { + let weight = if inclusive { + (rank * self.total_weight as f64).ceil() as u64 + } else { + (rank * self.total_weight as f64) as u64 + }; + + let idx = if inclusive { + lower_bound_by_weight(&self.entries, weight) + } else { + upper_bound_by_weight(&self.entries, weight) + }; + + if idx >= self.entries.len() { + return self.entries[self.entries.len() - 1].item.clone(); + } + self.entries[idx].item.clone() + } + + pub fn cdf(&self, split_points: &[T], inclusive: bool) -> Vec { + check_split_points(split_points); + let mut ranks = Vec::with_capacity(split_points.len() + 1); + for item in split_points { + ranks.push(self.rank(item, inclusive)); + } + ranks.push(1.0); + ranks + } + + pub fn pmf(&self, split_points: &[T], inclusive: bool) -> Vec { + let mut buckets = self.cdf(split_points, inclusive); + for i in (1..buckets.len()).rev() { + buckets[i] -= buckets[i - 1]; + } + buckets + } +} + +pub(crate) fn build_sorted_view(levels: &[Vec]) -> SortedView { + let num_retained: usize = levels.iter().map(|level| level.len()).sum(); + let mut entries = Vec::with_capacity(num_retained); + + for (level_idx, level) in levels.iter().enumerate() { + let weight = 1u64 << level_idx; + for item in level { + entries.push(Entry { + item: item.clone(), + weight, + }); + } + } + + SortedView::new(entries) +} + +#[track_caller] +fn check_split_points(split_points: &[T]) { + let len = split_points.len(); + if len == 1 && T::is_nan(&split_points[0]) { + panic!("split_points must not contain NaN values"); + } + for i in 0..len.saturating_sub(1) { + if T::is_nan(&split_points[i]) { + panic!("split_points must not contain NaN values"); + } + if T::cmp(&split_points[i], &split_points[i + 1]) == Ordering::Less { + continue; + } + panic!("split_points must be unique and monotonically increasing"); + } +} + +fn lower_bound(entries: &[Entry], item: &T) -> usize { + let mut left = 0usize; + let mut right = entries.len(); + while left < right { + let mid = left + (right - left) / 2; + if T::cmp(&entries[mid].item, item) == Ordering::Less { + left = mid + 1; + } else { + right = mid; + } + } + left +} + +fn upper_bound(entries: &[Entry], item: &T) -> usize { + let mut left = 0usize; + let mut right = entries.len(); + while left < right { + let mid = left + (right - left) / 2; + if T::cmp(&entries[mid].item, item) == Ordering::Greater { + right = mid; + } else { + left = mid + 1; + } + } + left +} + +fn lower_bound_by_weight(entries: &[Entry], weight: u64) -> usize { + let mut left = 0usize; + let mut right = entries.len(); + while left < right { + let mid = left + (right - left) / 2; + if entries[mid].weight < weight { + left = mid + 1; + } else { + right = mid; + } + } + left +} + +fn upper_bound_by_weight(entries: &[Entry], weight: u64) -> usize { + let mut left = 0usize; + let mut right = entries.len(); + while left < right { + let mid = left + (right - left) / 2; + if entries[mid].weight > weight { + right = mid; + } else { + left = mid + 1; + } + } + left +} diff --git a/datasketches/src/lib.rs b/datasketches/src/lib.rs index 009fd9e..9034d51 100644 --- a/datasketches/src/lib.rs +++ b/datasketches/src/lib.rs @@ -36,6 +36,7 @@ pub mod countmin; pub mod error; pub mod frequencies; pub mod hll; +pub mod kll; pub mod tdigest; pub mod theta; diff --git a/datasketches/tests/kll_serialization_test.rs b/datasketches/tests/kll_serialization_test.rs new file mode 100644 index 0000000..3c1dd57 --- /dev/null +++ b/datasketches/tests/kll_serialization_test.rs @@ -0,0 +1,285 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! KLL Sketch Serialization Compatibility Tests +//! +//! These tests verify binary compatibility with Apache DataSketches implementations: +//! - Java (datasketches-java) +//! - C++ (datasketches-cpp) +//! +//! Test data is generated by the reference implementations and stored in: +//! `tests/serialization_test_data/` + +mod common; + +use std::fs; +use std::path::PathBuf; + +use common::serialization_test_data; +use datasketches::kll::DEFAULT_K; +use datasketches::kll::KllSketch; + +fn test_f32_file(path: PathBuf, expected_n: usize) { + let bytes = fs::read(&path).unwrap(); + let sketch = KllSketch::::deserialize(&bytes).unwrap(); + + assert_eq!(sketch.k(), DEFAULT_K, "wrong k in {}", path.display()); + assert_eq!( + sketch.n() as usize, + expected_n, + "wrong n in {}", + path.display() + ); + assert_eq!( + sketch.is_estimation_mode(), + expected_n > DEFAULT_K as usize, + "wrong estimation mode in {}", + path.display() + ); + assert_eq!( + sketch.is_empty(), + expected_n == 0, + "wrong empty flag in {}", + path.display() + ); + + if expected_n == 0 { + assert!(sketch.min_item().is_none(), "min should be None"); + assert!(sketch.max_item().is_none(), "max should be None"); + } else { + assert_eq!( + sketch.min_item().cloned(), + Some(1.0), + "min item mismatch in {}", + path.display() + ); + assert_eq!( + sketch.max_item().cloned(), + Some(expected_n as f32), + "max item mismatch in {}", + path.display() + ); + } +} + +fn test_f64_file(path: PathBuf, expected_n: usize) { + let bytes = fs::read(&path).unwrap(); + let sketch = KllSketch::::deserialize(&bytes).unwrap(); + + assert_eq!(sketch.k(), DEFAULT_K, "wrong k in {}", path.display()); + assert_eq!( + sketch.n() as usize, + expected_n, + "wrong n in {}", + path.display() + ); + assert_eq!( + sketch.is_estimation_mode(), + expected_n > DEFAULT_K as usize, + "wrong estimation mode in {}", + path.display() + ); + assert_eq!( + sketch.is_empty(), + expected_n == 0, + "wrong empty flag in {}", + path.display() + ); + + if expected_n == 0 { + assert!(sketch.min_item().is_none(), "min should be None"); + assert!(sketch.max_item().is_none(), "max should be None"); + } else { + assert_eq!( + sketch.min_item().cloned(), + Some(1.0), + "min item mismatch in {}", + path.display() + ); + assert_eq!( + sketch.max_item().cloned(), + Some(expected_n as f64), + "max item mismatch in {}", + path.display() + ); + } +} + +fn test_i64_file(path: PathBuf, expected_n: usize) { + let bytes = fs::read(&path).unwrap(); + let sketch = KllSketch::::deserialize(&bytes).unwrap(); + + assert_eq!(sketch.k(), DEFAULT_K, "wrong k in {}", path.display()); + assert_eq!( + sketch.n() as usize, + expected_n, + "wrong n in {}", + path.display() + ); + assert_eq!( + sketch.is_estimation_mode(), + expected_n > DEFAULT_K as usize, + "wrong estimation mode in {}", + path.display() + ); + assert_eq!( + sketch.is_empty(), + expected_n == 0, + "wrong empty flag in {}", + path.display() + ); + + if expected_n == 0 { + assert!(sketch.min_item().is_none(), "min should be None"); + assert!(sketch.max_item().is_none(), "max should be None"); + } else { + assert_eq!( + sketch.min_item().cloned(), + Some(1), + "min item mismatch in {}", + path.display() + ); + assert_eq!( + sketch.max_item().cloned(), + Some(expected_n as i64), + "max item mismatch in {}", + path.display() + ); + } +} + +fn parse_string_value(value: &str) -> u64 { + value + .trim_start() + .parse::() + .expect("string value should be numeric") +} + +fn test_string_file(path: PathBuf, expected_n: usize) { + let bytes = fs::read(&path).unwrap(); + let sketch = KllSketch::::deserialize(&bytes).unwrap(); + + assert_eq!(sketch.k(), DEFAULT_K, "wrong k in {}", path.display()); + assert_eq!( + sketch.n() as usize, + expected_n, + "wrong n in {}", + path.display() + ); + assert_eq!( + sketch.is_estimation_mode(), + expected_n > DEFAULT_K as usize, + "wrong estimation mode in {}", + path.display() + ); + assert_eq!( + sketch.is_empty(), + expected_n == 0, + "wrong empty flag in {}", + path.display() + ); + + if expected_n == 0 { + assert!(sketch.min_item().is_none(), "min should be None"); + assert!(sketch.max_item().is_none(), "max should be None"); + } else { + let min_item = sketch.min_item().expect("missing min item"); + let max_item = sketch.max_item().expect("missing max item"); + assert_eq!( + parse_string_value(min_item), + 1, + "min item mismatch in {}", + path.display() + ); + assert_eq!( + parse_string_value(max_item), + expected_n as u64, + "max item mismatch in {}", + path.display() + ); + } +} + +#[test] +fn test_java_kll_float_compatibility() { + let test_cases = [0, 1, 10, 100, 1000, 10000, 100000, 1000000]; + for n in test_cases { + let filename = format!("kll_float_n{n}_java.sk"); + let path = serialization_test_data("java_generated_files", &filename); + test_f32_file(path, n); + } +} + +#[test] +fn test_java_kll_double_compatibility() { + let test_cases = [0, 1, 10, 100, 1000, 10000, 100000, 1000000]; + for n in test_cases { + let filename = format!("kll_double_n{n}_java.sk"); + let path = serialization_test_data("java_generated_files", &filename); + test_f64_file(path, n); + } +} + +#[test] +fn test_java_kll_long_compatibility() { + let test_cases = [0, 1, 10, 100, 1000, 10000, 100000, 1000000]; + for n in test_cases { + let filename = format!("kll_long_n{n}_java.sk"); + let path = serialization_test_data("java_generated_files", &filename); + test_i64_file(path, n); + } +} + +#[test] +fn test_java_kll_string_compatibility() { + let test_cases = [0, 1, 10, 100, 1000, 10000, 100000, 1000000]; + for n in test_cases { + let filename = format!("kll_string_n{n}_java.sk"); + let path = serialization_test_data("java_generated_files", &filename); + test_string_file(path, n); + } +} + +#[test] +fn test_cpp_kll_float_compatibility() { + let test_cases = [0, 1, 10, 100, 1000, 10000, 100000, 1000000]; + for n in test_cases { + let filename = format!("kll_float_n{n}_cpp.sk"); + let path = serialization_test_data("cpp_generated_files", &filename); + test_f32_file(path, n); + } +} + +#[test] +fn test_cpp_kll_double_compatibility() { + let test_cases = [0, 1, 10, 100, 1000, 10000, 100000, 1000000]; + for n in test_cases { + let filename = format!("kll_double_n{n}_cpp.sk"); + let path = serialization_test_data("cpp_generated_files", &filename); + test_f64_file(path, n); + } +} + +#[test] +fn test_cpp_kll_string_compatibility() { + let test_cases = [0, 1, 10, 100, 1000, 10000, 100000, 1000000]; + for n in test_cases { + let filename = format!("kll_string_n{n}_cpp.sk"); + let path = serialization_test_data("cpp_generated_files", &filename); + test_string_file(path, n); + } +} diff --git a/datasketches/tests/kll_test.rs b/datasketches/tests/kll_test.rs new file mode 100644 index 0000000..08ceaa7 --- /dev/null +++ b/datasketches/tests/kll_test.rs @@ -0,0 +1,322 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datasketches::kll::DEFAULT_K; +use datasketches::kll::KllSketch; +use datasketches::kll::MAX_K; +use datasketches::kll::MIN_K; + +const NUMERIC_NOISE_TOLERANCE: f64 = 1e-6; + +fn assert_approx_eq(actual: f64, expected: f64, tolerance: f64) { + let delta = (actual - expected).abs(); + assert!( + delta <= tolerance, + "expected {expected} +/- {tolerance}, got {actual}" + ); +} + +fn rank_eps(sketch: &KllSketch) -> f64 { + sketch.normalized_rank_error(false) +} + +#[test] +fn test_k_limits() { + let _min = KllSketch::::new(MIN_K); + let _max = KllSketch::::new(MAX_K); +} + +#[test] +#[should_panic(expected = "k must be in")] +fn test_k_too_small_panics() { + KllSketch::::new(MIN_K - 1); +} + +#[test] +fn test_empty() { + let sketch = KllSketch::::new(DEFAULT_K); + assert!(sketch.is_empty()); + assert!(!sketch.is_estimation_mode()); + assert_eq!(sketch.n(), 0); + assert_eq!(sketch.num_retained(), 0); + assert!(sketch.min_item().is_none()); + assert!(sketch.max_item().is_none()); + assert!(sketch.rank(&0.0, true).is_none()); + assert!(sketch.quantile(0.5, true).is_none()); + assert!(sketch.pmf(&[0.0f32], true).is_none()); + assert!(sketch.cdf(&[0.0f32], true).is_none()); +} + +#[test] +#[should_panic(expected = "rank must be in [0.0, 1.0]")] +fn test_quantile_out_of_range_panics() { + let mut sketch = KllSketch::::new(DEFAULT_K); + sketch.update(0.0); + sketch.quantile(-1.0, true); +} + +#[test] +fn test_one_item() { + let mut sketch = KllSketch::::new(DEFAULT_K); + sketch.update(1.0); + assert!(!sketch.is_empty()); + assert!(!sketch.is_estimation_mode()); + assert_eq!(sketch.n(), 1); + assert_eq!(sketch.num_retained(), 1); + assert_eq!(sketch.rank(&1.0, false), Some(0.0)); + assert_eq!(sketch.rank(&1.0, true), Some(1.0)); + assert_eq!(sketch.rank(&2.0, false), Some(1.0)); + assert_eq!(sketch.min_item().cloned(), Some(1.0)); + assert_eq!(sketch.max_item().cloned(), Some(1.0)); + assert_eq!(sketch.quantile(0.5, true), Some(1.0)); +} + +#[test] +fn test_nan_is_ignored() { + let mut sketch = KllSketch::::new(DEFAULT_K); + sketch.update(f32::NAN); + assert!(sketch.is_empty()); + sketch.update(0.0); + sketch.update(f32::NAN); + assert_eq!(sketch.n(), 1); +} + +#[test] +fn test_many_items_exact_mode() { + let mut sketch = KllSketch::::new(DEFAULT_K); + let n = DEFAULT_K as usize; + for i in 1..=n { + sketch.update(i as f32); + assert_eq!(sketch.n(), i as u64); + } + assert!(!sketch.is_empty()); + assert!(!sketch.is_estimation_mode()); + assert_eq!(sketch.num_retained(), n); + assert_eq!(sketch.min_item().cloned(), Some(1.0)); + assert_eq!(sketch.quantile(0.0, true), Some(1.0)); + assert_eq!(sketch.max_item().cloned(), Some(n as f32)); + assert_eq!(sketch.quantile(1.0, true), Some(n as f32)); + + for i in 1..=n { + let inclusive_rank = i as f64 / n as f64; + assert_eq!(sketch.rank(&(i as f32), true), Some(inclusive_rank)); + let exclusive_rank = (i - 1) as f64 / n as f64; + assert_eq!(sketch.rank(&(i as f32), false), Some(exclusive_rank)); + } +} + +#[test] +fn test_ten_items_quantiles() { + let mut sketch = KllSketch::::new(DEFAULT_K); + for i in 1..=10 { + sketch.update(i as f32); + } + assert_eq!(sketch.quantile(0.0, true), Some(1.0)); + assert_eq!(sketch.quantile(0.5, true), Some(5.0)); + assert_eq!(sketch.quantile(0.99, true), Some(10.0)); + assert_eq!(sketch.quantile(1.0, true), Some(10.0)); +} + +#[test] +fn test_hundred_items_quantiles() { + let mut sketch = KllSketch::::new(DEFAULT_K); + for i in 0..100 { + sketch.update(i as f32); + } + assert_eq!(sketch.quantile(0.0, true), Some(0.0)); + assert_eq!(sketch.quantile(0.01, true), Some(0.0)); + assert_eq!(sketch.quantile(0.5, true), Some(49.0)); + assert_eq!(sketch.quantile(0.99, true), Some(98.0)); + assert_eq!(sketch.quantile(1.0, true), Some(99.0)); +} + +#[test] +fn test_many_items_estimation_mode_rank_error() { + let mut sketch = KllSketch::::new(DEFAULT_K); + let n = 10_000; + for i in 0..n { + sketch.update(i as f32); + } + assert!(!sketch.is_empty()); + assert!(sketch.is_estimation_mode()); + assert_eq!(sketch.min_item().cloned(), Some(0.0)); + assert_eq!(sketch.max_item().cloned(), Some((n - 1) as f32)); + + let rank_eps = rank_eps(&sketch); + for i in (0..n).step_by(10) { + let true_rank = i as f64 / n as f64; + let rank = sketch.rank(&(i as f32), false).unwrap(); + assert_approx_eq(rank, true_rank, rank_eps); + } + + assert!(sketch.num_retained() > 0); +} + +#[test] +fn test_rank_cdf_pmf_consistency() { + let mut sketch = KllSketch::::new(DEFAULT_K); + let n = 200; + let mut values = Vec::with_capacity(n); + for i in 0..n { + sketch.update(i as f32); + values.push(i as f32); + } + + let ranks = sketch.cdf(&values, false).unwrap(); + let pmf = sketch.pmf(&values, false).unwrap(); + + let mut subtotal = 0.0; + for i in 0..n { + let rank = sketch.rank(&values[i], false).unwrap(); + assert_eq!(rank, ranks[i]); + subtotal += pmf[i]; + assert!( + (ranks[i] - subtotal).abs() <= NUMERIC_NOISE_TOLERANCE, + "cdf vs pmf mismatch at index {i}" + ); + } + + let ranks = sketch.cdf(&values, true).unwrap(); + let pmf = sketch.pmf(&values, true).unwrap(); + + let mut subtotal = 0.0; + for i in 0..n { + let rank = sketch.rank(&values[i], true).unwrap(); + assert_eq!(rank, ranks[i]); + subtotal += pmf[i]; + assert!( + (ranks[i] - subtotal).abs() <= NUMERIC_NOISE_TOLERANCE, + "cdf vs pmf mismatch at index {i}" + ); + } +} + +#[test] +#[should_panic(expected = "split_points must be unique and monotonically increasing")] +fn test_out_of_order_split_points_panics() { + let mut sketch = KllSketch::::new(DEFAULT_K); + sketch.update(0.0); + let split_points = [1.0, 0.0]; + let _ = sketch.cdf(&split_points, true); +} + +#[test] +#[should_panic(expected = "split_points must not contain NaN values")] +fn test_nan_split_point_panics() { + let mut sketch = KllSketch::::new(DEFAULT_K); + sketch.update(0.0); + let split_points = [f32::NAN]; + let _ = sketch.cdf(&split_points, true); +} + +#[test] +fn test_merge() { + let mut sketch1 = KllSketch::::new(DEFAULT_K); + let mut sketch2 = KllSketch::::new(DEFAULT_K); + let n = 10_000; + for i in 0..n { + sketch1.update(i as f32); + sketch2.update((2 * n - i - 1) as f32); + } + + assert_eq!(sketch1.min_item().cloned(), Some(0.0)); + assert_eq!(sketch1.max_item().cloned(), Some((n - 1) as f32)); + assert_eq!(sketch2.min_item().cloned(), Some(n as f32)); + assert_eq!(sketch2.max_item().cloned(), Some((2 * n - 1) as f32)); + + sketch1.merge(&sketch2); + + assert!(!sketch1.is_empty()); + assert_eq!(sketch1.n(), (2 * n) as u64); + assert_eq!(sketch1.min_item().cloned(), Some(0.0)); + assert_eq!(sketch1.max_item().cloned(), Some((2 * n - 1) as f32)); + let median = sketch1.quantile(0.5, true).unwrap(); + let rank_eps = rank_eps(&sketch1); + assert_approx_eq(median as f64, n as f64, n as f64 * rank_eps); +} + +#[test] +fn test_merge_lower_k() { + let mut sketch1 = KllSketch::::new(256); + let mut sketch2 = KllSketch::::new(128); + let n = 10_000; + for i in 0..n { + sketch1.update(i as f32); + sketch2.update((2 * n - i - 1) as f32); + } + + sketch1.merge(&sketch2); + + assert_eq!(sketch1.n(), (2 * n) as u64); + assert_eq!(sketch1.min_item().cloned(), Some(0.0)); + assert_eq!(sketch1.max_item().cloned(), Some((2 * n - 1) as f32)); + assert_eq!( + sketch1.normalized_rank_error(false), + sketch2.normalized_rank_error(false) + ); + assert_eq!( + sketch1.normalized_rank_error(true), + sketch2.normalized_rank_error(true) + ); + let median = sketch1.quantile(0.5, true).unwrap(); + let rank_eps = rank_eps(&sketch1); + assert_approx_eq(median as f64, n as f64, n as f64 * rank_eps); +} + +#[test] +fn test_merge_exact_mode_lower_k() { + let mut sketch1 = KllSketch::::new(256); + let sketch2 = KllSketch::::new(128); + let n = 10_000; + for i in 0..n { + sketch1.update(i as f32); + } + + let err_before = sketch1.normalized_rank_error(true); + sketch1.merge(&sketch2); + assert_eq!(sketch1.normalized_rank_error(true), err_before); + + assert_eq!(sketch1.n(), n as u64); + assert_eq!(sketch1.min_item().cloned(), Some(0.0)); + assert_eq!(sketch1.max_item().cloned(), Some((n - 1) as f32)); + let median = sketch1.quantile(0.5, true).unwrap(); + let rank_eps = rank_eps(&sketch1); + assert_approx_eq(median as f64, (n / 2) as f64, (n as f64 / 2.0) * rank_eps); +} + +#[test] +fn test_merge_min_max_from_other() { + let mut sketch1 = KllSketch::::new(DEFAULT_K); + let mut sketch2 = KllSketch::::new(DEFAULT_K); + sketch1.update(1.0); + sketch2.update(2.0); + sketch2.merge(&sketch1); + assert_eq!(sketch2.min_item().cloned(), Some(1.0)); + assert_eq!(sketch2.max_item().cloned(), Some(2.0)); +} + +#[test] +fn test_merge_min_max_large_other() { + let mut sketch1 = KllSketch::::new(DEFAULT_K); + for i in 0..1_000_000 { + sketch1.update(i as f32); + } + let mut sketch2 = KllSketch::::new(DEFAULT_K); + sketch2.merge(&sketch1); + assert_eq!(sketch2.min_item().cloned(), Some(0.0)); + assert_eq!(sketch2.max_item().cloned(), Some(999_999.0)); +}