diff --git a/xayn-ai/src/analytics.rs b/xayn-ai/src/analytics.rs index c6d201e43..d9fafb092 100644 --- a/xayn-ai/src/analytics.rs +++ b/xayn-ai/src/analytics.rs @@ -1,20 +1,445 @@ +use std::{cmp::Ordering, collections::HashMap}; + +use anyhow::bail; +use displaydoc::Display; +use thiserror::Error; + use crate::{ - data::{document::DocumentHistory, document_data::DocumentDataWithMab}, + data::{ + document::{DocumentHistory, Relevance}, + document_data::DocumentDataWithMab, + }, error::Error, reranker::systems, + utils::nan_safe_f32_cmp_desc, }; +/// Which k to use for nDCG@k +const DEFAULT_NDCG_K: usize = 2; +/// Calculated analytics data. #[derive(Clone)] -pub struct Analytics; +pub struct Analytics { + /// The nDCG@k score between the LTR ranking and the relevance based ranking + pub ndcg_ltr: f32, + /// The nDCG@k score between the Context ranking and the relevance based ranking + pub ndcg_context: f32, + /// The nDCG@k score between the initial ranking and the relevance based ranking + pub ndcg_initial_ranking: f32, + /// The nDCG@k score between the final ranking and the relevance based ranking + pub ndcg_final_ranking: f32, +} + +#[derive(Error, Debug, Display)] +/// Can not calculate Analytics as no relevant history is available. +pub(crate) struct NoRelevantHistoricInfo; pub(crate) struct AnalyticsSystem; impl systems::AnalyticsSystem for AnalyticsSystem { fn compute_analytics( &self, - _history: &[DocumentHistory], - _documents: &[DocumentDataWithMab], + history: &[DocumentHistory], + documents: &[DocumentDataWithMab], ) -> Result { - Ok(Analytics) + // We need to be able to lookup relevances by document id. + // and linear search is most likely a bad idea. So we create + // a hashmap for the lookups. 
+ let relevance_lookups: HashMap<_, _> = { + history + .iter() + .map(|h_doc| (&h_doc.id, score_for_relevance(h_doc.relevance))) + .collect() + }; + + let mut paired_ltr_scores = Vec::new(); + let mut paired_context_scores = Vec::new(); + let mut paired_final_ranking_score = Vec::new(); + + for document in documents { + if let Some(relevance) = relevance_lookups.get(&document.document_id.id).copied() { + paired_ltr_scores.push((relevance, document.ltr.ltr_score)); + paired_context_scores.push((relevance, document.context.context_value)); + + // nDCG expects higher scores to be better but for the ranking + // it's the oposite, the solution carried over from the dart impl + // is to multiply by -1. + let final_ranking_desc = -(document.mab.rank as f32); + paired_final_ranking_score.push((relevance, final_ranking_desc)); + } + } + + if paired_ltr_scores.is_empty() { + bail!(NoRelevantHistoricInfo); + } + + let ndcg_ltr = calcuate_reordered_ndcg_at_k_score(&mut paired_ltr_scores, DEFAULT_NDCG_K); + + let ndcg_context = + calcuate_reordered_ndcg_at_k_score(&mut paired_context_scores, DEFAULT_NDCG_K); + + let ndcg_final_ranking = + calcuate_reordered_ndcg_at_k_score(&mut paired_final_ranking_score, DEFAULT_NDCG_K); + + Ok(Analytics { + //FIXME: We currently have no access to the initial score as thiss will require + // some changes to the main applications type state/component system this + // will be done in a followup PR. + ndcg_initial_ranking: f32::NAN, + ndcg_ltr, + ndcg_context, + ndcg_final_ranking, + }) + } +} + +/// Returns a score for the given `Relevance`. +fn score_for_relevance(relevance: Relevance) -> f32 { + match relevance { + Relevance::Low => 0., + Relevance::Medium => 1., + Relevance::High => 2., + } +} + +/// Calculates the nDCG@k for given paired relevances. +/// +/// The input is a slice over `(relevance, ordering_score)` pairs, +/// where the `ordering_score` is used to reorder the relevances +/// based on sorting them in descending order. 
+/// +/// **Note that the `paired_relevances` are sorted in place.** +/// +/// After the reordering of the pairs the `relevance` values +/// are used to calculate the nDCG@k. +/// +/// ## NaN Handling. +/// +/// NaN values are treated as the lowest possible socres wrt. the sorting. +/// +/// If a `NaN` is in the k-first relevances the resulting nDCG@k score will be `NaN`. +fn calcuate_reordered_ndcg_at_k_score(paired_relevances: &mut [(f32, f32)], k: usize) -> f32 { + paired_relevances + .sort_by(|(_, ord_sc_1), (_, ord_sc_2)| nan_safe_f32_cmp_desc(ord_sc_1, ord_sc_2)); + ndcg_at_k(paired_relevances.iter().map(|(rel, _ord)| *rel), k) +} + +/// Calculates the nDCG@k. +/// +/// This taks the first k values for the DCG score and the "best" k values +/// for the IDCG score and then calculates the nDCG score with that. +fn ndcg_at_k(relevances: impl Iterator + Clone + ExactSizeIterator, k: usize) -> f32 { + let dcg_at_k = dcg(relevances.clone().take(k)); + + let ideal_relevances = pick_k_highest_sorted_desc(relevances, k); + let idcg_at_k = dcg(ideal_relevances.into_iter()); + + // if there is no ideal score, pretent the ideal score is 1 + if idcg_at_k == 0.0 { + dcg_at_k + } else { + dcg_at_k / idcg_at_k + } +} + +/// Pick the k-highest values in given iterator. +/// +/// (As if a vector is sorted and then &sorted_score[..k]). +/// +/// If `NaN`'s is treated as the smallest possible value, i.e. +/// preferably not picked at all if possible. +fn pick_k_highest_sorted_desc( + scores: impl Iterator + ExactSizeIterator, + k: usize, +) -> Vec { + // Due to specialization this has no overhead if scores is already fused. 
+ let mut scores = scores.fuse(); + let mut k_highest: Vec<_> = (&mut scores).take(k).collect(); + + k_highest.sort_by(nan_safe_f32_cmp_desc); + + for score in scores { + //Supposed to act as NaN safe version of: if k_highest[k-1] < score { + if nan_safe_f32_cmp_desc(&k_highest[k - 1], &score) == Ordering::Greater { + let _ = k_highest.pop(); + + let idx = k_highest + .binary_search_by(|other| nan_safe_f32_cmp_desc(other, &score)) + .unwrap_or_else(|not_found_insert_idx| not_found_insert_idx); + + k_highest.insert(idx, score); + } + } + + k_highest +} + +/// Calculates the DCG of given input sequence. +fn dcg(scores: impl Iterator) -> f32 { + // - As this is only used for analytics and bound by `k`(==2) and `&[Document].len()` (~ 10 to 40) + // no further optimizations make sense. Especially not if they require memory allocations. + // - A "simple commulative" sum is ok as we only use small number of scores (default k=2) + scores.enumerate().fold(0.0, |sum, (idx, score)| { + //it's i+2 as our i starts with 0, while the formular starts with 1 and uses i+1 + sum + (2f32.powf(score) - 1.) 
/ (idx as f32 + 2.).log2() + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{reranker::systems::AnalyticsSystem, tests, UserFeedback}; + + #[test] + fn test_full_analytics_system() { + let history = tests::document_history(vec![ + (2, Relevance::Low, UserFeedback::None), + (3, Relevance::Medium, UserFeedback::None), + (1, Relevance::High, UserFeedback::None), + (0, Relevance::Medium, UserFeedback::None), + (10, Relevance::Low, UserFeedback::None), + ]); + + let mut documents = tests::data_with_mab(tests::from_ids(0..3)); + documents[0].ltr.ltr_score = 3.; + documents[0].context.context_value = 3.5; + documents[0].mab.rank = 1; + + documents[1].ltr.ltr_score = 2.; + documents[1].context.context_value = 7.; + documents[1].mab.rank = 0; + + documents[2].ltr.ltr_score = 7.; + documents[2].context.context_value = 6.; + documents[2].mab.rank = 2; + + let Analytics { + ndcg_ltr, + ndcg_context, + ndcg_initial_ranking: _, + ndcg_final_ranking, + } = AnalyticsSystem + .compute_analytics(&history, &documents) + .unwrap(); + + assert_f32_eq!(ndcg_ltr, 0.173_765_35); + assert_f32_eq!(ndcg_context, 0.826_234_64); + //FIXME: Currently not possible as `ndcg_initial_ranking` is not yet computed + // assert!(approx_eq!(f32, ndcg_initial_ranking, 0.7967075809905066, ulps = 2)); + assert_f32_eq!(ndcg_final_ranking, 1.0); + } + + #[test] + fn test_calcuate_reordered_ndcg_at_k_score_tests_from_dart() { + let relevances = &mut [(0., -50.), (0., 0.001), (1., 4.14), (2., 1000.)]; + let res = calcuate_reordered_ndcg_at_k_score(relevances, 2); + assert_eq!(format!("{:.4}", res), "1.0000"); + + let relevances = &mut [(0., -10.), (0., 1.), (1., 0.), (2., 6.)]; + let res = calcuate_reordered_ndcg_at_k_score(relevances, 2); + let res2 = ndcg_at_k([2., 0., 1., 0.].iter().copied(), 2); + assert_f32_eq!(res, res2); + + let relevances = &mut [(0., 1.), (0., -10.), (1., -11.), (2., -11.6)]; + let res = calcuate_reordered_ndcg_at_k_score(relevances, 2); + assert_f32_eq!(res, 0.0); + 
+ let relevances = &mut [(0., 1.), (0., -10.), (1., 100.), (2., 99.)]; + let res = calcuate_reordered_ndcg_at_k_score(relevances, 2); + let res2 = ndcg_at_k([1., 2., 1., 0.].iter().copied(), 2); + assert_f32_eq!(res, res2); + } + + #[test] + fn test_calcuate_reordered_ndcg_at_k_score_without_reordering() { + let relevances = &mut [(1., 12.), (4., 9.), (10., 7.), (3., 5.), (0., 4.), (6., 1.)]; + let res = calcuate_reordered_ndcg_at_k_score(relevances, 2); + assert_f32_eq!(res, 0.009_846_116); + + let relevances = &mut [(1., 12.), (4., 9.), (10., 7.), (3., 5.), (0., 4.), (6., 1.)]; + let res = calcuate_reordered_ndcg_at_k_score(relevances, 4); + assert_f32_eq!(res, 0.489_142_48); + + let relevances = &mut [(1., 12.), (4., 9.), (10., 7.), (3., 5.), (0., 4.), (6., 1.)]; + let res = calcuate_reordered_ndcg_at_k_score(relevances, 100); + assert_f32_eq!(res, 0.509_867_9); + + let relevances = &mut [ + (-1., 12.), + (7., 9.), + (-10., 7.), + (3., 5.), + (0., 4.), + (-6., 1.), + ]; + let res = calcuate_reordered_ndcg_at_k_score(relevances, 2); + assert_f32_eq!(res, 0.605_921_45); + + let relevances = &mut [ + (-1., 12.), + (7., 9.), + (-10., 7.), + (3., 5.), + (0., 4.), + (-6., 1.), + ]; + let res = calcuate_reordered_ndcg_at_k_score(relevances, 4); + assert_f32_eq!(res, 0.626_086_65); + + let relevances = &mut [ + (-1., 12.), + (7., 9.), + (-10., 7.), + (3., 5.), + (0., 4.), + (-6., 1.), + ]; + let res = calcuate_reordered_ndcg_at_k_score(relevances, 100); + assert_f32_eq!(res, 0.626_934_23); + } + + #[test] + fn test_calcuate_reordered_ndcg_at_k_score_with_reordering() { + let relevances = &mut [(4., 9.), (10., 7.), (6., 1.), (0., 4.), (3., 5.), (1., 12.)]; + let res = calcuate_reordered_ndcg_at_k_score(relevances, 2); + assert_f32_eq!(res, 0.009_846_116); + + let relevances = &mut [(4., 9.), (10., 7.), (6., 1.), (0., 4.), (3., 5.), (1., 12.)]; + let res = calcuate_reordered_ndcg_at_k_score(relevances, 4); + assert_f32_eq!(res, 0.489_142_48); + + let relevances = &mut 
[(4., 9.), (10., 7.), (6., 1.), (0., 4.), (3., 5.), (1., 12.)]; + let res = calcuate_reordered_ndcg_at_k_score(relevances, 100); + assert_f32_eq!(res, 0.509_867_9); + + let relevances = &mut [ + (3., 5.), + (-10., 7.), + (0., 4.), + (-1., 12.), + (7., 9.), + (-6., 1.), + ]; + let res = calcuate_reordered_ndcg_at_k_score(relevances, 2); + assert_f32_eq!(res, 0.605_921_45); + + let relevances = &mut [ + (3., 5.), + (-10., 7.), + (0., 4.), + (-1., 12.), + (7., 9.), + (-6., 1.), + ]; + let res = calcuate_reordered_ndcg_at_k_score(relevances, 4); + assert_f32_eq!(res, 0.626_086_65); + + let relevances = &mut [ + (3., 5.), + (-10., 7.), + (0., 4.), + (-1., 12.), + (7., 9.), + (-6., 1.), + ]; + let res = calcuate_reordered_ndcg_at_k_score(relevances, 100); + assert_f32_eq!(res, 0.626_934_23); + } + + #[test] + fn test_ndcg_at_k_tests_from_dart() { + let res = ndcg_at_k([0., 0., 1., 2.].iter().copied(), 4); + assert_eq!(format!("{:.4}", res), "0.4935"); + } + + #[test] + fn ndcg_at_k_produces_expected_values_for_k_larger_then_input() { + let res = ndcg_at_k([1., 4., 10., 3., 0., 6.].iter().copied(), 100); + assert_f32_eq!(res, 0.509_867_9); + + let res = ndcg_at_k([-1., 7., -10., 3., 0., -6.].iter().copied(), 100); + assert_f32_eq!(res, 0.626_934_23); + } + + #[test] + fn ndcg_at_k_produces_expected_values_for_k_smaller_then_input() { + let res = ndcg_at_k([1., 4., 10., 3., 0., 6.].iter().copied(), 2); + assert_f32_eq!(res, 0.009_846_116); + let res = ndcg_at_k([1., 4., 10., 3., 0., 6.].iter().copied(), 4); + assert_f32_eq!(res, 0.489_142_48); + + let res = ndcg_at_k([-1., 7., -10., 3., 0., -6.].iter().copied(), 2); + assert_f32_eq!(res, 0.605_921_45); + let res = ndcg_at_k([-1., 7., -10., 3., 0., -6.].iter().copied(), 4); + assert_f32_eq!(res, 0.626_086_65); + } + + #[test] + fn test_dcg_tests_from_dart() { + // there is no dcg@k function in my code. It's dcg(input_iter.take(k)). 
+ let res = dcg([0., 0., 1., 1.].iter().copied().take(2)); + assert_f32_eq!(res, 0.0); + + let res = dcg([0., 0., 1., 1.].iter().copied().take(4)); + // Dart used ln instead of log2 so the values diverge. + // assert_eq!(format!("{:.4}", res), "1.3426"); + assert_eq!(format!("{:.4}", res), "0.9307"); + + let res = dcg([0., 0., 1., 2.].iter().copied().take(4)); + // Dart used ln instead of log2 so the values diverge. + // assert_eq!(format!("{:.4}", res), "2.5853"); + assert_eq!(format!("{:.4}", res), "1.7920"); + } + + #[test] + fn dcg_produces_expected_results() { + assert_f32_eq!(dcg([3f32, 2., 3., 0., 1., 2.].iter().copied()), 13.848_264); + assert_f32_eq!( + dcg([-3.2, -2., -4., 0., -1., -2.].iter().copied()), + -2.293_710_2 + ); + } + + #[test] + fn test_pick_k_highest_picks_the_highest_values_and_only_them() { + let cases: &[(&[f32], &[f32])] = &[ + (&[3., 2., 1., 0.], &[3., 2.]), + (&[0., 1., 2., 3.], &[3., 2.]), + (&[-2., -2.], &[-2., -2.]), + (&[-30., 3., 2., 10., -3., 0.], &[10., 3.]), + (&[-3., 0., -1., -2.], &[0., -1.]), + ]; + + for (input, pick) in cases { + let res = pick_k_highest_sorted_desc(input.iter().copied(), 2); + assert_eq!( + &*res, &**pick, + "res={:?}, expected={:?}, input={:?}", + res, pick, input + ); + } + } + + #[test] + fn test_pick_k_highest_does_not_pick_nans_if_possible() { + #![allow(clippy::float_cmp)] + + let res = pick_k_highest_sorted_desc([3., 2., f32::NAN].iter().copied(), 2); + assert_eq!(&*res, &[3., 2.]); + + let res = pick_k_highest_sorted_desc( + [f32::NAN, 3., f32::NAN, f32::NAN, 2., 4., f32::NAN] + .iter() + .copied(), + 2, + ); + assert_eq!(&*res, &[4., 3.]); + + let res = pick_k_highest_sorted_desc([f32::NAN, 3., 2., f32::NAN].iter().copied(), 3); + assert_eq!(&res[..2], &[3., 2.]); + assert!(res[2].is_nan()); + + let res = pick_k_highest_sorted_desc([f32::NAN].iter().copied(), 1); + assert_eq!(res.len(), 1); + assert!(res[0].is_nan()); } } diff --git a/xayn-ai/src/lib.rs b/xayn-ai/src/lib.rs index 
8a8fd8e4c..88e8ab8b6 100644 --- a/xayn-ai/src/lib.rs +++ b/xayn-ai/src/lib.rs @@ -1,3 +1,6 @@ +#[macro_use] +mod utils; + mod analytics; mod bert; mod coi; @@ -7,7 +10,6 @@ mod error; mod ltr; mod mab; mod reranker; -mod utils; pub use crate::{ analytics::Analytics, diff --git a/xayn-ai/src/mab.rs b/xayn-ai/src/mab.rs index babdd6f79..da993e7bc 100644 --- a/xayn-ai/src/mab.rs +++ b/xayn-ai/src/mab.rs @@ -6,6 +6,7 @@ use crate::{ UserInterests, }, reranker::systems::MabSystem, + utils::nan_safe_f32_cmp, Error, }; @@ -49,21 +50,6 @@ impl BetaSample for BetaSampler { } } -/// Pretend that comparing two f32 is total. The function will rank `NaN` -/// as the lowest value, similar to what [`f32::max`] does. -fn f32_total_cmp(a: &f32, b: &f32) -> Ordering { - a.partial_cmp(&b).unwrap_or_else(|| { - // if `partial_cmp` returns None we have at least one `NaN`, - // we treat it as the lowest value - match (a.is_nan(), b.is_nan()) { - (true, true) => Ordering::Equal, - (true, _) => Ordering::Less, - (_, true) => Ordering::Greater, - _ => unreachable!("partial_cmp returned None but both numbers are not NaN"), - } - }) -} - /// Wrapper to order documents by `context_value`. /// We need to implement `Ord` to use it in the `BinaryHeap`. 
#[cfg_attr(test, derive(Debug, Clone))] @@ -84,7 +70,7 @@ impl PartialOrd for DocumentByContext { impl Ord for DocumentByContext { fn cmp(&self, other: &Self) -> Ordering { - f32_total_cmp( + nan_safe_f32_cmp( &self.0.context.context_value, &other.0.context.context_value, ) @@ -169,7 +155,7 @@ fn pull_arms( |max, coi_id| -> Result<_, MabError> { let sample = sample_from_coi(coi_id)?; - if let Ordering::Greater = f32_total_cmp(&sample, &max.0) { + if let Ordering::Greater = nan_safe_f32_cmp(&sample, &max.0) { Ok((sample, coi_id)) } else { Ok(max) diff --git a/xayn-ai/src/reranker/mod.rs b/xayn-ai/src/reranker/mod.rs index 97f9c34ef..d70b65e90 100644 --- a/xayn-ai/src/reranker/mod.rs +++ b/xayn-ai/src/reranker/mod.rs @@ -597,7 +597,12 @@ mod tests { let cs = common_systems_with_fail!(analytics, MockAnalyticsSystem, compute_analytics, |_,_|); let mut reranker = Reranker::new(cs).unwrap(); - reranker.analytics = Some(Analytics); + reranker.analytics = Some(Analytics { + ndcg_initial_ranking: 0., + ndcg_ltr: 0., + ndcg_context: 0., + ndcg_final_ranking: 0., + }); let documents = car_interest_example::documents(); let history = history_for_prev_docs( &reranker.data.prev_documents.to_coi_system_data(), diff --git a/xayn-ai/src/utils.rs b/xayn-ai/src/utils.rs index b9769c74b..ed3e1d6bd 100644 --- a/xayn-ai/src/utils.rs +++ b/xayn-ai/src/utils.rs @@ -1,3 +1,5 @@ +use std::cmp::Ordering; + #[macro_export] macro_rules! to_vec_of_ref_of { ($data: expr, $type:ty) => { @@ -7,3 +9,96 @@ macro_rules! to_vec_of_ref_of { .collect::>() }; } + +/// Allows comparing and sorting f32 even if `NaN` is involved. +/// +/// Pretend that f32 has a total ordering. +/// +/// `NaN` is treated as the lowest possible value, similar to what [`f32::max`] does. +/// +/// If this is used for sorting this will lead to an ascending order, like +/// for example `[NaN, 0.5, 1.5, 2.0]`. 
///
/// By switching the input parameters around this can be used to create a
/// descending sorted order, like e.g.: `[2.0, 1.5, 0.5, NaN]`.
pub(crate) fn nan_safe_f32_cmp(a: &f32, b: &f32) -> Ordering {
    a.partial_cmp(b).unwrap_or_else(|| {
        // if `partial_cmp` returns None we have at least one `NaN`,
        // we treat it as the lowest value
        match (a.is_nan(), b.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, _) => Ordering::Less,
            (_, true) => Ordering::Greater,
            _ => unreachable!("partial_cmp returned None but both numbers are not NaN"),
        }
    })
}

