Skip to content

Commit

Permalink
chore: threshold should be algebraic data type
Browse files Browse the repository at this point in the history
  • Loading branch information
fiksn committed Sep 12, 2023
1 parent 46def56 commit d6bf914
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 17 deletions.
63 changes: 61 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,14 @@ use std::collections::HashMap;
use std::fmt;
use std::ops::Deref;
use std::str::FromStr;
use std::string::ParseError;
use std::sync::{Arc, Mutex};
use std::{error::Error, net::SocketAddr};
use std::{
thread,
time::{Duration, SystemTime},
};
use thiserror::Error;
use tokio::main;
use tokio::net::TcpStream;
use tokio::net::ToSocketAddrs;
Expand All @@ -66,6 +68,42 @@ struct ChannelInfo {
node2: NodeId,
}

#[derive(Error, Debug, Eq, PartialEq)]
pub enum ParseThresholdError {
#[error("Parse error")]
ParseError,
#[error("Parse int error {0}")]
ParseIntError(#[from] std::num::ParseIntError),
#[error("Parse foat error {0}")]
ParseFloatError(#[from] std::num::ParseFloatError),
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Threshold {
Number(u64),
Percentage(f64),
}

impl FromStr for Threshold {
type Err = ParseThresholdError;

fn from_str(s: &str) -> Result<Self, Self::Err> {
if s.ends_with('%') {
let x = &s[..s.len() - 1];
let num = x.parse::<f64>()?;

if num < 0f64 || num > 100f64 {
return Err(ParseThresholdError::ParseError);
}

Ok(Self::Percentage(num))
} else {
let num = s.parse::<u64>()?;

Ok(Self::Number(num))
}
}
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
Expand All @@ -76,8 +114,8 @@ struct Args {
nodes: Vec<LightningNodeAddr>,

/// Threshold
#[arg(short, long, num_args = 1, default_value_t = 5)]
threshold: u8,
#[arg(short, long, num_args = 1, default_value = "10%")]
threshold: Threshold,
}

const DEBUG: bool = false;
Expand Down Expand Up @@ -193,3 +231,24 @@ async fn connect(
return None;
}
}

#[cfg(test)]
mod tests {
use super::*;

#[tokio::test]
async fn test_threshold_parsing() {
assert!(Threshold::from_str("100%") == Ok(Threshold::Percentage(100f64)));
assert!(Threshold::from_str("99.9%") == Ok(Threshold::Percentage(99.9f64)));
assert!(Threshold::from_str("101%").is_err());
assert!(Threshold::from_str("%").is_err());
assert!(Threshold::from_str("-1%").is_err());
assert!(Threshold::from_str("aa").is_err());
assert!(Threshold::from_str("%2").is_err());
assert!(Threshold::from_str("21%") == Ok(Threshold::Percentage(21f64)));

assert!(Threshold::from_str("-11").is_err());
assert!(Threshold::from_str("21") == Ok(Threshold::Number(21)));
assert!(Threshold::from_str("101") == Ok(Threshold::Number(101)));
}
}
73 changes: 58 additions & 15 deletions src/voter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use std::ops::Deref;
use std::str::FromStr;
use std::sync::{Arc, Mutex};

use super::Threshold;

use super::dummy::*;
use super::resolve::*;

Expand All @@ -23,15 +25,15 @@ where
L::Target: Logger,
{
logger: L,
threshold: u8,
threshold: Threshold,
resolver: Mutex<Option<Arc<CachingChannelResolving<Arc<DummyLogger>>>>>,
votes: Mutex<HashMap<NodeId, HashSet<u64>>>,
}
impl<L: Deref + Send + std::marker::Sync + 'static> Voter<L>
where
L::Target: Logger,
{
pub fn new(threshold: u8, logger: L) -> Voter<L> {
pub fn new(threshold: Threshold, logger: L) -> Voter<L> {
Voter {
resolver: Mutex::new(None),
logger: logger,
Expand Down Expand Up @@ -102,13 +104,13 @@ where
}
}

if num >= self.threshold {
let info = Self::get_nodeinfo(node.node_id).await;
let (b, channels) = self.threshold_breached(node.node_id, num as u64).await;
if b {
log_info!(
self.logger,
"THRESHOLD BREACHED num: {}/{} node: {} alias: {}",
num,
info.map_or(0, |info| info.channelcount),
channels,
node.node_id,
node.alias
);
Expand Down Expand Up @@ -136,18 +138,59 @@ where
node.alias
);

let mut guard = self.votes.lock().unwrap();
if let Some(one) = guard.get(&node.node_id) {
if one.len() as u8 >= self.threshold {
log_info!(
self.logger,
"THRESHOLD NOT BREACHED anymore, node: {} alias: {}",
node.node_id,
node.alias
);
let num: u8;

{
let guard = self.votes.lock().unwrap();

if let Some(one) = guard.get(&node.node_id) {
num = one.len() as u8;
} else {
num = 0;
}
}
guard.remove(&node.node_id);

let (b, channels) = self.threshold_breached(node.node_id, num as u64).await;
if b {
log_info!(
self.logger,
"THRESHOLD NOT BREACHED anymore, node: {} alias: {}",
node.node_id,
node.alias
);
}

{
// Delete all
let mut guard = self.votes.lock().unwrap();
guard.remove(&node.node_id);
}
}

async fn threshold_breached(&self, node_id: NodeId, num: u64) -> (bool, u64) {
let mut limit = 3;
let mut percent: f64 = 0f64;

match self.threshold {
Threshold::Percentage(value) => {
percent = value as f64;
}
Threshold::Number(value) => {
limit = value;
}
};

if num >= limit {
let info = Self::get_nodeinfo(node_id).await;
let channels = info.map_or(1, |info| info.channelcount);
if percent > 0f64 && ((num / channels * 100) as f64) < percent {
return (false, 0);
}

return (true, channels);
}

return (false, 0);
}

async fn get_nodeinfo(node_id: NodeId) -> Option<NodeInfo> {
Expand Down

0 comments on commit d6bf914

Please sign in to comment.