// Requires the clap and statrs crates
use statrs::function::gamma::ln_gamma;
use std::f64::consts::E;
use clap::Parser;

struct LnGammaCache {
    n: usize,
    k: usize,
    n_log: f64,
    // contains ln(0!) ..= ln(k!) at indices [0, k] (ln_gamma arguments [1, k+1])
    lk_tab: Vec<f64>,
    // contains ln(n!) ..= ln((n-k)!) at indices [0, k]
    hk_tab: Vec<f64>,
}

impl LnGammaCache {
    fn precompute_ktab(&mut self, n: usize) {
        for i in 0..=self.k {
            self.lk_tab[i] = ln_gamma((i + 1) as f64);
            self.hk_tab[i] = ln_gamma((n - i + 1) as f64);
        }
    }

    fn new(n: usize, k: usize) -> Self {
        let mut cache = LnGammaCache {
            n,
            k,
            n_log: (n as f64).ln(),
            lk_tab: vec![0f64; k + 1],
            hk_tab: vec![0f64; k + 1],
        };
        cache.precompute_ktab(n);
        cache
    }

    unsafe fn ln_gamma_unchecked(&self, i: usize) -> f64 {
        unsafe {
            if i <= self.k + 1 {
                debug_assert!(i > 0, "Uncached call to ln_gamma");
                *self.lk_tab.get_unchecked(i - 1)
            } else {
                debug_assert!(
                    self.n - self.k + 1 <= i && i <= self.n + 1,
                    "Uncached call to ln_gamma"
                );
                *self.hk_tab.get_unchecked(self.n - (i - 1))
            }
        }
    }
}

// Overall performance depends heavily on the execution time of this function.
// After many optimizations, powf is the most computationally expensive unit in this routine.
//
// Evaluates, in log space, the probability of a single occupancy pattern:
//   P = n! * k! / (n^k * c0! * prod_i(partition[i]! * ((i+1)!)^partition[i]))
// where partition[i] is the number of buckets holding exactly i+1 values and
// c0 is the number of empty buckets.
fn evaluate_probability_log_gamma(
    partition: &[usize],
    k: usize,
    n: usize,
    maxpart: usize,
    cache: &LnGammaCache,
) -> f64 {
    // c0 = number of empty buckets = n minus the number of occupied buckets
    let mut c0 = 0;
    for i in 0..maxpart {
        c0 += partition[i];
    }
    c0 = n - c0;
    let log_probability = unsafe {
        let mut log_denominator = (k as f64) * cache.n_log + cache.ln_gamma_unchecked(c0 + 1);
        let log_numerator = cache.ln_gamma_unchecked(n + 1) + cache.ln_gamma_unchecked(k + 1);
        for i in 0..maxpart {
            log_denominator += cache.ln_gamma_unchecked(partition[i] + 1);
            log_denominator += (partition[i] as f64) * cache.ln_gamma_unchecked(i + 2);
        }
        log_numerator - log_denominator
    };
    E.powf(log_probability)
}

fn evaluate_probability(
    partition: &[usize],
    probdist: &mut [f64],
    k: usize,
    n: usize,
    maxpart: usize,
    cache: &LnGammaCache,
) {
    // A bucket holding i+1 values accounts for i collisions.
    let mut collisions = 0usize;
    for i in 1..maxpart {
        collisions += i * partition[i];
    }
    probdist[collisions] += evaluate_probability_log_gamma(partition, k, n, maxpart, cache);
}

fn generate_partitions_dfs(
    partition: &mut [usize],
    probdist: &mut [f64],
    k: usize,
    n: usize,
    n_collisions: usize,
    collision_ub: usize,
    level_ub: &[usize],
    cache: &LnGammaCache,
) {
    if n_collisions > collision_ub {
        return;
    }
    // First explore partitions with no group of n_collisions co-colliding hashes.
    generate_partitions_dfs(partition, probdist, k, n, n_collisions + 1, collision_ub, level_ub, cache);
    // Then move singletons into 1..=n_subcomb groups of n_collisions co-colliding hashes.
    let mut n_subcomb = partition[0] / n_collisions;
    n_subcomb = std::cmp::min(n_subcomb, level_ub[n_collisions - 1]);
    for i in 1..=n_subcomb {
        partition[0] -= i * n_collisions;
        partition[n_collisions - 1] += i;
        evaluate_probability(partition, probdist, k, n, n_collisions, cache);
        generate_partitions_dfs(partition, probdist, k, n, n_collisions + 1, collision_ub, level_ub, cache);
        partition[0] += i * n_collisions;
        partition[n_collisions - 1] -= i;
    }
}

fn generate_partitions(
    k: usize,
    n: usize,
    collision_ub: usize,
    level_ub: &[usize],
) -> Vec<f64> {
    let mut partitions = vec![0usize; k];
    let mut probdist = vec![0f64; k];
    // Start from the collision-free partition: k buckets each holding one value.
    partitions[0] = k;
    let cache = LnGammaCache::new(n, k);
    if collision_ub > 0 {
        evaluate_probability(&partitions, &mut probdist, k, n, 1, &cache);
    }
    if collision_ub > 1 {
        generate_partitions_dfs(&mut partitions, &mut probdist, k, n, 2, collision_ub, level_ub, &cache);
    }
    probdist
}
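
// A minimal sanity check (a sketch added here, not part of the original program):
// with the collision bound set to k and every per-level bound also set to k, the
// DFS enumerates every partition of the k hashes, so the probabilities should sum
// to 1. The values n = 16, k = 5 and the tolerance are arbitrary assumptions.
#[cfg(test)]
mod tests {
    use super::generate_partitions;

    #[test]
    fn distribution_sums_to_one() {
        let (n, k) = (16usize, 5usize);
        let level_ub = vec![k; k];
        let probdist = generate_partitions(k, n, k, &level_ub[..]);
        let total: f64 = probdist.iter().sum();
        assert!((total - 1.0).abs() < 1e-9, "probabilities summed to {}", total);
    }
}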

fn parse_max_level_collisions(s: &str) -> Vec<usize> {
    let mut max_sizes = Vec::new();
    let mut add_token = |token: &str| {
        let sz = match token.parse::<usize>() {
            Ok(res) => res,
            Err(err) => {
                eprintln!("Failed to parse max sizes: {}", err);
                std::process::exit(1);
            }
        };
        max_sizes.push(sz);
    };
    let mut i = 0usize;
    while let Some(j) = s[i..].find(',') {
        add_token(&s[i..i + j]);
        i += j + 1;
    }
    add_token(&s[i..]);
    max_sizes
}

#[derive(Parser, Debug)]
#[command(author, version, about)]
struct Args {
    n: usize,
    k: usize,
    // Only consider collisions up to this bound (inclusive).
    // For collision_ub=4, combinations of zero, one, two, three, and four
    // collisions are calculated.
    collision_ub: Option<usize>,
    // Only consider up to level_ub_l collisions for level l.
    // For example, level_ub_3 = 10 means we only consider scenarios in which
    // three hashes co-collide up to 10 times.
    level_ub: Option<String>,
}

fn main() {
    let args = Args::parse();
    if args.n == 0 || args.k == 0 {
        return;
    }
    if args.k >= args.n {
        eprintln!("This computation is optimized for k < n.");
        std::process::exit(1);
    }
    let collision_ub = match args.collision_ub {
        Some(ub) => {
            if ub > args.k {
                eprintln!("collision_ub={}: that many hashes cannot collide when only k={} values are hashed", ub, args.k);
            }
            ub
        },
        None => args.k,
    };
    let level_ub: Vec<usize>;
    if let Some(ub_spec) = args.level_ub {
        level_ub = parse_max_level_collisions(&ub_spec[..]);
        if level_ub.len() != collision_ub {
            eprintln!("Please specify {} bounds, one for each level", collision_ub);
            std::process::exit(1);
        }
    } else {
        level_ub = vec![args.k; collision_ub + 1];
    }
    let probdist = generate_partitions(args.k, args.n, collision_ub, &level_ub[..]);
    for i in 0..probdist.len() {
        println!("{},{}", i, probdist[i]);
    }
}
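
// Example invocation (a sketch, not from the original source; the argument
// values below are arbitrary illustrative choices):
//
//   cargo run --release -- 1000003 1000 4 1000,1000,1000,1000
//
// computes the collision-count distribution for k = 1000 values hashed into
// n = 1000003 buckets with a collision bound of 4 and a per-level bound of 1000,
// printing one `collisions,probability` line for each count from 0 to k-1.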