/// `nan_safe_f32_cmp_desc(a,b)` is syntax sugar for `nan_safe_f32_cmp(b, a)`
///
/// Sorting with it yields a descending order in which `NaN` comes last.
#[inline]
pub(crate) fn nan_safe_f32_cmp_desc(a: &f32, b: &f32) -> Ordering {
    nan_safe_f32_cmp(b, a)
}

/// Asserts that two `f32` expressions are approximately equal.
///
/// The comparison is done with `float_cmp::approx_eq!` using a tolerance
/// of `ulps` (2 if not given). Both expressions are evaluated exactly once.
#[cfg(test)]
macro_rules! assert_f32_eq {
    ($left:expr, $right:expr) => {
        assert_f32_eq! { $left, $right, ulps = 2 }
    };
    ($left:expr, $right:expr, ulps = $ulps:expr) => {{
        // Bind the operands first so that the comparison and the failure
        // message both use the same, once evaluated, values.
        let left = $left;
        let right = $right;
        let ulps = $ulps;
        assert!(
            ::float_cmp::approx_eq!(f32, left, right, ulps = ulps),
            "approximated equal assertion failed (ulps={}): {} == {}",
            ulps,
            left,
            right
        );
    }};
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_nan_safe_f32_cmp_sorts_in_the_right_order() {
        #![allow(clippy::float_cmp)]

        let data = &mut [f32::NAN, 1., 5., f32::NAN, 4.];
        data.sort_by(nan_safe_f32_cmp);

        assert_eq!(&data[2..], &[1., 4., 5.]);
        assert!(data[0].is_nan());
        assert!(data[1].is_nan());

        data.sort_by(nan_safe_f32_cmp_desc);

        assert_eq!(&data[..3], &[5., 4., 1.]);
        assert!(data[3].is_nan());
        assert!(data[4].is_nan());

        let data = &mut [1., 5., 3., 4.];

        data.sort_by(nan_safe_f32_cmp);
        assert_eq!(&data[..], &[1., 3., 4., 5.]);

        data.sort_by(nan_safe_f32_cmp_desc);
        assert_eq!(&data[..], &[5., 4., 3., 1.]);
    }

    #[test]
    fn test_nan_safe_f32_cmp_nans_compare_as_expected() {
        assert_eq!(nan_safe_f32_cmp(&f32::NAN, &f32::NAN), Ordering::Equal);
        assert_eq!(nan_safe_f32_cmp(&-12., &f32::NAN), Ordering::Greater);
        assert_eq!(nan_safe_f32_cmp_desc(&-12., &f32::NAN), Ordering::Less);
        assert_eq!(nan_safe_f32_cmp(&f32::NAN, &-12.), Ordering::Less);
        assert_eq!(nan_safe_f32_cmp_desc(&f32::NAN, &-12.), Ordering::Greater);
        assert_eq!(nan_safe_f32_cmp(&12., &f32::NAN), Ordering::Greater);
        assert_eq!(nan_safe_f32_cmp_desc(&12., &f32::NAN), Ordering::Less);
        assert_eq!(nan_safe_f32_cmp(&f32::NAN, &12.), Ordering::Less);
        assert_eq!(nan_safe_f32_cmp_desc(&f32::NAN, &12.), Ordering::Greater);
    }
}