diff --git a/pallets/subtensor/src/math.rs b/pallets/subtensor/src/math.rs index 1abc9ed9b..9aab2c4b5 100644 --- a/pallets/subtensor/src/math.rs +++ b/pallets/subtensor/src/math.rs @@ -889,59 +889,48 @@ pub fn weighted_median( score: &[I32F32], partition_idx: &[usize], minority: I32F32, - partition_lo: I32F32, - partition_hi: I32F32, + mut partition_lo: I32F32, + mut partition_hi: I32F32, ) -> I32F32 { - let n = partition_idx.len(); - if n == 0 { - return I32F32::from_num(0); - } - if n == 1 { - return score[partition_idx[0]]; - } - assert!(stake.len() == score.len()); - let mid_idx: usize = n.saturating_div(2); - let pivot: I32F32 = score[partition_idx[mid_idx]]; - let mut lo_stake: I32F32 = I32F32::from_num(0); - let mut hi_stake: I32F32 = I32F32::from_num(0); - let mut lower: Vec = vec![]; - let mut upper: Vec = vec![]; - for &idx in partition_idx { - if score[idx] == pivot { - continue; + let mut current_partition_idx = partition_idx.to_vec(); + while !current_partition_idx.is_empty() { + let n = current_partition_idx.len(); + if n == 1 { + return score[current_partition_idx[0]]; + } + let mid_idx: usize = n.saturating_div(2); + let pivot: I32F32 = score[current_partition_idx[mid_idx]]; + let mut lo_stake: I32F32 = I32F32::from_num(0); + let mut hi_stake: I32F32 = I32F32::from_num(0); + let mut lower: Vec = vec![]; + let mut upper: Vec = vec![]; + for &idx in ¤t_partition_idx { + if score[idx] == pivot { + continue; + } + if score[idx] < pivot { + lo_stake = lo_stake.saturating_add(stake[idx]); + lower.push(idx); + } else { + hi_stake = hi_stake.saturating_add(stake[idx]); + upper.push(idx); + } } - if score[idx] < pivot { - lo_stake = lo_stake.saturating_add(stake[idx]); - lower.push(idx); + if partition_lo.saturating_add(lo_stake) <= minority + && minority < partition_hi.saturating_sub(hi_stake) + { + return pivot; + } else if (minority < partition_lo.saturating_add(lo_stake)) && (!lower.is_empty()) { + current_partition_idx = lower; + partition_hi = partition_lo.saturating_add(lo_stake); + } else if (partition_hi.saturating_sub(hi_stake) <= minority) && (!upper.is_empty()) { + current_partition_idx = upper; + partition_lo = partition_hi.saturating_sub(hi_stake); } else { - hi_stake = hi_stake.saturating_add(stake[idx]); - upper.push(idx); + return pivot; } } - if (partition_lo.saturating_add(lo_stake) <= minority) - && (minority < partition_hi.saturating_sub(hi_stake)) - { - return pivot; - } else if (minority < partition_lo.saturating_add(lo_stake)) && (!lower.is_empty()) { - return weighted_median( - stake, - score, - &lower, - minority, - partition_lo, - partition_lo.saturating_add(lo_stake), - ); - } else if (partition_hi.saturating_sub(hi_stake) <= minority) && (!upper.is_empty()) { - return weighted_median( - stake, - score, - &upper, - minority, - partition_hi.saturating_sub(hi_stake), - partition_hi, - ); - } - pivot + I32F32::from_num(0) } /// Column-wise weighted median, e.g. stake-weighted median scores per server (column) over all validators (rows).